Vector API performance issues with port of XXH3
Vladimir Ivanov
vladimir.x.ivanov at oracle.com
Mon Jul 15 22:14:41 UTC 2024
FTR there was a relevant discussion about how to express vectorized
32x32->64-bit multiplication in Vector API:
https://mail.openjdk.org/pipermail/panama-dev/2018-August/002440.html
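For context, the scalar form of that operation is simple to write in Java; the open question in that thread (and here) is how to express the vectorized equivalent so that C2 emits a single VPMULUDQ per lane. A minimal illustrative sketch, not taken from that thread:

// Scalar form of the operation in question: an unsigned 32x32->64-bit
// multiply. (Illustrative only.)
static long mul32x32to64(int a, int b) {
    // Zero-extend both operands; the full product fits in 64 bits.
    return Integer.toUnsignedLong(a) * Integer.toUnsignedLong(b);
}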
Best regards,
Vladimir Ivanov
On 7/15/24 05:11, Bhateja, Jatin wrote:
> Hi Martin,
>
> Thanks for reporting this.
>
> The instruction sequence for the 64x64-bit multiplier on AVX2 targets is
> agnostic to whether the upper or lower double word is known to be zero.
> This is because we do not split the multiplier at the IR level, and we do
> not rely on constant folding to sweep out the redundant logic. It can,
> however, be handled as a point optimization.
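>
> For illustration, this is the identity such a point optimization would
> exploit, written as scalar Java (the method and variable names here are
> illustrative only, not part of the patch):
>
> // Any 64x64 multiply decomposes into 32-bit halves (mod 2^64):
> //   a * b == aLo*bLo + ((aLo*bHi + aHi*bLo) << 32)
> static long mulDecomposed(long a, long b) {
>     long aLo = a & 0xFFFF_FFFFL, aHi = a >>> 32;
>     long bLo = b & 0xFFFF_FFFFL, bHi = b >>> 32;
>     // When aHi and bHi are provably zero (as in the code below, where one
>     // operand is masked and the other logically shifted right), the cross
>     // terms vanish and this reduces to aLo * bLo: a single unsigned
>     // 32x32->64 multiply (VPMULUDQ) per 64-bit lane.
>     return aLo * bLo + ((aLo * bHi + aHi * bLo) << 32);
> }
>
> (The AVX2 sequence C2 currently emits, shown in the mail below, is exactly
> this decomposition: vpmulld/vpshufd/vpaddd/vpsllq produce the cross terms
> and vpmuludq produces aLo * bLo.)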
>
> I just did a quick patch [1] to attempt that, and I can see the compiler
> is now emitting VPMULDQ [2].
>
> Best Regards,
>
> Jatin
>
> [1]
> https://github.com/jatin-bhateja/external_staging/blob/main/NewOperationSamples/doubleWordMultQuadWordAccum/jdk_patch.diff
>
> [2] https://www.felixcloutier.com/x86/pmuldq
>
> *From:* panama-dev <panama-dev-retn at openjdk.org> *On Behalf Of* Martin Traverso
> *Sent:* Thursday, July 11, 2024 6:58 AM
> *To:* panama-dev at openjdk.org
> *Subject:* Vector API performance issues with port of XXH3
>
> Hi,
>
> Following up on my attempts to port XXH3 to Java
> (https://github.com/Cyan4973/xxHash), I'd like to ask for some advice.
> The core loop of that algorithm uses SIMD, with custom implementations
> for NEON, AVX2, AVX512, etc. I have been unable to get the performance
> of the Vector API-based implementation anywhere near that of the native
> code (~3x difference for the core loop on a CPU with AVX2).
>
>
> private static final VectorShuffle<Long> LONG_SHUFFLE_PREFERRED =
>         VectorShuffle.fromOp(LongVector.SPECIES_PREFERRED, i -> i ^ 1);
>
> ...
>
> for (int block = 0; block < input.length / 1024; block++) {
>     for (int stripe = 0; stripe < 16; stripe++) {
>         int inputOffset = block * 1024 + stripe * 64;
>         int secretOffset = stripe * 8;
>
>         for (int i = 0; i < 8; i += LongVector.SPECIES_PREFERRED.length()) {
>             LongVector accumulatorsVector =
>                     LongVector.fromArray(LongVector.SPECIES_PREFERRED, accumulators, i);
>             LongVector inputVector =
>                     ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, input,
>                             inputOffset + i * 8).reinterpretAsLongs();
>             LongVector secretVector =
>                     ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, SECRET,
>                             secretOffset + i * 8).reinterpretAsLongs();
>
>             // data_key = data ^ secret
>             LongVector key = inputVector
>                     .lanewise(XOR, secretVector)
>                     .reinterpretAsLongs();
>
>             // Both operands have their upper 32 bits cleared.
>             LongVector low = key.and(0xFFFF_FFFFL);
>             LongVector high = key.lanewise(LSHR, 32);
>
>             // acc += swap(data) + high * low
>             accumulatorsVector
>                     .add(inputVector.rearrange(LONG_SHUFFLE_PREFERRED))
>                     .add(high.mul(low))
>                     .intoArray(accumulators, i);
>         }
>     }
> }
>
> It generates the following assembly (loop unrolling disabled for clarity):
>
> ...
> 0x0000762f8044b730: lea r11d,[r8*8+0x0]
> 0x0000762f8044b738: movsxd r11,r11d
> 0x0000762f8044b73b: vmovdqu ymm0,YMMWORD PTR [r14+r11*1+0x10]
> 0x0000762f8044b742: vmovdqu ymm1,YMMWORD PTR [r13+r11*1+0x10]
> 0x0000762f8044b749: vpshufd ymm2,ymm1,0xb1
> 0x0000762f8044b74e: vpmulld ymm2,ymm0,ymm2
> 0x0000762f8044b753: vpshufd ymm3,ymm2,0xb1
> 0x0000762f8044b758: vpaddd ymm3,ymm3,ymm2
> 0x0000762f8044b75c: vpsllq ymm3,ymm3,0x20
> 0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1
> 0x0000762f8044b765: vpaddq ymm0,ymm2,ymm3
> 0x0000762f8044b769: vmovdqu YMMWORD PTR [rdi+r8*8+0x10],ymm0
> 0x0000762f8044b770: add r8d,0x4
> 0x0000762f8044b774: cmp r8d,0x8
> 0x0000762f8044b778: jl 0x0000762f8044b730
> ...
>
> The native implementation for AVX2 looks like this:
>
> __attribute__((aligned(32))) uint64_t accumulators[8] = {};
> __m256i* const xacc = (__m256i*) accumulators;
>
> for (size_t block = 0; block < length / 1024; block++) {
>     for (size_t stripe = 0; stripe < 16; stripe++) {
>         unsigned char* in = input + block * 1024 + stripe * 64;
>         unsigned char* secret = SECRET + stripe * 8;
>
>         const __m256i* const xinput = (const __m256i *) in;
>         const __m256i* const xsecret = (const __m256i *) secret;
>         for (size_t i = 0; i < 2; i++) {
>             __m256i const data_vec = _mm256_loadu_si256(xinput + i);      // data_vec = xinput[i];
>             __m256i const key_vec = _mm256_loadu_si256(xsecret + i);      // key_vec = xsecret[i];
>             __m256i const data_key = _mm256_xor_si256(data_vec, key_vec); // data_key = data_vec ^ key_vec;
>             __m256i const data_key_lo = _mm256_srli_epi64(data_key, 32);  // data_key_lo = data_key >> 32;
>             __m256i const product = _mm256_mul_epu32(data_key, data_key_lo);
>                 // product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff);
>             __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2));
>                 // xacc[i] += swap(data_vec);
>             __m256i const sum = _mm256_add_epi64(xacc[i], data_swap);
>             xacc[i] = _mm256_add_epi64(product, sum);                     // xacc[i] += product;
>         }
>     }
> }
>
> The corresponding assembly is:
>
> 1198: vmovdqu ymm4,YMMWORD PTR [rax-0x20]
> 119d: vmovdqu ymm5,YMMWORD PTR [rax]
> 11a1: add rax,0x8
> 11a5: add rdx,0x40
> 11a9: vpxor ymm0,ymm4,YMMWORD PTR [rdx-0x60]
> 11ae: vpsrlq ymm1,ymm0,0x20
> 11b3: vpmuludq ymm0,ymm0,ymm1
> 11b7: vpshufd ymm1,YMMWORD PTR [rdx-0x60],0x4e
> 11bd: vpaddq ymm0,ymm0,ymm1
> 11c1: vpaddq ymm3,ymm0,ymm3
> 11c5: vpxor ymm0,ymm5,YMMWORD PTR [rdx-0x40]
> 11ca: vpsrlq ymm1,ymm0,0x20
> 11cf: vpmuludq ymm0,ymm0,ymm1
> 11d3: vpshufd ymm1,YMMWORD PTR [rdx-0x40],0x4e
> 11d9: vpaddq ymm0,ymm0,ymm1
> 11dd: vpaddq ymm2,ymm0,ymm2
> 11e1: cmp rcx,rax
> 11e4: jne 1198
>
> As far as I can tell, the main difference is in how the multiplication
> is performed. The native code uses _mm256_mul_epu32 to perform the
> equivalent of "(v & 0xFFFF_FFFF) * (v >>> 32)", and it emits a single
> vpmuludq instruction.
>
> On the other hand, the Java implementation does not seem to understand
> that only the lower 32 bits of each lane are set, and it does the full
> 64-bit x 64-bit product (if I'm interpreting this correctly):
>
> 0x0000762f8044b749: vpshufd ymm2,ymm1,0xb1
> 0x0000762f8044b74e: vpmulld ymm2,ymm0,ymm2
> 0x0000762f8044b753: vpshufd ymm3,ymm2,0xb1
> 0x0000762f8044b758: vpaddd ymm3,ymm3,ymm2
> 0x0000762f8044b75c: vpsllq ymm3,ymm3,0x20
> 0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1
>
> Is there any way to perform a 32x32->64-bit product, or to provide enough
> structure for the compiler to realize it doesn't need to consider the
> upper 32 bits when computing the product, since they are all zeros?
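>
> For concreteness, here is the per-lane computation I'm trying to express,
> as a plain scalar sketch (it assumes the stripe and secret bytes have
> already been read into long[8] arrays, which the vector code above does
> via reinterpretAsLongs):
>
> // Scalar reference for one stripe (sketch only).
> static void accumulateStripe(long[] acc, long[] input, long[] secret) {
>     for (int j = 0; j < 8; j++) {
>         long dataKey = input[j] ^ secret[j];
>         long lo = dataKey & 0xFFFF_FFFFL;  // upper 32 bits are zero
>         long hi = dataKey >>> 32;          // upper 32 bits are zero
>         // hi * lo is exactly an unsigned 32x32->64 product, i.e. one
>         // VPMULUDQ per lane; j ^ 1 is the lane swap performed by
>         // LONG_SHUFFLE_PREFERRED above.
>         acc[j] += input[j ^ 1] + hi * lo;
>     }
> }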
>
> Anything else I'm missing?
>
>
> Thanks,
> - Martin
>