Question on supporting vectorized per-byte table lookups for MXFP4 dequantization in gpt-oss.java

xu xuzh1002 at gmail.com
Mon Oct 20 05:16:41 UTC 2025


Hi all,

I’d like to ask the Project Panama team for help with an efficient
implementation of MXFP4 dequantization using the Java Vector API.

*## Background*

I’ve been developing a pure Java implementation of OpenAI’s gpt-oss
inference program, optimized for CPU execution: amzn/gpt-oss.java
<https://github.com/amzn/gpt-oss.java>. The model uses MXFP4 weights, and
I’m exploring efficient ways to handle them in the computation-intensive
MLP layer. However, I have run into limitations in the current Vector API
(up to JDK 24) that seem to prevent a fully vectorized implementation.

*## MXFP4 layout*

In the C++ implementation, each byte of the Mixture-of-Experts weights
stores two 4-bit weights (nibbles). Each 4-bit value is used as an index
into a predefined 16-entry lookup table:
```
kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f,
                 -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f]
```
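To make the layout concrete, here is a minimal Java sketch of decoding one packed byte. It assumes the low nibble holds the first of the two weights (the ordering used by the experimental Java code later in this message) and that each table entry follows the FP4 E2M1 bit layout (1 sign bit, 2 exponent bits, 1 mantissa bit). The class and method names are illustrative only.

```java
// Illustrative sketch only; class and method names are hypothetical.
final class Mxfp4Nibbles {
    // The 16-entry table from this message, indexed by the 4-bit code.
    static final float[] TABLE = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    // Assumption: the low nibble is the first weight, the high nibble the second.
    static float[] unpack(byte packed) {
        int lo = packed & 0x0F;
        int hi = (packed >>> 4) & 0x0F;
        return new float[] { TABLE[lo], TABLE[hi] };
    }

    // Cross-check: rebuild a table entry from its E2M1 bit fields.
    static float fromBits(int code) {
        int sign = (code >>> 3) & 1;
        int exp  = (code >>> 1) & 3;
        int man  = code & 1;
        float magnitude = (exp == 0)
                ? 0.5f * man                               // subnormal: 0 or 0.5
                : (1.0f + 0.5f * man) * (1 << (exp - 1));  // normal: 1, 1.5, 2, 3, 4, 6
        return sign == 0 ? magnitude : -magnitude;
    }

    public static void main(String[] args) {
        float[] pair = unpack((byte) 0xD2);   // low nibble 0x2, high nibble 0xD
        System.out.println(pair[0] + " " + pair[1]);   // prints "1.0 -3.0"
    }
}
```

`fromBits` is only there to show where the table values come from; the lookup in `unpack` is the operation that the rest of this message tries to vectorize.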

*## llama.cpp implementation*

Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation
(e.g., quants.c in llama.cpp
<https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790>)
can perform 16 parallel table lookups per instruction, which makes
converting nibbles to table values efficient. Take the AVX2 path as an
example:
```
#if defined __AVX2__

    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b  = _mm_set1_epi8(0x0f);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
                _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
                _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
```
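For readers less familiar with the intrinsic, the following scalar Java sketch emulates what `_mm_shuffle_epi8` does in the excerpt above: each of the 16 index bytes selects one entry from a 16-byte table. It assumes an int8 table holding the values doubled, which is what the `GGML_E8M0_TO_FP32_HALF` scale factor in the excerpt suggests. The class and method names are hypothetical.

```java
// Scalar emulation for illustration only; not a performance path.
final class ShuffleEpi8Sketch {
    // Assumption: int8 table with the MXFP4 values doubled (0.5 -> 1, 6 -> 12).
    static final byte[] VALUES = { 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 };

    // Emulates _mm_shuffle_epi8(table, idx) for one 128-bit register:
    // out[i] = 0 if bit 7 of idx[i] is set, else table[idx[i] & 0x0F].
    static byte[] shuffleEpi8(byte[] table, byte[] idx) {
        byte[] out = new byte[16];
        for (int i = 0; i < 16; i++) {
            out[i] = (idx[i] & 0x80) != 0 ? 0 : table[idx[i] & 0x0F];
        }
        return out;
    }

    // Mirrors the excerpt: low nibbles fill elements 0..15, high nibbles 16..31.
    static byte[] decodeBlock(byte[] qs) {
        byte[] lo = new byte[16];
        byte[] hi = new byte[16];
        for (int i = 0; i < 16; i++) {
            lo[i] = (byte) (qs[i] & 0x0F);
            hi[i] = (byte) ((qs[i] >>> 4) & 0x0F);
        }
        byte[] out = new byte[32];
        System.arraycopy(shuffleEpi8(VALUES, lo), 0, out, 0, 16);
        System.arraycopy(shuffleEpi8(VALUES, hi), 0, out, 16, 16);
        return out;
    }

    public static void main(String[] args) {
        byte[] qs = new byte[16];
        qs[0] = (byte) 0xF7;   // low nibble 7 -> 12, high nibble 15 -> -12
        byte[] decoded = decodeBlock(qs);
        System.out.println(decoded[0] + " " + decoded[16]);   // prints "12 -12"
    }
}
```

Note that this ordering (all low nibbles, then all high nibbles) differs from the low, high, low, high ordering used by the Java code later in this message.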

*## Problem in Java Vector API*

In Java, there doesn’t seem to be an equivalent way to perform dynamic
per-byte lookups using `ByteVector`.

While `VectorShuffle` supports lane rearrangements, the lookup indices range
from 0 to 15 (the element indices of `kvalues_mxfp4`), so it seems to work
only for rearrangements of 512-bit vectors, which have exactly 16 float
lanes. As far as I can tell, this means that dynamic per-lane table lookups
like `_mm_shuffle_epi8` cannot be expressed directly in Java.

The current workaround requires scalar operations:
```
for (int i = 0; i < lanes; i++) {
    int nibble = byteVector.lane(i) & 0x0F;
    weights[i] = MXFP4_VALUES[nibble];  // Scalar operation
}
```

This prevents MXFP4 dequantization from being fully vectorized — each
nibble must be individually decoded and indexed in scalar form.
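Short of a vectorized lookup, one scalar-side mitigation that might be worth measuring is a table indexed by the whole byte, so that each packed byte costs one index computation instead of two nibble extractions. The sketch below is unbenchmarked, keeps the low-nibble-first assumption, and uses hypothetical names throughout.

```java
// Hypothetical scalar mitigation; not a substitute for a vector lookup.
final class Mxfp4ByteTable {
    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    // PAIRS[2*b] is the low-nibble value of byte b, PAIRS[2*b + 1] the high-nibble value.
    static final float[] PAIRS = new float[512];
    static {
        for (int b = 0; b < 256; b++) {
            PAIRS[2 * b]     = MXFP4_VALUES[b & 0x0F];
            PAIRS[2 * b + 1] = MXFP4_VALUES[b >>> 4];
        }
    }

    // Decodes packed.length bytes into 2 * packed.length floats, low nibble first.
    static void decode(byte[] packed, float[] out) {
        for (int i = 0; i < packed.length; i++) {
            int b = packed[i] & 0xFF;          // treat the byte as unsigned
            out[2 * i]     = PAIRS[2 * b];
            out[2 * i + 1] = PAIRS[2 * b + 1];
        }
    }

    public static void main(String[] args) {
        float[] out = new float[4];
        decode(new byte[] { 0x21, (byte) 0x9F }, out);
        System.out.println(java.util.Arrays.toString(out));   // prints "[0.5, 1.0, -6.0, -0.5]"
    }
}
```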

Below is an experimental version for 256-bit vectors:

```
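// Imports assumed by this excerpt (the Vector API is an incubator module,
// so it needs --add-modules jdk.incubator.vector):
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;
private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256;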

private static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

/**
 * @param vec              input vector, length == 2880 (float32)
 * @param allExpertWeights MemorySegment storing the MXFP4 values to take the
 *                         dot product with the input vector
 * @param offset           expert weights offset, counted in weights rather
 *                         than in physical bytes
 * @param n                input dim
 * @param scales           scales for the weights
 * @return the dot product (a single float)
 */
private static float vectorizedMXFP4DotProduct(final float[] vec,
                                               final MemorySegment allExpertWeights,
                                               final int offset,
                                               final int n,
                                               final float[] scales) {
    float acc = 0.0f;
    int i = 0;

    FloatVector accV = FloatVector.zero(F_SPECIES);
    float[] weights = new float[F_SPECIES.length()];

    // ---------- Experimental: use MXFP4-precision weights for the projection ----------
    // Process in blocks of 32 elements (16 bytes = 32 FP4 values).
    // With a 256-bit vector, 8 floats are processed per vector instruction.
    // To build a JAR, skip test cases, just run ./gradlew shadowJar
    while (i + 32 <= n) {
        int blockIdx = (offset + i) / 32;
        float scale = scales[blockIdx];
        FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);

        ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES, allExpertWeights,
                blockIdx * 16, ByteOrder.LITTLE_ENDIAN);
        ByteVector loNibbles = wBytes.and((byte) 0x0F);
        ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);

        for (int bytePair = 0; bytePair < 4; bytePair++) {
            int byte1 = bytePair * 4;
            int byte2 = bytePair * 4 + 1;
            int byte3 = bytePair * 4 + 2;
            int byte4 = bytePair * 4 + 3;

            // !!! :( This is where we cannot execute in parallel unless
            // the vector unit is 512-bit.
            weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];
            weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];
            weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];
            weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];

            weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];
            weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];
            weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];
            weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];

            FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);
            int vecPos = i + (bytePair * 8);
            FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);

            accV = wv.mul(vScales).fma(xv, accV);
        }

        i += 32;
    }

    acc = accV.reduceLanes(VectorOperators.ADD);

    // Handle remaining elements
    while (i < n) {
        int blockIdx = (offset + i) / 32;
        int elemInBlock = (offset + i) % 32;
        float scale = scales[blockIdx];

        int byteIdx = elemInBlock / 2;
        int blockStart = blockIdx * 16;
        byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);

        int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
        float weight = MXFP4_VALUES[fp4Idx] * scale;
        acc += weight * vec[i];
        i++;
    }

    return acc;
}
```
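When experimenting with vectorized variants, a plain scalar reference can help check results. The sketch below mirrors the indexing of the method above (32 weights and 16 bytes per block, low nibble first, one scale per block) but reads from a `byte[]` rather than a `MemorySegment` so that it stays self-contained. The names are hypothetical, and small floating-point differences from a vectorized version are to be expected because the summation order differs.

```java
// Hypothetical scalar reference for cross-checking; not part of gpt-oss.java.
final class Mxfp4ReferenceDot {
    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    // offset and n are counted in weights, not bytes, as in the method above.
    static float dot(float[] vec, byte[] packedWeights, int offset, int n, float[] scales) {
        float acc = 0.0f;
        for (int i = 0; i < n; i++) {
            int elem = offset + i;
            int blockIdx = elem / 32;            // one scale per 32 weights
            int elemInBlock = elem % 32;
            byte packed = packedWeights[blockIdx * 16 + elemInBlock / 2];
            int code = (elemInBlock % 2 == 0)
                    ? (packed & 0x0F)            // even element: low nibble
                    : ((packed >>> 4) & 0x0F);   // odd element: high nibble
            acc += MXFP4_VALUES[code] * scales[blockIdx] * vec[i];
        }
        return acc;
    }

    public static void main(String[] args) {
        byte[] weights = new byte[16];
        java.util.Arrays.fill(weights, (byte) 0x22);   // every weight decodes to 1.0
        float[] x = new float[32];
        java.util.Arrays.fill(x, 1.0f);
        System.out.println(dot(x, weights, 0, 32, new float[] { 0.5f }));   // prints "16.0"
    }
}
```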

*## Questions*

1. Is my understanding correct that the current Java Vector API cannot
perform dynamic per-byte table lookups? If so, given that `VectorShuffle`
seems to work only with constant indices known ahead of time (and
effectively only within 512-bit vectors with 16 float lanes), is there any
plan to support such lookups in future releases?

2. As a follow-up, MXFP4 decoding also requires interleaving low and high
nibbles. Are there any patterns in the API that handle this efficiently?
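To make the interleaving requirement concrete: masking and shifting a block yields all low-nibble values followed by all high-nibble values, whereas the weights are consumed in the order low, high, low, high. The scalar sketch below shows both orderings, along with one possible way to sidestep the interleave, namely reordering the input vector instead of the weights, since a dot product is unchanged when both operands are permuted identically. The names are hypothetical, and this is not a claim about what the Vector API does or does not support.

```java
// Hypothetical illustration of the two nibble orderings.
final class NibbleOrderSketch {
    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    // Interleaved order: element 2k is the low nibble of byte k, element 2k+1 the high nibble.
    static float[] decodeInterleaved(byte[] block) {
        float[] w = new float[2 * block.length];
        for (int k = 0; k < block.length; k++) {
            w[2 * k]     = MXFP4_VALUES[block[k] & 0x0F];
            w[2 * k + 1] = MXFP4_VALUES[(block[k] >>> 4) & 0x0F];
        }
        return w;
    }

    // Split order: all low nibbles first, then all high nibbles.
    static float[] decodeSplit(byte[] block) {
        int n = block.length;
        float[] w = new float[2 * n];
        for (int k = 0; k < n; k++) {
            w[k]     = MXFP4_VALUES[block[k] & 0x0F];
            w[n + k] = MXFP4_VALUES[(block[k] >>> 4) & 0x0F];
        }
        return w;
    }

    // Reorders x as [x0, x2, ..., x1, x3, ...] so that it lines up with the split weight order.
    static float[] deinterleave(float[] x) {
        int half = x.length / 2;
        float[] out = new float[x.length];
        for (int k = 0; k < half; k++) {
            out[k]        = x[2 * k];
            out[half + k] = x[2 * k + 1];
        }
        return out;
    }

    static float dot(float[] a, float[] b) {
        float acc = 0.0f;
        for (int i = 0; i < a.length; i++) {
            acc += a[i] * b[i];
        }
        return acc;
    }

    public static void main(String[] args) {
        byte[] block = { 0x21, 0x43 };
        float[] x = { 1.0f, 2.0f, 3.0f, 4.0f };
        System.out.println(dot(decodeInterleaved(block), x));
        System.out.println(dot(decodeSplit(block), deinterleave(x)));
    }
}
```

Because the input vector is reused across many weight rows, reordering it once could be cheaper than interleaving every decoded block. The two sums are accumulated in a different order, though, so results may differ in the last bits for general inputs.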

Any insights from the Project Panama developers would be greatly
appreciated!

Thanks,
Xu


More information about the panama-dev mailing list