RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v11]

Ferenc Rakoczi duke at openjdk.org
Mon Mar 31 14:28:20 UTC 2025


On Mon, 24 Mar 2025 15:16:20 GMT, Volodymyr Paprotski <vpaprotski at openjdk.org> wrote:

>> Ferenc Rakoczi has updated the pull request incrementally with two additional commits since the last revision:
>> 
>>  - Further readability improvements.
>>  - Added asserts for array sizes
>
> I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!

@vpaprotsk, thanks a lot for the very thorough review!

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 270:
> 
>> 268: }
>> 269: 
>> 270: static void loadPerm(int destinationRegs[], Register perms,
> 
> `replXmm`? i.e. this function replicates (any) xmm register, not just perms?

Since I am only using it for permutation describers, I thought that this way it is easier to follow what is happening.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 327:
> 
>> 325: //
>> 326: //
>> 327: static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
> 
> Similar comments as to `generate_dilithiumAlmostInverseNtt_avx512`
> 
> - similar comment about the 'pair-wise' operation, updating `[j]` and `[j+l]` at a time.. 
> - somehow I had less trouble following the flow through the registers here; perhaps I am getting used to it. FYI, I ended up renaming some as:
> 
> // xmm16_27 = Temp1
> // xmm0_3 = Coeffs1
> // xmm4_7 = Coeffs2
> // xmm8_11 = Coeffs3
> // xmm12_15 = Coeffs4 = Temp2
> // xmm16_27 = Scratch

For me, it was easier to follow what goes where using the xmm... names (with the symbolic names you always have to remember which one overlaps with which, and by how much).

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 421:
> 
>> 419:   for (int i = 0; i < 8; i += 2) {
>> 420:     __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
>> 421:   }
> 
> Wish there was a more 'abstract' way to arrange this, so it's obvious from the shape of the code what registers are inputs/outputs (i.e. and use the register arrays). Even though it's just 'elementary index operations', `i/2 + 16` is still 'clever'. Couldn't think of anything myself though (same elsewhere in this function for the table permutes).

Well, this is how it is when we have three inputs, one of which also serves as the output... At least the output is always the first one (so that is the one that gets clobbered). This is why you have to replicate the permutation describer when you need both permutands later.
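
To make the clobbering concrete, here is a rough plain-Java model of what a two-table dword permute in the vpermi2d style does on 16-lane registers. This is an illustration only, not the stub code; the 5-bit index encoding follows my reading of the instruction description:

    // Model of a two-table permute: 'idxAndDest' stands in for the register that
    // holds the permutation describer and also receives the result.
    static void permi2dModel(int[] idxAndDest, int[] src1, int[] src2) {
        int[] result = new int[16];
        for (int i = 0; i < 16; i++) {
            int idx = idxAndDest[i] & 31;                 // 5-bit index into the 32-dword table
            result[i] = idx < 16 ? src1[idx] : src2[idx - 16];
        }
        // The result overwrites the register that held the indices, so the original
        // permutation describer is gone after the instruction - hence the replication.
        System.arraycopy(result, 0, idxAndDest, 0, 16);
    }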

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 509:
> 
>> 507: // coeffs (int[256]) = c_rarg0
>> 508: // zetas (int[256]) = c_rarg1
>> 509: static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
> 
> Done with this function; perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!
> 
> Some general comments first, rest inline.
> 
> - The array names for registers helped a lot. And so did the new helper functions!
> - The Java version of this code is quite intimidating to vectorize: a 3D loop with geometric iteration variables, and the literature is even more intimidating (discrete convolutions, which I haven't touched in two decades, FFTs, NTTs, etc.). Here is my attempt at a comment to 'un-scare' the next reader, though feel free to reword however you like.
> 
> The core of the (Java) loop is this 'pair-wise' operation:
>         int a = coeffs[j];
>         int b = coeffs[j + offset];
>         coeffs[j] = (a + b);
>         coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);
> 
> There are 8 'levels' (0-7); the 'levels' correspond to (unrolling) the outer (Java) loop.
> At each level, the 'pair-wise-offset' doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).
> 
> To vectorize this Java code, observe that at each level, REGARDLESS of the offset, half the operations are SUMs and the other half are
> Montgomery MULTIPLICATIONs (of the pair-difference with a constant). At each level, one 'just' has to shuffle
> the coefficients so that the SUMs and MULTIPLICATIONs line up accordingly.
> 
> Otherwise, this pattern is 'lightly similar' to a discrete convolution (compute integral/summation of two functions at every offset)
> 
> - I still would prefer (more) symbolic register names. I wouldn't hold my approval over it, so I won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:
> 
> // xmm8_11  = Perms1
> // xmm12_15 = Perms2
> // xmm16_27 = Scratch
> // xmm0_3 = CoeffsPlus
> // xmm4_7 = CoeffsMul
> // xmm24_27 = CoeffsMinus (overlaps with Scratch)
> 
> (I made a similar comment, but I think it is now hidden after the last refactor)
> - would prefer to see the helper functions get ALL the registers passed explicitly (i.e. currently `montMulPerm`, `montQInvModR`, `dilithium_q`, `xmm29` are implicit). As a general rule, I've tried to set up all the registers at the 'entry' function (`generate_dilithium*` in this case) and ...

I added some more comments, but I kept the xmm... names for the registers, just like with the NTT function.
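
For readers coming to the comment above cold, here is a plain-Java sketch of the level structure it describes. It is schematic only: montMul below is the standard signed Montgomery reduction with the reference-implementation constants, and the zeta indexing is illustrative rather than the exact bookkeeping in ML_DSA.java:

    static final int Q = 8380417;       // the ML-DSA modulus
    static final int Q_INV = 58728449;  // q^-1 mod 2^32 (reference-implementation constant)

    // Returns a value congruent to x * y * 2^-32 mod q (signed Montgomery reduction).
    static int montMul(int x, int y) {
        long a = (long) x * y;
        int t = (int) a * Q_INV;                  // low 32 bits of a * q^-1
        return (int) ((a - (long) t * Q) >> 32);
    }

    // The eight 'levels' of the pair-wise operation: at each level the offset doubles,
    // half the results are sums and the other half are Montgomery products of the
    // differences. 'zetas' needs at least 255 entries with this schematic indexing.
    static void almostInverseNttSketch(int[] coeffs, int[] zetas) {
        int m = 0;
        for (int offset = 1; offset < 256; offset *= 2) {              // levels 0..7
            for (int start = 0; start < 256; start += 2 * offset) {
                for (int j = start; j < start + offset; j++) {
                    int a = coeffs[j];
                    int b = coeffs[j + offset];
                    coeffs[j] = a + b;                                 // the SUM half
                    coeffs[j + offset] = montMul(a - b, zetas[m]);     // the MULTIPLY half
                }
                m++;
            }
        }
    }

The vector code has to produce the same results; the shuffles before and after each level are what line the SUM lanes and the MULTIPLY lanes up, so that the additions and Montgomery multiplications can be done on whole registers at a time.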

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 554:
> 
>> 552:   for (int i = 0; i < 8; i += 2) {
>> 553:     __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
>> 554:     __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 
> Took a bit to unscramble the flow, so maybe a comment is needed? The purpose is 'fairly obvious' once I got the general shape of the level/algorithm (as per my top-level comment), but something like "shuffle xmm0-7 into xmm8-15"?

I hope the comment that I added at the beginning of the function sheds some light on the purpose of these permutations.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 656:
> 
>> 654:   for (int i = 0; i < 8; i++) {
>> 655:     __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
>> 656:   }
> 
> Fairly clean as is, but it could also be two sub_add calls, I think (you would have to swap the order of add/sub in the helper to be able to clobber `xmm(i)`... or swap register usage downstream, so perhaps not... but it would be cleaner)
> 
>   sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
>   sub_add(CoeffsMul,  &Scratch[4], Perms2, CoeffsMul, _masm);
> 
> 
> If nothing else, I would have preferred to see the use of the register array variables.

I would rather leave this alone, too. I was considering the same, but decided that this is fairly easy to follow, and that it would be more complicated to either add a new helper function or to keep track of where the symbolically named register sets overlap.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 871:
> 
>> 869:   __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
>> 870:   __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
>> 871:   __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);
> 
> Fairly 'straightforward' transcription of the Java code; no comments from me.
> 
> At first glance, using `xmm0_3`, `xmm4_7`, etc. might have been a good idea, but you only save one line per 4x group. (Unless you have one big loop, but I suspect that would give you worse performance? Is that something you tried already? Might be worth it otherwise.)

I have considered this but decided to leave it alone (for the reason that you mentioned).

> src/java.base/share/classes/sun/security/provider/ML_DSA.java line 1418:
> 
>> 1416:                                          int twoGamma2, int multiplier) {
>> 1417:         assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
>> 1418:                 && (highPart.length == ML_DSA_N);
> 
> I wrote this test to check Java-to-intrinsic correspondence. It might be good to include it (and add the other 4 intrinsics). This is very similar to all my other *Fuzz* tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself).
> 
> import java.util.Arrays;
> import java.util.Random;
> 
> import java.lang.invoke.MethodHandle;
> import java.lang.invoke.MethodHandles;
> import java.lang.reflect.Field;
> import java.lang.reflect.Method;
> import java.lang.reflect.Constructor;
> 
> public class ML_DSA_Intrinsic_Test {
> 
>     public static void main(String[] args) throws Exception {
>         MethodHandles.Lookup lookup = MethodHandles.lookup();
>         Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
>         Constructor<?> constructor = kClazz.getDeclaredConstructor(
>                 int.class);
>         constructor.setAccessible(true);
>         
>         Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
>                 int[].class, int[].class, int[].class);
>         m.setAccessible(true);
>         MethodHandle mult = lookup.unreflect(m);
> 
>         m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
>                 int[].class, int[].class, int[].class);
>         m.setAccessible(true);
>         MethodHandle multJava = lookup.unreflect(m);
> 
>         Random rnd = new Random();
>         long seed = rnd.nextLong();
>         rnd.setSeed(seed);
>         //Note: it might be useful to increase this number during development of new intrinsics
>         final int repeat = 1000000;
>         int[] coeffs1 = new int[ML_DSA_N];
>         int[] coeffs2 = new int[ML_DSA_N];
>         int[] prod1 = new int[ML_DSA_N];
>         int[] prod2 = new int[ML_DSA_N];
>         try {
>             for (int i = 0; i < repeat; i++) {
>                 run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
>             }
>             System.out.println("Fuzz Success");
>         } catch (Throwable e) {
>             System.out.println("Fuzz Failed: " + e);
>         }
>     }
> 
>     private static final int ML_DSA_N = 256;
>     public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2, 
>         MethodH...

We will consider it for a follow-up PR.
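
(The quoted run() method is cut off by the archive. To show the shape of the comparison it presumably performs, here is a hypothetical sketch, not the reviewer's original code; it assumes the two handles refer to static methods that write the product into their first int[] argument, and it reuses the imports already present in the quoted class.)

    // Hypothetical sketch of the comparison step: fill the inputs with random
    // coefficients, run the (possibly intrinsified) method and the pure-Java
    // fallback, and require bit-identical outputs.
    public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
            MethodHandle mult, MethodHandle multJava, Random rnd, long seed, int iter)
            throws Throwable {
        for (int j = 0; j < coeffs1.length; j++) {
            // A real test would probably constrain the inputs to the coefficient
            // range the Java code actually produces.
            coeffs1[j] = rnd.nextInt();
            coeffs2[j] = rnd.nextInt();
        }
        mult.invoke(prod1, coeffs1, coeffs2);      // intrinsic path (once compiled)
        multJava.invoke(prod2, coeffs1, coeffs2);  // pure-Java reference
        if (!Arrays.equals(prod1, prod2)) {
            throw new AssertionError("products differ at iteration " + iter
                    + " (seed " + seed + ")");
        }
    }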

-------------

PR Comment: https://git.openjdk.org/jdk/pull/23860#issuecomment-2766414076
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021150966
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151152
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151361
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151680
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021152095
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021152962
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021154571
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021156249

