<div dir="ltr"><div dir="ltr"><div class="gmail_default" style=""><font face="arial, sans-serif">Hi Vladimir,</font><br><br><font face="arial, sans-serif">Thank you for the quick response and the valuable point.</font><br><br><font face="arial, sans-serif">Your suggestion helps me to double down and think deeper on using ByteVector.rearrange(). Initially, I was trying to use FloatVector for the lookup table:</font><br><br><font face="arial, sans-serif">```</font><br><font face="arial, sans-serif">private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_128;</font><br><br><font face="arial, sans-serif">private static final float[] MXFP4_VALUES = {</font><br><font face="arial, sans-serif"> +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,</font><br><font face="arial, sans-serif"> -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f</font><br><font face="arial, sans-serif">};</font><br><br><font face="arial, sans-serif">private static final FloatVector MXFP4_TABLE = FloatVector.fromArray(FLOAT_SPECIES, MXFP4_VALUES, 0);</font><br><font face="arial, sans-serif">ByteVector wBytes = ByteVector.fromArray(ByteVector.SPECIES_128, blocks, blockStart);</font><br><br><font face="arial, sans-serif">ByteVector loNibbles = wBytes.and((byte) 0x0F);</font><br><font face="arial, sans-serif">ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4);</font><br><br><font face="arial, sans-serif">VectorShuffle<Float> loShuffle = loNibbles.castShape(FLOAT_SPECIES, j).toShuffle();</font><br><font face="arial, sans-serif">VectorShuffle<Float> hiShuffle = hiNibbles.castShape(FLOAT_SPECIES, j).toShuffle();</font><br><br><font face="arial, sans-serif">FloatVector loW = MXFP4_TABLE.rearrange(loShuffle);</font><br><font face="arial, sans-serif">FloatVector hiW = MXFP4_TABLE.rearrange(hiShuffle);</font><br><font face="arial, sans-serif">```</font><br><br><font face="arial, sans-serif">This turned out to be not working as expected. </font><br><br><font face="arial, sans-serif">Thanks to your insight about ByteVector.rearrange(), I’ve had a great idea, I can multiply and scale the MXFP4_VALUES by 10x, so that I can store them as ByteVector, and use later. During inference time, I can multiply them by 0.1f back. 
Since I still find it hard to figure out how to write the nibble-interleaving code, I took a different approach: rearranging the input vector to match the MXFP4 unpacking pattern. MXFP4 weights are stored as 4-bit values packed two per byte, so after unpacking with the Java Vector API's ByteVector.rearrange() they naturally separate into even- and odd-indexed weight vectors:

```
loWeights: w0, w2, w4, w6, ... (even-indexed weights)
hiWeights: w1, w3, w5, w7, ... (odd-indexed weights)
```

So I reshape the input vector instead. Taking the 128-bit vector size as an example, I group 4 floats per vector to match the weights layout:

```
Input vector : [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, ...]
Reshaped     : [x0, x2, x4, x6, x1, x3, x5, x7, x8, x10, x12, x14, x9, x11, x13, x15, ...]
                └─── even ───┘  └─── odd ────┘  └───── even ────┘  └───── odd ─────┘
```
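For reference, one 32-weight block can then be consumed along these lines (a simplified sketch rather than my exact code; here F_SPECIES means FloatVector.SPECIES_128, xReshaped is the pre-permuted input from the diagram above, and scale is the per-block scale that already carries the 0.1f correction):

```
// Sketch: combine the even/odd byte weights with the reshaped input for one 32-weight block.
// Assumes F_SPECIES = FloatVector.SPECIES_128 (4 float lanes); variable names are placeholders.
static FloatVector accumulateBlock(ByteVector loWeights, ByteVector hiWeights,
                                   float[] xReshaped, int base, float scale,
                                   FloatVector accV) {
    for (int part = 0; part < 4; part++) {
        FloatVector loW = (FloatVector) loWeights.castShape(F_SPECIES, part); // w0,w2,w4,w6 then w8,w10,...
        FloatVector hiW = (FloatVector) hiWeights.castShape(F_SPECIES, part); // w1,w3,w5,w7 then w9,w11,...
        FloatVector xEven = FloatVector.fromArray(F_SPECIES, xReshaped, base + part * 8);
        FloatVector xOdd  = FloatVector.fromArray(F_SPECIES, xReshaped, base + part * 8 + 4);
        accV = loW.mul(scale).fma(xEven, accV);
        accV = hiW.mul(scale).fma(xOdd, accV);
    }
    return accV;
}
```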
The change is now working with MXFP4 weights in my gpt-oss.java. Performance is very close to BF16, but with a slight slowdown (a bit of a shock..). I'll continue the work and profile it, and will consult again later to share my experience of using the Vector API in a modern Java-based LLM inference program.

Thank you again for pointing me toward a solution!

Thanks,
Xu

On Tue, Oct 21, 2025 at 01:42 Vladimir Ivanov <vladimir.x.ivanov@oracle.com> wrote:
Hi Xu,
<br>
I didn't look at your benchmark in detail, but I see that <br>
ByteVector.rearrange() on Byte128Vector is lowered into pshufb [1]. Have <br>
you tried it?<br>
<br>
Best regards,<br>
Vladimir Ivanov<br>
<br>
[1] <br>
<a href="https://github.com/openjdk/jdk/blob/master/src/hotspot/cpu/x86/x86.ad#L8769" rel="noreferrer" target="_blank">https://github.com/openjdk/jdk/blob/master/src/hotspot/cpu/x86/x86.ad#L8769</a><br>
<br>
On 10/19/25 22:16, xu wrote:<br>
> Hi all,<br>
> <br>
> I’d like to seek some help from the Project Panama team regarding efficient <br>
> implementation of MXFP4 dequantization using the Java Vector API.<br>
> <br>
> *## Background*<br>
> <br>
> I’ve been developing a pure Java implementation of OpenAI’s gpt-oss <br>
> inference program, optimized for CPU execution, github: amzn/gpt- <br>
> oss.java <<a href="https://github.com/amzn/gpt-oss.java" rel="noreferrer" target="_blank">https://github.com/amzn/gpt-oss.java</a>>. The model uses MXFP4 <br>
> weights, and I’m exploring efficient ways to handle them in the <br>
> computation-intensive MLP layer. However, I run into limitations with <br>
> the current Vector API (up to JDK 24) that seem to prevent a fully <br>
> vectorized implementation.<br>
> <br>
> *## MXFP4 layout*<br>
> In C++, each Mixture of Experts weight byte stores two 4-bit weights <br>
> (nibbles). These 4-bit values are used as indices into a predefined 16- <br>
> entry lookup table:<br>
> ```<br>
> kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, <br>
> -1.5f, -2.f, -3.f, -4.f, -6.f]<br>
> ```<br>
> <br>
> *## llama.cpp implementation*<br>
> <br>
> Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation <br>
> (e.g., quants.c in llama.cpp <<a href="https://github.com/ggml-org/llama.cpp/" rel="noreferrer" target="_blank">https://github.com/ggml-org/llama.cpp/</a> <br>
> blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/ <br>
> x86/quants.c#L790>) can perform 16 parallel table lookups per <br>
> instruction — converting nibbles to float values efficiently, take AVX2 <br>
> for example:<br>
> ```<br>
> #if defined __AVX2__<br>
> <br>
> const __m128i values128 = _mm_loadu_si128((const <br>
> __m128i*)kvalues_mxfp4);<br>
> const __m128i m4b = _mm_set1_epi8(0x0f);<br>
> const __m256i mone = _mm256_set1_epi16(1);<br>
> <br>
> __m256 accum1 = _mm256_setzero_ps();<br>
> __m256 accum2 = _mm256_setzero_ps();<br>
> for (; ib + 1 < nb; ib += 2) {<br>
> const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + <br>
> 0].qs);<br>
> const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + <br>
> 1].qs);<br>
> const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib <br>
> + 0].qs);<br>
> const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib <br>
> + 1].qs);<br>
> const __m256i q4b_1 = <br>
> MM256_SET_M128I(_mm_shuffle_epi8(values128, <br>
> _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),<br>
> <br>
> _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));<br>
> const __m256i q4b_2 = <br>
> MM256_SET_M128I(_mm_shuffle_epi8(values128, <br>
> _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),<br>
> <br>
> _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));<br>
> const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);<br>
> const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);<br>
> const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);<br>
> const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);<br>
> accum1 = <br>
> _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + <br>
> 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),<br>
> _mm256_cvtepi32_ps(p_1), accum1);<br>
> accum2 = <br>
> _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + <br>
> 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),<br>
> _mm256_cvtepi32_ps(p_2), accum2);<br>
> }<br>
> <br>
> sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));<br>
> ```<br>
> <br>
> *## Problem in Java Vector API*<br>
> <br>
> In Java, there doesn’t seem to be an equivalent way to perform dynamic <br>
> per-byte lookups using `ByteVector`.<br>
> <br>
> While `VectorShuffle` supports lane rearrangements, its indices must <br>
> fall within 0 to 15, which are the element indices of the kvalues_mxfp4, so <br>
> it only works for rearrangements of 512-bit vectors with exactly 16 <br>
> float lanes. This means that dynamic, per-lane table lookups like <br>
> `_mm_shuffle_epi8` cannot be expressed directly in Java.<br>
> <br>
> The current workaround requires scalar operations:<br>
> ```<br>
> for (int i = 0; i < lanes; i++) {<br>
> int nibble = byteVector.lane(i) & 0x0F;<br>
> weights[i] = MXFP4_VALUES[nibble]; // Scalar operation<br>
> }<br>
> ```<br>
> <br>
> This prevents MXFP4 dequantization from being fully vectorized — each <br>
> nibble must be individually decoded and indexed in scalar form.<br>
> <br>
> Below is an experimental version for 256-bit vectors:<br>
> <br>
> ```<br>
> private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;<br>
> private static final VectorSpecies<Float> F_SPECIES = <br>
> FloatVector.SPECIES_256;<br>
> <br>
> private static final float[] MXFP4_VALUES = {<br>
> +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,<br>
> -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f<br>
> };<br>
> <br>
> /**<br>
> * @param vec input vector length == 2880 (float32)<br>
> * @param allExpertWeights MemorySegment storing MXFP4 values to dot <br>
> product the input vector<br>
> * @param offset Expert weights offset in terms of the total <br>
> amount of weights, not physical byte offset<br>
> * @param n input dim<br>
> * @param scales scales for the weights<br>
> * @return Dot product of the input vector with the dequantized expert weights<br>
> */<br>
> private static float vectorizedMXFP4DotProduct(final float[] vec,<br>
> final MemorySegment <br>
> allExpertWeights,<br>
> final int offset,<br>
> final int n,<br>
> final float[] scales) {<br>
> float acc = 0.0f;<br>
> int i = 0;<br>
> <br>
> FloatVector accV = FloatVector.zero(F_SPECIES);<br>
> float[] weights = new float[F_SPECIES.length()];<br>
> <br>
> // ---------- The below code is experimental to use MXFP4 precision <br>
> weights for projection !!! -----------<br>
> // Process in blocks of 32 elements (16 bytes = 32 FP4 values)<br>
> // Take 256-bit lane sized vector for instance which means we can <br>
> process 8 floats in instruction-parallel way<br>
> // To build a JAR, skip test cases, just run ./gradlew shadowJar<br>
> while (i + 32 <= n) {<br>
> int blockIdx = (offset + i) / 32;<br>
> float scale = scales[blockIdx];<br>
> FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);<br>
> <br>
> ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES, <br>
> allExpertWeights, blockIdx * 16, ByteOrder.LITTLE_ENDIAN);<br>
> ByteVector loNibbles = wBytes.and((byte) 0x0F);<br>
> ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, <br>
> 4).and((byte) 0x0F);<br>
> <br>
> for (int bytePair = 0; bytePair < 4; bytePair++) {<br>
> int byte1 = bytePair * 4;<br>
> int byte2 = bytePair * 4 + 1;<br>
> int byte3 = bytePair * 4 + 2;<br>
> int byte4 = bytePair * 4 + 3;<br>
> <br>
> //!!! :( this is where we can not do parallel execution <br>
> unless it is 512-bitvector unit<br>
> weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];<br>
> weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];<br>
> weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];<br>
> weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];<br>
> <br>
> weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];<br>
> weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];<br>
> weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];<br>
> weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];<br>
> <br>
> FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);<br>
> int vecPos = i + (bytePair * 8);<br>
> FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);<br>
> <br>
> accV = wv.mul(vScales).fma(xv, accV);<br>
> }<br>
> <br>
> i += 32;<br>
> }<br>
> <br>
> acc = accV.reduceLanes(VectorOperators.ADD);<br>
> <br>
> // Handle remaining elements<br>
> while (i < n) {<br>
> int blockIdx = (offset + i) / 32;<br>
> int elemInBlock = (offset + i) % 32;<br>
> float scale = scales[blockIdx];<br>
> <br>
> int byteIdx = elemInBlock / 2;<br>
> int blockStart = blockIdx * 16;<br>
> byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, <br>
> blockStart + byteIdx);<br>
> <br>
> int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : <br>
> ((packed >> 4) & 0x0F);<br>
> float weight = MXFP4_VALUES[fp4Idx] * scale;<br>
> acc += weight * vec[i];<br>
> i++;<br>
> }<br>
> <br>
> return acc;<br>
> }<br>
> ```<br>
> <br>
> *## Questions*<br>
> <br>
> 1. Is my understanding correct that the current Java Vector API cannot <br>
> perform dynamic per-byte table lookups? If yes, since `VectorShuffle` <br>
> only works with compile-time matched constant indices (and effectively <br>
> only within 512-bit / 16-float lane vectors), is there any plan to <br>
> support this in future releases?<br>
> <br>
> 2. As a follow-up, MXFP4 decoding also requires interleaving low and <br>
> high nibbles — are there any performant ways or patterns in the API to <br>
> handle this efficiently?<br>
> <br>
> Any insights from the Project Panama developers would be greatly <br>
> appreciated!<br>
> <br>
> Thanks,<br>
> Xu<br>
<br>
-- 
Best regards,
Xu