[vector] Thoughts on ChaCha20 prototype/benchmark
Adam Petcher
adam.petcher at oracle.com
Fri Feb 8 21:26:36 UTC 2019
I'm still working on crypto prototypes on the vector API, and I managed
to get a version of ChaCha20[1] working. It's pretty fast on
AVX2---about 2/3 the speed of openssl in my initial testing. I have a
couple of questions for the experts on this mailing list:
1) Is there any interest in including this benchmark in the repo? The
code is relatively simple, and it may be a good "real-world" example to
track for purposes of API suitability and performance. If so, I'd be
happy to submit a patch---I'll just need a little guidance on where this
benchmark should go in the source tree. Note that we already have
benchmarks for the scalar ChaCha20 implementation in the JDK.
2) Does anybody see any obvious improvements that could be made to the
code (included at the end of this e-mail)? The code is pretty
rough---all of my attempts to better organize the code seem to have
eliminated inlining. I'm mostly looking for advice that improves
performance, but tips on how to organize the code without hurting
performance would also be appreciated. This code should work for vectors
of 128, 256, and 512 bits, though I haven't tested on a machine with a
512-bit vector unit.
[1] https://eprint.iacr.org/2013/759.pdf
------------------------------------------------------------------------------------
import jdk.incubator.vector.*;
/**
 * Prototype ChaCha20 stream cipher written against the incubating Vector API.
 *
 * <p>Each vector holds one row (four 32-bit words) of {@code numBlocks}
 * independent ChaCha20 blocks, so a single pass over the four row vectors
 * produces {@code 64 * numBlocks} bytes of keystream. Changing the shape used
 * for {@code intSpecies} selects 1, 2 or 4 blocks per pass.
 */
public class ChaCha20 {

    private static final IntVector.IntSpecies intSpecies =
        IntVector.species(Vector.Shape.S_256_BIT);

    /** Number of ChaCha20 blocks computed side by side (4 lanes per block row). */
    private static final int numBlocks = intSpecies.length() / 4;

    // Shuffles that rotate the four lanes of every block row left by 1, 2 or 3
    // positions; used to move between the column and diagonal arrangements.
    private static final IntVector.Shuffle rot1 = rotate(intSpecies, 1);
    private static final IntVector.Shuffle rot2 = rotate(intSpecies, 2);
    private static final IntVector.Shuffle rot3 = rotate(intSpecies, 3);

    /**
     * Builds a shuffle in which lane {@code i} selects lane
     * {@code (i + amount) % 4} of the same 4-lane group.
     */
    private static IntVector.Shuffle rotate(IntVector.IntSpecies spec, int amount) {
        int[] shuffleArr = new int[spec.length()];
        for (int i = 0; i < spec.length(); i++) {
            int groupStart = (i / 4) * 4;
            shuffleArr[i] = groupStart + ((i + amount) % 4);
        }
        return spec.shuffleFromValues(shuffleArr);
    }

