RFR: 8350589: Investigate cleaner implementation of AArch64 ML-DSA intrinsic introduced in JDK-8348561
Andrew Dinn
adinn at openjdk.org
Thu Mar 13 10:22:16 UTC 2025
On Thu, 13 Mar 2025 08:57:18 GMT, Andrew Dinn <adinn at openjdk.org> wrote:
> This PR reworks the existing AArch64 ML_DSA intrinsic code generator to make it clearer to read and easier to maintain.
@ferakocz I have modified your generator code to employ vector sequences and auxiliaries that handle iterative loads, stores and math/logic operations over vector sequences. It would be useful to have a review of the code from you, and also for you to test it (see the comments below regarding testing).
The rewrite has allowed much of the generator logic to be condensed into calls to simple auxiliaries, which provides a better mid-level view of how the code is structured. It has also clarified register use. I think this will be a lot easier for maintainers to understand.
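For readers who have not looked at the generator code, the idea of a vector sequence plus an auxiliary that iterates an operation over it can be sketched in plain C++. This is a hypothetical, simplified model, not HotSpot's actual `VSeq` class or its auxiliaries: the namespace `sketch`, the base/delta representation and the functions `vs_disjoint`, `vs_same`, `check_disjoint_or_same` and `vs_add` are assumptions made for illustration, and the "generator" just returns AArch64 assembler text instead of emitting instructions.

```cpp
#include <cassert>
#include <string>

namespace sketch {

// Hypothetical model of a sequence of N vector registers: element i of the
// sequence is register v(base + i * delta). Registers are plain ints here.
template <int N>
class VSeq {
 public:
  explicit VSeq(int base, int delta = 1) : _base(base), _delta(delta) {}
  int operator[](int i) const { return _base + i * _delta; }

 private:
  int _base;
  int _delta;
};

// True if no register appears in both sequences.
template <int N>
bool vs_disjoint(const VSeq<N>& a, const VSeq<N>& b) {
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      if (a[i] == b[j]) return false;
    }
  }
  return true;
}

// True if both sequences name the same registers in the same order.
template <int N>
bool vs_same(const VSeq<N>& a, const VSeq<N>& b) {
  for (int i = 0; i < N; i++) {
    if (a[i] != b[i]) return false;
  }
  return true;
}

// Hypothetical aliasing check: outputs must either not touch the inputs at
// all or reuse exactly the same registers (no partial overlap).
template <int N>
void check_disjoint_or_same(const VSeq<N>& out, const VSeq<N>& in) {
  assert(vs_disjoint(out, in) || vs_same(out, in));
}

// Auxiliary that "emits" one vector add per element of the sequences,
// d[i] = a[i] + b[i] on 4 x 32-bit lanes, as AArch64 assembler text.
template <int N>
std::string vs_add(const VSeq<N>& d, const VSeq<N>& a, const VSeq<N>& b) {
  std::string out;
  for (int i = 0; i < N; i++) {
    out += "add v" + std::to_string(d[i]) + ".4s, v" + std::to_string(a[i]) +
           ".4s, v" + std::to_string(b[i]) + ".4s\n";
  }
  return out;
}

}  // namespace sketch
```

In this model a declaration such as `VSeq<4> vtmp(20)` names v20 to v23, and a single call like `vs_add(va, va, vb)` replaces four hand-written instructions, which is the kind of condensation described above.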
A few further comments:
1. I have added some asserts to the montmul operations to ensure that input and output register sequences are either disjoint or fully overlapping. There may be further opportunities to add asserts in a follow-up.
2. One thing I noticed (and commented on in the code) after switching to passing vector sequences rather than relying on fixed mappings is that some reloading of q and qinv inside loops is unnecessary, as the code in the loop does not write the relevant vectors. I left the code as is so that I could check that the generated code is identical to the original, but I will move the relevant loads outside the loop before pushing.
3. I compared before and after disassemblies of the generated code and it is unchanged apart from routine dilithiumDecomposePoly. For that intrinsic your generator code wrote successive intermediate results into the next unused set of 4 vectors, which in most cases are later used to hold a non-temporary result needed by a later computation. My code always writes intermediate results into the last set of 4 vectors (declared as `VSeq<4> vtmp(20)`). As a result the generated code has the same structure but a slightly different register mapping. I don't believe this affects performance, but the change makes it clearer how the computed values are being used.
4. As well as comparing disassemblies of the generated code, I verified the patch by running test `jdk/sun/security/provider/acvp/ML_DSA_Test.java`. However, I noted a problem with relying on the test as currently implemented, since it did not appear to catch some errors in my code. I ran the test under the debugger and found that only one of the intrinsics was being exercised (dilithiumAlmostNtt). I also confirmed this by adding -XX:+PrintCompilation to the test command line. It seems that all the other calls to intrinsic candidates were made from the interpreter, because the callers did not run often enough to trigger their compilation. Instead, some of the `impl*Java` methods were being compiled.
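As background to the montmul asserts and the q/qinv loads, the montmul operations compute Montgomery products. A scalar sketch is given below, assuming the ML-DSA modulus q = 8380417, Montgomery radix R = 2^32 and qinv = q^-1 mod 2^32 = 58728449, the constants used by the Dilithium reference implementation. The function names are hypothetical and this is not the intrinsic's vectorized code; it only shows why q and qinv are loop-invariant.

```cpp
#include <cassert>
#include <cstdint>

// ML-DSA modulus q = 2^23 - 2^13 + 1 and qinv = q^-1 mod 2^32
// (Montgomery radix R = 2^32).
constexpr int32_t kQ = 8380417;
constexpr int32_t kQInv = 58728449;
static_assert(static_cast<uint32_t>(kQ) * static_cast<uint32_t>(kQInv) == 1u,
              "qinv must be the inverse of q modulo 2^32");

// Montgomery reduction: returns r with r == a * 2^-32 (mod q) and -q < r < q,
// provided |a| < 2^31 * q.
inline int32_t montgomery_reduce(int64_t a) {
  assert(a > -(INT64_C(1) << 31) * kQ && a < (INT64_C(1) << 31) * kQ);
  // t = a * qinv mod 2^32, computed in unsigned arithmetic to avoid overflow.
  uint32_t t_bits = static_cast<uint32_t>(static_cast<uint64_t>(a)) *
                    static_cast<uint32_t>(kQInv);
  int32_t t = static_cast<int32_t>(t_bits);
  // a - t * q is an exact multiple of 2^32, so the shift is an exact division.
  return static_cast<int32_t>((a - static_cast<int64_t>(t) * kQ) >> 32);
}

// Scalar analogue of a montmul-by-constant loop over polynomial coefficients.
// q and qinv are only read inside the loop, so in vectorized generated code
// the registers holding them can be loaded once, before the loop.
inline void mont_mul_by_constant(int32_t* coeffs, int n, int32_t c) {
  for (int i = 0; i < n; i++) {
    coeffs[i] = montgomery_reduce(static_cast<int64_t>(coeffs[i]) * c);
  }
}
```

Because montmul(x, c) is congruent to x * c * 2^-32 mod q, multiplying by c = 2^32 mod q (4193792) leaves each coefficient unchanged modulo q, which makes a convenient sanity check.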
The last of these points needs some thought. I worked around the limitations of the current test by adding the following compile exclusions to the test command line:
-XX:CompileCommand=exclude,sun.security.provider.ML_DSA::implDilithiumNttMultJava \
-XX:CompileCommand=exclude,sun.security.provider.ML_DSA::implDilithiumAlmostInverseNttJava \
-XX:CompileCommand=exclude,sun.security.provider.ML_DSA::implDilithiumMontMulByConstantJava \
-XX:CompileCommand=exclude,sun.security.provider.ML_DSA::implDilithiumAlmostNttJava \
-XX:CompileCommand=exclude,sun.security.provider.ML_DSA::decomposePolyJava
This fixes the problem, ensuring that all the intrinsics get exercised. However, there are two problems with modifying the test to pass these options automatically. Firstly, it will slow down the test on ports that don't implement the intrinsics. Secondly, it is only a partial fix -- it won't stop the same problem arising with the other tests launched by the associated test class `Launcher` (e.g. `ML_KEM` tests).
Am I simply failing to spot some other test that you ran to verify the correctness of the code? If not, then we need to fix the current tests so that they are guaranteed to exercise the intrinsics. We could supply a new test that exercises the callers often enough to trigger compilation, or we could modify the current test to exclude compilation of the Java implementations as appropriate to the architecture. A third, more complex solution would be to modify the interpreter to call out to the generated code when dilithium intrinsics are switched off (as happens with some of the other crypto routines). That might be a useful thing to do anyway, since it would ensure that callers cannot gain information from a change in timing when we switch from interpreted to compiled code. I'm not sure how significant any timing disparity might be. Perhaps you can advise?
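One way to apply the exclusions only where they help is to generate them conditionally, so that ports without the intrinsics keep compiling the Java implementations. The sketch below is hypothetical: the function names and the `has_ml_dsa_intrinsics` flag are assumptions, and a real fix would more likely live in the jtreg test itself (for example per-architecture `@run` lines); it only shows the shape of the option list, using the method names from the workaround above.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: turn a class name and a list of method names into
// -XX:CompileCommand=exclude,<class>::<method> options.
std::vector<std::string> exclude_args(const std::string& klass,
                                      const std::vector<std::string>& methods) {
  assert(!klass.empty());
  std::vector<std::string> args;
  for (const std::string& m : methods) {
    args.push_back("-XX:CompileCommand=exclude," + klass + "::" + m);
  }
  return args;
}

// Hypothetical policy: exclude the Java implementations only when the target
// provides the ML-DSA intrinsics. On other ports the Java code is the only
// implementation, so excluding it would just force it to run interpreted.
std::vector<std::string> ml_dsa_exclusions(bool has_ml_dsa_intrinsics) {
  if (!has_ml_dsa_intrinsics) return {};
  return exclude_args("sun.security.provider.ML_DSA",
                      {"implDilithiumNttMultJava",
                       "implDilithiumAlmostInverseNttJava",
                       "implDilithiumMontMulByConstantJava",
                       "implDilithiumAlmostNttJava",
                       "decomposePolyJava"});
}
```

This addresses the slowdown on ports without the intrinsics, but not the second concern: other tests launched by `Launcher` would still need equivalent treatment.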
-------------
PR Comment: https://git.openjdk.org/jdk/pull/24026#issuecomment-2720710691
More information about the hotspot-compiler-dev mailing list