RFR: 8348561: Add aarch64 intrinsics for ML-DSA [v5]

Andrew Dinn adinn at openjdk.org
Tue Feb 18 13:46:17 UTC 2025


On Thu, 6 Feb 2025 18:47:54 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:

>> By using the aarch64 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
> 
>   Adding comments + some code reorganization

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 4066:

> 4064:   }
> 4065: 
> 4066:   // Execute on round of keccak of two computations in parallel.

Suggestion:

It would be helpful to add comments that relate the register and instruction selection to the original Java source code, e.g. change the header as follows:

    // Performs one keccak round on each of two independent states in
    // parallel, using the two 64-bit lanes of the vector registers.
    //
    // Two sets of 25 * 64-bit input states a0[lo:hi]...a24[lo:hi] are passed in
    // the lower/upper halves of registers v0...v24 and the transformed states
    // are returned in the same registers. Intermediate 64-bit pairs
    // c0...c4 and d0...d4 are computed in registers v25...v30. v31 is
    // loaded with the 64-bit round constant, replicated into both lanes.
    // During computation of the output states some intermediate results are
    // shuffled around registers v0...v31. Comments on each line indicate
    // how the values in registers correspond to variables ai, ci, di in
    // the Java source code, likewise how the generated machine instructions
    // correspond to Java source operations (n.b. rol means rotate left).
    // A primed name such as a10' denotes a value temporarily held in a
    // register other than its home register (v10 for a10).