    /** The four constant words that fill the first row of every block. */
    private static final int[] constants =
        new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};

    /**
     * Fills {@code result} with the initial state of {@code numBlocks}
     * consecutive blocks, laid out row-major across blocks: all first rows,
     * then all second rows, and so on, four words per block per row.
     *
     * @param key     key bytes; bytes 0..31 are read as eight little-endian words
     * @param nonce   nonce bytes; bytes 0..11 are read as three little-endian words
     * @param counter block counter of the first block; block {@code j} gets
     *                {@code (int) (counter + j)}
     * @param result  destination, at least {@code 16 * numBlocks} ints long
     * @return {@code result}
     */
    public static int[] makeState(byte[] key, byte[] nonce, long counter, int[] result) {
        // first field is constants
        for (int word = 0; word < 4; word++) {
            broadcastWord(result, 0, word, constants[word]);
        }

        // second field is first part of key
        for (int word = 0; word < 4; word++) {
            broadcastWord(result, 4 * numBlocks, word, littleEndianInt(key, 4 * word));
        }

        // third field is second part of key
        for (int word = 0; word < 4; word++) {
            broadcastWord(result, 8 * numBlocks, word, littleEndianInt(key, 4 * (word + 4)));
        }

        // fourth field is counter and nonce
        int fieldStart = 12 * numBlocks;
        for (int block = 0; block < numBlocks; block++) {
            result[fieldStart + 4 * block] = (int) (counter + block);
        }
        for (int word = 0; word < 3; word++) {
            broadcastWord(result, fieldStart, word + 1, littleEndianInt(nonce, 4 * word));
        }

        return result;
    }

    /** Reads four bytes starting at {@code offset} as a little-endian int. */
    private static int littleEndianInt(byte[] src, int offset) {
        int value = 0;
        for (int b = 0; b < 4; b++) {
            value |= (0xFF & src[offset + b]) << (8 * b);
        }
        return value;
    }

    /** Stores {@code value} at position {@code word} of every block's row starting at {@code fieldStart}. */
    private static void broadcastWord(int[] result, int fieldStart, int word, int value) {
        for (int block = 0; block < numBlocks; block++) {
            result[fieldStart + 4 * block + word] = value;
        }
    }

    /** Added to the fourth row after each pass: advances every block's counter by {@code numBlocks}. */
    private static final IntVector counterAdd = makeCounterAdd();

    private static IntVector makeCounterAdd() {
        int[] addArr = new int[intSpecies.length()];
        for (int i = 0; i < numBlocks; i++) {
            // only the counter word (first lane of each 4-lane group) is non-zero
            addArr[4 * i] = numBlocks;
        }
        return intSpecies.fromArray(addArr, 0);
    }

    /**
     * Builds a shuffle in which every lane {@code i} selects lane
     * {@code (i % 4) + 4 * order}, i.e. the 4-lane group number {@code order}
     * repeated across the whole vector.
     */
    private static IntVector.Shuffle makeRearrangeShuffle(int order) {
        int[] shuffleArr = new int[intSpecies.length()];
        int start = order * 4;
        for (int i = 0; i < shuffleArr.length; i++) {
            shuffleArr[i] = (i % 4) + start;
        }
        return intSpecies.shuffleFromArray(shuffleArr, 0);
    }

    /**
     * Builds a mask that is set only on the 4-lane group number {@code order};
     * the mask is all-false when that group lies beyond the vector length.
     */
    private static IntVector.Mask makeRearrangeMask(int order) {
        boolean[] maskArr = new boolean[intSpecies.length()];
        int start = order * 4;
        if (start < maskArr.length) {
            for (int i = 0; i < 4; i++) {
                maskArr[i + start] = true;
            }
        }
        return intSpecies.maskFromValues(maskArr);
    }

    private static final IntVector.Shuffle shuf0 = makeRearrangeShuffle(0);
    private static final IntVector.Shuffle shuf1 = makeRearrangeShuffle(1);
    private static final IntVector.Shuffle shuf2 = makeRearrangeShuffle(2);
    private static final IntVector.Shuffle shuf3 = makeRearrangeShuffle(3);
    private static final IntVector.Mask mask0 = makeRearrangeMask(0);
    private static final IntVector.Mask mask1 = makeRearrangeMask(1);
    private static final IntVector.Mask mask2 = makeRearrangeMask(2);
    private static final IntVector.Mask mask3 = makeRearrangeMask(3);

    /**
     * XORs {@code in} with the ChaCha20 keystream and writes the result to
     * {@code out}. Applying the method twice with the same key, nonce and
     * counter restores the original bytes.
     *
     * <p>Input of any length is accepted: whole chunks of
     * {@code 64 * numBlocks} bytes are processed directly on {@code in} and
     * {@code out}; a trailing partial chunk is staged through zero-padded
     * scratch buffers so the full-width vector loads and stores never run past
     * the end of either array.
     *
     * @param key     key bytes; bytes 0..31 are used
     * @param nonce   nonce bytes; bytes 0..11 are used
     * @param counter block counter for the first 64-byte block (only the low
     *                32 bits enter the state)
     * @param in      bytes to encrypt or decrypt
     * @param out     destination, at least {@code in.length} bytes long
     * @throws UnsupportedOperationException if the species has a lane count
     *         other than 4, 8 or 16
     */
    public static void chacha20(byte[] key, byte[] nonce, long counter, byte[] in, byte[] out) {
        // Per-call state buffer: a shared static buffer would be overwritten
        // by any concurrent call.
        int[] state = makeState(key, nonce, counter, new int[numBlocks * 16]);

        int specLen = intSpecies.length();
        IntVector sa = intSpecies.fromArray(state, 0);
        IntVector sb = intSpecies.fromArray(state, specLen);
        IntVector sc = intSpecies.fromArray(state, 2 * specLen);
        IntVector sd = intSpecies.fromArray(state, 3 * specLen);

        int stateLenBytes = state.length * 4;
        int fullChunks = in.length / stateLenBytes;
        int tailLen = in.length - fullChunks * stateLenBytes;
        // One extra pass when a partial chunk remains; computed without the
        // "length + chunk - 1" form so it cannot overflow for huge inputs.
        int totalChunks = (tailLen > 0) ? fullChunks + 1 : fullChunks;

        for (int j = 0; j < totalChunks; j++) {
            IntVector a = sa;
            IntVector b = sb;
            IntVector c = sc;
            IntVector d = sd;

            // 10 double rounds = 20 rounds
            for (int i = 0; i < 10; i++) {
                // first round (quarter rounds on the columns)
                a = a.add(b);
                d = d.xor(a);
                d = d.rotateL(16);

                c = c.add(d);
                b = b.xor(c);
                b = b.rotateL(12);

                a = a.add(b);
                d = d.xor(a);
                d = d.rotateL(8);

                c = c.add(d);
                b = b.xor(c);
                b = b.rotateL(7);

                // rotate rows so the diagonals line up as columns
                b = b.rearrange(rot1);
                c = c.rearrange(rot2);
                d = d.rearrange(rot3);

                // second round (same quarter rounds, now on the diagonals)
                a = a.add(b);
                d = d.xor(a);
                d = d.rotateL(16);

                c = c.add(d);
                b = b.xor(c);
                b = b.rotateL(12);

                a = a.add(b);
                d = d.xor(a);
                d = d.rotateL(8);

                c = c.add(d);
                b = b.xor(c);
                b = b.rotateL(7);

                // rotate rows back to the column arrangement
                b = b.rearrange(rot3);
                c = c.rearrange(rot2);
                d = d.rearrange(rot1);
            }

            // feed-forward: add the input state to the permuted state
            a = a.add(sa);
            b = b.add(sb);
            c = c.add(sc);
            d = d.add(sd);

            // Rearrange the vectors from "one row of every block" into
            // "consecutive rows of one block" so the keystream bytes come out
            // in block order.
            if (intSpecies.length() == 4) {
                // one block per pass: rows are already in output order
            } else if (intSpecies.length() == 8) {
                // NOTE(review): relies on the masked two-vector rearrange
                // taking lanes from the argument where the mask is set and
                // from the receiver elsewhere - confirm against the API docs.
                IntVector a_r = a.rearrange(b, shuf0, mask1);
                IntVector b_r = c.rearrange(d, shuf0, mask1);
                IntVector c_r = a.rearrange(b, shuf1, mask1);
                IntVector d_r = c.rearrange(d, shuf1, mask1);

                a = a_r;
                b = b_r;
                c = c_r;
                d = d_r;
            } else if (intSpecies.length() == 16) {
                // output vector k collects 4-lane group k of a, b, c and d
                IntVector a_r = a;
                a_r = a_r.blend(b.rearrange(shuf0), mask1);
                a_r = a_r.blend(c.rearrange(shuf0), mask2);
                a_r = a_r.blend(d.rearrange(shuf0), mask3);

                IntVector b_r = b;
                b_r = b_r.blend(a.rearrange(shuf1), mask0);
                b_r = b_r.blend(c.rearrange(shuf1), mask2);
                b_r = b_r.blend(d.rearrange(shuf1), mask3);

                IntVector c_r = c;
                c_r = c_r.blend(a.rearrange(shuf2), mask0);
                c_r = c_r.blend(b.rearrange(shuf2), mask1);
                c_r = c_r.blend(d.rearrange(shuf2), mask3);

                IntVector d_r = d;
                d_r = d_r.blend(a.rearrange(shuf3), mask0);
                d_r = d_r.blend(b.rearrange(shuf3), mask1);
                d_r = d_r.blend(c.rearrange(shuf3), mask2);

                a = a_r;
                b = b_r;
                c = c_r;
                d = d_r;
            } else {
                throw new UnsupportedOperationException("not supported");
            }

            // Pick the buffers for this pass. Full chunks work in place on the
            // caller's arrays; the final partial chunk goes through padded
            // scratch buffers because the vector loads/stores are full width.
            int chunkStart = stateLenBytes * j;
            boolean partial = (j == fullChunks);
            byte[] src = in;
            byte[] dst = out;
            int off = chunkStart;
            if (partial) {
                src = new byte[stateLenBytes];
                System.arraycopy(in, chunkStart, src, 0, tailLen);
                dst = new byte[stateLenBytes];
                off = 0;
            }

            // xor keystream with input
            IntVector ina = intSpecies.fromByteArray(src, off);
            IntVector inb = intSpecies.fromByteArray(src, off + 4 * specLen);
            IntVector inc = intSpecies.fromByteArray(src, off + 8 * specLen);
            IntVector ind = intSpecies.fromByteArray(src, off + 12 * specLen);

            ina.xor(a).intoByteArray(dst, off);
            inb.xor(b).intoByteArray(dst, off + 4 * specLen);
            inc.xor(c).intoByteArray(dst, off + 8 * specLen);
            ind.xor(d).intoByteArray(dst, off + 12 * specLen);

            if (partial) {
                // only the bytes that correspond to real input are published
                System.arraycopy(dst, 0, out, chunkStart, tailLen);
            }

            // increment counter
            sd = sd.add(counterAdd);
        }
    }
}
More information about the panama-dev
mailing list