RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v7]

Ferenc Rakoczi duke at openjdk.org
Thu Mar 20 21:09:12 UTC 2025


On Mon, 17 Mar 2025 19:24:52 GMT, Volodymyr Paprotski <vpaprotski at openjdk.org> wrote:

>> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
>> 
>>   Made the intrinsics test separate from the pure java test.
>
> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 58:
> 
>> 56: 
>> 57: ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Perms[] = {
>> 58:      // collect montmul results into the destination register
> 
> same as `dilithiumAvx512Consts()`, 'magic offsets'; except here they are harder to count (eg. not clear visually what is the offset of `ntt inverse`).
> 
> Could be split into three constant arrays to make the compiler count for us

Well, it is 64 bytes per line (16 4-byte uint32_ts), so the offsets are not that hard to count :-) ...
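
For illustration only, here is a minimal sketch of how the compiler could be made to count the offsets, as the review comment suggests. It uses one struct with named rows rather than three separate arrays. The struct name, the member names and the row counts are all hypothetical and do not reflect the real contents of `dilithiumAvx512Perms`.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical layout: each row is one 64-byte ZMM-sized line of 16 uint32_t.
// The row counts below are placeholders, not the real table sizes.
struct alignas(64) PermTables {
  uint32_t montmulCollect[1][16];  // collect montmul results
  uint32_t ntt[2][16];             // forward NTT permutations
  uint32_t nttInverse[2][16];      // inverse NTT permutations
};

// Bytes per row: 16 entries of 4 bytes each.
constexpr std::size_t kRowBytes = 16 * sizeof(uint32_t);

// Offsets computed by the compiler instead of counted by hand.
constexpr std::size_t kNttOffset = offsetof(PermTables, ntt);
constexpr std::size_t kNttInverseOffset = offsetof(PermTables, nttInverse);

static_assert(kRowBytes == 64, "one row must fill one 512-bit register");
```

With a layout like this, adding or removing a row moves the later offsets automatically.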

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 140:
> 
>> 138:   __ vpmuldq(xmm(scratchReg1 + 1), xmm(inputReg12), xmm(inputReg2 + 1), Assembler::AVX_512bit);
>> 139:   __ vpmuldq(xmm(scratchReg1 + 2), xmm(inputReg13), xmm(inputReg2 + 2), Assembler::AVX_512bit);
>> 140:   __ vpmuldq(xmm(scratchReg1 + 3), xmm(inputReg14), xmm(inputReg2 + 3), Assembler::AVX_512bit);
> 
> Another option for these four lines, to keep the style of rest of function
> 
> int inputReg1[] = {inputReg11, inputReg12, inputReg13, inputReg14};
>   for (int i = 0; i < parCnt; i++) {
>     __ vpmuldq(xmm(scratchReg1 + i), xmm(inputReg1[i]), xmm(inputReg2 + i), Assembler::AVX_512bit);
>   }

I have changed the whole structure instead.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 197:
> 
>> 195: 
>> 196:   // level 0
>> 197:   montmulEven(20, 8, 29, 20, 16, 4);
> 
> It would improve readability to know which parameter is a register, and which is a count.. i.e. 
> 
> `montmulEven(xmm20, xmm8, xmm29, xmm20, xmm16, 4);`
> 
> (it's not _that_ bad, once I remember that it's always the last parameter.. but it does add to the 'mental load' one has to carry, and this code is already interesting enough)

I have changed the structure, now it is clear(er) which parameter is what.
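
As a sketch of the general idea (not the code in the PR), a register type that cannot be built from a bare int makes the distinction between registers and the count visible at every call site. `VecReg` and `describeCall` are hypothetical names; HotSpot's own `XMMRegister` plays this role in the real code.

```cpp
#include <initializer_list>
#include <string>
#include <type_traits>

// Hypothetical stand-in for a vector register: the explicit constructor
// prevents a plain int from being passed where a register is expected.
struct VecReg {
  explicit constexpr VecReg(int enc) : encoding(enc) {}
  int encoding;
};

static_assert(!std::is_convertible<int, VecReg>::value,
              "a bare int must not silently become a register");

// Hypothetical helper with five register parameters followed by a count.
// It only formats its arguments so that the call shape can be inspected.
std::string describeCall(VecReg a, VecReg b, VecReg c, VecReg d, VecReg e,
                         int parCnt) {
  std::string s;
  for (VecReg r : {a, b, c, d, e}) {
    s += "zmm" + std::to_string(r.encoding) + " ";
  }
  return s + "count=" + std::to_string(parCnt);
}
```

A call then reads `describeCall(VecReg(20), VecReg(8), VecReg(29), VecReg(20), VecReg(16), 4)`, and passing `20` directly for a register no longer compiles.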

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 980:
> 
>> 978: // Dilithium multiply polynomials in the NTT domain.
>> 979: // Implements
>> 980: // static int implDilithiumNttMult(
> 
> I suppose no java changes in this PR, but I notice that the inputs are all assumed to have fixed size.
> 
> Most/all intrinsics I worked with had some sort of guard (eg `Objects.checkFromIndexSize`) right before the intrinsic java call. (It usually looks like it can be optimized away). But I notice no such guard here on the java side.

These functions will not be used anywhere else, and in ML_DSA.java all of the arrays passed to the intrinsics are of the correct size.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1010:
> 
>> 1008:   __ vpbroadcastd(xmm31, Address(dilithiumConsts, 4), Assembler::AVX_512bit); // q
>> 1009:   __ vpbroadcastd(xmm29, Address(dilithiumConsts, 12), Assembler::AVX_512bit); // 2^64 mod q
>> 1010:   __ evmovdqul(xmm28, Address(perms, 0), Assembler::AVX_512bit);
> 
> - use of `c_rarg3` is 'clever' so probably should have a comment (ie. 'no 3rd parameter, free register')
> - Alternatively, load directly into the vector with `ExternalAddress()`; you need a scratch register (use r10) but address is close enough, it actually won't be used. Here is the disassembly I got:
> 
> ```
> StubRoutines::dilithiumNttMult [0x00007f414fb68280, 0x00007f414fb68548] (712 bytes)
> --------------------------------------------------------------------------------
> add    %al,(%rax)
>   0x00007f414fb68280:   push   %rbp
>   0x00007f414fb68281:   mov    %rsp,%rbp
>   0x00007f414fb68284:   vpbroadcastd 0x18f9fe32(%rip),%zmm30        # 0x00007f4168b080c0
>   0x00007f414fb6828e:   vpbroadcastd 0x18f9fe2c(%rip),%zmm31        # 0x00007f4168b080c4
>   0x00007f414fb68298:   vpbroadcastd 0x18f9fe2a(%rip),%zmm29        # 0x00007f4168b080cc
>   0x00007f414fb682a2:   vmovdqu32 0x18f9f8d4(%rip),%zmm28        # 0x00007f4168b07b80
> ```
> 
> The `ExternalAddress()` calls for above assembler
> ```
>   const Register scratch = r10;
>   const XMMRegister montRSquareModQ = xmm29;
>   const XMMRegister montQInvModR = xmm30;
>   const XMMRegister dilithium_q = xmm31;
>   const XMMRegister perms = xmm28;
> 
>   __ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr()), Assembler::AVX_512bit, scratch); // q^-1 mod 2^32
>   __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr() + 4), Assembler::AVX_512bit, scratch); // q
>   __ vpbroadcastd(montRSquareModQ, ExternalAddress(dilithiumAvx512ConstsAddr() + 12), Assembler::AVX_512bit, scratch); // 2^64 mod q
>   __ evmovdqul(perms, k0, ExternalAddress(dilithiumAvx512PermsAddr()), false, Assembler::AVX_512bit, scratch);
> ```
> 
> (and `dilithiumAvx512ConstsAddr(offset)` could take an int parameter too)

I added comments and changed the vpbroadcast loads to load directly from memory.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1012:
> 
>> 1010:   __ evmovdqul(xmm28, Address(perms, 0), Assembler::AVX_512bit);
>> 1011: 
>> 1012:   __ movl(len, 4);
> 
> Compile-time constant, why not 'unroll at compile time'? i.e. wrap this loop with `for (int len=0; len<4; len++)` instead?

I have found that unrolling these loops actually hurts performance (probably an I-cache effect).
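
To make the trade-off concrete, here is a toy sketch of the two code-generation strategies. It only counts emitted lines and measures nothing about performance. The `emitUnrolled` and `emitLooped` helpers and the instruction strings are hypothetical and are not the stub generator's API.

```cpp
#include <string>
#include <vector>

using Code = std::vector<std::string>;

// Unroll at generation time: the body is emitted once per iteration,
// so the stub grows linearly with the iteration count.
Code emitUnrolled(const Code& body, int iterations) {
  Code out;
  for (int i = 0; i < iterations; i++) {
    out.insert(out.end(), body.begin(), body.end());
  }
  return out;
}

// Runtime loop: the body is emitted once, plus a counter setup, a label,
// a decrement and a backward branch.
Code emitLooped(const Code& body, int iterations) {
  Code out;
  out.push_back("movl len, " + std::to_string(iterations));
  out.push_back("L_loop:");
  out.insert(out.end(), body.begin(), body.end());
  out.push_back("subl len, 1");
  out.push_back("jcc notEqual, L_loop");
  return out;
}
```

For a 10-line body and 4 iterations, the unrolled form emits 40 lines and the looped form 14, which is the kind of size difference that could plausibly matter for the instruction cache.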

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1041:
> 
>> 1039:   for (int i = 0; i < 4; i++) {
>> 1040:     __ evmovdqul(Address(result, i * 64), xmm(i), Assembler::AVX_512bit);
>> 1041:   }
> 
> This is nice, compact and clean. The biggest issue I have with following this code is really with all the 'raw' registers. I would much rather prefer symbolic names, but up to you to decide style.
> 
> I ended up 'annotating' this snippet, so I could understand it and confirm everything..  as with montmulEven, hope some of it can be useful to you to copy out.
> 
> 
>   XMMRegister POLY1[] = {xmm0, xmm1, xmm2, xmm3};
>   XMMRegister POLY2[] = {xmm4, xmm5, xmm6, xmm7};
>   XMMRegister SCRATCH1[] = {xmm12, xmm13, xmm14, xmm15};
>   XMMRegister SCRATCH2[] = {xmm16, xmm17, xmm18, xmm19};
>   XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};
>   for (int i = 0; i < 4; i++) {
>     __ evmovdqul(POLY1[i], Address(poly1, i * 64), Assembler::AVX_512bit);
>     __ evmovdqul(POLY2[i], Address(poly2, i * 64), Assembler::AVX_512bit);
>   }
> 
>   // montmulEven: inputs are in even columns and output is in odd columns
>   // scratch3_even = poly2_even*montRSquareModQ // poly2 to montgomery domain
>   montmulEven2(SCRATCH3[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
>   for (int i = 0; i < 4; i++) {
>     // swap even/odd; 0xB1 == 2-3-0-1
>     __ vpshufd(SCRATCH3[i], SCRATCH3[i], 0xB1, Assembler::AVX_512bit);
>   }
> 
>   // scratch3_odd = poly1_even*scratch3_even = poly1_even*poly2_even*montRSquareModQ
>   montmulEven2(SCRATCH3[0], POLY1[0], SCRATCH3[0], SCRATCH1[0], SCRATCH2[0], 4, montQInvModR, dilithium_q, 4, _masm);
>   for (int i = 0; i < 4; i++) {
>     __ vpshufd(POLY1[i], POLY1[i], 0xB1, Assembler::AVX_512bit);
>     __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
>   }
> 
>   // poly2_even = poly2_odd*montRSquareModQ // poly2 to montgomery domain
>   montmulEven2(POLY2[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], 4, montQInvModR, dilithium_q, 4, _masm);
>   for (int i = 0; i < 4; i++) {
>     __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
>   }
> 
>   // poly1_odd = poly1_even*poly2_even
>   montmulEven2(POLY1[0], POLY1[0], POLY2[0], SCRATCH1[0], SCRATCH2[0], 4, montQInvModR, dilithium_q, 4, _masm);
>   for (int i = 0; i < 4; i++) {
>     // result is scrambled between scratch3_odd and poly1_odd; unscramble
>     __ evpermt2d(POLY1[i], perms, SCRATCH3[i], Assembler::AVX_512bit);
>   }
>   for (int i = 0; i < 4; i++) {
>     __ evmovdqul(Address(result, i *...

I have rewritten it to use full montmuls (a new function) here and everywhere else. It is much easier to follow the code that way.
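
For readers following the arithmetic, here is a scalar sketch of a Montgomery multiplication modulo the ML-DSA prime q = 8380417 with R = 2^32. It is not the AVX-512 code from the PR; the function names are hypothetical and the constants are computed in the sketch rather than taken from the stub's tables. It mirrors the scheme in the annotated snippet: one operand is first multiplied by 2^64 mod q to bring it into the Montgomery domain, then a second montmul yields the plain product.

```cpp
#include <cstdint>

constexpr uint32_t Q = 8380417;  // ML-DSA modulus, 2^23 - 2^13 + 1

// q^-1 mod 2^32 by Newton iteration; each step doubles the number of
// correct low bits (x = q is already correct modulo 8 for odd q).
constexpr uint32_t invMod2to32(uint32_t q) {
  uint32_t x = q;
  for (int i = 0; i < 5; i++) {
    x *= 2u - q * x;
  }
  return x;
}

constexpr uint32_t Q_INV = invMod2to32(Q);  // q^-1 mod 2^32
constexpr uint32_t R_MOD_Q = static_cast<uint32_t>((uint64_t{1} << 32) % Q);
constexpr uint32_t R2_MOD_Q =
    static_cast<uint32_t>((uint64_t{R_MOD_Q} * R_MOD_Q) % Q);  // 2^64 mod q

// Montgomery product a * b * 2^-32 mod q, for inputs in [0, q).
constexpr uint32_t montmul(uint32_t a, uint32_t b) {
  uint64_t p = uint64_t{a} * b;
  uint32_t m = static_cast<uint32_t>(p) * (0u - Q_INV);  // m*q == -p mod 2^32
  uint64_t t = (p + uint64_t{m} * Q) >> 32;  // low 32 bits cancel exactly
  return static_cast<uint32_t>(t >= Q ? t - Q : t);
}

// Plain product a * b mod q built from two montmuls: b is first moved
// into the Montgomery domain by multiplying with 2^64 mod q.
constexpr uint32_t mulModQ(uint32_t a, uint32_t b) {
  return montmul(a, montmul(b, R2_MOD_Q));
}
```

The vectorized stubs apply the same reduction to many lanes at once; this scalar version is only meant to show what a single lane computes.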

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 1090:
> 
>> 1088:   __ evpbroadcastd(xmm29, constant, Assembler::AVX_512bit); // constant multiplier
>> 1089: 
>> 1090:   __ movl(len, 2);
> 
> Same comment here as the `generate_dilithiumNttMult_avx512`
> - constants can be loaded directly into XMM
> - len can be removed by unrolling at compile time
> - symbolic names could be used for registers
> - comments could be added

Done.

-------------

PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455445
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455814
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455732
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006454991
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455529
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455662
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455178
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2006455086

