RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v10]
Volodymyr Paprotski
vpaprotski at openjdk.org
Sat Mar 22 20:05:11 UTC 2025
On Thu, 20 Mar 2025 20:37:25 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:
>> By using the AVX-512 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
>
> Fix windows build
I was going to finish the rest of the functions, but I see you pushed an update, so I had better rebase! Here are the pending comments I had; some of them may no longer be applicable.
(Still working through the NTT math.)
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 121:
> 119: static void montmulEven(int outputReg, int inputReg1, int inputReg2,
> 120: int scratchReg1, int scratchReg2,
> 121: int parCnt, MacroAssembler *_masm) {
Nitpick: this could be made to look more like `montMul64()` by also taking in an array of registers.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 160:
> 158: for (int i = 0; i < 4; i++) {
> 159: __ vpmuldq(xmm(scratchRegs[i]), xmm(inputRegs1[i]), xmm(inputRegs2[i]),
> 160: Assembler::AVX_512bit);
Using an array of registers instead of an array of ints would read somewhat more compactly, with fewer 'indirections', i.e.:
static void montMul64(XMMRegister* outputRegs, XMMRegister* inputRegs1, XMMRegister* inputRegs2,
                      ...
    __ vpmuldq(scratchRegs[i], inputRegs1[i], inputRegs2[i], Assembler::AVX_512bit);
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 216:
> 214: // Zmm8-Zmm23 used as scratch registers
> 215: // result goes to Zmm0-Zmm7
> 216: static void montMulByConst128(MacroAssembler *_masm) {
I wish the input and output register arrays were explicit parameters; it would be easier to follow that way.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 230:
> 228: }
> 229:
> 230: static void sub_add(int subResult[], int addResult[],
Big fan of all these helper functions! They make reading the top-level functions way easier. Thanks for refactoring!
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 279:
> 277: static int xmm4_20_24[] = {4, 5, 6, 7, 20, 21, 22, 23, 24, 25, 26, 27};
> 278: static int xmm16_27[] = {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27};
> 279: static int xmm29_29[] = {29, 29, 29, 29};
I very much like the new refactor; it is waaaay clearer now. Some 'Could Do' comments:
- I probably would have preferred 'even more symbolic' variable names (ideally, names that match the Java variable names!). Conversely, if 'forced to defend this style', these names are MUCH easier to debug from GDB: it's clear what the matching instruction is.
- Not sure about these arrays being global. It works currently, but it is less 'future proof'.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 645:
> 643: // poly1 (int[256]) = c_rarg1
> 644: // poly2 (int[256]) = c_rarg2
> 645: static address generate_dilithiumNttMult_avx512(StubGenerator *stubgen,
This would be a 'nice to have', something that was 'lost' with the refactor.
As I was reviewing this (original) function, I was thinking: "there is nothing here _that_ specific to AVX-512, mostly columnar and independent operations... This function could be made 'vector-length-independent'..."
- double the loop length (iteration count):
int iter = (vector_len == Assembler::AVX_512bit) ? 4 : 8;
__ movl(len, 4);  ->  __ movl(len, iter);
- halve the register arrays (or keep them the same but shuffle them to make SURE the first half are in the xmm0-xmm15 range, since VEX-encoded instructions can only address xmm0-xmm15)
XMMRegister POLY1[] = {xmm0, xmm1, xmm12, xmm13};
XMMRegister POLY2[] = {xmm4, xmm5, xmm16, xmm17};
XMMRegister SCRATCH1[] = {xmm2, xmm3, xmm14, xmm15}; <<< here
XMMRegister SCRATCH2[] = {xmm6, xmm7, xmm18, xmm19}; <<< and here
XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};
- a couple of other int constants (like the memory 'step' and such)
- for assembler calls like `evmovdqul` and `evpsubd`, we would need a few small new MacroAssembler helpers that generate VEX-encoded versions instead (plenty of instructions already do that).
- I think only the perm instruction was unique to EVEX (I didn't really think of an alternative for AVX2, but it can be abstracted away with another helper).
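The loop count and memory step in the list above follow directly from the vector width, so they could be derived rather than hard-coded. A minimal C++ sketch of that arithmetic, assuming 256-coefficient polynomials of 32-bit ints, `parCnt` registers processed per loop iteration, and a 'step' equal to the bytes one iteration consumes from each polynomial. `VecLen`, `loop_shape`, and `vex_encodable` are hypothetical stand-ins, not HotSpot APIs:

```cpp
// Hypothetical stand-in for the two vector widths under discussion.
enum class VecLen { AVX_256bit, AVX_512bit };

constexpr int kCoeffs = 256;  // ML-DSA polynomials have 256 coefficients

// Number of 32-bit lanes in one register of the given width.
constexpr int lanes(VecLen v) { return v == VecLen::AVX_512bit ? 16 : 8; }

struct LoopShape {
  int iterations;  // value that would be loaded into `len`
  int stepBytes;   // assumed memory 'step' per iteration, per polynomial
};

// parCnt = number of registers processed per iteration (4 in the current arrays).
constexpr LoopShape loop_shape(VecLen v, int parCnt) {
  int perIter = parCnt * lanes(v);           // coefficients handled per iteration
  return {kCoeffs / perIter, perIter * 4};   // 4 bytes per int coefficient
}

// VEX-encoded (AVX2) instructions can only name xmm0-xmm15, so check that the
// first `used` entries of a register-number array stay in that range.
bool vex_encodable(const int* regs, int used) {
  for (int i = 0; i < used; i++) {
    if (regs[i] > 15) return false;
  }
  return true;
}
```

With `parCnt = 4` this yields 4 iterations for AVX-512 and 8 for AVX2, matching the `iter` expression above. `vex_encodable` flags an array like `POLY2 = {xmm4, xmm5, xmm16, xmm17}` when all four entries are used, but accepts its first half.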
Anyway, I'm not suggesting it's something you do here, but it would be convenient to leave breadcrumbs/hooks for a future update so one of us can revisit this code and add AVX2 support. E.g. the `parCnt` variable was very convenient for exactly this before, and now it's gone. It could probably be derived in each function from `vector_len`, but... It's now cleaner, but also harder to 'upgrade'?
Why AVX2? Many of the newer (Atom/E-core-based/`EnableX86ECoreOpts`) processors do not have AVX-512 support, so it's something I've been prioritizing recently.
The alternative would be to write a completely separate AVX2 implementation, but it would be a shame not to 'just' reuse this code.
"For fun", I even went and parameterized the mult function with `vector_len` to see how it would look (almost identical to the original version):
static void montmulEven2(XMMRegister* outputReg, XMMRegister* inputReg1, XMMRegister* inputReg2, XMMRegister* scratchReg1,
                         XMMRegister* scratchReg2, XMMRegister montQInvModR, XMMRegister dilithium_q, int parCnt, int vector_len, MacroAssembler* _masm) {
  for (int i = 0; i < parCnt; i++) {
    // scratch1 = (int64)input1_even*input2_even
    // Java: long a = (long) b * (long) c;
    __ vpmuldq(scratchReg1[i], inputReg1[i], inputReg2[i], vector_len);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = int32(montQInvModR*(int32)scratch1)
    // Java: int aLow = (int) a;
    // Java: int m = MONT_Q_INV_MOD_R * aLow; // signed low product
    __ vpmulld(scratchReg2[i], scratchReg1[i], montQInvModR, vector_len);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = (int64)scratch2_even*dilithium_q_even
    // Java: ((long)m * MONT_Q)
    __ vpmuldq(scratchReg2[i], scratchReg2[i], dilithium_q, vector_len);
  }
  for (int i = 0; i < parCnt; i++) {
    // output_odd = scratch1_odd - scratch2_odd
    // Java: (aHigh - (int) (("scratch2") >> MONT_R_BITS))
    __ vpsubd(outputReg[i], scratchReg1[i], scratchReg2[i], vector_len);
  }
}
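To sanity-check the lane arithmetic in the function above, here is a scalar C++ model of a single 64-bit lane, written as a sketch. It assumes q = 8380417 (MONT_Q), MONT_R_BITS = 32, and MONT_Q_INV_MOD_R = 58728449 (q^-1 mod 2^32, assumed to match the Java constant); `montmul_lane` and `is_montgomery_product` are hypothetical names, not part of the PR. It shows why subtracting only the odd (high) 32-bit halves is exact: by construction, the low halves of `a` and `m * q` are equal, so they cancel.

```cpp
#include <cstdint>

// Assumed ML-DSA constants: q and q^-1 mod 2^32.
constexpr int32_t kQ = 8380417;
constexpr int32_t kQInvModR = 58728449;

// Scalar model of one 64-bit lane of montmulEven2: only the even (low) 32-bit
// element of each input lane is read; the result is the odd (high) element.
int32_t montmul_lane(int32_t b, int32_t c) {
  // vpmuldq: signed 32x32 -> 64-bit product of the even elements.
  int64_t a = static_cast<int64_t>(b) * c;
  // vpmulld: low 32 bits of (low 32 bits of a) * montQInvModR.
  uint32_t mLow = static_cast<uint32_t>(a) * static_cast<uint32_t>(kQInvModR);
  int32_t m = static_cast<int32_t>(mLow);
  // vpmuldq: signed 64-bit product m * q.
  int64_t mq = static_cast<int64_t>(m) * kQ;
  // vpsubd on the odd elements: low halves of a and mq are equal, so the
  // difference of the high halves is exactly (a - mq) >> 32.
  uint32_t aHigh = static_cast<uint32_t>(a >> 32);
  uint32_t mqHigh = static_cast<uint32_t>(mq >> 32);
  return static_cast<int32_t>(aHigh - mqHigh);
}

// Montgomery property: r * 2^32 == b * c (mod q).
// Assumes |b|, |c| < q so that no intermediate overflows int64_t.
bool is_montgomery_product(int32_t r, int32_t b, int32_t c) {
  int64_t lhs = static_cast<int64_t>(r) * (int64_t{1} << 32);
  int64_t rhs = static_cast<int64_t>(b) * c;
  return (lhs - rhs) % kQ == 0;
}
```

As a usage check, multiplying by 2^32 mod q (= 4193792) should return a value congruent to the other input modulo q, which is the usual way to move a value out of Montgomery form.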
-------------
PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2708079853
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008809855
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811046
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811541
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811704
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008808110
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008824304
More information about the hotspot-dev mailing list