[vector] Thoughts on ChaCha20 prototype/benchmark
Vladimir Ivanov
vladimir.x.ivanov at oracle.com
Sat Feb 9 00:11:29 UTC 2019
Hi Adam,
> I'm still working on crypto prototypes on the vector API, and I managed
> to get a version of ChaCha20[1] working. It's pretty fast on
> AVX2---about 2/3 the speed of openssl in my initial testing. I have a
> couple of questions for the experts on this mailing list:
>
> 1) Is there any interest in including this benchmark in the repo? The
> code is relatively simple, and it may be a good "real-world" example to
> track for purposes of API suitability and performance. If so, I'd be
> happy to submit a patch---I'll just need a little guidance on where this
> benchmark should go in the source tree. Note that we already have
> benchmarks for the scalar ChaCha20 implementation in the JDK.
Definitely! Vector benchmarks reside in [1]:
test/jdk/jdk/incubator/vector/benchmark/
> 2) Does anybody see any obvious improvements that could be made to the
> code (included at the end of this e-mail)? The code is pretty
> rough---all of my attempts to better organize the code seem to have
> eliminated inlining. I'm mostly looking for advice that improves
> performance, but tips on how to organize the code without hurting
> performance would also be appreciated. This code should work for vectors
> of 128, 256, and 512 bits, though I haven't tested on a machine with a
> 512-bit vector unit.
It looks fine to me, and I haven't spotted anything suspicious in the
generated code either.
Best regards,
Vladimir Ivanov
[1]
http://hg.openjdk.java.net/panama/dev/file/b610ed069715/test/jdk/jdk/incubator/vector/benchmark/src/main/java/benchmark/jdk/incubator/vector
> [1] https://eprint.iacr.org/2013/759.pdf
>
> ------------------------------------------------------------------------------------
>
>
> import jdk.incubator.vector.*;
>
> public class ChaCha20 {
>
> private static final IntVector.IntSpecies intSpecies =
> IntVector.species(Vector.Shape.S_256_BIT);
> private static int numBlocks = intSpecies.length() / 4;
>
> private static final IntVector.Shuffle rot1 = rotate(intSpecies, 1);
> private static final IntVector.Shuffle rot2 = rotate(intSpecies, 2);
> private static final IntVector.Shuffle rot3 = rotate(intSpecies, 3);
>
> private static IntVector.Shuffle rotate(IntVector.IntSpecies spec,
> int amount) {
> int[] shuffleArr = new int[spec.length()];
>
> for (int i = 0; i < spec.length(); i ++) {
> int offset = (i / 4) * 4;
> shuffleArr[i] = offset + ((i + amount) % 4);
> }
>
> return spec.shuffleFromValues(shuffleArr);
> }
>
> private static int[] constants = new int[]{0x61707865, 0x3320646e,
> 0x79622d32, 0x6b206574};
>
> public static int[] makeState(byte[] key, byte[] nonce, long
> counter, int[] result) {
>
> // first field is constants
> for (int i = 0; i < 4; i++) {
> for (int j = 0; j < numBlocks; j++) {
> result[4*j + i] = constants[i];
> }
> }
>
> // second field is first part of key
> int fieldStart = 4 * numBlocks;
> for (int i = 0; i < 4; i++) {
> int keyInt = 0;
> for (int j = 0; j < 4; j++) {
> keyInt += (0xFF & key[4 * i + j]) << 8 * j;
> }
> for (int j = 0; j < numBlocks; j++) {
> result[fieldStart + j*4 + i] = keyInt;
> }
> }
>
> // third field is second part of key
> fieldStart = 8 * numBlocks;
> for (int i = 0; i < 4; i++) {
> int keyInt = 0;
> for (int j = 0; j < 4; j++) {
> keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j;
> }
>
> for (int j = 0; j < numBlocks; j++) {
> result[fieldStart + j*4 + i] = keyInt;
> }
> }
>
> // fourth field is counter and nonce
> fieldStart = 12 * numBlocks;
> for (int j = 0; j < numBlocks; j++) {
> result[fieldStart + j*4] = (int) (counter + j);
> }
>
> for (int i = 0; i < 3; i++) {
> int nonceInt = 0;
> for (int j = 0; j < 4; j++) {
> nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j;
> }
>
> for (int j = 0; j < numBlocks; j++) {
> result[fieldStart + j*4 + 1 + i] = nonceInt;
> }
> }
>
> return result;
> }
>
> private static int[] state = new int[numBlocks * 16];
>
> private static IntVector counterAdd = makeCounterAdd();
>
> private static IntVector makeCounterAdd() {
> int[] addArr = new int[intSpecies.length()];
> for(int i = 0; i < numBlocks; i++) {
> addArr[4 * i] = numBlocks;
> }
> return intSpecies.fromArray(addArr, 0);
> }
>
> private static IntVector.Shuffle makeRearrangeShuffle(int order) {
> int[] shuffleArr = new int[intSpecies.length()];
> int start = order * 4;
> for (int i = 0; i < shuffleArr.length; i++) {
> shuffleArr[i] = (i % 4) + start;
> }
> return intSpecies.shuffleFromArray(shuffleArr, 0);
> }
>
> private static IntVector.Mask makeRearrangeMask(int order) {
> boolean[] maskArr = new boolean[intSpecies.length()];
> int start = order * 4;
> if (start < maskArr.length) {
> for (int i = 0; i < 4; i++) {
> maskArr[i + start] = true;
> }
> }
>
> return intSpecies.maskFromValues(maskArr);
> }
>
> private static IntVector.Shuffle shuf0 = makeRearrangeShuffle(0);
> private static IntVector.Shuffle shuf1 = makeRearrangeShuffle(1);
> private static IntVector.Shuffle shuf2 = makeRearrangeShuffle(2);
> private static IntVector.Shuffle shuf3 = makeRearrangeShuffle(3);
>
> private static IntVector.Mask mask0 = makeRearrangeMask(0);
> private static IntVector.Mask mask1 = makeRearrangeMask(1);
> private static IntVector.Mask mask2 = makeRearrangeMask(2);
> private static IntVector.Mask mask3 = makeRearrangeMask(3);
>
> public static void chacha20(byte[] key, byte[] nonce, long counter,
> byte[] in, byte[] out) {
>
> makeState(key, nonce, counter, state);
>
> int specLen = intSpecies.length();
>
> IntVector sa = intSpecies.fromArray(state, 0);
> IntVector sb = intSpecies.fromArray(state, specLen);
> IntVector sc = intSpecies.fromArray(state, 2 * specLen);
> IntVector sd = intSpecies.fromArray(state, 3 * specLen);
>
> int stateLenBytes = state.length * 4;
>
> for (int j = 0; j < (in.length + stateLenBytes - 1) /
> stateLenBytes; j++){
>
> IntVector a = sa;
> IntVector b = sb;
> IntVector c = sc;
> IntVector d = sd;
>
> for (int i = 0; i < 10; i++) {
> // first round
> a = a.add(b);
> d = d.xor(a);
> d = d.rotateL(16);
>
> c = c.add(d);
> b = b.xor(c);
> b = b.rotateL(12);
>
> a = a.add(b);
> d = d.xor(a);
> d = d.rotateL(8);
>
> c = c.add(d);
> b = b.xor(c);
> b = b.rotateL(7);
>
> // rotate
> b = b.rearrange(rot1);
> c = c.rearrange(rot2);
> d = d.rearrange(rot3);
>
> // second round
> a = a.add(b);
> d = d.xor(a);
> d = d.rotateL(16);
>
> c = c.add(d);
> b = b.xor(c);
> b = b.rotateL(12);
>
> a = a.add(b);
> d = d.xor(a);
> d = d.rotateL(8);
>
> c = c.add(d);
> b = b.xor(c);
> b = b.rotateL(7);
>
> // rotate
> b = b.rearrange(rot3);
> c = c.rearrange(rot2);
> d = d.rearrange(rot1);
> }
>
> a = a.add(sa);
> b = b.add(sb);
> c = c.add(sc);
> d = d.add(sd);
>
> // rearrange the vectors
> if (intSpecies.length() == 4) {
> // no rearrange needed
> } else if (intSpecies.length() == 8) {
> IntVector a_r = a.rearrange(b, shuf0, mask1);
> IntVector b_r = c.rearrange(d, shuf0, mask1);
> IntVector c_r = a.rearrange(b, shuf1, mask1);
> IntVector d_r = c.rearrange(d, shuf1, mask1);
>
> a = a_r;
> b = b_r;
> c = c_r;
> d = d_r;
> } else if (intSpecies.length() == 16) {
> IntVector a_r = a;
> a_r = a_r.blend(b.rearrange(shuf0), mask1);
> a_r = a_r.blend(c.rearrange(shuf0), mask2);
> a_r = a_r.blend(d.rearrange(shuf0), mask3);
>
> IntVector b_r = b;
> b_r = b_r.blend(a.rearrange(shuf1), mask0);
> b_r = b_r.blend(c.rearrange(shuf1), mask2);
> b_r = b_r.blend(d.rearrange(shuf1), mask3);
>
> IntVector c_r = c;
> c_r = c_r.blend(a.rearrange(shuf2), mask0);
> c_r = c_r.blend(b.rearrange(shuf2), mask1);
> c_r = c_r.blend(d.rearrange(shuf2), mask3);
>
> IntVector d_r = d;
> d_r = d_r.blend(a.rearrange(shuf3), mask0);
> d_r = d_r.blend(b.rearrange(shuf3), mask1);
> d_r = d_r.blend(c.rearrange(shuf3), mask2);
>
> a = a_r;
> b = b_r;
> c = c_r;
> d = d_r;
> } else {
> throw new RuntimeException("not supported");
> }
>
> // xor keystream with input
> int inOff = stateLenBytes * j;
> IntVector ina = intSpecies.fromByteArray(in, inOff);
> IntVector inb = intSpecies.fromByteArray(in, inOff + 4 *
> specLen);
> IntVector inc = intSpecies.fromByteArray(in, inOff + 8 *
> specLen);
> IntVector ind = intSpecies.fromByteArray(in, inOff + 12 *
> specLen);
>
> ina.xor(a).intoByteArray(out, inOff);
> inb.xor(b).intoByteArray(out, inOff + 4 * specLen);
> inc.xor(c).intoByteArray(out, inOff + 8 * specLen);
> ind.xor(d).intoByteArray(out, inOff + 12 * specLen);
>
> // increment counter
> sd = sd.add(counterAdd);
> }
> }
> }
>
More information about the panama-dev
mailing list