Question on supporting vectorized per-byte table lookups for MXFP4 dequantization in gpt-oss.java
Vladimir Ivanov
vladimir.x.ivanov at oracle.com
Mon Oct 20 17:41:28 UTC 2025
Hi Xu,
I didn't look at your benchmark in detail, but I see that
ByteVector.rearrange() on Byte128Vector is lowered into pshufb [1]. Have
you tried it?
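
The pattern would look roughly like this (an untested sketch; it stores
the table as doubled int8 values the way llama.cpp's kvalues_mxfp4 does,
so the 0.5 factor has to be folded into the per-block scale):

```
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;

class Mxfp4ByteLookup {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;

    // MXFP4 values scaled by 2 so they fit in int8 (llama.cpp's
    // kvalues_mxfp4); the 0.5 factor folds into the per-block scale.
    static final byte[] KVALUES_X2 = {0, 1, 2, 3, 4, 6, 8, 12,
                                      0, -1, -2, -3, -4, -6, -8, -12};
    static final ByteVector TABLE = ByteVector.fromArray(B128, KVALUES_X2, 0);

    // Per-byte table lookup: every lane of `nibbles` must be in [0, 15],
    // which is a valid lane index for the 16-lane Byte128Vector, so the
    // rearrange can be lowered to a single pshufb.
    static ByteVector lookup(ByteVector nibbles) {
        return TABLE.rearrange(nibbles.toShuffle());
    }
}
```

The int8-to-float conversion still remains after the lookup, but the
table step itself stays fully vectorized.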
Best regards,
Vladimir Ivanov
[1]
https://github.com/openjdk/jdk/blob/master/src/hotspot/cpu/x86/x86.ad#L8769
On 10/19/25 22:16, xu wrote:
> Hi all,
>
> I’d like to seek some help from the Project Panama team regarding an
> efficient implementation of MXFP4 dequantization using the Java Vector API.
>
> *## Background*
>
> I’ve been developing a pure Java implementation of OpenAI’s gpt-oss
> inference program, optimized for CPU execution: amzn/gpt-oss.java
> <https://github.com/amzn/gpt-oss.java>. The model uses MXFP4 weights,
> and I’m exploring efficient ways to handle them in the
> computation-intensive MLP layer. However, I’ve run into limitations
> with the current Vector API (up to JDK 24) that seem to prevent a fully
> vectorized implementation.
>
> *## MXFP4 layout*
>
> Each Mixture-of-Experts weight byte stores two 4-bit weights (nibbles).
> These 4-bit values are used as indices into a predefined 16-entry
> lookup table:
> ```
> kvalues_mxfp4 = [ 0f,  .5f,  1f,  1.5f,  2f,  3f,  4f,  6f,
>                  -0f, -.5f, -1f, -1.5f, -2f, -3f, -4f, -6f]
> ```
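>
> For concreteness, a scalar decode of one packed byte looks like this (a
> sketch; `kvalues` is the table above, and the low nibble comes first):
>
> ```
> // Decode one packed MXFP4 byte into two float weights:
> // low nibble -> first element, high nibble -> second element.
> static void decodeByte(byte packed, float[] kvalues, float[] out) {
>     out[0] = kvalues[packed & 0x0F];          // low nibble
>     out[1] = kvalues[(packed >>> 4) & 0x0F];  // high nibble
> }
> ```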
>
> *## llama.cpp implementation*
>
> Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation
> (e.g., quants.c in llama.cpp
> <https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790>)
> can perform 16 parallel table lookups per instruction, converting
> nibbles to float values efficiently. Taking AVX2 as an example:
> ```
> #if defined __AVX2__
>
>     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
>     const __m128i m4b  = _mm_set1_epi8(0x0f);
>     const __m256i mone = _mm256_set1_epi16(1);
>
>     __m256 accum1 = _mm256_setzero_ps();
>     __m256 accum2 = _mm256_setzero_ps();
>     for (; ib + 1 < nb; ib += 2) {
>         const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
>         const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
>         const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
>         const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
>         const __m256i q4b_1 = MM256_SET_M128I(
>             _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
>             _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
>         const __m256i q4b_2 = MM256_SET_M128I(
>             _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
>             _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
>         const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
>         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
>         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
>         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
>         accum1 = _mm256_fmadd_ps(
>             _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
>             _mm256_cvtepi32_ps(p_1), accum1);
>         accum2 = _mm256_fmadd_ps(
>             _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
>             _mm256_cvtepi32_ps(p_2), accum2);
>     }
>
>     sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
> ```
>
> *## Problem in Java Vector API*
>
> In Java, there doesn’t seem to be an equivalent way to perform dynamic
> per-byte lookups using `ByteVector`.
>
> While `VectorShuffle` supports lane rearrangements, its indices must be
> valid lane indices of the target species. Since the nibble values 0-15
> index the 16-entry kvalues_mxfp4 table, a direct rearrange of the float
> table only works for a species with exactly 16 float lanes, i.e. a
> 512-bit float vector, as the sketch below illustrates. This means that
> dynamic, per-lane table lookups like `_mm_shuffle_epi8` seemingly cannot
> be expressed directly in Java.
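>
> For example, the one shape where a direct rearrange of the float table
> does work is a 16-lane (512-bit) float species, since every nibble value
> is then a valid lane index (a sketch; `nibbleIndices` is assumed to hold
> 16 values in [0, 15]):
>
> ```
> import jdk.incubator.vector.*;
>
> class ShuffleLimit {
>     static final VectorSpecies<Float> F512 = FloatVector.SPECIES_512; // 16 lanes
>
>     static final float[] MXFP4_VALUES = {
>         +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
>         -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
>     };
>
>     // Works only because F512 has exactly 16 lanes, matching the table
>     // size; on a narrower species the indices would be out of range.
>     static FloatVector lookup16(int[] nibbleIndices) {
>         FloatVector table = FloatVector.fromArray(F512, MXFP4_VALUES, 0);
>         VectorShuffle<Float> sh = VectorShuffle.fromArray(F512, nibbleIndices, 0);
>         return table.rearrange(sh);
>     }
> }
> ```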
>
> The current workaround requires scalar operations:
> ```
> for (int i = 0; i < lanes; i++) {
>     int nibble = byteVector.lane(i) & 0x0F;
>     weights[i] = MXFP4_VALUES[nibble]; // scalar operation per lane
> }
> ```
>
> This prevents MXFP4 dequantization from being fully vectorized: each
> nibble must be individually decoded and indexed in scalar form.
>
> Below is an experimental version for 256-bit vectors:
>
> ```
> import java.lang.foreign.MemorySegment;
> import java.lang.foreign.ValueLayout;
> import java.nio.ByteOrder;
> import jdk.incubator.vector.*;
>
> private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;
> private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256;
>
> private static final float[] MXFP4_VALUES = {
>     +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
>     -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
> };
>
> /**
>  * @param vec              input vector, length == 2880 (float32)
>  * @param allExpertWeights MemorySegment storing the MXFP4 weights to
>  *                         dot-product with the input vector
>  * @param offset           expert weight offset in number of weights,
>  *                         not physical byte offset
>  * @param n                input dimension
>  * @param scales           per-block scales for the weights
>  * @return                 the dot-product result
>  */
> private static float vectorizedMXFP4DotProduct(final float[] vec,
>                                                final MemorySegment allExpertWeights,
>                                                final int offset,
>                                                final int n,
>                                                final float[] scales) {
>     float acc = 0.0f;
>     int i = 0;
>
>     FloatVector accV = FloatVector.zero(F_SPECIES);
>     float[] weights = new float[F_SPECIES.length()];
>
>     // Experimental path using MXFP4-precision weights for the projection.
>     // Process in blocks of 32 elements (16 bytes = 32 FP4 values); with a
>     // 256-bit vector we can process 8 floats per inner-loop iteration.
>     while (i + 32 <= n) {
>         int blockIdx = (offset + i) / 32;
>         float scale = scales[blockIdx];
>         FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);
>
>         ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES,
>                 allExpertWeights, (long) blockIdx * 16, ByteOrder.LITTLE_ENDIAN);
>         ByteVector loNibbles = wBytes.and((byte) 0x0F);
>         ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);
>
>         for (int byteQuad = 0; byteQuad < 4; byteQuad++) {
>             int byte1 = byteQuad * 4;
>             int byte2 = byteQuad * 4 + 1;
>             int byte3 = byteQuad * 4 + 2;
>             int byte4 = byteQuad * 4 + 3;
>
>             // !!! :( this is where we cannot run in parallel unless the
>             // vector unit is 512-bit
>             weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];
>             weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];
>             weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];
>             weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];
>             weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];
>             weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];
>             weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];
>             weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];
>
>             FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);
>             int vecPos = i + (byteQuad * 8);
>             FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);
>
>             accV = wv.mul(vScales).fma(xv, accV);
>         }
>
>         i += 32;
>     }
>
>     acc = accV.reduceLanes(VectorOperators.ADD);
>
>     // Handle remaining elements in scalar form
>     while (i < n) {
>         int blockIdx = (offset + i) / 32;
>         int elemInBlock = (offset + i) % 32;
>         float scale = scales[blockIdx];
>
>         int byteIdx = elemInBlock / 2;
>         int blockStart = blockIdx * 16;
>         byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);
>
>         int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
>         float weight = MXFP4_VALUES[fp4Idx] * scale;
>         acc += weight * vec[i];
>         i++;
>     }
>
>     return acc;
> }
> ```
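>
> For context, a hypothetical call site looks like this (the names
> `hidden` and `row` and the 2880 dimension are just from my setup):
>
> ```
> // Hypothetical caller: one expert row dotted against the input vector.
> static float expertRowDot(float[] hidden, MemorySegment allExpertWeights,
>                           float[] scales, int row) {
>     final int dim = 2880;  // gpt-oss MLP input dimension
>     return vectorizedMXFP4DotProduct(hidden, allExpertWeights,
>             row * dim, dim, scales);
> }
> ```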
>
> *## Questions*
>
> 1. Is my understanding correct that the current Java Vector API cannot
> perform dynamic per-byte table lookups? If yes, since `VectorShuffle`
> indices must be valid lane indices of the species (which effectively
> restricts this table lookup to 512-bit vectors with 16 float lanes), is
> there any plan to support per-byte lookups in future releases?
>
> 2. As a follow-up, MXFP4 decoding also requires interleaving the low
> and high nibbles. Are there any performant ways or patterns in the API
> to handle this efficiently? The best pattern I could come up with is
> sketched below.
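>
> (An untested sketch: interleave two 16-lane byte vectors `lo` and `hi`
> into the order lo0, hi0, lo1, hi1, ... with a precomputed shuffle plus
> a blend. This covers the first 16 output lanes, built from the first 8
> lanes of each input; a second shuffle with indices offset by 8 would
> produce the remaining 16.)
>
> ```
> import jdk.incubator.vector.*;
>
> class NibbleInterleave {
>     static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
>
>     // Output lane i takes lo[i/2] when i is even and hi[i/2] when i is
>     // odd, yielding lo0, hi0, lo1, hi1, ...
>     static final VectorShuffle<Byte> HALF_IDX;
>     static final VectorMask<Byte> ODD_LANES;
>     static {
>         int[] idx = new int[16];
>         boolean[] odd = new boolean[16];
>         for (int i = 0; i < 16; i++) {
>             idx[i] = i / 2;
>             odd[i] = (i & 1) == 1;
>         }
>         HALF_IDX = VectorShuffle.fromArray(B128, idx, 0);
>         ODD_LANES = VectorMask.fromValues(B128, odd);
>     }
>
>     static ByteVector interleaveLow(ByteVector lo, ByteVector hi) {
>         // Spread lanes 0..7 of each input across all 16 lanes, then keep
>         // lo on even positions and hi on odd positions.
>         return lo.rearrange(HALF_IDX).blend(hi.rearrange(HALF_IDX), ODD_LANES);
>     }
> }
> ```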
>
> Any insights from the Project Panama developers would be greatly
> appreciated!
>
> Thanks,
> Xu