Vector API performance issues with port of XXH3
Martin Traverso
mtraverso at gmail.com
Thu Jul 11 01:27:52 UTC 2024
Hi,
Following up on my attempts to port XXH3 (https://github.com/Cyan4973/xxHash) to
Java, I'd like to ask for some advice. The core loop of that algorithm uses SIMD,
with custom native implementations for NEON, AVX2, AVX512, etc. I have been unable
to get the performance of the Vector API-based implementation anywhere near that
of the native code (~3x difference for the core loop on a CPU with AVX2).
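For context, the per-stripe accumulate step in scalar form looks roughly like this
(my paraphrase of the upstream scalar path, transliterated to Java; the VarHandle
plumbing is mine):

    private static final VarHandle LONG_HANDLE =
            MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);

    // One 64-byte stripe = 8 accumulator lanes of 64 bits each.
    for (int i = 0; i < 8; i++) {
        long dataVal = (long) LONG_HANDLE.get(input, inputOffset + 8 * i);
        long dataKey = dataVal ^ (long) LONG_HANDLE.get(SECRET, secretOffset + 8 * i);
        accumulators[i ^ 1] += dataVal;                                 // accumulate into the swapped lane
        accumulators[i] += (dataKey & 0xFFFF_FFFFL) * (dataKey >>> 32); // 32x32->64 multiply
    }

My Vector API version of that loop: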
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.VectorShuffle;

import static jdk.incubator.vector.VectorOperators.LSHR;
import static jdk.incubator.vector.VectorOperators.XOR;

// Swaps adjacent 64-bit lanes: 0<->1, 2<->3, ...
private static final VectorShuffle<Long> LONG_SHUFFLE_PREFERRED =
        VectorShuffle.fromOp(LongVector.SPECIES_PREFERRED, i -> i ^ 1);

...

for (int block = 0; block < input.length / 1024; block++) {
    for (int stripe = 0; stripe < 16; stripe++) {
        int inputOffset = block * 1024 + stripe * 64;
        int secretOffset = stripe * 8;

        for (int i = 0; i < 8; i += LongVector.SPECIES_PREFERRED.length()) {
            LongVector accumulatorsVector =
                    LongVector.fromArray(LongVector.SPECIES_PREFERRED, accumulators, i);
            LongVector inputVector =
                    ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, input, inputOffset + i * 8)
                            .reinterpretAsLongs();
            LongVector secretVector =
                    ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, SECRET, secretOffset + i * 8)
                            .reinterpretAsLongs();

            LongVector key = inputVector.lanewise(XOR, secretVector); // data_key = input ^ secret
            LongVector low = key.and(0xFFFF_FFFFL);                   // data_key & 0xFFFFFFFF
            LongVector high = key.lanewise(LSHR, 32);                 // data_key >>> 32

            accumulatorsVector
                    .add(inputVector.rearrange(LONG_SHUFFLE_PREFERRED)) // acc += swap(input)
                    .add(high.mul(low))                                 // acc += high * low
                    .intoArray(accumulators, i);
        }
    }
}
It generates the following assembly (loop unrolling disabled for clarity):
...
0x0000762f8044b730: lea r11d,[r8*8+0x0]
0x0000762f8044b738: movsxd r11,r11d
0x0000762f8044b73b: vmovdqu ymm0,YMMWORD PTR [r14+r11*1+0x10]
0x0000762f8044b742: vmovdqu ymm1,YMMWORD PTR [r13+r11*1+0x10]
0x0000762f8044b749: vpshufd ymm2,ymm1,0xb1
0x0000762f8044b74e: vpmulld ymm2,ymm0,ymm2
0x0000762f8044b753: vpshufd ymm3,ymm2,0xb1
0x0000762f8044b758: vpaddd ymm3,ymm3,ymm2
0x0000762f8044b75c: vpsllq ymm3,ymm3,0x20
0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1
0x0000762f8044b765: vpaddq ymm0,ymm2,ymm3
0x0000762f8044b769: vmovdqu YMMWORD PTR [rdi+r8*8+0x10],ymm0
0x0000762f8044b770: add r8d,0x4
0x0000762f8044b774: cmp r8d,0x8
0x0000762f8044b778: jl 0x0000762f8044b730
...
The native implementation for AVX2 looks like this:
__attribute__((aligned(32))) uint64_t accumulators[8] = {};
__m256i* const xacc = (__m256i*) accumulators;

for (size_t block = 0; block < length / 1024; block++) {
    for (size_t stripe = 0; stripe < 16; stripe++) {
        unsigned char* in = input + block * 1024 + stripe * 64;
        unsigned char* secret = SECRET + stripe * 8;
        const __m256i* const xinput = (const __m256i*) in;
        const __m256i* const xsecret = (const __m256i*) secret;

        for (size_t i = 0; i < 2; i++) {
            __m256i const data_vec    = _mm256_loadu_si256(xinput + i);          // data_vec = xinput[i];
            __m256i const key_vec     = _mm256_loadu_si256(xsecret + i);         // key_vec = xsecret[i];
            __m256i const data_key    = _mm256_xor_si256(data_vec, key_vec);     // data_key = data_vec ^ key_vec;
            __m256i const data_key_lo = _mm256_srli_epi64(data_key, 32);         // data_key_lo = data_key >> 32;
            __m256i const product     = _mm256_mul_epu32(data_key, data_key_lo); // product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff);
            __m256i const data_swap   = _mm256_shuffle_epi32(data_vec,
                                            _MM_SHUFFLE(1, 0, 3, 2));            // data_swap = swap(data_vec);
            __m256i const sum         = _mm256_add_epi64(xacc[i], data_swap);    // sum = xacc[i] + data_swap;
            xacc[i] = _mm256_add_epi64(product, sum);                            // xacc[i] = sum + product;
        }
    }
}
The corresponding assembly is:
1198: vmovdqu ymm4,YMMWORD PTR [rax-0x20]
119d: vmovdqu ymm5,YMMWORD PTR [rax]
11a1: add rax,0x8
11a5: add rdx,0x40
11a9: vpxor ymm0,ymm4,YMMWORD PTR [rdx-0x60]
11ae: vpsrlq ymm1,ymm0,0x20
11b3: vpmuludq ymm0,ymm0,ymm1
11b7: vpshufd ymm1,YMMWORD PTR [rdx-0x60],0x4e
11bd: vpaddq ymm0,ymm0,ymm1
11c1: vpaddq ymm3,ymm0,ymm3
11c5: vpxor ymm0,ymm5,YMMWORD PTR [rdx-0x40]
11ca: vpsrlq ymm1,ymm0,0x20
11cf: vpmuludq ymm0,ymm0,ymm1
11d3: vpshufd ymm1,YMMWORD PTR [rdx-0x40],0x4e
11d9: vpaddq ymm0,ymm0,ymm1
11dd: vpaddq ymm2,ymm0,ymm2
11e1: cmp rcx,rax
11e4: jne 1198
As far as I can tell, the main difference is in how the multiplication is
performed. The native code uses _mm256_mul_epu32 to perform the equivalent
of "(v & 0xFFFF_FFFF) * (v >>> 32)", and it emits a single vpmuludq
instruction.
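In scalar terms, each 64-bit lane of that instruction computes something like this
(a sketch of my understanding of the vpmuludq semantics):

    // Per 64-bit lane, _mm256_mul_epu32(a, b) / vpmuludq computes:
    static long mulEpu32Lane(long a, long b) {
        return (a & 0xFFFF_FFFFL) * (b & 0xFFFF_FFFFL); // unsigned 32x32 -> full 64-bit product
    }

With a = data_key and b = data_key >>> 32, a single instruction produces exactly
the product the algorithm needs.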
On the other hand, the Java implementation does not seem to understand that only
the lower 32 bits of each lane are set, and it performs the full 64-bit x 64-bit
product (if I'm interpreting this correctly):
0x0000762f8044b749: vpshufd ymm2,ymm1,0xb1
0x0000762f8044b74e: vpmulld ymm2,ymm0,ymm2
0x0000762f8044b753: vpshufd ymm3,ymm2,0xb1
0x0000762f8044b758: vpaddd ymm3,ymm3,ymm2
0x0000762f8044b75c: vpsllq ymm3,ymm3,0x20
0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1
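If I'm reading this right, that sequence is the generic 64x64->64 lowering,
decomposed into 32-bit halves; per lane it is equivalent to something like this
(my scalar rendering, not the actual C2 IR):

    // a * b == aLo*bLo + ((aLo*bHi + aHi*bLo) << 32)  (mod 2^64)
    static long mul64(long a, long b) {
        long aLo = a & 0xFFFF_FFFFL, aHi = a >>> 32;
        long bLo = b & 0xFFFF_FFFFL, bHi = b >>> 32;
        long cross = (aLo * bHi + aHi * bLo) << 32; // vpshufd + vpmulld + vpshufd + vpaddd + vpsllq
        return aLo * bLo + cross;                   // vpmuludq + vpaddq
    }

In my case the operands are low = key & 0xFFFF_FFFF and high = key >>> 32, so aHi
and bHi are both zero, both cross terms vanish, and the single vpmuludq (aLo * bLo)
is all that's actually needed.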
Is there any way to perform a 32x32->64-bit product, or to provide enough
structure for the compiler to realize it doesn't need to consider the upper
32 bits when computing the product, since they are all zeros?
Anything else I'm missing?
Thanks,
- Martin