Vector API performance issues with port of XXH3

Martin Traverso mtraverso at gmail.com
Thu Jul 11 01:27:52 UTC 2024


Hi,

Following up on my attempts to port XXH3 (https://github.com/Cyan4973/xxHash)
to Java, I'd like to ask for some advice. The core loop of that algorithm
uses SIMD, with dedicated implementations for NEON, AVX2, AVX-512, etc. I
have not been able to get the Vector API-based implementation anywhere near
the performance of the native code: there is roughly a 3x gap in the core
loop on a CPU with AVX2.

    import jdk.incubator.vector.ByteVector;
    import jdk.incubator.vector.LongVector;
    import jdk.incubator.vector.VectorShuffle;
    import static jdk.incubator.vector.VectorOperators.LSHR;
    import static jdk.incubator.vector.VectorOperators.XOR;

    // Swaps adjacent 64-bit lanes (0<->1, 2<->3, ...)
    private static final VectorShuffle<Long> LONG_SHUFFLE_PREFERRED =
            VectorShuffle.fromOp(LongVector.SPECIES_PREFERRED, i -> i ^ 1);

    ...

    for (int block = 0; block < input.length / 1024; block++) {
        for (int stripe = 0; stripe < 16; stripe++) {
            int inputOffset = block * 1024 + stripe * 64;
            int secretOffset = stripe * 8;

            // 8 accumulators, processed SPECIES_PREFERRED.length() lanes at a time
            for (int i = 0; i < 8; i += LongVector.SPECIES_PREFERRED.length()) {
                LongVector accumulatorsVector =
                        LongVector.fromArray(LongVector.SPECIES_PREFERRED, accumulators, i);
                LongVector inputVector = ByteVector.fromArray(
                        ByteVector.SPECIES_PREFERRED, input, inputOffset + i * 8)
                        .reinterpretAsLongs();
                LongVector secretVector = ByteVector.fromArray(
                        ByteVector.SPECIES_PREFERRED, SECRET, secretOffset + i * 8)
                        .reinterpretAsLongs();

                // lanewise(XOR, ...) already returns a LongVector
                LongVector key = inputVector.lanewise(XOR, secretVector);

                // Split each 64-bit lane into its low and high 32-bit halves
                LongVector low = key.and(0xFFFF_FFFFL);
                LongVector high = key.lanewise(LSHR, 32);

                // acc[i] += swap(input)[i] + high[i] * low[i]
                accumulatorsVector
                        .add(inputVector.rearrange(LONG_SHUFFLE_PREFERRED))
                        .add(high.mul(low))
                        .intoArray(accumulators, i);
            }
        }
    }
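When comparing the two versions it can help to have a plain scalar version
of the same loop to check the vectorized output against. This is only a
sketch under assumptions, not part of the original port: the class and
method names are made up, the secret is passed in as a parameter (it must
be at least 15 * 8 + 64 = 184 bytes), and longs are read little-endian,
which is how XXH3 reads its input and what reinterpretAsLongs() yields on
x86.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

// Hypothetical scalar reference for the vectorized accumulate loop.
// Assumes 1024-byte blocks, 16 stripes of 64 bytes, and a secret that
// advances 8 bytes per stripe, as in the Vector API code.
class ScalarAccumulateSketch {
    static void accumulate(long[] accumulators, byte[] input, byte[] secret) {
        // Little-endian 64-bit reads, matching XXH3's input convention
        ByteBuffer in = ByteBuffer.wrap(input).order(ByteOrder.LITTLE_ENDIAN);
        ByteBuffer sec = ByteBuffer.wrap(secret).order(ByteOrder.LITTLE_ENDIAN);
        for (int block = 0; block < input.length / 1024; block++) {
            for (int stripe = 0; stripe < 16; stripe++) {
                int inputOffset = block * 1024 + stripe * 64;
                int secretOffset = stripe * 8;
                for (int lane = 0; lane < 8; lane++) {
                    long data = in.getLong(inputOffset + lane * 8);
                    long key = data ^ sec.getLong(secretOffset + lane * 8);
                    // Same as rearrange(i -> i ^ 1): take the neighbouring lane's input
                    long swapped = in.getLong(inputOffset + (lane ^ 1) * 8);
                    // Low 32 bits of key times high 32 bits of key
                    accumulators[lane] += swapped + (key & 0xFFFF_FFFFL) * (key >>> 32);
                }
            }
        }
    }

    public static void main(String[] args) {
        byte[] input = new byte[1024];
        byte[] secret = new byte[192]; // assumption: all-zero 192-byte secret
        input[0] = 2; // low 32 bits of input long 0
        input[4] = 3; // high 32 bits of input long 0
        long[] acc = new long[8];
        accumulate(acc, input, secret);
        System.out.println(Arrays.toString(acc));
    }
}
```

With only the first input long set to 0x0000000300000002 and an all-zero
secret, lane 0 picks up 2 * 3 = 6 and lane 1 picks up the swapped input
long, so the program prints [6, 12884901890, 0, 0, 0, 0, 0, 0].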

It generates the following assembly (loop unrolling disabled for clarity):

  ...
  0x0000762f8044b730:   lea    r11d,[r8*8+0x0]
  0x0000762f8044b738:   movsxd r11,r11d
  0x0000762f8044b73b:   vmovdqu ymm0,YMMWORD PTR [r14+r11*1+0x10]
  0x0000762f8044b742:   vmovdqu ymm1,YMMWORD PTR [r13+r11*1+0x10]
  0x0000762f8044b749:   vpshufd ymm2,ymm1,0xb1
  0x0000762f8044b74e:   vpmulld ymm2,ymm0,ymm2
  0x0000762f8044b753:   vpshufd ymm3,ymm2,0xb1
  0x0000762f8044b758:   vpaddd ymm3,ymm3,ymm2
  0x0000762f8044b75c:   vpsllq ymm3,ymm3,0x20
  0x0000762f8044b761:   vpmuludq ymm2,ymm0,ymm1
  0x0000762f8044b765:   vpaddq ymm0,ymm2,ymm3
  0x0000762f8044b769:   vmovdqu YMMWORD PTR [rdi+r8*8+0x10],ymm0
  0x0000762f8044b770:   add    r8d,0x4
  0x0000762f8044b774:   cmp    r8d,0x8
  0x0000762f8044b778:   jl     0x0000762f8044b730
  ...

The native implementation for AVX2 looks like this:

    #include <immintrin.h>
    #include <stdint.h>

    __attribute__((aligned(32))) uint64_t accumulators[8] = {0};
    __m256i* const xacc = (__m256i*) accumulators;

    for (size_t block = 0; block < length / 1024; block++) {
        for (size_t stripe = 0; stripe < 16; stripe++) {
            unsigned char* in = input + block * 1024 + stripe * 64;
            unsigned char* secret = SECRET + stripe * 8;

            const __m256i* const xinput  = (const __m256i *) in;
            const __m256i* const xsecret = (const __m256i *) secret;
            for (size_t i = 0; i < 2; i++) {
                // data_vec = xinput[i];
                __m256i const data_vec    = _mm256_loadu_si256(xinput + i);
                // key_vec = xsecret[i];
                __m256i const key_vec     = _mm256_loadu_si256(xsecret + i);
                // data_key = data_vec ^ key_vec;
                __m256i const data_key    = _mm256_xor_si256(data_vec, key_vec);
                // data_key_lo = data_key >> 32;
                __m256i const data_key_lo = _mm256_srli_epi64(data_key, 32);
                // product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff);
                __m256i const product     = _mm256_mul_epu32(data_key, data_key_lo);
                // xacc[i] += swap(data_vec);
                __m256i const data_swap   = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2));
                __m256i const sum         = _mm256_add_epi64(xacc[i], data_swap);
                // xacc[i] += product;
                xacc[i]                   = _mm256_add_epi64(product, sum);
            }
        }
    }

The corresponding assembly is:

    1198:   vmovdqu ymm4,YMMWORD PTR [rax-0x20]
    119d:   vmovdqu ymm5,YMMWORD PTR [rax]
    11a1:   add    rax,0x8
    11a5:   add    rdx,0x40
    11a9:   vpxor  ymm0,ymm4,YMMWORD PTR [rdx-0x60]
    11ae:   vpsrlq ymm1,ymm0,0x20
    11b3:   vpmuludq ymm0,ymm0,ymm1
    11b7:   vpshufd ymm1,YMMWORD PTR [rdx-0x60],0x4e
    11bd:   vpaddq ymm0,ymm0,ymm1
    11c1:   vpaddq ymm3,ymm0,ymm3
    11c5:   vpxor  ymm0,ymm5,YMMWORD PTR [rdx-0x40]
    11ca:   vpsrlq ymm1,ymm0,0x20
    11cf:   vpmuludq ymm0,ymm0,ymm1
    11d3:   vpshufd ymm1,YMMWORD PTR [rdx-0x40],0x4e
    11d9:   vpaddq ymm0,ymm0,ymm1
    11dd:   vpaddq ymm2,ymm0,ymm2
    11e1:   cmp    rcx,rax
    11e4:   jne    1198
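To line the two listings up: the 0x4e immediate in the native loop is
_MM_SHUFFLE(1, 0, 3, 2), which on 32-bit elements swaps adjacent 64-bit
lanes, i.e. the same permutation as the Java i -> i ^ 1 shuffle. The 0xb1
immediate in the JIT output is _MM_SHUFFLE(2, 3, 0, 1), which instead swaps
the two 32-bit halves inside each 64-bit lane. A small scalar model of
_mm256_shuffle_epi32 (class and method names are illustrative only) shows
both:

```java
import java.util.Arrays;

// Scalar model of _mm256_shuffle_epi32: a ymm register viewed as 8 dwords;
// the 2-bit selectors in imm8 pick a source dword within each 128-bit half.
class ShuffleEpi32Sketch {
    static int[] shuffleEpi32(int[] src, int imm8) {
        int[] dst = new int[8];
        for (int half = 0; half < 8; half += 4) {   // two 128-bit halves
            for (int j = 0; j < 4; j++) {
                int sel = (imm8 >>> (2 * j)) & 3;   // selector for dword j
                dst[half + j] = src[half + sel];
            }
        }
        return dst;
    }

    // _MM_SHUFFLE(z, y, x, w) packs four 2-bit selectors, highest first
    static int mmShuffle(int z, int y, int x, int w) {
        return (z << 6) | (y << 4) | (x << 2) | w;
    }

    public static void main(String[] args) {
        int[] dwords = {0, 1, 2, 3, 4, 5, 6, 7};
        // 0x4e: 64-bit lanes k and k ^ 1 trade places
        System.out.println(Arrays.toString(shuffleEpi32(dwords, mmShuffle(1, 0, 3, 2))));
        // 0xb1: the 32-bit halves of each 64-bit lane trade places
        System.out.println(Arrays.toString(shuffleEpi32(dwords, mmShuffle(2, 3, 0, 1))));
    }
}
```

This prints [2, 3, 0, 1, 6, 7, 4, 5] and then [1, 0, 3, 2, 5, 4, 7, 6].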

As far as I can tell, the main difference is in how the multiplication is
performed. The native code uses _mm256_mul_epu32 to compute the equivalent
of "(v & 0xFFFF_FFFF) * (v >>> 32)", which compiles to a single vpmuludq
instruction.
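In scalar terms, each 64-bit lane of vpmuludq zero-extends the low 32 bits
of both operands and keeps the full 64-bit product. With the second operand
being v >>> 32, that is exactly the expression above, and even the largest
inputs cannot overflow 64 bits. A minimal sketch (class and method names
are illustrative only):

```java
// Scalar model of one 64-bit lane of _mm256_mul_epu32 / vpmuludq:
// only the low 32 bits of each operand are used, zero-extended,
// and the full 64-bit product is kept.
class MulEpu32Sketch {
    static long mulEpu32(long a, long b) {
        return (a & 0xFFFF_FFFFL) * (b & 0xFFFF_FFFFL);
    }

    // The XXH3 step: low half of key times high half of key
    static long halvesProduct(long key) {
        return (key & 0xFFFF_FFFFL) * (key >>> 32);
    }

    public static void main(String[] args) {
        long key = 0x0000_0003_0000_0002L;
        System.out.println(halvesProduct(key));
        System.out.println(mulEpu32(key, key >>> 32));
        // Worst case still fits: (2^32 - 1)^2 = 0xFFFFFFFE00000001
        System.out.println(Long.toHexString(mulEpu32(-1L, -1L)));
    }
}
```

The program prints 6, 6 and fffffffe00000001.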

On the other hand, the Java implementation does not seem to recognize that
only the lower 32 bits of each lane can be non-zero, so it computes the full
64-bit x 64-bit product (if I'm interpreting this correctly):

  0x0000762f8044b749:   vpshufd ymm2,ymm1,0xb1
  0x0000762f8044b74e:   vpmulld ymm2,ymm0,ymm2
  0x0000762f8044b753:   vpshufd ymm3,ymm2,0xb1
  0x0000762f8044b758:   vpaddd ymm3,ymm3,ymm2
  0x0000762f8044b75c:   vpsllq ymm3,ymm3,0x20
  0x0000762f8044b761:   vpmuludq ymm2,ymm0,ymm1
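That sequence matches the usual way of building a 64 x 64 -> 64 multiply
from 32-bit pieces, a * b = aLo*bLo + ((aLo*bHi + aHi*bLo) << 32) mod 2^64.
Below is a scalar model of one lane, assuming ymm0 holds a and ymm1 holds b;
the names are illustrative, and the comments map each step to the
instruction it models.

```java
// Scalar model of the six-instruction sequence for one 64-bit lane.
// Assumption: ymm0 holds a, ymm1 holds b.
class Mul64FromHalvesSketch {
    static long mul64(long a, long b) {
        long aLo = a & 0xFFFF_FFFFL, aHi = a >>> 32;
        long bLo = b & 0xFFFF_FFFFL, bHi = b >>> 32;
        // vpshufd 0xb1 + vpmulld: both cross products, each truncated to 32 bits
        int cross1 = (int) (aLo * bHi);
        int cross2 = (int) (aHi * bLo);
        // vpshufd 0xb1 + vpaddd: sum of the cross terms (mod 2^32)
        int crossSum = cross1 + cross2;
        // vpsllq 0x20: move that sum into the upper 32 bits
        long crossShifted = ((long) crossSum) << 32;
        // vpmuludq: full 64-bit product of the low halves
        long lowProduct = aLo * bLo;
        // vpaddq
        return lowProduct + crossShifted;
    }

    public static void main(String[] args) {
        long a = 0x1234_5678_9ABC_DEF0L, b = 0x0FED_CBA9_8765_4321L;
        System.out.println(mul64(a, b) == a * b);
        // Both upper halves zero: the cross terms vanish, only vpmuludq matters
        long lo = 0xFFFF_FFFFL, hi = 0x7654_3210L;
        System.out.println(mul64(lo, hi) == lo * hi);
    }
}
```

Both lines print true. When aHi and bHi are zero, as they are for the low
and high halves of the key, the vpshufd/vpmulld/vpaddd/vpsllq part adds
nothing, which is why a single vpmuludq suffices in the native loop.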

Is there any way to perform a 32x32->64-bit product, or to give the compiler
enough structure to realize that it doesn't need to consider the upper 32
bits when computing the product, since they are all zeros?

Anything else I'm missing?

Thanks,
- Martin


More information about the panama-dev mailing list