[vector] Thoughts on ChaCha20 prototype/benchmark

Adam Petcher adam.petcher at oracle.com
Fri Feb 8 21:26:36 UTC 2019


I'm still working on crypto prototypes on the vector API, and I managed 
to get a version of ChaCha20[1] working. It's pretty fast on 
AVX2---about 2/3 the speed of OpenSSL in my initial testing. I have a 
couple of questions for the experts on this mailing list:

1) Is there any interest in including this benchmark in the repo? The 
code is relatively simple, and it may be a good "real-world" example to 
track for purposes of API suitability and performance. If so, I'd be 
happy to submit a patch---I'll just need a little guidance on where this 
benchmark should go in the source tree. Note that we already have 
benchmarks for the scalar ChaCha20 implementation in the JDK.

2) Does anybody see any obvious improvements that could be made to the 
code (included at the end of this e-mail)? The code is pretty 
rough---all of my attempts to better organize the code seem to have 
eliminated inlining. I'm mostly looking for advice that improves 
performance, but tips on how to organize the code without hurting 
performance would also be appreciated. This code should work for vectors 
of 128, 256, and 512 bits, though I haven't tested on a machine with a 
512-bit vector unit.

[1] https://eprint.iacr.org/2013/759.pdf

------------------------------------------------------------------------------------

import jdk.incubator.vector.*;

/**
 * Prototype ChaCha20 keystream/XOR routine written against the incubating
 * Vector API.
 *
 * <p>The 4x4 word ChaCha state is held as four row vectors. Every group of
 * four consecutive lanes belongs to one ChaCha block, so a vector with
 * {@code 4 * numBlocks} int lanes processes {@code numBlocks} blocks at once.
 *
 * <p>Not thread-safe: {@link #chacha20} builds its initial state in a single
 * static scratch array shared by all callers.
 */
public class ChaCha20 {

     private static final IntVector.IntSpecies intSpecies = IntVector.species(Vector.Shape.S_256_BIT);

     /** Number of ChaCha blocks computed in parallel (one per 4-lane group). */
     private static final int numBlocks = intSpecies.length() / 4;

     // Lane rotations applied within each 4-lane group, used to move between
     // the column and diagonal arrangement of the state rows.
     private static final IntVector.Shuffle rot1 = rotate(intSpecies, 1);
     private static final IntVector.Shuffle rot2 = rotate(intSpecies, 2);
     private static final IntVector.Shuffle rot3 = rotate(intSpecies, 3);

     /**
      * Builds a shuffle that rotates the lanes of every 4-lane group by
      * {@code amount} positions: output lane {@code i} takes the value of lane
      * {@code (i + amount) % 4} of the same group.
      */
     private static IntVector.Shuffle rotate(IntVector.IntSpecies spec, int amount) {
         int[] shuffleArr = new int[spec.length()];

         for (int i = 0; i < spec.length(); i ++) {
             int offset = (i / 4) * 4;
             shuffleArr[i] = offset + ((i + amount) % 4);
         }

         return spec.shuffleFromValues(shuffleArr);
     }

