[vector] Thoughts on ChaCha20 prototype/benchmark

Vladimir Ivanov vladimir.x.ivanov at oracle.com
Sat Feb 9 00:11:29 UTC 2019


Hi Adam,

> I'm still working on crypto prototypes on the vector API, and I managed 
> to get a version of ChaCha20[1] working. It's pretty fast on 
> AVX2---about 2/3 the speed of openssl in my initial testing. I have a 
> couple of questions for the experts on this mailing list:
> 
> 1) Is there any interest in including this benchmark in the repo? The 
> code is relatively simple, and it may be a good "real-world" example to 
> track for purposes of API suitability and performance. If so, I'd be 
> happy to submit a patch---I'll just need a little guidance on where this 
> benchmark should go in the source tree. Note that we already have 
> benchmarks for the scalar ChaCha20 implementation in the JDK.

Definitely! Vector benchmarks reside in [1]:
   test/jdk/jdk/incubator/vector/benchmark/

> 2) Does anybody see any obvious improvements that could be made to the 
> code (included at the end of this e-mail)? The code is pretty 
> rough---all of my attempts to better organize the code seem to have 
> eliminated inlining. I'm mostly looking for advice that improves 
> performance, but tips on how to organize the code without hurting 
> performance would also be appreciated. This code should work for vectors 
> of 128, 256, and 512 bits, though I haven't tested on a machine with a 
> 512-bit vector unit.

It looks fine to me, and I haven't spotted anything suspicious in the 
generated code either.

Best regards,
Vladimir Ivanov

[1] 
http://hg.openjdk.java.net/panama/dev/file/b610ed069715/test/jdk/jdk/incubator/vector/benchmark/src/main/java/benchmark/jdk/incubator/vector

