RFR: 8348561: Add aarch64 intrinsics for ML-DSA [v5]
Andrew Dinn
adinn at openjdk.org
Tue Feb 18 13:46:17 UTC 2025
On Thu, 6 Feb 2025 18:47:54 GMT, Ferenc Rakoczi <duke at openjdk.org> wrote:
>> By using the aarch64 vector registers the speed of the computation of the ML-DSA algorithms (key generation, document signing, signature verification) can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional commit since the last revision:
>
> Adding comments + some code reorganization
src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 4066:
> 4064: }
> 4065:
> 4066: // Execute on round of keccak of two computations in parallel.
Suggestion:
It would be helpful to add comments that relate the register and instruction selection to the original Java source code. e.g. change the header as follows
// Performs 2 keccak round transformations using vector parallelism
//
// Two sets of 25 * 64-bit input states a0[lo:hi]...a24[lo:hi] are passed in
// the lower/upper halves of registers v0...v24 and the transformed states
// are returned in the same registers. Intermediate 64-bit pairs
// c0...c4 and d0...d4 are computed in registers v25...v30. v31 is
// loaded with the required pair of 64 bit round constants.
// During computation of the output states some intermediate results are
// shuffled around registers v0...v30. Comments on each line indicate
// how the values in registers correspond to variables ai, ci, di in
// the Java source code, likewise how the generated machine instructions
// correspond to Java source operations (n.b. rol means rotate left).
Then annotate the generation steps as follows:
__ eor3(v29, __ T16B, v4, v9, v14); // c4 = a4 ^ a9 ^ a14
__ eor3(v26, __ T16B, v1, v6, v11); // c1 = a1 ^ a6 ^ a11
__ eor3(v28, __ T16B, v3, v8, v13); // c3 = a3 ^ a8 ^ a13
__ eor3(v25, __ T16B, v0, v5, v10); // c0 = a0 ^ a5 ^ a10
__ eor3(v27, __ T16B, v2, v7, v12); // c2 = a2 ^ a7 ^ a12
__ eor3(v29, __ T16B, v29, v19, v24); // c4 ^= a19 ^ a24
__ eor3(v26, __ T16B, v26, v16, v21); // c1 ^= a16 ^ a21
__ eor3(v28, __ T16B, v28, v18, v23); // c3 ^= a18 ^ a23
__ eor3(v25, __ T16B, v25, v15, v20); // c0 ^= a15 ^ a20
__ eor3(v27, __ T16B, v27, v17, v22); // c2 ^= a17 ^ a22
__ rax1(v30, __ T2D, v29, v26); // d0 = c4 ^ rol(c1, 1)
__ rax1(v26, __ T2D, v26, v28); // d2 = c1 ^ rol(c3, 1)
__ rax1(v28, __ T2D, v28, v25); // d4 = c3 ^ rol(c0, 1)
__ rax1(v25, __ T2D, v25, v27); // d1 = c0 ^ rol(c2, 1)
__ rax1(v27, __ T2D, v27, v29); // d3 = c2 ^ rol(c4, 1)
__ eor(v0, __ T16B, v0, v30); // a0 = a0 ^ d0
__ xar(v29, __ T2D, v1, v25, (64 - 1)); // a10' = rol((a1^d1), 1)
__ xar(v1, __ T2D, v6, v25, (64 - 44)); // a1 = rol((a6^d1), 44)
__ xar(v6, __ T2D, v9, v28, (64 - 20)); // a6 = rol((a9^d4), 20)
__ xar(v9, __ T2D, v22, v26, (64 - 61)); // a9 = rol((a22^d2), 61)
__ xar(v22, __ T2D, v14, v28, (64 - 39)); // a22 = rol((a14^d4), 39)
__ xar(v14, __ T2D, v20, v30, (64 - 18)); // a14 = rol((a20^d0), 18)
__ xar(v31, __ T2D, v2, v26, (64 - 62)); // a20' = rol((a2^d2), 62)
__ xar(v2, __ T2D, v12, v26, (64 - 43)); // a2 = rol((a12^d2), 43)
__ xar(v12, __ T2D, v13, v27, (64 - 25)); // a12 = rol((a13^d3), 25)
__ xar(v13, __ T2D, v19, v28, (64 - 8)); // a13 = rol((a19^d4), 8)
__ xar(v19, __ T2D, v23, v27, (64 - 56)); // a19 = rol((a23^d3), 56)
__ xar(v23, __ T2D, v15, v30, (64 - 41)); // a23 = rol((a15^d0), 41)
__ xar(v15, __ T2D, v4, v28, (64 - 27)); // a15 = rol((a4^d4), 27)
__ xar(v28, __ T2D, v24, v28, (64 - 14)); // a4' = rol((a24^d4), 14)
__ xar(v24, __ T2D, v21, v25, (64 - 2)); // a24 = rol((a21^d1), 2)
__ xar(v8, __ T2D, v8, v27, (64 - 55)); // a21' = rol((a8^d3), 55)
__ xar(v4, __ T2D, v16, v25, (64 - 45)); // a8' = rol((a16^d1), 45)
__ xar(v16, __ T2D, v5, v30, (64 - 36)); // a16 = rol((a5^d0), 36)
__ xar(v5, __ T2D, v3, v27, (64 - 28)); // a5 = rol((a3^d3), 28)
__ xar(v27, __ T2D, v18, v27, (64 - 21)); // a3' = rol((a18^d3), 21)
__ xar(v3, __ T2D, v17, v26, (64 - 15)); // a18' = rol((a17^d2), 15)
__ xar(v25, __ T2D, v11, v25, (64 - 10)); // a17' = rol((a11^d1), 10)
__ xar(v26, __ T2D, v7, v26, (64 - 6)); // a11' = rol((a7^d2), 6)
__ xar(v30, __ T2D, v10, v30, (64 - 3)); // a7' = rol((a10^d0), 3)
__ bcax(v20, __ T16B, v31, v22, v8); // a20 = a20' ^ (~a21' & a22)
__ bcax(v21, __ T16B, v8, v23, v22); // a21 = a21' ^ (~a22 & a23)
__ bcax(v22, __ T16B, v22, v24, v23); // a22 = a22 ^ (~a23 & a24)
__ bcax(v23, __ T16B, v23, v31, v24); // a23 = a23 ^ (~a24 & a20')
__ bcax(v24, __ T16B, v24, v8, v31); // a24 = a24 ^ (~a20' & a21')
__ ld1r(v31, __ T2D, __ post(rscratch1, 8)); // rc = round_constants[i]
__ bcax(v17, __ T16B, v25, v19, v3); // a17 = a17' ^ (~a18' & a19)
__ bcax(v18, __ T16B, v3, v15, v19); // a18 = a18' ^ (~a19 & a15)
__ bcax(v19, __ T16B, v19, v16, v15); // a19 = a19 ^ (~a15 & a16)
__ bcax(v15, __ T16B, v15, v25, v16); // a15 = a15 ^ (~a16 & a17')
__ bcax(v16, __ T16B, v16, v3, v25); // a16 = a16 ^ (~a17' & a18')
__ bcax(v10, __ T16B, v29, v12, v26); // a10 = a10' ^ (~a11' & a12)
__ bcax(v11, __ T16B, v26, v13, v12); // a11 = a11' ^ (~a12 & a13)
__ bcax(v12, __ T16B, v12, v14, v13); // a12 = a12 ^ (~a13 & a14)
__ bcax(v13, __ T16B, v13, v29, v14); // a13 = a13 ^ (~a14 & a10')
__ bcax(v14, __ T16B, v14, v26, v29); // a14 = a14 ^ (~a10' & a11')
__ bcax(v7, __ T16B, v30, v9, v4); // a7 = a7' ^ (~a8' & a9)
__ bcax(v8, __ T16B, v4, v5, v9); // a8 = a8' ^ (~a9 & a5)
__ bcax(v9, __ T16B, v9, v6, v5); // a9 = a9 ^ (~a5 & a6)
__ bcax(v5, __ T16B, v5, v30, v6); // a5 = a5 ^ (~a6 & a7')
__ bcax(v6, __ T16B, v6, v4, v30); // a6 = a6 ^ (~a7' & a8')
__ bcax(v3, __ T16B, v27, v0, v28); // a3 = a3' ^ (~a4' & a0)
__ bcax(v4, __ T16B, v28, v1, v0); // a4 = a4' ^ (~a0 & a1)
__ bcax(v0, __ T16B, v0, v2, v1); // a0 = a0 ^ (~a1 & a2)
__ bcax(v1, __ T16B, v1, v27, v2); // a1 = a1 ^ (~a2 & a3')
__ bcax(v2, __ T16B, v2, v28, v27); // a2 = a2 ^ (~a3' & a4')
__ eor(v0, __ T16B, v0, v31); // a0 = a0 ^ rc
-------------
PR Review Comment: https://git.openjdk.org/jdk/pull/23300#discussion_r1959776475
More information about the hotspot-dev
mailing list