     /** The four ChaCha constant words ("expand 32-byte k" read as little-endian ints). */
     private static final int[] constants = new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};

     /** Reads four bytes starting at {@code off} as a little-endian int. */
     private static int littleEndianInt(byte[] src, int off) {
         int value = 0;
         for (int j = 0; j < 4; j++) {
             value |= (0xFF & src[off + j]) << (8 * j);
         }
         return value;
     }

     /** Stores {@code value} as word {@code word} of every block inside the row starting at {@code fieldStart}. */
     private static void fillWord(int[] result, int fieldStart, int word, int value) {
         for (int j = 0; j < numBlocks; j++) {
             result[fieldStart + 4 * j + word] = value;
         }
     }

     /**
      * Fills {@code result} with the initial state for {@code numBlocks}
      * consecutive blocks, laid out row by row: each row occupies
      * {@code 4 * numBlocks} ints, and within a row the four words of block
      * {@code j} sit at offset {@code 4 * j}.
      *
      * @param key     key bytes; the first 32 bytes are read
      * @param nonce   nonce bytes; the first 12 bytes are read
      * @param counter block counter of the first block; block {@code j} gets
      *                {@code counter + j}, truncated to 32 bits
      * @param result  destination, at least {@code 16 * numBlocks} ints long
      * @return {@code result}
      */
     public static int[] makeState(byte[] key, byte[] nonce, long counter, int[] result) {

         int rowLen = 4 * numBlocks;

         // first field is constants
         for (int i = 0; i < 4; i++) {
             fillWord(result, 0, i, constants[i]);
         }

         // second and third fields are the two halves of the key
         for (int i = 0; i < 8; i++) {
             int fieldStart = (1 + i / 4) * rowLen;
             fillWord(result, fieldStart, i % 4, littleEndianInt(key, 4 * i));
         }

         // fourth field is counter and nonce
         int fieldStart = 3 * rowLen;
         for (int j = 0; j < numBlocks; j++) {
             result[fieldStart + j*4] = (int) (counter + j);
         }

         for (int i = 0; i < 3; i++) {
             fillWord(result, fieldStart, 1 + i, littleEndianInt(nonce, 4 * i));
         }

         return result;
     }

     /** Scratch buffer for the initial state; shared by every call to {@link #chacha20}. */
     private static final int[] state = new int[numBlocks * 16];

     /** Per-iteration counter increment: {@code numBlocks} in the counter lane of each block, zero elsewhere. */
     private static final IntVector counterAdd = makeCounterAdd();

     private static IntVector makeCounterAdd() {
         int[] addArr = new int[intSpecies.length()];
         for(int i = 0; i < numBlocks; i++) {
             addArr[4 * i] = numBlocks;
         }
         return intSpecies.fromArray(addArr, 0);
     }

     /**
      * Builds a shuffle that broadcasts the 4-lane group number {@code order}
      * into every 4-lane group of the result.
      *
      * <p>Indices are wrapped into the lane range so the shuffle is well formed
      * even when {@code order} does not exist for the current vector width;
      * {@link #chacha20} never uses such a shuffle.
      */
     private static IntVector.Shuffle makeRearrangeShuffle(int order) {
         int[] shuffleArr = new int[intSpecies.length()];
         int start = order * 4;
         for (int i = 0; i < shuffleArr.length; i++) {
             shuffleArr[i] = ((i % 4) + start) % shuffleArr.length;
         }
         return intSpecies.shuffleFromArray(shuffleArr, 0);
     }

     /**
      * Builds a mask that selects the four lanes of group {@code order}, or no
      * lanes at all when that group lies beyond the vector width.
      */
     private static IntVector.Mask makeRearrangeMask(int order) {
         boolean[] maskArr = new boolean[intSpecies.length()];
         int start = order * 4;
         if (start < maskArr.length) {
             for (int i = 0; i < 4; i++) {
                 maskArr[i + start] = true;
             }
         }

         return intSpecies.maskFromValues(maskArr);
     }

     private static final IntVector.Shuffle shuf0 = makeRearrangeShuffle(0);
     private static final IntVector.Shuffle shuf1 = makeRearrangeShuffle(1);
     private static final IntVector.Shuffle shuf2 = makeRearrangeShuffle(2);
     private static final IntVector.Shuffle shuf3 = makeRearrangeShuffle(3);

     private static final IntVector.Mask mask0 = makeRearrangeMask(0);
     private static final IntVector.Mask mask1 = makeRearrangeMask(1);
     private static final IntVector.Mask mask2 = makeRearrangeMask(2);
     private static final IntVector.Mask mask3 = makeRearrangeMask(3);

     /**
      * XORs {@code in} with the ChaCha20 keystream and writes the result to the
      * first {@code in.length} bytes of {@code out}.
      *
      * <p>The input is consumed in chunks of {@code 64 * numBlocks} bytes. A
      * final chunk shorter than that is staged through a zero-padded scratch
      * buffer so the vector loads and stores never run past the end of
      * {@code in} or {@code out}.
      *
      * @param key     key bytes; the first 32 bytes are read
      * @param nonce   nonce bytes; the first 12 bytes are read
      * @param counter counter of the first block (low 32 bits are used)
      * @param in      input bytes
      * @param out     output buffer, at least {@code in.length} bytes long
      * @throws UnsupportedOperationException if the species has a lane count
      *         other than 4, 8 or 16 and there is at least one chunk to process
      */
     public static void chacha20(byte[] key, byte[] nonce, long counter, byte[] in, byte[] out) {

         makeState(key, nonce, counter, state);

         int specLen = intSpecies.length();
         int vecBytes = 4 * specLen;

         IntVector sa = intSpecies.fromArray(state, 0);
         IntVector sb = intSpecies.fromArray(state, specLen);
         IntVector sc = intSpecies.fromArray(state, 2 * specLen);
         IntVector sd = intSpecies.fromArray(state, 3 * specLen);

         int stateLenBytes = state.length * 4;

         // Division/remainder form avoids int overflow of (length + chunk - 1).
         int fullChunks = in.length / stateLenBytes;
         int tailLen = in.length % stateLenBytes;
         int chunks = fullChunks + (tailLen == 0 ? 0 : 1);

         for (int j = 0; j < chunks; j++){

             IntVector a = sa;
             IntVector b = sb;
             IntVector c = sc;
             IntVector d = sd;

             for (int i = 0; i < 10; i++) {
                 // first round (columns)
                 a = a.add(b);
                 d = d.xor(a);
                 d = d.rotateL(16);

                 c = c.add(d);
                 b = b.xor(c);
                 b = b.rotateL(12);

                 a = a.add(b);
                 d = d.xor(a);
                 d = d.rotateL(8);

                 c = c.add(d);
                 b = b.xor(c);
                 b = b.rotateL(7);

                 // rotate rows so the diagonals line up as columns
                 b = b.rearrange(rot1);
                 c = c.rearrange(rot2);
                 d = d.rearrange(rot3);

                 // second round (diagonals)
                 a = a.add(b);
                 d = d.xor(a);
                 d = d.rotateL(16);

                 c = c.add(d);
                 b = b.xor(c);
                 b = b.rotateL(12);

                 a = a.add(b);
                 d = d.xor(a);
                 d = d.rotateL(8);

                 c = c.add(d);
                 b = b.xor(c);
                 b = b.rotateL(7);

                 // rotate rows back
                 b = b.rearrange(rot3);
                 c = c.rearrange(rot2);
                 d = d.rearrange(rot1);
             }

             a = a.add(sa);
             b = b.add(sb);
             c = c.add(sc);
             d = d.add(sd);

             // rearrange the vectors from row-major (one row of every block per
             // vector) to block-major (consecutive 16-word blocks)
             if (specLen == 4) {
                 // no rearrange needed
             } else if (specLen == 8) {
                 IntVector a_r = a.rearrange(b, shuf0, mask1);
                 IntVector b_r = c.rearrange(d, shuf0, mask1);
                 IntVector c_r = a.rearrange(b, shuf1, mask1);
                 IntVector d_r = c.rearrange(d, shuf1, mask1);

                 a = a_r;
                 b = b_r;
                 c = c_r;
                 d = d_r;
             } else if (specLen == 16) {
                 IntVector a_r = a;
                 a_r = a_r.blend(b.rearrange(shuf0), mask1);
                 a_r = a_r.blend(c.rearrange(shuf0), mask2);
                 a_r = a_r.blend(d.rearrange(shuf0), mask3);

                 IntVector b_r = b;
                 b_r = b_r.blend(a.rearrange(shuf1), mask0);
                 b_r = b_r.blend(c.rearrange(shuf1), mask2);
                 b_r = b_r.blend(d.rearrange(shuf1), mask3);

                 IntVector c_r = c;
                 c_r = c_r.blend(a.rearrange(shuf2), mask0);
                 c_r = c_r.blend(b.rearrange(shuf2), mask1);
                 c_r = c_r.blend(d.rearrange(shuf2), mask3);

                 IntVector d_r = d;
                 d_r = d_r.blend(a.rearrange(shuf3), mask0);
                 d_r = d_r.blend(b.rearrange(shuf3), mask1);
                 d_r = d_r.blend(c.rearrange(shuf3), mask2);

                 a = a_r;
                 b = b_r;
                 c = c_r;
                 d = d_r;
             } else {
                 throw new UnsupportedOperationException("not supported");
             }

             // xor keystream with input
             int inOff = stateLenBytes * j;

             // Only the chunk after the last full one can be short; stage it in
             // a full-size scratch buffer so the vector accesses stay in bounds.
             boolean partial = (j == fullChunks);
             byte[] src = in;
             byte[] dst = out;
             int off = inOff;
             if (partial) {
                 byte[] tail = new byte[stateLenBytes];
                 System.arraycopy(in, inOff, tail, 0, tailLen);
                 src = tail;
                 dst = tail;
                 off = 0;
             }

             IntVector ina = intSpecies.fromByteArray(src, off);
             IntVector inb = intSpecies.fromByteArray(src, off + vecBytes);
             IntVector inc = intSpecies.fromByteArray(src, off + 2 * vecBytes);
             IntVector ind = intSpecies.fromByteArray(src, off + 3 * vecBytes);

             ina.xor(a).intoByteArray(dst, off);
             inb.xor(b).intoByteArray(dst, off + vecBytes);
             inc.xor(c).intoByteArray(dst, off + 2 * vecBytes);
             ind.xor(d).intoByteArray(dst, off + 3 * vecBytes);

             if (partial) {
                 // copy back only the bytes that correspond to real input
                 System.arraycopy(dst, 0, out, inOff, tailLen);
             }

             // increment counter
             sd = sd.add(counterAdd);
         }
     }
}



More information about the panama-dev mailing list