Then annotate the generation steps as follows:

    // Theta, part 1: column parities ci = a[i] ^ a[i+5] ^ a[i+10] ^ a[i+15] ^ a[i+20],
    // built with two 3-way XORs (eor3) per column. c0..c4 end up in v25..v29.
    __ eor3(v29, __ T16B, v4, v9, v14);       // c4 = a4 ^ a9 ^ a14
    __ eor3(v26, __ T16B, v1, v6, v11);       // c1 = a1 ^ a6 ^ a11
    __ eor3(v28, __ T16B, v3, v8, v13);       // c3 = a3 ^ a8 ^ a13
    __ eor3(v25, __ T16B, v0, v5, v10);       // c0 = a0 ^ a5 ^ a10
    __ eor3(v27, __ T16B, v2, v7, v12);       // c2 = a2 ^ a7 ^ a12
    __ eor3(v29, __ T16B, v29, v19, v24);     // c4 ^= a19 ^ a24
    __ eor3(v26, __ T16B, v26, v16, v21);     // c1 ^= a16 ^ a21
    __ eor3(v28, __ T16B, v28, v18, v23);     // c3 ^= a18 ^ a23
    __ eor3(v25, __ T16B, v25, v15, v20);     // c0 ^= a15 ^ a20
    __ eor3(v27, __ T16B, v27, v17, v22);     // c2 ^= a17 ^ a22

    // Theta, part 2: di = c[i-1] ^ rol(c[i+1], 1), indices taken mod 5.
    // rax1(d, n, m) computes n ^ rol(m, 1). The order is chosen so each ci is
    // read before its register is overwritten; afterwards d0, d1, d2, d3, d4
    // live in v30, v25, v26, v27, v28 respectively.
    __ rax1(v30, __ T2D, v29, v26);           // d0 = c4 ^ rol(c1, 1)
    __ rax1(v26, __ T2D, v26, v28);           // d2 = c1 ^ rol(c3, 1)
    __ rax1(v28, __ T2D, v28, v25);           // d4 = c3 ^ rol(c0, 1)
    __ rax1(v25, __ T2D, v25, v27);           // d1 = c0 ^ rol(c2, 1)
    __ rax1(v27, __ T2D, v27, v29);           // d3 = c2 ^ rol(c4, 1)

    // Theta application fused with rho and pi: each lane is XORed with the d
    // value of its column, rotated, and moved to its new position.
    // xar(dst, n, m, imm) rotates (n ^ m) right by imm, so an immediate of
    // (64 - k) is a rotate left by k. The sequence is a chain in which each
    // destination register has already been consumed as a source. Where the
    // home register of a result is still needed (e.g. v10 still holds a10),
    // the result is parked in a spare register and marked with a prime
    // (a10' in v29, a20' in v31, a4' in v28, a3' in v27, ...). d3 (v27) and
    // d4 (v28) are overwritten only by the instruction that reads them last.
    __ eor(v0, __ T16B, v0, v30);             // a0 = a0 ^ d0
    __ xar(v29, __ T2D, v1,  v25, (64 - 1));  // a10' = rol((a1^d1), 1)
    __ xar(v1,  __ T2D, v6,  v25, (64 - 44)); // a1 = rol((a6^d1), 44)
    __ xar(v6,  __ T2D, v9,  v28, (64 - 20)); // a6 = rol((a9^d4), 20)
    __ xar(v9,  __ T2D, v22, v26, (64 - 61)); // a9 = rol((a22^d2), 61)
    __ xar(v22, __ T2D, v14, v28, (64 - 39)); // a22 = rol((a14^d4), 39)
    __ xar(v14, __ T2D, v20, v30, (64 - 18)); // a14 = rol((a20^d0), 18)
    __ xar(v31, __ T2D, v2,  v26, (64 - 62)); // a20' = rol((a2^d2), 62)
    __ xar(v2,  __ T2D, v12, v26, (64 - 43)); // a2 = rol((a12^d2), 43)
    __ xar(v12, __ T2D, v13, v27, (64 - 25)); // a12 = rol((a13^d3), 25)
    __ xar(v13, __ T2D, v19, v28, (64 - 8));  // a13 = rol((a19^d4), 8)
    __ xar(v19, __ T2D, v23, v27, (64 - 56)); // a19 = rol((a23^d3), 56)
    __ xar(v23, __ T2D, v15, v30, (64 - 41)); // a23 = rol((a15^d0), 41)
    __ xar(v15, __ T2D, v4,  v28, (64 - 27)); // a15 = rol((a4^d4), 27)
    __ xar(v28, __ T2D, v24, v28, (64 - 14)); // a4' = rol((a24^d4), 14)
    __ xar(v24, __ T2D, v21, v25, (64 - 2));  // a24 = rol((a21^d1), 2)
    __ xar(v8,  __ T2D, v8,  v27, (64 - 55)); // a21' = rol((a8^d3), 55)
    __ xar(v4,  __ T2D, v16, v25, (64 - 45)); // a8' = rol((a16^d1), 45)
    __ xar(v16, __ T2D, v5,  v30, (64 - 36)); // a16 = rol((a5^d0), 36)
    __ xar(v5,  __ T2D, v3,  v27, (64 - 28)); // a5 = rol((a3^d3), 28)
    __ xar(v27, __ T2D, v18, v27, (64 - 21)); // a3' = rol((a18^d3), 21)
    __ xar(v3,  __ T2D, v17, v26, (64 - 15)); // a18' = rol((a17^d2), 15)
    __ xar(v25, __ T2D, v11, v25, (64 - 10)); // a17' = rol((a11^d1), 10)
    __ xar(v26, __ T2D, v7,  v26, (64 - 6));  // a11' = rol((a7^d2), 6)
    __ xar(v30, __ T2D, v10, v30, (64 - 3));  // a7' = rol((a10^d0), 3)

    // Chi, one row of five lanes at a time: a[x] = a[x] ^ (~a[x+1] & a[x+2])
    // within the row (indices mod 5). bcax(dst, n, m, a) computes
    // n ^ (m & ~a). Within each row the outputs are written in an order such
    // that every register is read for the last time before it is overwritten.
    __ bcax(v20, __ T16B, v31, v22, v8);      // a20 = a20' ^ (~a21' & a22)
    __ bcax(v21, __ T16B, v8,  v23, v22);     // a21 = a21' ^ (~a22 & a23)
    __ bcax(v22, __ T16B, v22, v24, v23);     // a22 = a22 ^ (~a23 & a24)
    __ bcax(v23, __ T16B, v23, v31, v24);     // a23 = a23 ^ (~a24 & a20')
    __ bcax(v24, __ T16B, v24, v8,  v31);     // a24 = a24 ^ (~a20' & a21')

    // v31 held a20', whose last use is the row above, so it can now receive
    // the round constant (one 64-bit value replicated into both lanes),
    // advancing rscratch1 by 8 bytes to the next constant.
    __ ld1r(v31, __ T2D, __ post(rscratch1, 8)); // rc = round_constants[i]

    __ bcax(v17, __ T16B, v25, v19, v3);      // a17 = a17' ^ (~a18' & a19)
    __ bcax(v18, __ T16B, v3,  v15, v19);     // a18 = a18' ^ (~a19 & a15)
    __ bcax(v19, __ T16B, v19, v16, v15);     // a19 = a19 ^ (~a15 & a16)
    __ bcax(v15, __ T16B, v15, v25, v16);     // a15 = a15 ^ (~a16 & a17')
    __ bcax(v16, __ T16B, v16, v3,  v25);     // a16 = a16 ^ (~a17' & a18')

    __ bcax(v10, __ T16B, v29, v12, v26);     // a10 = a10' ^ (~a11' & a12)
    __ bcax(v11, __ T16B, v26, v13, v12);     // a11 = a11' ^ (~a12 & a13)
    __ bcax(v12, __ T16B, v12, v14, v13);     // a12 = a12 ^ (~a13 & a14)
    __ bcax(v13, __ T16B, v13, v29, v14);     // a13 = a13 ^ (~a14 & a10')
    __ bcax(v14, __ T16B, v14, v26, v29);     // a14 = a14 ^ (~a10' & a11')

    __ bcax(v7, __ T16B, v30, v9,  v4);       // a7 = a7' ^ (~a8' & a9)
    __ bcax(v8, __ T16B, v4,  v5,  v9);       // a8 = a8' ^ (~a9 & a5)
    __ bcax(v9, __ T16B, v9,  v6,  v5);       // a9 = a9 ^ (~a5 & a6)
    __ bcax(v5, __ T16B, v5,  v30, v6);       // a5 = a5 ^ (~a6 & a7')
    __ bcax(v6, __ T16B, v6,  v4,  v30);      // a6 = a6 ^ (~a7' & a8')

    __ bcax(v3, __ T16B, v27, v0,  v28);      // a3 = a3' ^ (~a4' & a0)
    __ bcax(v4, __ T16B, v28, v1,  v0);       // a4 = a4' ^ (~a0 & a1)
    __ bcax(v0, __ T16B, v0,  v2,  v1);       // a0 = a0 ^ (~a1 & a2)
    __ bcax(v1, __ T16B, v1,  v27, v2);       // a1 = a1 ^ (~a2 & a3')
    __ bcax(v2, __ T16B, v2,  v28, v27);      // a2 = a2 ^ (~a3' & a4')

    // Iota: mix the round constant into lane a0 of both states.
    __ eor(v0, __ T16B, v0, v31);             // a0 = a0 ^ rc

-------------

PR Review Comment: https://git.openjdk.org/jdk/pull/23300#discussion_r1959776475


More information about the hotspot-dev mailing list