<p class="MsoNormal"><span style="mso-fareast-language:EN-US">Hi Martin,<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US"><o:p> </o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">Thanks for reporting this.<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US"><o:p> </o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">Instruction sequence for 64x64 bit multiplier on AVX2 targets is agnostic to existence of zeroing of upper / lower double word, this is because we do not split Multiplier at IR level and depend on
any constant folding to sweep out the redundant logic, this can however be handled as a point optimization.<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US"><o:p> </o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">I just did a quick patch[1] to attempt that, and I can see compiler is now emitting “VPMULDQ”[2]<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US"><o:p> </o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">Best Regards,<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">Jatin<o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">[1] <a href="https://github.com/jatin-bhateja/external_staging/blob/main/NewOperationSamples/doubleWordMultQuadWordAccum/jdk_patch.diff">
https://github.com/jatin-bhateja/external_staging/blob/main/NewOperationSamples/doubleWordMultQuadWordAccum/jdk_patch.diff</a><o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US">[2] <a href="https://www.felixcloutier.com/x86/pmuldq">
https://www.felixcloutier.com/x86/pmuldq</a><o:p></o:p></span></p>
<p class="MsoNormal"><span style="mso-fareast-language:EN-US"><o:p> </o:p></span></p>
<div style="border:none;border-left:solid blue 1.5pt;padding:0cm 0cm 0cm 4.0pt">
<div>
<div style="border:none;border-top:solid #E1E1E1 1.0pt;padding:3.0pt 0cm 0cm 0cm">
<p class="MsoNormal"><b><span lang="EN-US" style="font-size:11.0pt;font-family:"Calibri",sans-serif">From:</span></b><span lang="EN-US" style="font-size:11.0pt;font-family:"Calibri",sans-serif"> panama-dev <panama-dev-retn@openjdk.org>
<b>On Behalf Of </b>Martin Traverso<br>
<b>Sent:</b> Thursday, July 11, 2024 6:58 AM<br>
<b>To:</b> panama-dev@openjdk.org<br>
<b>Subject:</b> Vector API performance issues with port of XXH3<o:p></o:p></span></p>
</div>
</div>
<p class="MsoNormal"><o:p> </o:p></p>

Hi,

Following up on my attempts to port XXH3 to Java (https://github.com/Cyan4973/xxHash), I'd like to ask for some advice. The core loop of that algorithm uses SIMD, with custom implementations for NEON, AVX2, AVX512, etc. I have been unable to get the performance of the Vector API-based implementation anywhere near that of the native code (~3x difference for the core loop on a CPU with AVX2).
<p class="MsoNormal" style="margin-bottom:12.0pt"><br>
private static final VectorShuffle<Long> LONG_SHUFFLE_PREFERRED = VectorShuffle.fromOp(LongVector.SPECIES_PREFERRED, i -> i ^ 1);<br>
<br>
...<br>
<br>
for (int block = 0; block < input.length / 1024; block++) {<br>
for (int stripe = 0; stripe < 16; stripe++) {<br>
int inputOffset = block * 1024 + stripe * 64;<br>
int secretOffset = stripe * 8;<br>
<br>
for (int i = 0; i < 8; i += LongVector.SPECIES_PREFERRED.length()) {<br>
LongVector accumulatorsVector = LongVector.fromArray(LongVector.SPECIES_PREFERRED, accumulators, i);<br>
LongVector inputVector = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, input, inputOffset + i * 8).reinterpretAsLongs();<br>
LongVector secretVector = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, SECRET, secretOffset + i * 8).reinterpretAsLongs();<br>
<br>
LongVector key = inputVector<br>
.lanewise(XOR, secretVector)<br>
.reinterpretAsLongs();<br>
<br>
LongVector low = key.and(0xFFFF_FFFFL);<br>
LongVector high = key.lanewise(LSHR, 32);<br>
<br>
accumulatorsVector<br>
.add(inputVector.rearrange(LONG_SHUFFLE_PREFERRED))<br>
.add(high.mul(low))<br>
.intoArray(accumulators, i);<br>
}<br>
}<br>
}<br>
<br>

It generates the following assembly (loop unrolling disabled for clarity):

...
0x0000762f8044b730: lea      r11d,[r8*8+0x0]
0x0000762f8044b738: movsxd   r11,r11d
0x0000762f8044b73b: vmovdqu  ymm0,YMMWORD PTR [r14+r11*1+0x10]
0x0000762f8044b742: vmovdqu  ymm1,YMMWORD PTR [r13+r11*1+0x10]
0x0000762f8044b749: vpshufd  ymm2,ymm1,0xb1
0x0000762f8044b74e: vpmulld  ymm2,ymm0,ymm2
0x0000762f8044b753: vpshufd  ymm3,ymm2,0xb1
0x0000762f8044b758: vpaddd   ymm3,ymm3,ymm2
0x0000762f8044b75c: vpsllq   ymm3,ymm3,0x20
0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1
0x0000762f8044b765: vpaddq   ymm0,ymm2,ymm3
0x0000762f8044b769: vmovdqu  YMMWORD PTR [rdi+r8*8+0x10],ymm0
0x0000762f8044b770: add      r8d,0x4
0x0000762f8044b774: cmp      r8d,0x8
0x0000762f8044b778: jl       0x0000762f8044b730
...

The native implementation for AVX2 looks like this:

__attribute__((aligned(32))) uint64_t accumulators[8] = {};
__m256i* const xacc = (__m256i*) accumulators;

for (size_t block = 0; block < length / 1024; block++) {
    for (size_t stripe = 0; stripe < 16; stripe++) {
        unsigned char* in = input + block * 1024 + stripe * 64;
        unsigned char* secret = SECRET + stripe * 8;

        const __m256i* const xinput = (const __m256i *) in;
        const __m256i* const xsecret = (const __m256i *) secret;
        for (size_t i = 0; i < 2; i++) {
            __m256i const data_vec    = _mm256_loadu_si256(xinput + i);          // data_vec = xinput[i];
            __m256i const key_vec     = _mm256_loadu_si256(xsecret + i);         // key_vec = xsecret[i];
            __m256i const data_key    = _mm256_xor_si256(data_vec, key_vec);     // data_key = data_vec ^ key_vec;
            __m256i const data_key_lo = _mm256_srli_epi64(data_key, 32);         // data_key_lo = data_key >> 32;
            __m256i const product     = _mm256_mul_epu32(data_key, data_key_lo); // product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff);
            __m256i const data_swap   = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2)); // data_swap = swap(data_vec);
            __m256i const sum         = _mm256_add_epi64(xacc[i], data_swap);    // sum = xacc[i] + data_swap;
            xacc[i] = _mm256_add_epi64(product, sum);                            // xacc[i] = product + sum;
        }
    }
}

The corresponding assembly is:

1198: vmovdqu  ymm4,YMMWORD PTR [rax-0x20]
119d: vmovdqu  ymm5,YMMWORD PTR [rax]
11a1: add      rax,0x8
11a5: add      rdx,0x40
11a9: vpxor    ymm0,ymm4,YMMWORD PTR [rdx-0x60]
11ae: vpsrlq   ymm1,ymm0,0x20
11b3: vpmuludq ymm0,ymm0,ymm1
11b7: vpshufd  ymm1,YMMWORD PTR [rdx-0x60],0x4e
11bd: vpaddq   ymm0,ymm0,ymm1
11c1: vpaddq   ymm3,ymm0,ymm3
11c5: vpxor    ymm0,ymm5,YMMWORD PTR [rdx-0x40]
11ca: vpsrlq   ymm1,ymm0,0x20
11cf: vpmuludq ymm0,ymm0,ymm1
11d3: vpshufd  ymm1,YMMWORD PTR [rdx-0x40],0x4e
11d9: vpaddq   ymm0,ymm0,ymm1
11dd: vpaddq   ymm2,ymm0,ymm2
11e1: cmp      rcx,rax
11e4: jne      1198

As far as I can tell, the main difference is in how the multiplication is performed. The native code uses _mm256_mul_epu32 to perform the equivalent of "(v & 0xFFFF_FFFF) * (v >>> 32)", and it emits a single vpmuludq instruction.
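
Per 64-bit lane, that is the following scalar computation (a sketch; helper name mine):

// What vpmuludq computes per 64-bit lane: the unsigned product of the low
// 32 bits of each operand. With the second operand being v >>> 32, this is
// exactly (v & 0xFFFF_FFFF) * (v >>> 32).
static long mulEpu32Lane(long v) {
    return (v & 0xFFFF_FFFFL) * (v >>> 32);
}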

On the other hand, the Java implementation does not seem to understand that only the lower 32 bits of each lane are set, and it performs the full 64-bit x 64-bit product (if I'm interpreting this correctly):

0x0000762f8044b749: vpshufd  ymm2,ymm1,0xb1    ; swap hi/lo dwords of one operand
0x0000762f8044b74e: vpmulld  ymm2,ymm0,ymm2    ; 32-bit cross products
0x0000762f8044b753: vpshufd  ymm3,ymm2,0xb1
0x0000762f8044b758: vpaddd   ymm3,ymm3,ymm2    ; sum the cross terms
0x0000762f8044b75c: vpsllq   ymm3,ymm3,0x20    ; shift cross terms into the high half
0x0000762f8044b761: vpmuludq ymm2,ymm0,ymm1    ; lo x lo 32x32 -> 64

Is there any way to perform a 32x32 -> 64-bit product, or provide enough structure for the compiler to realize it doesn't need to consider the upper 32 bits when computing the product, since they are all zeros?

Anything else I'm missing?
<p class="MsoNormal"><br>
Thanks,<br>
- Martin<o:p></o:p></p>
</div>
</div>
</div>
</body>
</html>