RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v11]
Volodymyr Paprotski
vpaprotski at openjdk.org
Mon Mar 24 15:19:22 UTC 2025
On Sat, 22 Mar 2025 20:02:31 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:
>> By using the AVX-512 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with two additional commits since the last revision:
>
> - Further readability improvements.
> - Added asserts for array sizes
I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 270:
> 268: }
> 269:
> 270: static void loadPerm(int destinationRegs[], Register perms,
`replXmm`? i.e. this function is replicating (any) Xmm register, not just perm?..
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 327:
> 325: //
> 326: //
> 327: static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
Similar comments as for `generate_dilithiumAlmostInverseNtt_avx512`:
- a similar comment about the 'pair-wise' operation, which updates `[j]` and `[j+l]` at a time..
- somehow I had less trouble following the flow through the registers here; perhaps I am getting used to it. FYI, I ended up renaming some as:

      // xmm16_27 = Temp1
      // xmm0_3   = Coeffs1
      // xmm4_7   = Coeffs2
      // xmm8_11  = Coeffs3
      // xmm12_15 = Coeffs4 = Temp2
      // xmm16_27 = Scratch
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 421:
> 419: for (int i = 0; i < 8; i += 2) {
> 420: __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 421: }
Wish there was a more 'abstract' way to arrange this, so it's obvious from the shape of the code which registers are inputs/outputs (i.e. using the register arrays). Even though it's just 'elementary index arithmetic', `i/2 + 16` is still 'clever'. Couldn't think of anything myself though (same elsewhere in this function for the table permutes).
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 509:
> 507: // coeffs (int[256]) = c_rarg0
> 508: // zetas (int[256]) = c_rarg1
> 509: static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
Done with this function; perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!
Some general comments first, rest inline.
- The array names for registers helped a lot. And so did the new helper functions!
- The Java version of this code is quite intimidating to vectorize: a 3D loop with geometric iteration variables, and the literature is even more intimidating (discrete convolutions, which I haven't touched in two decades, FFTs, NTTs, etc.). Here is my attempt at a comment to 'un-scare' the next reader; feel free to reword however you like.
The core of the (Java) loop is this 'pair-wise' operation:

    int a = coeffs[j];
    int b = coeffs[j + offset];
    coeffs[j] = (a + b);
    coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);

There are 8 'levels' (0-7); 'levels' are equivalent to (unrolling) the outer (Java) loop.
At each level, the 'pair-wise offset' doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).
To vectorize this Java code, observe that at each level, REGARDLESS of the offset, half the operations are the SUM and the other half are the Montgomery MULTIPLICATION (of the pair-difference with a constant). At each level, one 'just' has to shuffle the coefficients so that the SUMs and MULTIPLICATIONs line up accordingly.
Otherwise, this pattern is loosely similar to a discrete convolution (compute the integral/summation of two functions at every offset).
- I still would prefer (more) symbolic register names. I wouldn't hold my approval over it, so I won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:
      // xmm8_11  = Perms1
      // xmm12_15 = Perms2
      // xmm16_27 = Scratch
      // xmm0_3   = CoeffsPlus
      // xmm4_7   = CoeffsMul
      // xmm24_27 = CoeffsMinus (overlaps with Scratch)
(I made a similar comment, but I think it is now hidden after the last refactor)
- would prefer to see the helper functions get ALL their registers passed explicitly (currently `montMulPerm`, `montQInvModR`, `dilithium_q`, `xmm29` are implicit). As a general rule, I've tried to set up all the registers at the 'entry' function (`generate_dilithium*` in this case) and from there on use symbolic names. Not always reasonable, but it's what I've grown used to seeing.
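To make the 'levels' description above concrete, here is a scalar Java sketch of the loop shape I have in mind. This is my own sketch: `Q` and `QINV` are the standard ML-DSA Montgomery constants, but the zeta table and its indexing are placeholders, so it illustrates the 8-level/doubling-offset structure only, not the exact JDK code.

```java
// Scalar sketch of the 8-level inverse-NTT butterfly structure described
// above. Q and QINV are the standard ML-DSA constants; the zeta table is
// a placeholder (loop shape only, not bit-exact JDK output).
public class NttShapeSketch {
    static final int Q = 8380417;        // ML-DSA modulus
    static final int QINV = 58728449;    // Q^-1 mod 2^32

    // Montgomery multiplication: returns a*b*2^-32 mod Q (up to sign/range).
    static int montMul(int a, int b) {
        long c = (long) a * b;
        int t = (int) c * QINV;                  // low 32 bits of c * QINV
        return (int) ((c - (long) t * Q) >> 32); // exact: c - t*Q is divisible by 2^32
    }

    // 8 levels; the pair-wise offset doubles: 1, 2, 4, ..., 128.
    static void invNttShape(int[] coeffs, int[] zetas) {
        int m = 0;
        for (int offset = 1; offset <= 128; offset *= 2) {     // one 'level' per iteration
            for (int start = 0; start < 256; start += 2 * offset) {
                int zeta = zetas[m++];
                for (int j = start; j < start + offset; j++) {
                    int a = coeffs[j];
                    int b = coeffs[j + offset];
                    coeffs[j] = a + b;                         // the SUM half
                    coeffs[j + offset] = montMul(a - b, zeta); // the MULTIPLICATION half
                }
            }
        }
    }
}
```

At each level the SUMs all land at `[j]` and the Montgomery products at `[j + offset]`, which is exactly why one shuffle per level is enough to line them up for the vector code.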
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 554:
> 552: for (int i = 0; i < 8; i += 2) {
> 553: __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 554: __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
Took a bit to unscramble the flow, so maybe a comment is needed? The purpose is 'fairly obvious' once I got the general shape of the level/algorithm (as per my top-level comment), but something like "shuffle xmm0-7 into xmm8-15"?
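For anyone else unscrambling these shuffles, here is a scalar model (my own sketch, not JDK code) of what `vpermi2d` does on 512-bit vectors: each index element selects a dword from the 32-element concatenation of the two sources.

```java
// Scalar model of VPERMI2D on 512-bit vectors (16 dwords): each index
// element uses 5 bits; bit 4 selects source a vs b, bits 0-3 pick the
// dword within it. The real instruction overwrites the index register.
public class Vpermi2dModel {
    static int[] vpermi2d(int[] idx, int[] a, int[] b) {
        int[] r = new int[16];
        for (int i = 0; i < 16; i++) {
            int s = idx[i] & 0x1F;              // 5 significant index bits
            r[i] = (s < 16) ? a[s] : b[s - 16]; // bit 4 chooses a vs b
        }
        return r;
    }
}
```

Since the destination (index) register is clobbered, this is also why the permute tables have to be reloaded before they can be reused at the next level.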
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 572:
> 570: load4Xmms(xmm4_7, zetas, 512, _masm);
> 571: sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm);
> 572: montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);
From my annotated version, levels 1-4 are fairly 'straightforward':

    // level 1
    replXmm(Perms1, perms, nttInvL1PermsIdx, _masm);
    replXmm(Perms2, perms, nttInvL1PermsIdx + 64, _masm);
    for (int i = 0; i < 4; i++) {
      __ evpermi2d(xmm(Perms1[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
      __ evpermi2d(xmm(Perms2[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
    }
    load4Xmms(CoeffsMul, zetas, 512, _masm);
    sub_add(CoeffsMinus, CoeffsPlus, Perms1, Perms2, _masm);
    montMul64(CoeffsMul, CoeffsMinus, CoeffsMul, Scratch, _masm);
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 613:
> 611: montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);
> 612:
> 613: // level 5
"// No shuffling for level 5 and 6; can just rearrange full registers"
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 656:
> 654: for (int i = 0; i < 8; i++) {
> 655: __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
> 656: }
Fairly clean as is, but this could also be two sub_add calls, I think (you would have to swap the order of add/sub in the helper to be able to clobber `xmm(i)`.. or swap register usage downstream, so perhaps not.. but it would be cleaner):

    sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
    sub_add(CoeffsMul, &Scratch[4], Perms2, CoeffsMul, _masm);

If nothing else, I would have preferred to see the register array variables used.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 660:
> 658: store4Xmms(coeffs, 0, xmm16_19, _masm);
> 659: store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
> 660: montMulByConst128(_masm);
Would prefer explicit parameters here. But I think this could also be two `montMul64` calls:

    montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch, _masm);
    montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch, _masm);

(I think there is one other use of `montMulByConst128` where the same applies; then you could delete both `montMulByConst128` and `montmulEven`.)
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 871:
> 869: __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
> 870: __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
> 871: __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);
Fairly 'straightforward' transcription of the Java code.. no comments from me.
At first glance using `xmm0_3`, `xmm4_7`, etc. might have been a good idea, but you only save one line per 4x group. (Unless you use one big loop, but I suspect that gives you worse performance? Is that something you tried already? Might be worth it otherwise..)
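(For what it's worth, the `barrettAddend` pattern in the quoted sequence looks like the reference implementation's rounding reduction, where the addend is 2^22 and Q is just under 2^23. A scalar Java sketch of that reduction, as my own reconstruction, not necessarily the JDK's exact constants:)

```java
// Scalar sketch of a Barrett-style rounding reduction for the ML-DSA
// modulus (reconstruction based on the reference reduce32; the 2^22
// addend is what `barrettAddend` is reminiscent of).
public class BarrettSketch {
    static final int Q = 8380417;

    // Returns r with r == a (mod Q) and |r| well inside +-Q.
    static int reduce32(int a) {
        int t = (a + (1 << 22)) >> 23; // round(a / 2^23), since Q ~ 2^23
        return a - t * Q;
    }
}
```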
src/java.base/share/classes/sun/security/provider/ML_DSA.java line 1418:
> 1416: int twoGamma2, int multiplier) {
> 1417: assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
> 1418: && (highPart.length == ML_DSA_N);
I wrote this test to check Java-to-intrinsic correspondence. Might be good to include it (and add the other 4 intrinsics). It is very similar to all my other *Fuzz* tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself):
    import java.util.Arrays;
    import java.util.Random;
    import java.lang.invoke.MethodHandle;
    import java.lang.invoke.MethodHandles;
    import java.lang.reflect.Field;
    import java.lang.reflect.Method;
    import java.lang.reflect.Constructor;

    public class ML_DSA_Intrinsic_Test {
        private static final int ML_DSA_N = 256;

        public static void main(String[] args) throws Exception {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
            Constructor<?> constructor = kClazz.getDeclaredConstructor(int.class);
            constructor.setAccessible(true);

            Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
                    int[].class, int[].class, int[].class);
            m.setAccessible(true);
            MethodHandle mult = lookup.unreflect(m);

            m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
                    int[].class, int[].class, int[].class);
            m.setAccessible(true);
            MethodHandle multJava = lookup.unreflect(m);

            Random rnd = new Random();
            long seed = rnd.nextLong();
            rnd.setSeed(seed);

            // Note: it might be useful to increase this number during development of new intrinsics
            final int repeat = 1000000;
            int[] coeffs1 = new int[ML_DSA_N];
            int[] coeffs2 = new int[ML_DSA_N];
            int[] prod1 = new int[ML_DSA_N];
            int[] prod2 = new int[ML_DSA_N];
            try {
                for (int i = 0; i < repeat; i++) {
                    run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
                }
                System.out.println("Fuzz Success");
            } catch (Throwable e) {
                System.out.println("Fuzz Failed: " + e);
            }
        }

        public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
                               MethodHandle mult, MethodHandle multJava, Random rnd,
                               long seed, int i) throws Throwable {
            for (int j = 0; j < ML_DSA_N; j++) {
                coeffs1[j] = rnd.nextInt();
                coeffs2[j] = rnd.nextInt();
            }
            mult.invoke(prod1, coeffs1, coeffs2);
            multJava.invoke(prod2, coeffs1, coeffs2);
            if (!Arrays.equals(prod1, prod2)) {
                throw new RuntimeException("[Seed " + seed + "@" + i + "] Result mismatch: "
                        + Arrays.toString(prod1) + " != " + Arrays.toString(prod2));
            }
        }
    }
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java
-------------
PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2708301954
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008921783
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009415317
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009477186
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009428310
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009433467
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435329
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435791
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009437669
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009438921
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009486160
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2010355575
More information about the hotspot-dev mailing list