<div dir="ltr">Hi all,<br><br>I’d like to ask the Project Panama team for some help with an efficient implementation of MXFP4 dequantization using the Java Vector API.<br><br><b>## Background</b><br><br>I’ve been developing a pure Java implementation of OpenAI’s gpt-oss inference program, optimized for CPU execution (GitHub: <a href="https://github.com/amzn/gpt-oss.java">amzn/gpt-oss.java</a>). The model uses MXFP4 weights, and I’m exploring efficient ways to handle them in the computation-intensive MLP layer. However, I have run into limitations in the current Vector API (up to JDK 24) that seem to prevent a fully vectorized implementation.<br><br><b>## MXFP4 layout</b><br>In MXFP4, each byte of the Mixture-of-Experts (MoE) weights stores two 4-bit weights (nibbles), and each block of 32 weights shares one scale. The 4-bit values are used as indices into a predefined 16-entry lookup table:<br>```<br>kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f]<br>```<br>(llama.cpp stores this table as 16 `int8_t` entries holding twice these values, `0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12`, and compensates through `GGML_E8M0_TO_FP32_HALF`, so the whole table fits in a single 128-bit register.)<br><br><b>## llama.cpp implementation</b><br><br>Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation (e.g., <a href="https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790">quants.c in llama.cpp</a>) performs 16 parallel byte-table lookups per instruction, turning nibbles into 8-bit integer weights that are then multiplied with the 8-bit quantized activations. The AVX2 path, for example:<br>```<br>#if defined __AVX2__<br><br>    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);<br>    const __m128i m4b  = _mm_set1_epi8(0x0f);<br>    const __m256i mone = _mm256_set1_epi16(1);<br><br>    __m256 accum1 = _mm256_setzero_ps();<br>    __m256 accum2 = _mm256_setzero_ps();<br>    for (; ib + 1 < nb; ib += 2) {<br>        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);<br>        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);<br>        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);<br>        const __m256i
q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);<br>        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),<br>                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));<br>        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),<br>                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));<br>        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);<br>        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);<br>        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);<br>        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);<br>        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),<br>                _mm256_cvtepi32_ps(p_1), accum1);<br>        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),<br>                _mm256_cvtepi32_ps(p_2), accum2);<br>    }<br><br>    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));<br>```<br><br><b>## Problem in Java Vector API</b><br><br>In Java, there doesn’t seem to be an equivalent way to perform dynamic per-byte table lookups with `ByteVector`.<br><br>`VectorShuffle` supports lane rearrangements, but its indices must fall within [0, VLENGTH) of the vector being rearranged. For the 16 float entries of kvalues_mxfp4 that means indices 0 to 15, so, as far as I can tell, a table lookup via rearrangement only works on 512-bit vectors with exactly 16 float lanes. This seems to mean that dynamic per-lane table lookups like `_mm_shuffle_epi8` cannot be expressed directly in Java on narrower hardware.<br><br>The current workaround requires scalar operations:<br>```<br>// Scalar decode: one lane extraction and one table lookup per nibble<br>for (int i = 0; i < byteVector.length(); i++) {<br>    int nibble = byteVector.lane(i) & 0x0F;<br>    weights[i] = MXFP4_VALUES[nibble];  // Scalar operation<br>}<br>```<br><br>This prevents MXFP4 dequantization from being fully vectorized: each nibble must be extracted and looked up individually in scalar code.<br><br>Below is an experimental version for 256-bit vectors:<br><br>```<br>import java.lang.foreign.MemorySegment;<br>import java.lang.foreign.ValueLayout;<br>import java.nio.ByteOrder;<br>import jdk.incubator.vector.ByteVector;<br>import jdk.incubator.vector.FloatVector;<br>import jdk.incubator.vector.VectorOperators;<br>import jdk.incubator.vector.VectorSpecies;<br><br>// Compile and run with --add-modules jdk.incubator.vector<br>private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;   // 16 bytes = one 32-weight block<br>private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256; // 8 float lanes<br><br>private static final float[] MXFP4_VALUES = {<br>        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,<br>        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f<br>};<br><br>/**<br> * Dot product of a float32 input vector with one row of MXFP4-packed expert weights.<br> *<br> * @param vec              input vector, length == 2880 (float32)<br> * @param allExpertWeights MemorySegment holding the packed MXFP4 weights (two 4-bit codes per byte)<br> * @param offset           index of this row's first weight, counted in weights rather than bytes; the vector loop assumes a multiple of 32<br> * @param n                input dimension<br> * @param scales           per-block scales (one per 32 weights), already decoded to float<br> * @return the dot product of the dequantized weights and {@code vec}<br> */<br> private static float vectorizedMXFP4DotProduct(final float[] vec,<br>                                                final MemorySegment allExpertWeights,<br>                                                final int offset,<br>                                                final int n,<br>                                                final float[] scales) {<br>    float acc = 0.0f;<br>    int i = 0;<br><br>    FloatVector accV = FloatVector.zero(F_SPECIES);<br>    float[] weights = new float[F_SPECIES.length()];<br><br>    // ---------- Experimental: use MXFP4-precision weights for the projection
-----------<br>    // Process in blocks of 32 elements (16 bytes = 32 FP4 values).<br>    // Using 256-bit vectors as the example, each vector operation processes 8 floats in parallel.<br>    // To build a JAR without running the tests, just run ./gradlew shadowJar<br>    while (i + 32 <= n) {<br>        // Assumes offset is a multiple of 32, so each iteration covers exactly one block<br>        int blockIdx = (offset + i) / 32;<br>        float scale = scales[blockIdx];<br>        FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);<br><br>        // Load the block's 16 packed bytes and split them into low and high nibbles<br>        ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES, allExpertWeights, blockIdx * 16, ByteOrder.LITTLE_ENDIAN);<br>        ByteVector loNibbles = wBytes.and((byte) 0x0F);<br>        ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);<br><br>        // Each iteration consumes 4 bytes = 8 nibbles = one 8-lane FloatVector<br>        for (int bytePair = 0; bytePair < 4; bytePair++) {<br>            int byte1 = bytePair * 4;<br>            int byte2 = bytePair * 4 + 1;<br>            int byte3 = bytePair * 4 + 2;<br>            int byte4 = bytePair * 4 + 3;<br><br>            // !!! :( This is where we cannot execute in parallel unless we have a 512-bit vector unit<br>            // Interleave: even elements come from low nibbles, odd elements from high nibbles<br>            weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];<br>            weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];<br>            weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];<br>            weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];<br><br>            weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];<br>            weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) &
0x0F];<br>            weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];<br>            weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];<br><br>            FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);<br>            int vecPos = i + (bytePair * 8);<br>            FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);<br><br>            // accV += (w * scale) * x, lane by lane<br>            accV = wv.mul(vScales).fma(xv, accV);<br>        }<br><br>        i += 32;<br>    }<br><br>    acc = accV.reduceLanes(VectorOperators.ADD);<br><br>    // Handle remaining elements that do not fill a whole 32-element block (scalar tail)<br>    while (i < n) {<br>        int blockIdx = (offset + i) / 32;<br>        int elemInBlock = (offset + i) % 32;<br>        float scale = scales[blockIdx];<br><br>        int byteIdx = elemInBlock / 2;<br>        int blockStart = blockIdx * 16;<br>        byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);<br><br>        // Even elements live in the low nibble, odd elements in the high nibble<br>        int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);<br>        float weight = MXFP4_VALUES[fp4Idx] * scale;<br>        acc += weight * vec[i];<br>        i++;<br>    }<br><br>    return acc;<br>}<br>```<br><br><b>## Questions</b><br><br>1. Is my understanding correct that the current Java Vector API cannot perform dynamic per-byte table lookups? If so, given that `VectorShuffle` appears to work only with compile-time constant indices (and effectively only within 512-bit vectors with 16 float lanes), is there any plan to support this in future releases?<br><br>2.
As a follow-up, MXFP4 decoding in this layout also requires interleaving the low and high nibbles (element 2k comes from the low nibble of byte k, element 2k+1 from its high nibble). Is there an efficient way, or a recommended pattern in the API, to handle this?<br><br>Any insights from the Project Panama developers would be greatly appreciated!<br><br>Thanks,<br>Xu<br></div>
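For reference, here is a minimal scalar sketch of how one 32-weight MXFP4 block decodes under the interleaved nibble order used by the scalar tail loop of `vectorizedMXFP4DotProduct` (element 2k from the low nibble of byte k, element 2k+1 from the high nibble). It also checks that a doubled `int8` table, combined with half the block scale, gives the same weights. That property is what lets a 16-entry byte lookup stand in for the 16-entry float table. The class and method names are hypothetical. `e8m0ToFloat` assumes the OCP MX encoding of a shared scale byte as 2^(e - 127); the Java method above receives scales that are already decoded to `float`.

```java
import java.util.Arrays;

// Hypothetical scalar reference for one MXFP4 block: 32 weights packed in 16 bytes.
// Assumed nibble order (as in the scalar tail loop above): element 2k comes from the
// low nibble of byte k, element 2k+1 from its high nibble.
final class Mxfp4Reference {
    // FP4 (E2M1) values indexed by the 4-bit code.
    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    // The same table doubled so every entry is an exact int8 (assumed to mirror
    // llama.cpp's kvalues_mxfp4); the missing factor 1/2 moves into the scale.
    static final byte[] MXFP4_INT8_X2 = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
    };

    // Assumption: an E8M0 scale byte e encodes 2^(e - 127); the NaN code 0xFF is ignored.
    static float e8m0ToFloat(int e) {
        return Math.scalb(1.0f, (e & 0xFF) - 127);
    }

    // Decodes one block with the float table and the full block scale.
    static float[] decodeBlock(byte[] packed, float scale) {
        float[] out = new float[32];
        for (int k = 0; k < 16; k++) {
            out[2 * k] = MXFP4_VALUES[packed[k] & 0x0F] * scale;
            out[2 * k + 1] = MXFP4_VALUES[(packed[k] >> 4) & 0x0F] * scale;
        }
        return out;
    }

    // Decodes the same block with the doubled int8 table and half the scale.
    static float[] decodeBlockInt8(byte[] packed, float scale) {
        float halfScale = scale * 0.5f;
        float[] out = new float[32];
        for (int k = 0; k < 16; k++) {
            out[2 * k] = MXFP4_INT8_X2[packed[k] & 0x0F] * halfScale;
            out[2 * k + 1] = MXFP4_INT8_X2[(packed[k] >> 4) & 0x0F] * halfScale;
        }
        return out;
    }

    public static void main(String[] args) {
        byte[] packed = new byte[16];
        packed[0] = (byte) 0x71; // low nibble 1 -> +0.5, high nibble 7 -> +6.0
        packed[1] = (byte) 0xF9; // low nibble 9 -> -0.5, high nibble 15 -> -6.0
        float scale = e8m0ToFloat(128); // 2^(128 - 127) = 2.0
        float[] a = decodeBlock(packed, scale);
        float[] b = decodeBlockInt8(packed, scale);
        boolean same = true;
        for (int i = 0; i < 32; i++) {
            same &= a[i] == b[i]; // == treats -0.0 and +0.0 as equal (code 8)
        }
        System.out.println(Arrays.toString(Arrays.copyOf(a, 4)) + " " + same);
    }
}
```

Running `main` prints `[1.0, 12.0, -1.0, -12.0] true`. Because both scale factors are powers of two, the two decodes agree exactly (apart from the sign of zero for code 8), which is why an integer table can replace the float one without losing precision.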