RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v7]
Volodymyr Paprotski
vpaprotski at openjdk.org
Mon Mar 17 21:49:12 UTC 2025
On Wed, 12 Mar 2025 19:19:08 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:
>> By using the AVX-512 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
>
> Made the intrinsics test separate from the pure java test.
Partial review; I just didn't want to sit on these comments any longer.
(I spent quite a bit of time catching up on the papers and math required.)
The biggest roadblock I have in following the code is the raw register numbers. (And more comments? Perhaps I need more math knowledge, but comments would help too.)
Also, 'hidden variables' (e.g. xmm30). I can't complain, because this is exactly what Vladimir Ivanov told me to do on my first PR (https://github.com/openjdk/jdk/pull/10582#discussion_r1022185591). Perhaps that discussion applies here too.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 45:
> 43: // Constants
> 44: //
> 45: ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Consts[] = {
This is really nitpicking... but could the constants have been loaded inline with `movl`, without requiring an `ExternalAddress()`?
It is nice to have the constants together; my only complaint is that the ASM uses 'magic offsets' to reach in for a particular one.
This one isn't too bad, since 32-bit offsets are easy to inspect visually (`dilithiumAvx512ConstsAddr()` could perhaps take a parameter).
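A minimal sketch of that idea, with hypothetical names: the byte offsets become enum constants and the accessor takes one, so a call site would read `dilithiumAvx512ConstsAddr(kDilithiumQ)` rather than `+ 4`. The offsets 0, 4 and 12 come from the comments in the stub (q^-1 mod 2^32, q, 2^64 mod q). The table below is only a stand-in holding the values for q = 8380417; the slot at offset 8 is not shown in this review, so it is a placeholder.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names for the byte offsets into the constants table.
// Offset 8 is not shown in this review, so it gets no name here.
enum DilithiumConstOffset : int {
  kMontQInvModR    = 0,   // q^-1 mod 2^32
  kDilithiumQ      = 4,   // q
  kMontRSquareModQ = 12   // 2^64 mod q
};

// Stand-in for the real table (values for q = 8380417).
alignas(64) static const uint32_t dilithiumAvx512Consts[] = {
  58728449u,  // q^-1 mod 2^32
  8380417u,   // q
  0u,         // offset 8: placeholder, real value not shown in the review
  2365951u    // 2^64 mod q
};

typedef const unsigned char* address;

// Hypothetical accessor: takes a named offset instead of a raw literal.
static address dilithiumAvx512ConstsAddr(DilithiumConstOffset offset = kMontQInvModR) {
  return reinterpret_cast<address>(dilithiumAvx512Consts) + offset;
}

// Reads back the 32-bit constant at a named offset (only to check the mapping).
static uint32_t loadConst(DilithiumConstOffset offset) {
  const uint32_t* p = reinterpret_cast<const uint32_t*>(dilithiumAvx512ConstsAddr(offset));
  assert(reinterpret_cast<uintptr_t>(p) % alignof(uint32_t) == 0);
  return *p;
}
```

If a slot ever moves, only the enum changes; the call sites keep reading by name.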
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 58:
> 56:
> 57: ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Perms[] = {
> 58: // collect montmul results into the destination register
Same as `dilithiumAvx512Consts`: 'magic offsets', except that here they are harder to count (e.g. it is not visually clear what the offset of `ntt inverse` is).
This could be split into three constant arrays to make the compiler count for us.
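A variant of letting the compiler count, sketched under assumptions: the groups become named members of one aligned struct and `offsetof` supplies the byte offsets. Apart from the montmul group, the member names are hypothetical, and each group is assumed to be one 512-bit vector (16 x `uint32_t`); the real sizes and contents of `dilithiumAvx512Perms` are not reproduced.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical layout of the permutation table as three named groups.
struct alignas(64) DilithiumAvx512Perms {
  uint32_t montmulCollect[16];  // collect montmul results into the destination register
  uint32_t groupB[16];          // hypothetical second group
  uint32_t nttInverse[16];      // hypothetical: the 'ntt inverse' group
};

// The compiler does the counting; the stub would use these instead of literals.
constexpr std::size_t kMontmulCollectOffset = offsetof(DilithiumAvx512Perms, montmulCollect);
constexpr std::size_t kGroupBOffset         = offsetof(DilithiumAvx512Perms, groupB);
constexpr std::size_t kNttInverseOffset     = offsetof(DilithiumAvx512Perms, nttInverse);

// Every group must start on a 64-byte boundary for an aligned full-vector load.
static_assert(kGroupBOffset % 64 == 0 && kNttInverseOffset % 64 == 0,
              "permutation groups must stay 64-byte aligned");
```

If a group ever changes size, the offsets (and the alignment check) follow automatically.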
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 127:
> 125: for (int i = 0; i < parCnt; i++) {
> 126: __ evpsubd(xmm(i + outputReg), k0, xmm(i + scratchReg1), xmm(i + scratchReg2), false, Assembler::AVX_512bit);
> 127: }
This is such a deceptively brilliant function! It took me a while to understand (and to map to the Java `montMul` function). Perhaps it needs more comments.
The comment on line 99 does provide good hints, but I still had some trouble, and I ended up annotating a copy quite heavily. I do think all 'clever code' needs comments. Here is my annotated version, if you want to copy anything out:
```
static void montmulEven2(XMMRegister outputReg, XMMRegister inputReg1, XMMRegister inputReg2, XMMRegister scratchReg1,
                         XMMRegister scratchReg2, XMMRegister montQInvModR, XMMRegister dilithium_q, int parCnt, MacroAssembler* _masm) {
  int output = outputReg->encoding();
  int input1 = inputReg1->encoding();
  int input2 = inputReg2->encoding();
  int scratch1 = scratchReg1->encoding();
  int scratch2 = scratchReg2->encoding();
  for (int i = 0; i < parCnt; i++) {
    // scratch1 = (int64)input1_even*input2_even
    // Java: long a = (long) b * (long) c;
    __ vpmuldq(xmm(i + scratch1), xmm(i + input1), xmm((input2 == 29) ? 29 : input2 + i), Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = int32(montQInvModR*(int32)scratch1)
    // Java: int aLow = (int) a;
    // Java: int m = MONT_Q_INV_MOD_R * aLow; // signed low product
    __ vpmulld(xmm(i + scratch2), xmm(i + scratch1), montQInvModR, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = (int64)scratch2_even*dilithium_q_even
    // Java: ((long)m * MONT_Q)
    __ vpmuldq(xmm(i + scratch2), xmm(i + scratch2), dilithium_q, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // output_odd = scratch1_odd - scratch2_odd
    // Java: (aHigh - (int) (((long) m * MONT_Q) >> MONT_R_BITS)), where (long) m * MONT_Q is scratch2
    __ evpsubd(xmm(i + output), k0, xmm(i + scratch1), xmm(i + scratch2), false, Assembler::AVX_512bit);
  }
}
```
- Add a comment that input2 can be xmm29, which is treated as a constant rather than as a run of consecutive registers (i.e. zetas).
- Candidate for ASCII art: even/odd columns, implicit int/long casts (or more 'math' comments on what happens).
- Use XMMRegisters instead of numbers (improves call-site readability). Either:
  - advance with `inputReg1 = inputReg1->successor()`, or
  - get `encoding()` and keep the current style.
- Could be a static (local) function (hidden from the header) that is passed `_masm`.
- Pass all registers used (helps to see the register allocation and to confirm there are no overlaps).
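To check my reading of the even/odd trick, here is a scalar C++ model of one register (parCnt == 1). `Zmm` and the lane helpers are my own stand-ins for the AVX-512 instructions, not HotSpot code, and `montMul` follows the Java lines quoted in the annotations. The model shows that the odd lanes of the result are exactly `montMul(input1[2k], input2[2k])`, while the even lanes come out as 0 because the low halves of `a` and `m * q` are equal.

```cpp
#include <cassert>
#include <cstdint>

// One 512-bit register: 16 signed 32-bit lanes; lanes 2k (low) and 2k+1 (high) form 64-bit lane k.
struct Zmm { int32_t d[16]; };

constexpr int32_t MONT_Q = 8380417;
constexpr int32_t MONT_Q_INV_MOD_R = 58728449;  // q^-1 mod 2^32

// Scalar Montgomery multiplication, following the Java montMul lines quoted above.
int32_t montMul(int32_t b, int32_t c) {
  int64_t a = (int64_t) b * (int64_t) c;
  int32_t aHigh = (int32_t) (a >> 32);
  int32_t aLow = (int32_t) (uint32_t) a;
  int32_t m = (int32_t) ((uint32_t) MONT_Q_INV_MOD_R * (uint32_t) aLow);  // wraps like Java int
  return aHigh - (int32_t) (((int64_t) m * MONT_Q) >> 32);
}

Zmm broadcast(int32_t v) {
  Zmm r;
  for (int i = 0; i < 16; i++) r.d[i] = v;
  return r;
}

// vpmuldq: signed product of the even (low) lanes, stored as a 64-bit lane.
Zmm vpmuldq(const Zmm& x, const Zmm& y) {
  Zmm r;
  for (int k = 0; k < 8; k++) {
    int64_t p = (int64_t) x.d[2 * k] * (int64_t) y.d[2 * k];
    r.d[2 * k] = (int32_t) (uint32_t) p;              // low half
    r.d[2 * k + 1] = (int32_t) (uint32_t) (p >> 32);  // high half
  }
  return r;
}

// vpmulld: low 32 bits of the lane-wise product, in all 16 lanes.
Zmm vpmulld(const Zmm& x, const Zmm& y) {
  Zmm r;
  for (int i = 0; i < 16; i++) r.d[i] = (int32_t) ((uint32_t) x.d[i] * (uint32_t) y.d[i]);
  return r;
}

// evpsubd: lane-wise wrapping 32-bit subtraction (no borrow between lanes).
Zmm evpsubd(const Zmm& x, const Zmm& y) {
  Zmm r;
  for (int i = 0; i < 16; i++) r.d[i] = (int32_t) ((uint32_t) x.d[i] - (uint32_t) y.d[i]);
  return r;
}

// The four steps of montmulEven2 for a single register.
Zmm montmulEven(const Zmm& input1, const Zmm& input2) {
  Zmm scratch1 = vpmuldq(input1, input2);                          // a = (long) b * c
  Zmm scratch2 = vpmulld(scratch1, broadcast(MONT_Q_INV_MOD_R));   // even lanes: m = qinv * aLow
  scratch2 = vpmuldq(scratch2, broadcast(MONT_Q));                 // (long) m * q
  return evpsubd(scratch1, scratch2);  // odd: aHigh - high(m*q); even: low halves cancel to 0
}
```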
False trails (i.e. nothing to do, but I already thought about them, so other reviewers don't have to):
- (Ignore: worse performance.) Squash into a single for loop and let the CPU do out-of-order execution (which would also improve readability).
- xmm30/xmm31 (montQInvModR/dilithium_q) are constants. At a glance, it looks like they should be combined into one precomputed constant, and paper 039.pdf suggests merging constants to precompute the product. But these are different constants, and, looking at the Java code, there are several implicit casts. From the paper:
> For reductions of products inside the NTT this is not a problem because one has to multiply by the roots of unity which are compile-time constants. So one can just precompute them with an additional factor of β mod q so that the results after Montgomery reduction are in fact congruent to the desired value a
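The quoted trick can be checked numerically: premultiplying a compile-time constant c by 2^32 mod q means a single Montgomery multiplication yields a value congruent to a*c. As I read it, this is why it fits multiplicands such as the zetas rather than montQInvModR/dilithium_q, which are inputs to the reduction itself. The helper names below are hypothetical; 2^32 mod q = 4193792 for q = 8380417.

```cpp
#include <cassert>
#include <cstdint>

constexpr int32_t MONT_Q = 8380417;
constexpr int32_t MONT_Q_INV_MOD_R = 58728449;  // q^-1 mod 2^32
constexpr int64_t MONT_R_MOD_Q = 4193792;       // 2^32 mod q

// Montgomery multiplication: returns a value congruent to b * c * 2^-32 (mod q).
int32_t montMul(int32_t b, int32_t c) {
  int64_t a = (int64_t) b * (int64_t) c;
  int32_t aHigh = (int32_t) (a >> 32);
  int32_t m = (int32_t) ((uint32_t) MONT_Q_INV_MOD_R * (uint32_t) a);
  return aHigh - (int32_t) (((int64_t) m * MONT_Q) >> 32);
}

// Precompute a constant c (0 <= c < q) as c * 2^32 mod q, i.e. with the
// extra factor that the later Montgomery reduction divides out.
constexpr int32_t premultiply(int64_t c) {
  return (int32_t) (c * MONT_R_MOD_Q % MONT_Q);
}

// Canonical representative in [0, q), for comparing congruent values.
int32_t modQ(int64_t x) {
  int64_t r = x % MONT_Q;
  return (int32_t) (r < 0 ? r + MONT_Q : r);
}

// One Montgomery multiplication by the premultiplied constant gives a * c (mod q).
int32_t mulByConstant(int32_t a, int32_t premultipliedC) {
  return montMul(a, premultipliedC);
}
```

Without the premultiplication, the same call is off by a factor of 2^-32 (mod q), which is what the extra factor compensates for.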
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 140:
> 138: __ vpmuldq(xmm(scratchReg1 + 1), xmm(inputReg12), xmm(inputReg2 + 1), Assembler::AVX_512bit);
> 139: __ vpmuldq(xmm(scratchReg1 + 2), xmm(inputReg13), xmm(inputReg2 + 2), Assembler::AVX_512bit);
> 140: __ vpmuldq(xmm(scratchReg1 + 3), xmm(inputReg14), xmm(inputReg2 + 3), Assembler::AVX_512bit);
Another option for these four lines, keeping the style of the rest of the function:
```
int inputReg1[] = {inputReg11, inputReg12, inputReg13, inputReg14};
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(scratchReg1 + i), xmm(inputReg1[i]), xmm(inputReg2 + i), Assembler::AVX_512bit);
}
```
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 197:
> 195:
> 196: // level 0
> 197: montmulEven(20, 8, 29, 20, 16, 4);
It would improve readability to make clear which parameters are registers and which one is a count, i.e.
`montmulEven(xmm20, xmm8, xmm29, xmm20, xmm16, 4);`
(It's not _that_ bad once I remember that the count is always the last parameter, but it does add to the 'mental load' one has to carry, and this code is already interesting enough.)
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 980:
> 978: // Dilithium multiply polynomials in the NTT domain.
> 979: // Implements
> 980: // static int implDilithiumNttMult(
I suppose there are no Java changes in this PR, but I notice that the inputs are all assumed to have a fixed size.
Most, if not all, intrinsics I have worked with had some sort of guard (e.g. `Objects.checkFromIndexSize`) right before the intrinsic Java call (it usually looks like it can be optimized away), but I see no such guard here on the Java side.
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1010:
> 1008: __ vpbroadcastd(xmm31, Address(dilithiumConsts, 4), Assembler::AVX_512bit); // q
> 1009: __ vpbroadcastd(xmm29, Address(dilithiumConsts, 12), Assembler::AVX_512bit); // 2^64 mod q
> 1010: __ evmovdqul(xmm28, Address(perms, 0), Assembler::AVX_512bit);
- The use of `c_rarg3` is 'clever', so it should probably have a comment (i.e. 'c_rarg3 carries no argument here, so it is a free register').
- Alternatively, load the constants directly into the vector registers with `ExternalAddress()`. You need a scratch register (use r10), but the address is close enough that the scratch register won't actually be used. Here is the disassembly I got:
```
StubRoutines::dilithiumNttMult [0x00007f414fb68280, 0x00007f414fb68548] (712 bytes)
--------------------------------------------------------------------------------
  add %al,(%rax)
  0x00007f414fb68280: push %rbp
  0x00007f414fb68281: mov %rsp,%rbp
  0x00007f414fb68284: vpbroadcastd 0x18f9fe32(%rip),%zmm30 # 0x00007f4168b080c0
  0x00007f414fb6828e: vpbroadcastd 0x18f9fe2c(%rip),%zmm31 # 0x00007f4168b080c4
  0x00007f414fb68298: vpbroadcastd 0x18f9fe2a(%rip),%zmm29 # 0x00007f4168b080cc
  0x00007f414fb682a2: vmovdqu32 0x18f9f8d4(%rip),%zmm28 # 0x00007f4168b07b80
```
The `ExternalAddress()` calls for the above assembly:
```
const Register scratch = r10;
const XMMRegister montRSquareModQ = xmm29;
const XMMRegister montQInvModR = xmm30;
const XMMRegister dilithium_q = xmm31;
const XMMRegister perms = xmm28;
__ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr()), Assembler::AVX_512bit, scratch); // q^-1 mod 2^32
__ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr() + 4), Assembler::AVX_512bit, scratch); // q
__ vpbroadcastd(montRSquareModQ, ExternalAddress(dilithiumAvx512ConstsAddr() + 12), Assembler::AVX_512bit, scratch); // 2^64 mod q
__ evmovdqul(perms, k0, ExternalAddress(dilithiumAvx512PermsAddr()), false, Assembler::AVX_512bit, scratch);
```
(and `dilithiumAvx512ConstsAddr(offset)` could take an int parameter, too)
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1012:
> 1010: __ evmovdqul(xmm28, Address(perms, 0), Assembler::AVX_512bit);
> 1011:
> 1012: __ movl(len, 4);
This is a compile-time constant, so why not 'unroll at compile time'? i.e. wrap this loop with `for (int len = 0; len < 4; len++)` instead?
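A toy illustration of the difference, using a hypothetical `ToyAsm` that only records instruction text (it is not the MacroAssembler API): with the runtime loop the body is emitted once and repeated through a counter register and a backward branch, while a C++ loop in the generator emits the body four times and leaves neither `len` nor the branch in the stub.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for MacroAssembler (hypothetical): records emitted instruction text only.
struct ToyAsm {
  std::vector<std::string> code;
  void emit(const std::string& insn) { code.push_back(insn); }
};

// Current shape: trip count in a register, body emitted once, repeated at run time.
void generateRuntimeLoop(ToyAsm& masm) {
  masm.emit("movl len, 4");
  masm.emit("L_loop:");
  masm.emit("<loop body>");
  masm.emit("subl len, 1");
  masm.emit("jcc notZero, L_loop");
}

// Suggested shape: the C++ loop runs while the stub is generated,
// so the body is emitted four times and no counter or branch remains.
void generateUnrolled(ToyAsm& masm) {
  for (int len = 0; len < 4; len++) {
    masm.emit("<loop body>");
  }
}

// Counts emitted lines that start with the given mnemonic.
int countInsns(const ToyAsm& masm, const std::string& mnemonic) {
  int n = 0;
  for (const std::string& insn : masm.code) {
    if (insn.compare(0, mnemonic.size(), mnemonic) == 0) n++;
  }
  return n;
}
```

The trade-off is stub size: the body is emitted once per iteration, so this pays off only when the trip count is a small compile-time constant, as it is here.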
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1041:
> 1039: for (int i = 0; i < 4; i++) {
> 1040: __ evmovdqul(Address(result, i * 64), xmm(i), Assembler::AVX_512bit);
> 1041: }
This is nice, compact and clean. The biggest issue I have with following this code is really all the 'raw' registers. I would much prefer symbolic names, but it's up to you to decide on the style.
I ended up 'annotating' this snippet so I could understand it and confirm everything. As with montmulEven, I hope some of it is useful for you to copy out.
```
XMMRegister POLY1[] = {xmm0, xmm1, xmm2, xmm3};
XMMRegister POLY2[] = {xmm4, xmm5, xmm6, xmm7};
XMMRegister SCRATCH1[] = {xmm12, xmm13, xmm14, xmm15};
XMMRegister SCRATCH2[] = {xmm16, xmm17, xmm18, xmm19};
XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};
for (int i = 0; i < 4; i++) {
  __ evmovdqul(POLY1[i], Address(poly1, i * 64), Assembler::AVX_512bit);
  __ evmovdqul(POLY2[i], Address(poly2, i * 64), Assembler::AVX_512bit);
}
// montmulEven: inputs are in even columns and output is in odd columns
// scratch3_odd = poly2_even*montRSquareModQ // poly2 to montgomery domain
montmulEven2(SCRATCH3[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  // swap even/odd; 0xB1 == 2-3-0-1
  __ vpshufd(SCRATCH3[i], SCRATCH3[i], 0xB1, Assembler::AVX_512bit);
}
// scratch3_odd = poly1_even*scratch3_even = poly1_even*poly2_even*montRSquareModQ
montmulEven2(SCRATCH3[0], POLY1[0], SCRATCH3[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  __ vpshufd(POLY1[i], POLY1[i], 0xB1, Assembler::AVX_512bit);
  __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
}
// poly2_odd = poly2_even*montRSquareModQ (the original odd coefficients) // poly2 to montgomery domain
montmulEven2(POLY2[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
}
// poly1_odd = poly1_even*poly2_even
montmulEven2(POLY1[0], POLY1[0], POLY2[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  // result is scrambled between scratch3_odd and poly1_odd; unscramble
  __ evpermt2d(POLY1[i], perms, SCRATCH3[i], Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
  __ evmovdqul(Address(result, i * 64), POLY1[i], Assembler::AVX_512bit);
}
```
With symbolic variable names, the code was much easier to follow conceptually. It also has the side benefit of making it obvious which XMM registers are used and that there are no conflicts.
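For anyone else checking the annotated snippet, here is a scalar C++ model of its dataflow for one register (16 coefficients). The lane helpers are my own models of `vpshufd` and `evpermt2d` (VPERMT2D semantics: index bit 4 picks the table, bits 3:0 the lane), and the permutation index vector is reconstructed from the dataflow. That is an assumption: the actual contents of the first group of `dilithiumAvx512Perms` are not shown here. The model checks that every coefficient equals `montMul(a, montMul(b, 2^64 mod q))`, which is congruent to a*b mod q.

```cpp
#include <cassert>
#include <cstdint>

// One 512-bit register: 16 signed 32-bit lanes.
struct Zmm { int32_t d[16]; };

constexpr int32_t MONT_Q = 8380417;
constexpr int32_t MONT_Q_INV_MOD_R = 58728449;    // q^-1 mod 2^32
constexpr int32_t MONT_R_SQUARE_MOD_Q = 2365951;  // 2^64 mod q

int32_t montMul(int32_t b, int32_t c) {
  int64_t a = (int64_t) b * (int64_t) c;
  int32_t aHigh = (int32_t) (a >> 32);
  int32_t m = (int32_t) ((uint32_t) MONT_Q_INV_MOD_R * (uint32_t) a);
  return aHigh - (int32_t) (((int64_t) m * MONT_Q) >> 32);
}

Zmm broadcast(int32_t v) { Zmm r; for (int i = 0; i < 16; i++) r.d[i] = v; return r; }

// montmulEven2 at the level of its result: odd lanes get montMul of the even
// inputs; even lanes end up 0 because the low halves cancel in the subtraction.
Zmm montmulEven(const Zmm& in1, const Zmm& in2) {
  Zmm r;
  for (int k = 0; k < 8; k++) {
    r.d[2 * k] = 0;
    r.d[2 * k + 1] = montMul(in1.d[2 * k], in2.d[2 * k]);
  }
  return r;
}

// vpshufd: within each 128-bit block, lane j takes source lane imm[2j+1:2j]; 0xB1 swaps pairs.
Zmm vpshufd(const Zmm& x, int imm) {
  Zmm r;
  for (int i = 0; i < 16; i++) r.d[i] = x.d[(i & ~3) + ((imm >> (2 * (i & 3))) & 3)];
  return r;
}

// evpermt2d(dst, idx, src): index bit 4 selects the table (0: old dst, 1: src).
Zmm evpermt2d(const Zmm& dst, const Zmm& idx, const Zmm& src) {
  Zmm r;
  for (int i = 0; i < 16; i++) {
    int sel = idx.d[i] & 31;
    r.d[i] = (sel & 16) ? src.d[sel & 15] : dst.d[sel & 15];
  }
  return r;
}

// The annotated dataflow for one register of poly1/poly2.
Zmm nttMultModel(Zmm poly1, Zmm poly2) {
  const Zmm montRSquareModQ = broadcast(MONT_R_SQUARE_MOD_Q);
  // Reconstructed indices: even results from scratch3's odd lanes (table 2),
  // odd results from poly1's odd lanes (table 1).
  Zmm perms;
  for (int k = 0; k < 8; k++) {
    perms.d[2 * k] = 16 + 2 * k + 1;
    perms.d[2 * k + 1] = 2 * k + 1;
  }
  Zmm scratch3 = montmulEven(poly2, montRSquareModQ);  // odd: toMont(poly2 even lanes)
  scratch3 = vpshufd(scratch3, 0xB1);                  // move to the even lanes
  scratch3 = montmulEven(poly1, scratch3);             // odd: products of the even lanes
  poly1 = vpshufd(poly1, 0xB1);                        // original odd lanes -> even
  poly2 = vpshufd(poly2, 0xB1);
  poly2 = montmulEven(poly2, montRSquareModQ);         // odd: toMont(original odd lanes)
  poly2 = vpshufd(poly2, 0xB1);
  poly1 = montmulEven(poly1, poly2);                   // odd: products of the odd lanes
  return evpermt2d(poly1, perms, scratch3);            // interleave back into coefficient order
}
```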
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1090:
> 1088: __ evpbroadcastd(xmm29, constant, Assembler::AVX_512bit); // constant multiplier
> 1089:
> 1090: __ movl(len, 2);
Same comments here as for `generate_dilithiumNttMult_avx512`:
- constants can be loaded directly into XMM
- len can be removed by unrolling at compile time
- symbolic names could be used for registers
- comments could be added
-------------
PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2665370975
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999468929
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999471763
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999625933
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1992230295
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1992235625
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999712200
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999413007
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999367607
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999683384
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r1999686631
More information about the hotspot-dev mailing list