RFR: 8351412: Add AVX-512 intrinsics for ML-KEM [v4]

Ferenc Rakoczi duke at openjdk.org
Wed May 14 11:43:58 UTC 2025


On Tue, 13 May 2025 17:53:50 GMT, Sandhya Viswanathan <sviswanathan at openjdk.org> wrote:

>> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
>> 
>>   Restoring copyright notice on ML_KEM.java
>
> src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 693:
> 
>> 691: // a (short[256]) = c_rarg1
>> 692: // b (short[256]) = c_rarg2
>> 693: // kyberConsts (short[40]) = c_rarg3
> 
> kyberConsts is not one of the arguments passed in.

Fixed.

> src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 696:
> 
>> 694: address generate_kyberAddPoly_2_avx512(StubGenerator *stubgen,
>> 695:                                        MacroAssembler *_masm) {
>> 696: 
> 
> The Java code for "implKyberAddPoly(short[] result, short[] a, short[] b)" does BarrettReduction but the intrinsic code here does not. Is that intentional and how is the reduction handled?

Actually, it is the Java version that is too cautious. A Barrett reduction is performed after at most 4 consecutive uses of mlKemAddPoly(), so doing the reduction in implKyberAddPoly() is not necessary. Thanks for discovering this!
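
To illustrate why the deferred reduction is safe, here is a minimal sketch, not the code in ML_KEM.java. It assumes that every input coefficient is already below q = 3329 in magnitude, so that a chain of 4 additions stays well inside the range of a short (5 * 3328 = 16640). The class name, the helper names, and the Barrett constants (20159 and a shift of 26) are assumptions made for this example.

```java
import java.util.Arrays;

// Hypothetical sketch: names and constants are illustrative assumptions.
class LazyReductionSketch {
    static final int Q = 3329;                    // ML-KEM modulus
    static final int BARRETT_MULTIPLIER = 20159;  // assumed: round(2^26 / Q)
    static final int BARRETT_SHIFT = 26;

    // Returns a value congruent to x mod Q with small magnitude.
    static short barrettReduce(short x) {
        int t = (x * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
        return (short) (x - t * Q);
    }

    // Plain coefficient-wise addition with no reduction.
    static void addPoly(short[] result, short[] a, short[] b) {
        for (int i = 0; i < result.length; i++) {
            result[i] = (short) (a[i] + b[i]);
        }
    }

    static void reducePoly(short[] p) {
        for (int i = 0; i < p.length; i++) {
            p[i] = barrettReduce(p[i]);
        }
    }

    // Adds all polynomials, reducing only once at the end.
    static short[] sumLazy(short[][] polys) {
        short[] acc = polys[0].clone();
        for (int k = 1; k < polys.length; k++) {
            addPoly(acc, acc, polys[k]);
        }
        reducePoly(acc);
        return acc;
    }

    // Adds all polynomials, reducing after every addition.
    static short[] sumEager(short[][] polys) {
        short[] acc = polys[0].clone();
        for (int k = 1; k < polys.length; k++) {
            addPoly(acc, acc, polys[k]);
            reducePoly(acc);
        }
        return acc;
    }

    public static void main(String[] args) {
        // Worst case under the stated assumption: every coefficient is Q - 1.
        short[][] polys = new short[5][256];
        for (short[] p : polys) {
            Arrays.fill(p, (short) (Q - 1));
        }
        short[] lazy = sumLazy(polys);
        short[] eager = sumEager(polys);
        for (int i = 0; i < lazy.length; i++) {
            if (Math.floorMod(lazy[i], Q) != Math.floorMod(eager[i], Q)) {
                throw new AssertionError("mismatch at coefficient " + i);
            }
        }
        System.out.println("lazy and eager sums agree mod q");
    }
}
```

Under that assumption both variants produce the same residues mod q, and the lazy one does 1 reduction pass instead of 4.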

> src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 742:
> 
>> 740: // b (short[256]) = c_rarg2
>> 741: // c (short[256]) = c_rarg3
>> 742: // kyberConsts (short[40]) = c_rarg4
> 
> kyberConsts is not one of the arguments passed in.

Fixed.

> src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 799:
> 
>> 797: // parsedLength (int) = c_rarg3
>> 798: address generate_kyber12To16_avx512(StubGenerator *stubgen,
>> 799:                                     MacroAssembler *_masm) {
> 
> If AVX512_VBMI and AVX512_VBMI2 are available, it looks to me that the loop body of this algorithm can be implemented with more efficient instructions in 5 simple steps:
> 
> Step 1:
> Load 0-47, 48-95, 96-143, 144-191 condensed bytes into xmm0, xmm1, xmm2, xmm3 respectively using masked load.
> 
> Step 2:
> Use vpermb to arrange xmm0 such that bytes 1, 4, 7, ... are duplicated
> xmm0 before: b47, b46, ..., b0, where each b is a byte
> xmm0 after: b47 b46 b46 b45, ..., b5 b4 b4 b3, b2 b1 b1 b0
> Repeat this for xmm1, xmm2, xmm3
> 
> Step 3:
> Use vpshldvw to shift every word (16 bits) in xmm0 by a per-word variable amount:
> shift word 31 by 4, word 30 by 0, ..., word 3 by 4, word 2 by 0, word 1 by 4, word 0 by 0
> Repeat this for xmm1, xmm2, xmm3
> 
> Step 4:
> Use vpand to "and" each word element in xmm0 by 0xfff.
> Repeat this for xmm1, xmm2, xmm3
> 
> Step 5:
> Store xmm0 into parsed
> Store xmm1 into parsed + 64
> Store xmm2 into parsed + 128
> Store xmm3 into parsed + 192
> 
> If you think there is not sufficient time, we could look into it after the merge of this PR as well.

Yes, that would speed this function up (in isolation, perhaps significantly). But with the current intrinsics, this function accounts for only about 1.5% of the overall running time, so the overall gain would be small, and not all AVX-512 capable processors support VBMI.
So I would rather not do it in this PR.
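
For reference, the per-element result that the five steps quoted above aim at can be modelled in scalar Java. This is only a sketch of the 12-bit to 16-bit unpacking, not the intrinsic and not the code in ML_KEM.java. The class and method names are hypothetical. In this scalar model the odd-numbered words are moved into place with a right shift by 4, which is how the shift of step 3 is represented here.

```java
// Hypothetical scalar model of the 12-to-16 unpacking; names are illustrative.
class Kyber12To16Sketch {
    // Expands 'count' 12-bit values (count must be even) packed into
    // condensed[offset..] into one short each in parsed[0..count).
    static void unpack(byte[] condensed, int offset, short[] parsed, int count) {
        for (int i = 0; i < count; i += 2) {
            int b0 = condensed[offset] & 0xff;
            int b1 = condensed[offset + 1] & 0xff;
            int b2 = condensed[offset + 2] & 0xff;

            // Analogue of step 2: the middle byte b1 appears in both words.
            int w0 = (b1 << 8) | b0;
            int w1 = (b2 << 8) | b1;

            // Analogue of step 3: even word shifted by 0, odd word by 4.
            w1 >>>= 4;

            // Analogue of step 4: keep the low 12 bits of each word.
            parsed[i] = (short) (w0 & 0xfff);
            parsed[i + 1] = (short) (w1 & 0xfff);

            offset += 3;
        }
    }

    public static void main(String[] args) {
        byte[] condensed = {0x21, 0x43, 0x65};
        short[] parsed = new short[2];
        unpack(condensed, 0, parsed, 2);
        System.out.println(Integer.toHexString(parsed[0]) + " "
                + Integer.toHexString(parsed[1]));
    }
}
```

A vectorized version would do the same work on 32 words per register, with the byte duplication done by a single permute instead of the two explicit byte pairs.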

-------------

PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2088738946
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2088738841
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2088738704
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2088738615

