RFR: 8349721: Add aarch64 intrinsics for ML-KEM [v7]

Andrew Dinn adinn at openjdk.org
Thu Apr 10 16:17:38 UTC 2025


On Thu, 10 Apr 2025 13:19:05 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:

>> By using the aarch64 vector registers the speed of the computation of the ML-KEM algorithms (key generation, encapsulation, decapsulation) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with two additional commits since the last revision:
> 
>  - Code rearrange, some renaming, fixing comments
>  - Changes suggested by Andrew Dinn.

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5933:

> 5931:     vs_ld3_post(vin, __ T16B, condensed);
> 5932: 
> 5933:     // expand groups of input bytes in vin to shorts in va and vb

I'd like to expand on the data layouts here so that maintenance engineers don't have to work them out every time they look at this code. So I would like to replace this comment as follows:

    // The front half of sequence vin (vin[0], vin[1] and vin[2])
    // holds 48 (16x3) contiguous bytes from memory striped
    // horizontally across each of the 16 byte lanes. Equivalently,
    // that is 16 pairs of 12-bit integers. Likewise the back half
    // holds the next 48 bytes in the same arrangement.

    // Each vector in the front half can also be viewed as a vertical
    // strip across the 16 pairs of 12-bit integers. Each byte in
    // vin[0] stores the low 8 bits of the first int in a pair. Each
    // byte in vin[1] stores the high 4 bits of the first int and the
    // low 4 bits of the second int. Each byte in vin[2] stores the
    // high 8 bits of the second int. Likewise for the vectors in the
    // second half.

    // Converting the data to 16-bit shorts requires first of all
    // expanding each of the 6 x 16B vectors into 6 corresponding
    // pairs of 8H vectors. Mask, shift and add operations on the
    // resulting vector pairs can be used to combine 4 and 8 bit
    // parts of related 8H vector elements.
    //
    // The middle vectors (vin[1] and vin[4]) are actually expanded
    // twice, one copy manipulated to provide the higher 4 bits
    // belonging to the first short in a pair and another copy
    // manipulated to provide the lower 4 bits belonging to the
    // second short in a pair. This is why the vector sequences va
    // and vb used to hold the expanded 8H elements are of length 8.

    // Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]
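To make the layout above concrete, here is a scalar C++ sketch (not HotSpot code; `ld3_model` and `decode_pair` are hypothetical names). It assumes an LD3 in the 16B arrangement places byte 3*i + k of the input in lane i of register k, and that the packing is the little-endian 12-bit layout ML-KEM uses for ByteDecode_12 in FIPS 203.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar model (not HotSpot code) of the striped layout.
using Stripe = std::array<uint8_t, 16>;

// Mimics an LD3 of 48 bytes in the 16B arrangement: byte 3*lane + k of the
// input lands in lane `lane` of stripe k, so stripes 0/1/2 play the roles
// of vin[0]/vin[1]/vin[2].
std::array<Stripe, 3> ld3_model(const uint8_t* in) {
  assert(in != nullptr);
  std::array<Stripe, 3> v{};
  for (int lane = 0; lane < 16; lane++) {
    for (int k = 0; k < 3; k++) {
      v[k][lane] = in[3 * lane + k];
    }
  }
  return v;
}

// Reference decode of one 3-byte group (one lane) into a pair of 12-bit ints:
// b0 = lo 8 bits of the 1st int, lo nibble of b1 = its hi 4 bits,
// hi nibble of b1 = lo 4 bits of the 2nd int, b2 = its hi 8 bits.
void decode_pair(uint8_t b0, uint8_t b1, uint8_t b2,
                 uint16_t& first, uint16_t& second) {
  first  = static_cast<uint16_t>(b0 | ((b1 & 0x0F) << 8));
  second = static_cast<uint16_t>((b1 >> 4) | (b2 << 4));
}
```

For example, the bytes 0xAB, 0xCD, 0xEF decode to the pair 0xDAB, 0xEFC.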

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5941:

> 5939:     __ ushll(va[4], __ T8H, vin[1], __ T8B, 0);
> 5940:     __ ushll2(va[5], __ T8H, vin[1], __ T16B, 0);
> 5941: 

Insert here

    // Likewise expand vin[3] into vb[0:1], and vin[4] into vb[2:3]
    // and vb[4:5]
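These expansions rely on `ushll` zero-extending the lower 8 byte lanes and `ushll2` the upper 8 byte lanes of a 16B vector into 8H lanes, which is why each vin[k] becomes a pair such as vb[0:1]. A hypothetical scalar model of the two (names are illustrative only):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using V16B = std::array<uint8_t, 16>;
using V8H  = std::array<uint16_t, 8>;

// Model of USHLL (unsigned shift left long): widen byte lanes 0..7 to
// 16-bit lanes, shifting each widened value left by `shift`.
V8H ushll_model(const V16B& src, int shift) {
  assert(shift >= 0 && shift < 8);  // immediate range for a byte source
  V8H dst{};
  for (int i = 0; i < 8; i++) dst[i] = static_cast<uint16_t>(src[i] << shift);
  return dst;
}

// Model of USHLL2: same as above but reads byte lanes 8..15.
V8H ushll2_model(const V16B& src, int shift) {
  assert(shift >= 0 && shift < 8);
  V8H dst{};
  for (int i = 0; i < 8; i++) dst[i] = static_cast<uint16_t>(src[i + 8] << shift);
  return dst;
}
```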

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5949:

> 5947:     __ ushll2(vb[5], __ T8H, vin[4], __ T16B, 0);
> 5948: 
> 5949:     // offset duplicated elements in va and vb by 8

To make this clearer it should say

    // shift lo byte of copy 1 of the middle stripe into the high byte
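A per-lane sketch of that shift (`shl8_model` is a hypothetical name): afterwards the low nibble of the middle byte b1 sits in bits [8..11], which is where the hi 4 bits of the 1st int belong.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using V8H = std::array<uint16_t, 8>;

// Model of SHL #8 on an 8H vector holding copy 1 of the middle stripe:
// each lane's low byte (a whole input byte b1) moves into the high byte.
V8H shl8_model(const V8H& src) {
  V8H dst{};
  for (int i = 0; i < 8; i++) {
    assert(src[i] <= 0xFF);  // lanes were zero-extended from bytes
    dst[i] = static_cast<uint16_t>(src[i] << 8);
  }
  return dst;
}
```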

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5955:

> 5953:     __ shl(vb[3], __ T8H, vb[3], 8);
> 5954: 
> 5955:     // expand remaining input bytes in vin to shorts in va and vb

To make this clearer it should say

    // Expand vin[2] into va[6:7] and vin[5] into vb[6:7] but this
    // time pre-shifted by 4 to ensure top bits of input 12-bit int
    // are in bit positions [4..11].
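A one-lane sketch (hypothetical name `expand_preshift4`) of why a shift of 4 is enough: even the largest byte, 0xFF, becomes 0x0FF0, so bits [0..3] stay free for the lo 4 bits of the 2nd int and bits [12..15] stay clear.

```cpp
#include <cassert>
#include <cstdint>

// Model of one lane of a ushll with shift 4 applied to vin[2]: the byte b2
// (hi 8 bits of the 2nd 12-bit int) is zero-extended and shifted left by 4,
// so it occupies bit positions [4..11] of the 16-bit lane.
uint16_t expand_preshift4(uint8_t b2) {
  return static_cast<uint16_t>(b2 << 4);
}
```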

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5962:

> 5960:     __ ushll2(vb[7], __ T8H, vin[5], __ T16B, 4);
> 5961: 
> 5962:     // split the duplicated 8 bit values into two distinct 4 bit

To make this clearer it should say

    // mask hi 4 bits of the 1st 12-bit int in a pair from copy1 and
    // shift lo 4 bits of the 2nd 12-bit int in a pair to the bottom of
    // copy2
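A per-lane sketch of this step, assuming copy1 already holds b1 << 8 and copy2 holds b1 zero-extended; the 0x0F00 mask constant and the names are assumptions for illustration, not taken from the patch:

```cpp
#include <cassert>
#include <cstdint>

struct Nibbles {
  uint16_t first_hi4;   // hi 4 bits of the 1st int, in bits [8..11]
  uint16_t second_lo4;  // lo 4 bits of the 2nd int, in bits [0..3]
};

// copy1 = b1 << 8 (after the earlier shift), copy2 = b1 zero-extended.
Nibbles split_middle_byte(uint16_t copy1, uint16_t copy2) {
  assert((copy1 & 0x00FF) == 0 && copy2 <= 0xFF);
  Nibbles n;
  n.first_hi4  = static_cast<uint16_t>(copy1 & 0x0F00);  // mask copy1
  n.second_lo4 = static_cast<uint16_t>(copy2 >> 4);      // ushr by 4
  return n;
}
```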

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5973:

> 5971:     __ ushr(vb[5], __ T8H, vb[5], 4);
> 5972: 
> 5973:     // sum resulting short values into the front halves of va and

This should be replaced to clarify details of the ordering for summing and grouping

    // sum hi 4 bits and lo 8 bits of the 1st 12-bit int in each pair and
    // hi 8 bits plus lo 4 bits of the 2nd 12-bit int in each pair

    // n.b. the ordering ensures: i) inputs are consumed before they
    // are overwritten ii) the order of 16-bit results across successive
    // pairs of vectors in va and then vb reflects the order of the
    // corresponding 12-bit inputs
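Putting the steps together, a per-lane sketch (hypothetical names) can be checked exhaustively over b0 and b1, sampling b2, against the direct decode. Because each sum combines pieces in disjoint bit ranges, the add can never carry, so it behaves exactly like an OR.

```cpp
#include <cassert>
#include <cstdint>

// One lane of the whole expand/shift/mask/sum sequence.
void lane_pipeline(uint8_t b0, uint8_t b1, uint8_t b2,
                   uint16_t& first, uint16_t& second) {
  uint16_t lo8   = b0;                                     // expand vin[0]
  uint16_t copy1 = static_cast<uint16_t>(b1 << 8);         // expand + shl 8
  uint16_t copy2 = b1;                                     // second expansion
  uint16_t hi8   = static_cast<uint16_t>(b2 << 4);         // pre-shifted by 4
  uint16_t hi4   = static_cast<uint16_t>(copy1 & 0x0F00);  // mask copy1
  uint16_t lo4   = static_cast<uint16_t>(copy2 >> 4);      // ushr copy2
  first  = static_cast<uint16_t>(hi4 + lo8);  // bits [8..11] + [0..7]
  second = static_cast<uint16_t>(hi8 + lo4);  // bits [4..11] + [0..3]
}

// Compare against the direct 12-bit decode for every b0, b1 and sampled b2.
bool pipeline_matches_reference() {
  for (int b0 = 0; b0 < 256; b0++) {
    for (int b1 = 0; b1 < 256; b1++) {
      for (int b2 = 0; b2 < 256; b2 += 17) {
        uint16_t f, s;
        lane_pipeline(static_cast<uint8_t>(b0), static_cast<uint8_t>(b1),
                      static_cast<uint8_t>(b2), f, s);
        uint16_t rf = static_cast<uint16_t>(b0 | ((b1 & 0x0F) << 8));
        uint16_t rs = static_cast<uint16_t>((b1 >> 4) | (b2 << 4));
        if (f != rf || s != rs) return false;
      }
    }
  }
  return true;
}
```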

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5984:

> 5982:     __ addv(vb[3], __ T8H, vb[5], vb[7]);
> 5983: 
> 5984:     // store results interleaved as shorts

Change to

    // store 64 results interleaved as shorts
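Assuming the store pairs each vector of 1st ints with the matching vector of 2nd ints in the manner of an ST2, each pair of 8H vectors writes 16 shorts in input order, and the equivalent of four such stores yields the 64 results. A hypothetical scalar model:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using V8H = std::array<uint16_t, 8>;

// Model of an ST2 of two 8H vectors: lane i of `firsts` goes to out[2*i]
// and lane i of `seconds` to out[2*i + 1]. Returns the advanced pointer.
uint16_t* st2_model(const V8H& firsts, const V8H& seconds, uint16_t* out) {
  assert(out != nullptr);
  for (int i = 0; i < 8; i++) {
    out[2 * i]     = firsts[i];
    out[2 * i + 1] = seconds[i];
  }
  return out + 16;
}
```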

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5993:

> 5991:     __ cbz(parsedLength, L_end);
> 5992: 
> 5993:     // if anything is left it should be a final 72 bytes. so we

Clarify as follows

    // if anything is left it should be a final 72 bytes of input,
    // i.e. a final 48 12-bit values, so we handle this by loading
    // 48 bytes into all 16B lanes of front(vin) and only 24
    // bytes into the lower 8B lanes of back(vin)
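The tail arithmetic: 72 bytes carry 48 12-bit values, i.e. 16 lanes x 3 bytes in front(vin) (32 values) plus 8 lanes x 3 bytes in back(vin) (16 values). The hypothetical helper below shows how a value count splits into 96-byte iterations plus that 72-byte tail. One length that ends in exactly this tail is 168 bytes (the SHAKE128 block size), i.e. 112 values = 64 + 48; whether that is the case driving this stub is an assumption here.

```cpp
#include <cassert>
#include <cstddef>

struct ChunkPlan {
  std::size_t full_blocks;  // iterations consuming 96 bytes -> 64 shorts
  bool        tail48;       // final 72 bytes -> 48 shorts
};

// Hypothetical: split n_values 12-bit values into full blocks plus a tail.
ChunkPlan plan_chunks(std::size_t n_values) {
  ChunkPlan p{n_values / 64, false};
  std::size_t rest = n_values % 64;
  assert(rest == 0 || rest == 48);  // no other remainder is expected
  p.tail48 = (rest == 48);
  return p;
}

// Each pair of 12-bit values occupies 3 bytes.
constexpr std::size_t bytes_for(std::size_t n_values) { return n_values * 3 / 2; }
```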

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 5999:

> 5997:     vs_ld3(vs_back(vin), __ T8B, condensed);
> 5998: 
> 5999:     // expand groups of input bytes in vin to shorts in va and vb

Modify as above

    // Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6009:

> 6007:     __ ushll2(va[5], __ T8H, vin[1], __ T16B, 0);
> 6008: 
> 6009:     __ ushll(vb[0], __ T8H, vin[3], __ T8B, 0);

Add a comment

    // This time expand just the lower 8 lanes

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6013:

> 6011:     __ ushll(vb[4], __ T8H, vin[4], __ T8B, 0);
> 6012: 
> 6013:     // offset duplicated elements in va and vb by 8

As before clarify as follows

    // shift lo byte of copy 1 of the middle stripe into the high byte

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6018:

> 6016:     __ shl(vb[2], __ T8H, vb[2], 8);
> 6017: 
> 6018:     // expand remaining input bytes in vin to shorts in va and vb

Again improve this comment

    // expand vin[2] into va[6:7] and lower 8 lanes of vin[5] into
    // vb[6] pre-shifted by 4 to ensure top bits of the input 12-bit
    // int are in bit positions [4..11].

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6024:

> 6022:     __ ushll(vb[6], __ T8H, vin[5], __ T8B, 4);
> 6023: 
> 6024:     // split the duplicated 8 bit values into two distinct 4 bit

Once again update

    // mask hi 4 bits of each 1st 12-bit int in pair from copy1 and
    // shift lo 4 bits of each 2nd 12-bit int in pair to bottom of
    // copy2

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6033:

> 6031:     __ ushr(vb[4], __ T8H, vb[4], 4);
> 6032: 
> 6033:     // sum resulting short values into the front halves of va and

Again update to provide more detail

    // sum hi 4 bits and lo 8 bits of each 1st 12-bit int in pair and
    // hi 8 bits plus lo 4 bits of each 2nd 12-bit int in pair

    // n.b. ordering ensures: i) inputs are consumed before they are
    // overwritten ii) order of 16-bit results across successive
    // pairs of vectors in va and then lower half of vb reflects order
    // of corresponding 12-bit inputs
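A lane-wise sketch of the tail (hypothetical `decode_tail72`, scalar rather than vector) of why decoding the 16 front lanes before the 8 lower back lanes, and storing each lane's pair together, reproduces the input order of all 48 values:

```cpp
#include <cassert>
#include <cstdint>

// Decode 72 bytes into 48 12-bit values, lane by lane: 16 lanes from
// front(vin) (48 bytes), then the lower 8 lanes of back(vin) (24 bytes).
void decode_tail72(const uint8_t* in, uint16_t* out) {
  int o = 0;
  auto decode_lanes = [&](const uint8_t* src, int lanes) {
    for (int lane = 0; lane < lanes; lane++) {
      uint8_t b0 = src[3 * lane], b1 = src[3 * lane + 1], b2 = src[3 * lane + 2];
      out[o++] = static_cast<uint16_t>(b0 | ((b1 & 0x0F) << 8));  // 1st int
      out[o++] = static_cast<uint16_t>((b1 >> 4) | (b2 << 4));    // 2nd int
    }
  };
  decode_lanes(in, 16);      // front(vin): 32 values
  decode_lanes(in + 48, 8);  // lower 8 lanes of back(vin): 16 values
  assert(o == 48);
}
```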

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 6042:

> 6040:     __ addv(vb[1], __ T8H, vb[4], vb[6]);
> 6041: 
> 6042:     // store results interleaved as shorts

Change to

    // store 48 results interleaved as shorts

-------------

PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037755555
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037758589
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037760493
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037762375
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037764723
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037767700
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037783970
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037771521
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037769694
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037774404
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037776668
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037779831
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037780704
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037781617
PR Review Comment: https://git.openjdk.org/jdk/pull/23663#discussion_r2037782757


More information about the graal-dev mailing list