RFR: 8351034: Add AVX-512 intrinsics for ML-DSA [v11]

Volodymyr Paprotski vpaprotski at openjdk.org
Mon Mar 24 15:19:22 UTC 2025


On Sat, 22 Mar 2025 20:02:31 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:

>> By using the AVX-512 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with two additional commits since the last revision:
> 
>  - Further readability improvements.
>  - Added asserts for array sizes

I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 270:

> 268: }
> 269: 
> 270: static void loadPerm(int destinationRegs[], Register perms,

Rename to `replXmm`? I.e., this function replicates (any) Xmm register, not just perms?

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 327:

> 325: //
> 326: //
> 327: static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,

Similar comments as for `generate_dilithiumAlmostInverseNtt_avx512`:

- a similar comment about the 'pair-wise' operation, updating `[j]` and `[j+l]` at the same time.
- I somehow had less trouble following the flow through the registers here; perhaps I am getting used to it. FYI, I ended up renaming some as:

// xmm16_27 = Temp1
// xmm0_3 = Coeffs1
// xmm4_7 = Coeffs2
// xmm8_11 = Coeffs3
// xmm12_15 = Coeffs4 = Temp2
// xmm16_27 = Scratch

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 421:

> 419:   for (int i = 0; i < 8; i += 2) {
> 420:     __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 421:   }

I wish there were a more 'abstract' way to arrange this, so that it's obvious from the shape of the code which registers are inputs and which are outputs (i.e. by using the register arrays). Even though these are just 'elementary index operations', `i/2 + 16` is still 'clever'. I couldn't think of anything better myself, though (the same applies elsewhere in this function to the table permutes).
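A scalar model of VPERMI2D may help a reader decode these index computations. This is a minimal sketch. It assumes the HotSpot `evpermi2d(dst, nds, src, vector_len)` operand order, in which `dst` supplies the indexes and is overwritten with the result, `nds` is the first table and `src` is the second. Masking is ignored, and the class and method names are hypothetical:

```java
/**
 * Hypothetical scalar model of AVX-512 VPERMI2D on 16 x 32-bit lanes (no masking).
 * Operand order follows evpermi2d(dst, nds, src): dst holds the indexes on entry and
 * receives the result; nds is table 0 and src is table 1.
 */
class Vpermi2dModel {
    static final int LANES = 16; // 512 bits / 32-bit lanes

    static void vpermi2d(int[] dst, int[] nds, int[] src) {
        int[] result = new int[LANES];
        for (int lane = 0; lane < LANES; lane++) {
            int idx = dst[lane];
            int[] table = (idx & 0x10) == 0 ? nds : src; // index bit 4 picks the table
            result[lane] = table[idx & 0x0F];           // index bits 3:0 pick the lane
        }
        System.arraycopy(result, 0, dst, 0, LANES);     // the index register is clobbered
    }
}
```

Suppose `nds` holds coefficients 0-15 and `src` holds 16-31. The index vector {0, 2, ..., 30} then gathers the even-numbered coefficients into one register. That is presumably the kind of gather the `evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), ...)` loop performs, with the index registers preloaded from the permute table.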

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 509:

> 507: // coeffs (int[256]) = c_rarg0
> 508: // zetas (int[256]) = c_rarg1
> 509: static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,

Done with this function. Perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!

Some general comments first, rest inline.

- The array names for registers helped a lot. And so did the new helper functions!
- The Java version of this code is quite intimidating to vectorize: a 3D loop with geometric iteration variables. The literature is even more intimidating (discrete convolutions, which I haven't touched in two decades, FFTs, NTTs, etc.). Here is my attempt at a comment to 'un-scare' the next reader; feel free to reword it however you like.

The core of the (Java) loop is this 'pair-wise' operation:
        int a = coeffs[j];
        int b = coeffs[j + offset];
        coeffs[j] = (a + b);
        coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);

There are 8 'levels' (0-7); each level corresponds to one iteration of the outer (Java) loop, unrolled.
At each level l, the pair-wise offset doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).

To vectorize this Java code, observe that at each level, REGARDLESS of the offset, half of the operations are SUMs and the other half are Montgomery
MULTIPLICATIONs (of the pair difference with a constant). At each level, one 'just' has to shuffle
the coefficients so that the SUMs and MULTIPLICATIONs line up accordingly.

Otherwise, this pattern is loosely similar to a discrete convolution (computing the integral/summation of two functions at every offset).
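The following is a minimal scalar sketch of the eight levels that makes the 'pair-wise' structure above concrete. It assumes Montgomery multiplication with R = 2^32, the ML-DSA modulus q = 8380417 and q^-1 mod 2^32 = 58728449. The class name, the `levels` method and the flat `zetas` array (one constant per butterfly group, 255 in total) are hypothetical. They do not reproduce the exact zeta indexing or sign convention of `ML_DSA.java`:

```java
/**
 * Hypothetical scalar sketch of the 8-level pair-wise loop (inverse-NTT style).
 * Not the ML_DSA.java code: zeta indexing and sign conventions are simplified.
 */
class PairwiseLevelsSketch {
    static final int N = 256;
    static final int Q = 8380417;     // ML-DSA modulus
    static final int QINV = 58728449; // Q^-1 mod 2^32

    // Montgomery multiplication: returns a value congruent to a * b * 2^-32 (mod Q).
    static int montMul(int a, int b) {
        long c = (long) a * b;
        int m = (int) c * QINV;                  // m = c * Q^-1 mod 2^32 (wrapping int multiply)
        return (int) ((c - (long) m * Q) >> 32); // low 32 bits are zero, so the shift is exact
    }

    // Levels 0..7 with pair offsets 1, 2, 4, ..., 128; zetas needs one entry per group (255).
    // Sums are not reduced here; inputs are assumed small enough not to overflow.
    static void levels(int[] coeffs, int[] zetas) {
        int z = 0;
        for (int offset = 1; offset < N; offset *= 2) {
            for (int start = 0; start < N; start += 2 * offset) {
                int zeta = zetas[z++];
                for (int j = start; j < start + offset; j++) {
                    int a = coeffs[j];
                    int b = coeffs[j + offset];
                    coeffs[j] = a + b;                         // SUM half
                    coeffs[j + offset] = montMul(a - b, zeta); // MULTIPLICATION half
                }
            }
        }
    }
}
```

For a quick sanity check, set every coefficient to 1. All pair differences are then zero, so `coeffs[0]` ends up as 256 and every other coefficient becomes 0, whatever the zetas are.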

- I would still prefer (more) symbolic register names. I wouldn't hold back my approval over it, so I won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing ('annotating') the code to make it easier for myself to follow the flow of data:

// xmm8_11  = Perms1
// xmm12_15 = Perms2
// xmm16_27 = Scratch
// xmm0_3 = CoeffsPlus
// xmm4_7 = CoeffsMul
// xmm24_27 = CoeffsMinus (overlaps with Scratch)

(I made a similar comment earlier, but I think it is now hidden after the last refactor.)
- I would prefer the helper functions to get ALL registers passed explicitly (currently `montMulPerm`, `montQInvModR`, `dilithium_q` and `xmm29` are implicit). As a general rule, I've tried to set up all the registers in the 'entry' function (`generate_dilithium*` in this case) and use symbolic names from there on. That is not always reasonable, but it's what I've grown used to seeing.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 554:

> 552:   for (int i = 0; i < 8; i += 2) {
> 553:     __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 554:     __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);

It took a while to unscramble the flow, so perhaps a comment is needed? The purpose is 'fairly obvious' once I got the general shape of the levels/algorithm (as per my top-level comment), but something like "shuffle xmm0-7 into xmm8-15" would help.
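The shuffle is easier to follow once you see what the permute tables must contain. Consider a window of 32 coefficients held in two registers and a pair offset l of 1, 2, 4 or 8. One table gathers the `coeffs[j]` operands and the other gathers the matching `coeffs[j + l]` operands. The hypothetical sketch below computes such a pair of index tables to illustrate the idea. The PR's precomputed tables (e.g. `nttInvL1PermsIdx`) may be laid out differently, for instance to also undo the previous level's ordering:

```java
/** Hypothetical generator of VPERMI2D index tables that split a 32-int window into pair halves. */
class PairPermTables {
    // For pair offset l (1, 2, 4 or 8) returns {first, second}: first lists the window
    // positions j with bit l clear (the coeffs[j] operands), second the partners j + l.
    // Values 16..31 have index bit 4 set, i.e. they select from the second register.
    static int[][] tables(int l) {
        if (l != 1 && l != 2 && l != 4 && l != 8) {
            throw new IllegalArgumentException("offset must be 1, 2, 4 or 8");
        }
        int[] first = new int[16];
        int[] second = new int[16];
        int k = 0;
        for (int j = 0; j < 32; j++) {
            if ((j & l) == 0) {
                first[k] = j;
                second[k] = j + l;
                k++;
            }
        }
        return new int[][] {first, second};
    }
}
```

Use `first` and `second` as the index registers of two `evpermi2d` calls over the same pair of source registers. The result is one register of `a` values and one of `b` values, so a single vector add and a single Montgomery multiply then handle 16 pair-wise operations.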

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 572:

> 570:   load4Xmms(xmm4_7, zetas, 512, _masm);
> 571:   sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm);
> 572:   montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);

From my annotated version, levels 1-4 are fairly 'straightforward':

  // level 1
  replXmm(Perms1, perms, nttInvL1PermsIdx, _masm);
  replXmm(Perms2, perms, nttInvL1PermsIdx + 64, _masm);

  for (int i = 0; i < 4; i++) {
    __ evpermi2d(xmm(Perms1[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
    __ evpermi2d(xmm(Perms2[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
  }

  load4Xmms(CoeffsMul, zetas, 512, _masm);
  sub_add(CoeffsMinus, CoeffsPlus, Perms1, Perms2, _masm);
  montMul64(CoeffsMul, CoeffsMinus, CoeffsMul, Scratch, _masm);

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 613:

> 611:   montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);
> 612: 
> 613:   // level 5

"// No shuffling for level 5 and 6; can just rearrange full registers"

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 656:

> 654:   for (int i = 0; i < 8; i++) {
> 655:     __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
> 656:   }

Fairly clean as is, but I think this could also be two `sub_add` calls. You would have to swap the order of the add/sub in the helper to be able to clobber `xmm(i)`, or swap the register usage downstream, so perhaps not, but it would be cleaner:

  sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
  sub_add(CoeffsMul,  &Scratch[4], Perms2, CoeffsMul, _masm);


If nothing else, I would have preferred to see the register array variables used.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 660:

> 658:   store4Xmms(coeffs, 0, xmm16_19, _masm);
> 659:   store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
> 660:   montMulByConst128(_masm);

I would prefer explicit parameters here. But I think this could also be two `montMul64` calls?

  montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch, _masm);
  montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch, _masm);

(I think there is one other use of `montMulByConst128` where the same applies; then you could delete both `montMulByConst128` and `montmulEven`.)

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 871:

> 869:   __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
> 870:   __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
> 871:   __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);

A fairly 'straightforward' transcription of the Java code; no comments from me.

At first glance, using `xmm0_3`, `xmm4_7`, etc. might have been a good idea, but you only save one line per group of 4. (Unless you use one big loop, but I suspect that would give you worse performance? Is that something you have already tried? It might be worth it otherwise.)

src/java.base/share/classes/sun/security/provider/ML_DSA.java line 1418:

> 1416:                                          int twoGamma2, int multiplier) {
> 1417:         assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
> 1418:                 && (highPart.length == ML_DSA_N);

I wrote this test to check Java-to-intrinsic correspondence. It might be good to include it (and extend it to the other 4 intrinsics). It is very similar to the other *Fuzz* tests I've been adding for my own intrinsics. You made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself.

import java.util.Arrays;
import java.util.Random;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Constructor;

public class ML_DSA_Intrinsic_Test {

    public static void main(String[] args) throws Exception {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
        Constructor<?> constructor = kClazz.getDeclaredConstructor(
                int.class);
        constructor.setAccessible(true);
        
        Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle mult = lookup.unreflect(m);

        m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle multJava = lookup.unreflect(m);

        Random rnd = new Random();
        long seed = rnd.nextLong();
        rnd.setSeed(seed);
        //Note: it might be useful to increase this number during development of new intrinsics
        final int repeat = 1000000;
        int[] coeffs1 = new int[ML_DSA_N];
        int[] coeffs2 = new int[ML_DSA_N];
        int[] prod1 = new int[ML_DSA_N];
        int[] prod2 = new int[ML_DSA_N];
        try {
            for (int i = 0; i < repeat; i++) {
                run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
            }
            System.out.println("Fuzz Success");
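        } catch (Throwable e) {
            System.out.println("Fuzz Failed: " + e);
            // Propagate the failure so that the test does not silently pass.
            throw new RuntimeException("Fuzz Failed", e);
        }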
    }

    private static final int ML_DSA_N = 256;
    public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
            MethodHandle mult, MethodHandle multJava, Random rnd,
            long seed, int i) throws Throwable {
        // Fill both inputs with random coefficients for this iteration.
        for (int j = 0; j < ML_DSA_N; j++) {
            coeffs1[j] = rnd.nextInt();
            coeffs2[j] = rnd.nextInt();
        }

        // Run the (possibly intrinsified) entry point and the pure-Java reference.
        mult.invoke(prod1, coeffs1, coeffs2);
        multJava.invoke(prod2, coeffs1, coeffs2);

        if (!Arrays.equals(prod1, prod2)) {
            throw new RuntimeException("[Seed " + seed + "@" + i + "] Result mismatch: "
                    + Arrays.toString(prod1) + " != " + Arrays.toString(prod2));
        }
    }
    }
}
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED  -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java

-------------

PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2708301954
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008921783
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009415317
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009477186
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009428310
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009433467
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435329
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435791
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009437669
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009438921
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009486160
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2010355575


More information about the hotspot-dev mailing list