Question on supporting vectorized per-byte table lookups for MXFP4 dequantization in gpt-oss.java
xu
xuzh1002 at gmail.com
Mon Oct 20 05:16:41 UTC 2025
Hi all,
I’d like to seek some help from the Project Panama team regarding an
efficient implementation of MXFP4 dequantization using the Java Vector API.
*## Background*
I’ve been developing a pure Java implementation of OpenAI’s gpt-oss
inference program, optimized for CPU execution: amzn/gpt-oss.java
<https://github.com/amzn/gpt-oss.java>. The model uses MXFP4 weights, and
I’m exploring efficient ways to handle them in the computation-intensive
MLP layer. However, I’ve run into limitations with the current Vector API
(up to JDK 24) that seem to prevent a fully vectorized implementation.
*## MXFP4 layout*
In the C++ reference implementation, each Mixture-of-Experts weight byte
stores two 4-bit weights (nibbles). Each 4-bit value is used as an index
into a predefined 16-entry lookup table:
```
kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f,
                 -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f]
```
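To make the layout concrete, here is a minimal scalar sketch in Java (an
illustration only, not code from the repository), assuming that the low
nibble of each byte holds the even-indexed element and the high nibble the
odd-indexed one, as in the remainder loop of the experimental code further
below:
```
// Scalar illustration of the MXFP4 byte layout (assumption: low nibble = even
// element, high nibble = odd element, matching the remainder loop below).
static final float[] MXFP4_VALUES = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

static void decodePackedByte(byte packed, float[] out, int outPos) {
    out[outPos]     = MXFP4_VALUES[packed & 0x0F];        // low nibble -> even element
    out[outPos + 1] = MXFP4_VALUES[(packed >> 4) & 0x0F]; // high nibble -> odd element
}
```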
*## llama.cpp implementation*
Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation
(e.g., quants.c in llama.cpp
<https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790>)
can perform 16 parallel table lookups per instruction, converting nibbles
to their table values efficiently. Take AVX2 for example:
```
#if defined __AVX2__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b  = _mm_set1_epi8(0x0f);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
                                 _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
                                 _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
```
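For reference, the per-byte lookup semantics of `_mm_shuffle_epi8` that the
loop above relies on can be written out in scalar Java like this (a sketch
for illustration only, not code from llama.cpp or from my repository):
```
// Scalar reference model of a 128-bit _mm_shuffle_epi8: each index byte selects
// one of the 16 table bytes, and an index with its high bit set produces zero.
// This is the dynamic, per-lane table lookup I would like to express with the
// Vector API.
static byte[] shuffleEpi8(byte[] table16, byte[] idx16) {
    byte[] out = new byte[16];
    for (int i = 0; i < 16; i++) {
        out[i] = (idx16[i] & 0x80) != 0 ? 0 : table16[idx16[i] & 0x0F];
    }
    return out;
}
```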
*## Problem in Java Vector API*
In Java, there doesn’t seem to be an equivalent way to perform dynamic
per-byte lookups using `ByteVector`.
While `VectorShuffle` supports lane rearrangement, its indices must fall
within 0 to 15 (the element indices of `kvalues_mxfp4`), so a shuffle-based
lookup only lines up with a 512-bit vector holding exactly 16 float lanes.
This means that dynamic, per-lane table lookups like `_mm_shuffle_epi8`
cannot be expressed directly in Java.
The current workaround falls back to scalar operations:
```
for (int i = 0; i < lanes; i++) {
    int nibble = byteVector.lane(i) & 0x0F;
    weights[i] = MXFP4_VALUES[nibble]; // Scalar operation
}
```
This prevents MXFP4 dequantization from being fully vectorized: each
nibble must be individually decoded and indexed in scalar form.
Below is an experimental version for 256-bit vectors:
```
private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;
private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256;

private static final float[] MXFP4_VALUES = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

/**
 * @param vec              input vector, length == 2880 (float32)
 * @param allExpertWeights MemorySegment storing the MXFP4 weights to dot-product with the input vector
 * @param offset           expert-weight offset in number of weights, not a physical byte offset
 * @param n                input dimension
 * @param scales           per-block scales for the weights
 * @return dot product of the input vector and the dequantized weights
 */
private static float vectorizedMXFP4DotProduct(final float[] vec,
                                               final MemorySegment allExpertWeights,
                                               final int offset,
                                               final int n,
                                               final float[] scales) {
    float acc = 0.0f;
    int i = 0;
    FloatVector accV = FloatVector.zero(F_SPECIES);
    float[] weights = new float[F_SPECIES.length()];

    // ---------- The code below is experimental: use MXFP4-precision weights for the projection ----------
    // Process in blocks of 32 elements (16 bytes = 32 FP4 values).
    // With a 256-bit vector we can process 8 floats per iteration in an instruction-parallel way.
    while (i + 32 <= n) {
        int blockIdx = (offset + i) / 32;
        float scale = scales[blockIdx];
        FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);

        ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES, allExpertWeights,
                blockIdx * 16, ByteOrder.LITTLE_ENDIAN);
        ByteVector loNibbles = wBytes.and((byte) 0x0F);
        ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);

        for (int bytePair = 0; bytePair < 4; bytePair++) {
            int byte1 = bytePair * 4;
            int byte2 = bytePair * 4 + 1;
            int byte3 = bytePair * 4 + 2;
            int byte4 = bytePair * 4 + 3;
            // !!! :( this is where we cannot do parallel execution unless it is a 512-bit vector unit
            weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];
            weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];
            weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];
            weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];
            weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];
            weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];
            weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];
            weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];

            FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);
            int vecPos = i + (bytePair * 8);
            FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);
            accV = wv.mul(vScales).fma(xv, accV);
        }
        i += 32;
    }
    acc = accV.reduceLanes(VectorOperators.ADD);

    // Handle remaining elements
    while (i < n) {
        int blockIdx = (offset + i) / 32;
        int elemInBlock = (offset + i) % 32;
        float scale = scales[blockIdx];
        int byteIdx = elemInBlock / 2;
        int blockStart = blockIdx * 16;
        byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);
        int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
        float weight = MXFP4_VALUES[fp4Idx] * scale;
        acc += weight * vec[i];
        i++;
    }
    return acc;
}
```
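One pattern I have considered (a sketch under assumptions, with my own
naming, not code from the repository) is to build an `int[]` index map from
the nibbles and let `FloatVector.fromArray` gather the table entries. The
nibble extraction is still scalar, so this only relocates the lookup into a
gather load, and I have not measured whether it beats the `lane()` loop
above. `idxScratch` is a caller-provided scratch array of length
`F_SPECIES.length()`:
```
// Hypothetical sketch: replace the eight scalar stores into weights[] with a
// gather from MXFP4_VALUES. The nibble-to-index step is still scalar; only the
// table lookup itself becomes a vector load.
private static FloatVector lookupBytePair(final ByteVector loNibbles,
                                          final ByteVector hiNibbles,
                                          final int firstByte,
                                          final int[] idxScratch) {
    for (int b = 0; b < 4; b++) {
        idxScratch[2 * b]     = loNibbles.lane(firstByte + b) & 0x0F; // even element
        idxScratch[2 * b + 1] = hiNibbles.lane(firstByte + b) & 0x0F; // odd element
    }
    // Gather the eight table entries in one call: MXFP4_VALUES[idxScratch[n]] per lane n.
    return FloatVector.fromArray(F_SPECIES, MXFP4_VALUES, 0, idxScratch, 0);
}
```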
*## Questions*
1. Is my understanding correct that the current Java Vector API cannot
perform dynamic per-byte table lookups? If so, given that `VectorShuffle`
requires indices that match the source vector's lane count (so a 16-entry
table lookup effectively only works with 512-bit / 16-float-lane vectors),
is there any plan to support this in a future release?
2. As a follow-up, MXFP4 decoding also requires interleaving low and high
nibbles. Are there any performant ways or patterns in the API to handle
this efficiently?
Any insights from the Project Panama developers would be greatly
appreciated!
Thanks,
Xu