Question on supporting vectorized per-byte table lookups for MXFP4 dequantization in gpt-oss.java

xu xuzh1002 at gmail.com
Tue Oct 21 11:05:14 UTC 2025


Hi Vladimir,

Thank you for the quick response and the valuable point.

Your suggestion prompted me to look more closely at
ByteVector.rearrange(). Initially, I had tried to use a FloatVector as the
lookup table:

```
private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_128;

private static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

// SPECIES_128 has only 4 float lanes, so this loads just MXFP4_VALUES[0..3].
private static final FloatVector MXFP4_TABLE =
        FloatVector.fromArray(FLOAT_SPECIES, MXFP4_VALUES, 0);

// blocks, blockStart and j come from the surrounding loop.
ByteVector wBytes = ByteVector.fromArray(ByteVector.SPECIES_128, blocks, blockStart);

ByteVector loNibbles = wBytes.and((byte) 0x0F);
ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4);

// Part j widens 4 of the 16 nibbles to floats; nibble values 4..15 are
// out of range as shuffle indices for a 4-lane vector.
VectorShuffle<Float> loShuffle = loNibbles.castShape(FLOAT_SPECIES, j).toShuffle();
VectorShuffle<Float> hiShuffle = hiNibbles.castShape(FLOAT_SPECIES, j).toShuffle();

FloatVector loW = MXFP4_TABLE.rearrange(loShuffle);
FloatVector hiW = MXFP4_TABLE.rearrange(hiShuffle);
```

This turned out not to work as expected.
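For comparison, here is a minimal sketch of how the float-table idea can be made functionally correct: a 16-entry float table needs 16 float lanes (FloatVector.SPECIES_512), and the 16 byte nibbles can be widened into 16 float lanes with castShape(..., 0) before calling toShuffle(). The class name Mxfp4FloatTableLookup is hypothetical, the code needs `--add-modules jdk.incubator.vector`, and on CPUs without 512-bit vectors the 512-bit species is likely not intrinsified, so treat this as a correctness illustration rather than a fast path.

```java
import java.util.Arrays;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShuffle;
import jdk.incubator.vector.VectorSpecies;

// Compile and run with: --add-modules jdk.incubator.vector
public class Mxfp4FloatTableLookup {
    // 16 float lanes: enough to hold the whole 16-entry table.
    static final VectorSpecies<Float> F512 = FloatVector.SPECIES_512;

    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };
    static final FloatVector TABLE = FloatVector.fromArray(F512, MXFP4_VALUES, 0);

    /** Decodes 16 packed bytes into 16 low-nibble and 16 high-nibble weights. */
    static void decode(byte[] packed, int offset, float[] lo, float[] hi) {
        ByteVector b = ByteVector.fromArray(ByteVector.SPECIES_128, packed, offset);
        ByteVector loNib = b.and((byte) 0x0F);
        // High nibble; the mask keeps every index in 0..15.
        ByteVector hiNib = b.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);
        // Byte128 (16 lanes) -> Float512 (16 lanes): same lane count, so part 0 covers all lanes.
        VectorShuffle<Float> loIdx = loNib.castShape(F512, 0).toShuffle();
        VectorShuffle<Float> hiIdx = hiNib.castShape(F512, 0).toShuffle();
        TABLE.rearrange(loIdx).intoArray(lo, 0);
        TABLE.rearrange(hiIdx).intoArray(hi, 0);
    }

    public static void main(String[] args) {
        byte[] packed = new byte[16];
        for (int i = 0; i < 16; i++) {
            packed[i] = (byte) (((15 - i) << 4) | i); // low nibble i, high nibble 15 - i
        }
        float[] lo = new float[16];
        float[] hi = new float[16];
        decode(packed, 0, lo, hi);
        System.out.println(Arrays.toString(lo));
        System.out.println(Arrays.toString(hi));
    }
}
```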

Thanks to your insight about ByteVector.rearrange(), I came up with an idea:
scale MXFP4_VALUES by 10 so that every entry becomes an integer that can be
stored in a ByteVector, and then multiply by 0.1f at inference time to undo
the scaling. My new code looks roughly like this:

```
/**
 * MXFP4_BYTES contains values that are 10x larger than the actual {@link #MXFP4_VALUES}.
 *
 * <p>A ByteVector can only hold byte values, so <code>ByteVector.rearrange()</code> can only
 * look up bytes. MXFP4 values such as 0.5f and 1.5f cannot be stored as bytes directly, so we
 * store them scaled by 10 (5, 15, ...) and multiply by 0.1 to recover the real values.
 *
 * <p>To avoid the 0.1 multiplication in every MXFP4 dot product, the 0.1 factor is applied once
 * during model loading rather than repeatedly during inference. See {@link
 * ModelBinLoader#loadU8TensorsAsFloatArrays(String, long, int)}: each scale is multiplied by 0.1
 * and the pre-scaled values are stored in the mlp1Scales[] and mlp2Scales[] arrays.
 */
private static final byte[] MXFP4_BYTES = {
        0, 5, 10, 15, 20, 30, 40, 60,
        0, -5, -10, -15, -20, -30, -40, -60
};
// B_SPECIES must be ByteVector.SPECIES_128 (16 lanes) to hold all 16 table entries.
private static final ByteVector MXFP4_BYTE_TABLE =
        ByteVector.fromArray(B_SPECIES, MXFP4_BYTES, 0);

ByteVector weights = ByteVector.fromMemorySegment(B_SPECIES,
        allExpertWeights, (long) blockIdx * 16, ByteOrder.LITTLE_ENDIAN);

ByteVector loNibbles = weights.and((byte) 0x0F);                  // even-indexed weights
ByteVector hiNibbles = weights.lanewise(VectorOperators.LSHR, 4); // odd-indexed weights

ByteVector loWeights = MXFP4_BYTE_TABLE.rearrange(loNibbles.toShuffle());
ByteVector hiWeights = MXFP4_BYTE_TABLE.rearrange(hiNibbles.toShuffle());
```
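A self-contained sketch of the byte-table path, assuming a 16-lane ByteVector.SPECIES_128 table. It uses selectFrom, which treats the lanes of the index vector as indices into the table and is equivalent to TABLE.rearrange(indices.toShuffle()). As a variation on the 10x scaling, the table here is scaled by 2: every E2M1 value is a multiple of 0.5, so doubling is the smallest integer scaling, and the undo factor 0.5f is exactly representable in binary floating point (0.1f is not). The class and method names are hypothetical, and the code needs `--add-modules jdk.incubator.vector`.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Compile and run with: --add-modules jdk.incubator.vector
public class Mxfp4ByteTableDecode {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;   // 16 byte lanes
    static final VectorSpecies<Float> F128 = FloatVector.SPECIES_128; // 4 float lanes

    // E2M1 values times 2: every entry is an exact small integer that fits in a byte.
    static final byte[] MXFP4_X2 = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};
    static final ByteVector TABLE = ByteVector.fromArray(B128, MXFP4_X2, 0);

    /**
     * Decodes one block of 16 packed bytes (32 weights) into 16 low-nibble weights
     * (even[]) and 16 high-nibble weights (odd[]), multiplied by the block scale.
     */
    static void decodeBlock(byte[] packed, int off, float scale, float[] even, float[] odd) {
        ByteVector b = ByteVector.fromArray(B128, packed, off);
        // selectFrom: result lane i is TABLE[index lane i]; indices are 0..15.
        ByteVector lo = b.and((byte) 0x0F).selectFrom(TABLE);
        ByteVector hi = b.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F).selectFrom(TABLE);
        // 0.5f undoes the x2 table scaling; it could also be folded into the scale at load time.
        float s = 0.5f * scale;
        for (int part = 0; part < 4; part++) {
            // Widen byte lanes 4*part .. 4*part+3 into one 4-lane float vector.
            FloatVector wLo = (FloatVector) lo.castShape(F128, part);
            FloatVector wHi = (FloatVector) hi.castShape(F128, part);
            wLo.mul(s).intoArray(even, 4 * part);
            wHi.mul(s).intoArray(odd, 4 * part);
        }
    }
}
```

With the low nibble of byte k holding weight 2k and the high nibble weight 2k + 1, even[j] ends up as block weight 2j and odd[j] as block weight 2j + 1.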

Since I still find it hard to figure out how to write the nibble-interleaving
code, I took a different approach: rearranging the input vector to match the
MXFP4 unpacking pattern. MXFP4 weights are stored as 4-bit values packed two
per byte, so after unpacking with ByteVector.rearrange() they naturally
separate into even- and odd-indexed weight vectors:

```
loWeights: w0, w2, w4, w6, ... (even-indexed weights)
hiWeights: w1, w3, w5, w7, ... (odd-indexed weights)
```
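A scalar reference decoder can serve as a test oracle for the vectorized code. This sketch assumes the layout used by the scalar loop in the original question: the low nibble of byte k is weight 2k and the high nibble is weight 2k + 1, so collecting all low nibbles gives the even-indexed weights and all high nibbles the odd-indexed ones. The class name Mxfp4ScalarReference is hypothetical.

```java
public class Mxfp4ScalarReference {
    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };

    /** Decodes one 16-byte block into its 32 weights in natural order, times the block scale. */
    static float[] decodeBlock(byte[] packed, int off, float scale) {
        float[] w = new float[32];
        for (int k = 0; k < 16; k++) {
            int b = packed[off + k] & 0xFF;                // treat the byte as unsigned
            w[2 * k]     = MXFP4_VALUES[b & 0x0F] * scale; // low nibble  -> even index
            w[2 * k + 1] = MXFP4_VALUES[b >>> 4] * scale;  // high nibble -> odd index
        }
        return w;
    }
}
```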

So I reshape the input vector instead. Taking a 128-bit vector as an example,
I organize the input into groups of 4 floats per vector to match the weight
layout:
```
Input vector : [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, ...]
Reshaped     : [x0, x2, x4, x6, x1, x3, x5, x7, x8, x10, x12, x14, x9, x11, x13, x15, ...]
                └─── even ───┘  └─── odd ────┘  └───── even ────┘  └────── odd ────┘
```
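The reshaping can be written as a one-time permutation of the input. Because a dot product is unchanged (up to floating-point summation order) when both operands are permuted the same way, reordering x once per input vector avoids re-interleaving nibbles for every weight block, and the cost is shared by all rows of the weight matrix. A minimal sketch, assuming the input length is a multiple of 2 × lanes (lanes = 4 for 128-bit float vectors); the class and method names are hypothetical.

```java
import java.util.Arrays;

public class InputReshape {
    /**
     * Reorders x so that, within each group of 2 * lanes elements, the even-indexed
     * elements come first and the odd-indexed elements second.
     * lanes = 4 corresponds to a 128-bit float vector.
     */
    static float[] reshape(float[] x, int lanes) {
        int group = 2 * lanes;
        if (x.length % group != 0) {
            throw new IllegalArgumentException("x.length must be a multiple of " + group);
        }
        float[] out = new float[x.length];
        for (int g = 0; g < x.length; g += group) {
            for (int j = 0; j < lanes; j++) {
                out[g + j] = x[g + 2 * j];             // even-indexed element
                out[g + lanes + j] = x[g + 2 * j + 1]; // odd-indexed element
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[] x = new float[16];
        for (int i = 0; i < 16; i++) x[i] = i;
        // prints [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0, 8.0, 10.0, 12.0, 14.0, 9.0, 11.0, 13.0, 15.0]
        System.out.println(Arrays.toString(reshape(x, 4)));
    }
}
```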

The change now works with MXFP4 weights in my gpt-oss.java. Performance is
very close to BF16, but slightly slower (which was a bit of a shock). I'll
continue working on it and profiling, will likely consult the list again,
and will share my experience using the Vector API in a modern Java-based LLM
inference program later.

Thank you again for pointing me toward a solution!

Thanks,
Xu

On Tue, Oct 21, 2025 at 01:42, Vladimir Ivanov <vladimir.x.ivanov at oracle.com> wrote:

> Hi Xu,
>
> I didn't look at your benchmark in detail, but I see that
> ByteVector.rearrange() on Byte128Vector is lowered into pshufb [1]. Have
> you tried it?
>
> Best regards,
> Vladimir Ivanov
>
> [1]
> https://github.com/openjdk/jdk/blob/master/src/hotspot/cpu/x86/x86.ad#L8769
>
> On 10/19/25 22:16, xu wrote:
> > Hi all,
> >
> > I’d like to seek some help from the Project Panama team regarding an
> > efficient implementation of MXFP4 dequantization using the Java Vector API.
> >
> > *## Background*
> >
> > I’ve been developing a pure Java implementation of OpenAI’s gpt-oss
> > inference program, optimized for CPU execution (GitHub: amzn/gpt-oss.java
> > <https://github.com/amzn/gpt-oss.java>). The model uses MXFP4
> > weights, and I’m exploring efficient ways to handle them in the
> > computation-intensive MLP layer. However, I have run into limitations with
> > the current Vector API (up to JDK 24) that seem to prevent a fully
> > vectorized implementation.
> >
> > *## MXFP4 layout*
> >
> > In C++ implementations, each Mixture-of-Experts weight byte stores two
> > 4-bit weights (nibbles). These 4-bit values are used as indices into a
> > predefined 16-entry lookup table:
> > ```
> > kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f]
> > ```
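The 16 entries of kvalues_mxfp4 are exactly the values of the 4-bit E2M1 element format from the OCP Microscaling (MX) specification (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit), where each block of 32 elements shares one E8M0 scale. A small sketch that derives the table from the bit fields; the class name E2M1Decode is hypothetical.

```java
public class E2M1Decode {
    /** Decodes a 4-bit E2M1 code (0..15) to its float value. */
    static float decode(int code) {
        int sign = (code >>> 3) & 1;   // bit 3
        int exp = (code >>> 1) & 0b11; // bits 2..1, bias 1
        int man = code & 1;            // bit 0
        float mag = (exp == 0)
                ? man * 0.5f                              // subnormal: 0.m x 2^0
                : (1.0f + man * 0.5f) * (1 << (exp - 1)); // normal: 1.m x 2^(exp-1)
        return sign == 1 ? -mag : mag;                    // code 8 gives -0.0f
    }

    public static void main(String[] args) {
        for (int c = 0; c < 16; c++) {
            System.out.println(c + " -> " + decode(c));
        }
    }
}
```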
> >
> > *## llama.cpp implementation*
> >
> > Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation
> > (e.g., quants.c in llama.cpp <https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790>)
> > can perform 16 parallel table lookups per instruction, converting nibbles
> > to their byte-sized table values efficiently. Taking AVX2 as an example:
> > ```
> > #if defined __AVX2__
> >
> >      const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
> >      const __m128i m4b  = _mm_set1_epi8(0x0f);
> >      const __m256i mone = _mm256_set1_epi16(1);
> >
> >      __m256 accum1 = _mm256_setzero_ps();
> >      __m256 accum2 = _mm256_setzero_ps();
> >      for (; ib + 1 < nb; ib += 2) {
> >          const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
> >          const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
> >          const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
> >          const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
> >          const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
> >                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
> >          const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
> >                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
> >          const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
> >          const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
> >          const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
> >          const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
> >          accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
> >                  _mm256_cvtepi32_ps(p_1), accum1);
> >          accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
> >                  _mm256_cvtepi32_ps(p_2), accum2);
> >      }
> >
> >      sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
> > ```
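What `_mm_shuffle_epi8` does on a 128-bit register can be modeled in scalar Java: each of the 16 index bytes either selects one of 16 table bytes via its low 4 bits or, if its top bit is set, produces zero. Because the kernel masks the indices with 0x0F first, only the table-lookup case occurs, which is the same operation a 16-lane ByteVector.rearrange() expresses (and, per Vladimir's reply, what Byte128Vector rearrange is lowered to). The class name PshufbModel is hypothetical.

```java
import java.util.Arrays;

public class PshufbModel {
    /**
     * Scalar model of the 128-bit PSHUFB instruction (_mm_shuffle_epi8):
     * out[i] = 0 if bit 7 of idx[i] is set, otherwise table[idx[i] & 0x0F].
     */
    static byte[] pshufb(byte[] table, byte[] idx) {
        byte[] out = new byte[16];
        for (int i = 0; i < 16; i++) {
            out[i] = (idx[i] & 0x80) != 0 ? 0 : table[idx[i] & 0x0F];
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical example data: a doubled MXFP4 table and nibble indices 15..0.
        byte[] table = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};
        byte[] idx = new byte[16];
        for (int i = 0; i < 16; i++) idx[i] = (byte) (15 - i);
        System.out.println(Arrays.toString(pshufb(table, idx)));
    }
}
```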
> >
> > *## Problem in Java Vector API*
> >
> > In Java, there doesn’t seem to be an equivalent way to perform dynamic
> > per-byte lookups using `ByteVector`.
> >
> > While `VectorShuffle` supports lane rearrangements, its indices must
> > fall within the vector's lane count. Indexing all 16 entries of
> > kvalues_mxfp4 as floats therefore requires a vector with exactly 16 float
> > lanes, i.e. a 512-bit vector. This means that dynamic, per-lane table
> > lookups like `_mm_shuffle_epi8` cannot be expressed directly in Java.
> >
> > The current workaround requires scalar operations:
> > ```
> > for (int i = 0; i < lanes; i++) {
> >      int nibble = byteVector.lane(i) & 0x0F;
> >      weights[i] = MXFP4_VALUES[nibble];  // Scalar operation
> > }
> > ```
> >
> > This prevents MXFP4 dequantization from being fully vectorized — each
> > nibble must be individually decoded and indexed in scalar form.
> >
> > Below is an experimental version for 256-bit vectors:
> >
> > ```
> > import java.lang.foreign.MemorySegment;
> > import java.lang.foreign.ValueLayout;
> > import java.nio.ByteOrder;
> > import jdk.incubator.vector.ByteVector;
> > import jdk.incubator.vector.FloatVector;
> > import jdk.incubator.vector.VectorOperators;
> > import jdk.incubator.vector.VectorSpecies;
> >
> > private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;
> > private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256;
> >
> > private static final float[] MXFP4_VALUES = {
> >          +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
> >          -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
> > };
> >
> > /**
> >  * @param vec              input vector, length == 2880 (float32)
> >  * @param allExpertWeights MemorySegment storing the MXFP4 weights to dot with the input vector
> >  * @param offset           expert weights offset in number of weights, not physical bytes
> >  *                         (assumed to be a multiple of 32, i.e. block-aligned)
> >  * @param n                input dimension
> >  * @param scales           per-block scales for the weights
> >  * @return the dot product of vec and the dequantized weights
> >  */
> > private static float vectorizedMXFP4DotProduct(final float[] vec,
> >                                                final MemorySegment allExpertWeights,
> >                                                final int offset,
> >                                                final int n,
> >                                                final float[] scales) {
> >     float acc = 0.0f;
> >     int i = 0;
> >
> >     FloatVector accV = FloatVector.zero(F_SPECIES);
> >     float[] weights = new float[F_SPECIES.length()];
> >
> >     // ---------- Experimental: use MXFP4-precision weights for the projection ----------
> >     // Process in blocks of 32 elements (16 bytes = 32 FP4 values).
> >     // A 256-bit vector processes 8 floats in parallel.
> >     // To build a JAR, skip test cases, just run ./gradlew shadowJar
> >     while (i + 32 <= n) {
> >         int blockIdx = (offset + i) / 32;
> >         float scale = scales[blockIdx];
> >         FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);
> >
> >         ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES,
> >                 allExpertWeights, (long) blockIdx * 16, ByteOrder.LITTLE_ENDIAN);
> >         ByteVector loNibbles = wBytes.and((byte) 0x0F);
> >         ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);
> >
> >         // Each iteration decodes 4 bytes = 8 weights (one 256-bit float vector).
> >         for (int bytePair = 0; bytePair < 4; bytePair++) {
> >             int byte1 = bytePair * 4;
> >             int byte2 = bytePair * 4 + 1;
> >             int byte3 = bytePair * 4 + 2;
> >             int byte4 = bytePair * 4 + 3;
> >
> >             // !!! :( This is where we cannot run in parallel unless we have a 512-bit vector unit
> >             weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];
> >             weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];
> >             weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];
> >             weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];
> >
> >             weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];
> >             weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];
> >             weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];
> >             weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];
> >
> >             FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);
> >             int vecPos = i + (bytePair * 8);
> >             FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);
> >
> >             accV = wv.mul(vScales).fma(xv, accV);
> >         }
> >
> >         i += 32;
> >     }
> >
> >     acc = accV.reduceLanes(VectorOperators.ADD);
> >
> >     // Handle remaining elements
> >     while (i < n) {
> >         int blockIdx = (offset + i) / 32;
> >         int elemInBlock = (offset + i) % 32;
> >         float scale = scales[blockIdx];
> >
> >         int byteIdx = elemInBlock / 2;
> >         long blockStart = (long) blockIdx * 16;
> >         byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);
> >
> >         int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
> >         float weight = MXFP4_VALUES[fp4Idx] * scale;
> >         acc += weight * vec[i];
> >         i++;
> >     }
> >
> >     return acc;
> > }
> > ```
> >
> > *## Questions*
> >
> > 1. Is my understanding correct that the current Java Vector API cannot
> > perform dynamic per-byte table lookups? If so, since `VectorShuffle`
> > only works with compile-time matched constant indices (and effectively
> > only within 512-bit / 16-float-lane vectors), is there any plan to
> > support this in future releases?
> >
> > 2. As a follow-up, MXFP4 decoding also requires interleaving low and
> > high nibbles — are there any performant ways or patterns in the API to
> > handle this efficiently?
> >
> > Any insights from the Project Panama developers would be greatly
> > appreciated!
> >
> > Thanks,
> > Xu
>
>

-- 

Best regards,

Xu