> [1] https://eprint.iacr.org/2013/759.pdf
> 
> ------------------------------------------------------------------------------------ 
> 
> 
> import jdk.incubator.vector.*;
> 
> public class ChaCha20 {
> 
>      private static final IntVector.IntSpecies intSpecies = 
> IntVector.species(Vector.Shape.S_256_BIT);
>      private static int numBlocks = intSpecies.length() / 4;
> 
>      private static final IntVector.Shuffle rot1 = rotate(intSpecies, 1);
>      private static final IntVector.Shuffle rot2 = rotate(intSpecies, 2);
>      private static final IntVector.Shuffle rot3 = rotate(intSpecies, 3);
> 
>      private static IntVector.Shuffle rotate(IntVector.IntSpecies spec, 
> int amount) {
>          int[] shuffleArr = new int[spec.length()];
> 
>          for (int i = 0; i < spec.length(); i ++) {
>              int offset = (i / 4) * 4;
>              shuffleArr[i] = offset + ((i + amount) % 4);
>          }
> 
>          return spec.shuffleFromValues(shuffleArr);
>      }
> 
>      private static int[] constants = new int[]{0x61707865, 0x3320646e, 
> 0x79622d32, 0x6b206574};
> 
>      public static int[] makeState(byte[] key, byte[] nonce, long 
> counter, int[] result) {
> 
>          // first field is constants
>          for (int i = 0; i < 4; i++) {
>              for (int j = 0; j < numBlocks; j++) {
>                  result[4*j + i] = constants[i];
>              }
>          }
> 
>          // second field is first part of key
>          int fieldStart = 4 * numBlocks;
>          for (int i = 0; i < 4; i++) {
>              int keyInt = 0;
>              for (int j = 0; j < 4; j++) {
>                  keyInt += (0xFF & key[4 * i + j]) << 8 * j;
>              }
>              for (int j = 0; j < numBlocks; j++) {
>                  result[fieldStart + j*4 + i] = keyInt;
>              }
>          }
> 
>          // third field is second part of key
>          fieldStart = 8 * numBlocks;
>          for (int i = 0; i < 4; i++) {
>              int keyInt = 0;
>              for (int j = 0; j < 4; j++) {
>                  keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j;
>              }
> 
>              for (int j = 0; j < numBlocks; j++) {
>                  result[fieldStart + j*4 + i] = keyInt;
>              }
>          }
> 
>          // fourth field is counter and nonce
>          fieldStart = 12 * numBlocks;
>          for (int j = 0; j < numBlocks; j++) {
>              result[fieldStart + j*4] = (int) (counter + j);
>          }
> 
>          for (int i = 0; i < 3; i++) {
>              int nonceInt = 0;
>              for (int j = 0; j < 4; j++) {
>                  nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j;
>              }
> 
>              for (int j = 0; j < numBlocks; j++) {
>                  result[fieldStart + j*4 + 1 + i] = nonceInt;
>              }
>          }
> 
>          return result;
>      }
> 
>      private static int[] state = new int[numBlocks * 16];
> 
>      private static IntVector counterAdd = makeCounterAdd();
> 
>      private static IntVector makeCounterAdd() {
>          int[] addArr = new int[intSpecies.length()];
>          for(int i = 0; i < numBlocks; i++) {
>              addArr[4 * i] = numBlocks;
>          }
>          return intSpecies.fromArray(addArr, 0);
>      }
> 
>      private static IntVector.Shuffle makeRearrangeShuffle(int order) {
>          int[] shuffleArr = new int[intSpecies.length()];
>          int start = order * 4;
>          for (int i = 0; i < shuffleArr.length; i++) {
>              shuffleArr[i] = (i % 4) + start;
>          }
>          return intSpecies.shuffleFromArray(shuffleArr, 0);
>      }
> 
>      private static IntVector.Mask makeRearrangeMask(int order) {
>          boolean[] maskArr = new boolean[intSpecies.length()];
>          int start = order * 4;
>          if (start < maskArr.length) {
>              for (int i = 0; i < 4; i++) {
>                  maskArr[i + start] = true;
>              }
>          }
> 
>          return intSpecies.maskFromValues(maskArr);
>      }
> 
>      private static IntVector.Shuffle shuf0 = makeRearrangeShuffle(0);
>      private static IntVector.Shuffle shuf1 = makeRearrangeShuffle(1);
>      private static IntVector.Shuffle shuf2 = makeRearrangeShuffle(2);
>      private static IntVector.Shuffle shuf3 = makeRearrangeShuffle(3);
> 
>      private static IntVector.Mask mask0 = makeRearrangeMask(0);
>      private static IntVector.Mask mask1 = makeRearrangeMask(1);
>      private static IntVector.Mask mask2 = makeRearrangeMask(2);
>      private static IntVector.Mask mask3 = makeRearrangeMask(3);
> 
>      public static void chacha20(byte[] key, byte[] nonce, long counter, 
> byte[] in, byte[] out) {
> 
>          makeState(key, nonce, counter, state);
> 
>          int specLen = intSpecies.length();
> 
>          IntVector sa = intSpecies.fromArray(state, 0);
>          IntVector sb = intSpecies.fromArray(state, specLen);
>          IntVector sc = intSpecies.fromArray(state, 2 * specLen);
>          IntVector sd = intSpecies.fromArray(state, 3 * specLen);
> 
>          int stateLenBytes = state.length * 4;
> 
>          for (int j = 0; j < (in.length + stateLenBytes - 1) / 
> stateLenBytes; j++){
> 
>              IntVector a = sa;
>              IntVector b = sb;
>              IntVector c = sc;
>              IntVector d = sd;
> 
>              for (int i = 0; i < 10; i++) {
>                  // first round
>                  a = a.add(b);
>                  d = d.xor(a);
>                  d = d.rotateL(16);
> 
>                  c = c.add(d);
>                  b = b.xor(c);
>                  b = b.rotateL(12);
> 
>                  a = a.add(b);
>                  d = d.xor(a);
>                  d = d.rotateL(8);
> 
>                  c = c.add(d);
>                  b = b.xor(c);
>                  b = b.rotateL(7);
> 
>                  // rotate
>                  b = b.rearrange(rot1);
>                  c = c.rearrange(rot2);
>                  d = d.rearrange(rot3);
> 
>                  // second round
>                  a = a.add(b);
>                  d = d.xor(a);
>                  d = d.rotateL(16);
> 
>                  c = c.add(d);
>                  b = b.xor(c);
>                  b = b.rotateL(12);
> 
>                  a = a.add(b);
>                  d = d.xor(a);
>                  d = d.rotateL(8);
> 
>                  c = c.add(d);
>                  b = b.xor(c);
>                  b = b.rotateL(7);
> 
>                  // rotate
>                  b = b.rearrange(rot3);
>                  c = c.rearrange(rot2);
>                  d = d.rearrange(rot1);
>              }
> 
>              a = a.add(sa);
>              b = b.add(sb);
>              c = c.add(sc);
>              d = d.add(sd);
> 
>              // rearrange the vectors
>              if (intSpecies.length() == 4) {
>                  // no rearrange needed
>              } else if (intSpecies.length() == 8) {
>                  IntVector a_r = a.rearrange(b, shuf0, mask1);
>                  IntVector b_r = c.rearrange(d, shuf0, mask1);
>                  IntVector c_r = a.rearrange(b, shuf1, mask1);
>                  IntVector d_r = c.rearrange(d, shuf1, mask1);
> 
>                  a = a_r;
>                  b = b_r;
>                  c = c_r;
>                  d = d_r;
>              } else if (intSpecies.length() == 16) {
>                  IntVector a_r = a;
>                  a_r = a_r.blend(b.rearrange(shuf0), mask1);
>                  a_r = a_r.blend(c.rearrange(shuf0), mask2);
>                  a_r = a_r.blend(d.rearrange(shuf0), mask3);
> 
>                  IntVector b_r = b;
>                  b_r = b_r.blend(a.rearrange(shuf1), mask0);
>                  b_r = b_r.blend(c.rearrange(shuf1), mask2);
>                  b_r = b_r.blend(d.rearrange(shuf1), mask3);
> 
>                  IntVector c_r = c;
>                  c_r = c_r.blend(a.rearrange(shuf2), mask0);
>                  c_r = c_r.blend(b.rearrange(shuf2), mask1);
>                  c_r = c_r.blend(d.rearrange(shuf2), mask3);
> 
>                  IntVector d_r = d;
>                  d_r = d_r.blend(a.rearrange(shuf3), mask0);
>                  d_r = d_r.blend(b.rearrange(shuf3), mask1);
>                  d_r = d_r.blend(c.rearrange(shuf3), mask2);
> 
>                  a = a_r;
>                  b = b_r;
>                  c = c_r;
>                  d = d_r;
>              } else {
>                  throw new RuntimeException("not supported");
>              }
> 
>              // xor keystream with input
>              int inOff = stateLenBytes * j;
>              IntVector ina = intSpecies.fromByteArray(in, inOff);
>              IntVector inb = intSpecies.fromByteArray(in, inOff + 4 * 
> specLen);
>              IntVector inc = intSpecies.fromByteArray(in, inOff + 8 * 
> specLen);
>              IntVector ind = intSpecies.fromByteArray(in, inOff + 12 * 
> specLen);
> 
>              ina.xor(a).intoByteArray(out, inOff);
>              inb.xor(b).intoByteArray(out, inOff + 4 * specLen);
>              inc.xor(c).intoByteArray(out, inOff + 8 * specLen);
>              ind.xor(d).intoByteArray(out, inOff + 12 * specLen);
> 
>              // increment counter
>              sd = sd.add(counterAdd);
>          }
>      }
> }
> 


More information about the panama-dev mailing list