Question on supporting vectorized per-byte table lookups for MXFP4 dequantization in gpt-oss.java

Vladimir Ivanov vladimir.x.ivanov at oracle.com
Mon Oct 20 17:41:28 UTC 2025


Hi Xu,

I didn't look at your benchmark in detail, but I see that 
ByteVector.rearrange() on Byte128Vector is lowered into pshufb [1]. Have 
you tried it?
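
For illustration, here is a minimal sketch of that approach. The class and method names are hypothetical, and it assumes JDK 16 or later run with `--add-modules jdk.incubator.vector`. A 16-entry byte table sits in one 128-bit `ByteVector`, and a vector of runtime indices is turned into the shuffle with `toShuffle()`. Whether the JIT actually emits `pshufb` for it depends on the JDK build and the CPU.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical helper: 16 parallel byte-table lookups with the Vector API.
public class ByteTableLookup {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128; // 16 byte lanes

    /** Returns out[i] = table[idx[i]] for 16 lanes; every idx[i] must be in [0, 16). */
    static byte[] lookup16(byte[] table, byte[] idx) {
        ByteVector t = ByteVector.fromArray(B128, table, 0);
        ByteVector i = ByteVector.fromArray(B128, idx, 0);
        // The shuffle is built from runtime data, not from compile-time constants.
        return t.rearrange(i.toShuffle()).toArray();
    }
}
```

`i.selectFrom(t)` expresses the same lookup with the operands swapped.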

Best regards,
Vladimir Ivanov

[1] 
https://github.com/openjdk/jdk/blob/master/src/hotspot/cpu/x86/x86.ad#L8769

On 10/19/25 22:16, xu wrote:
> Hi all,
> 
> I’d like to ask the Project Panama team for some help with an efficient 
> implementation of MXFP4 dequantization using the Java Vector API.
> 
> *## Background*
> 
> I’ve been developing a pure Java implementation of OpenAI’s gpt-oss 
> inference program, optimized for CPU execution (GitHub: amzn/gpt-oss.java 
> <https://github.com/amzn/gpt-oss.java>). The model uses MXFP4 weights, 
> and I’m exploring efficient ways to handle them in the 
> computation-intensive MLP layer. However, I ran into limitations in the 
> current Vector API (up to JDK 24) that seem to prevent a fully 
> vectorized implementation.
> 
> *## MXFP4 layout*
> In C++, each Mixture-of-Experts weight byte stores two 4-bit weights 
> (nibbles). These 4-bit values are used as indices into a predefined 
> 16-entry lookup table:
> ```
> kvalues_mxfp4 = [0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f]
> ```
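
These 16 values are the code points of the FP4 E2M1 element format from the OCP Microscaling (MX) specification: 1 sign bit, 2 exponent bits (bias 1) and 1 mantissa bit, where exponent 0 encodes 0 or 0.5. Below is a small scalar sketch (hypothetical names) that rebuilds the table from the bits. It also includes a byte decoder that assumes the low-nibble-first order used by the scalar tail loop further below.

```java
// Hypothetical helper: decodes FP4 E2M1 codes (1 sign, 2 exponent, 1 mantissa bit).
public class Fp4E2M1 {
    static float decode(int code) {
        int sign = (code >> 3) & 1;
        int exp = (code >> 1) & 3;
        int man = code & 1;
        // exp == 0: subnormal, 0.0 or 0.5; otherwise (1 + man/2) * 2^(exp - 1), i.e. exponent bias 1
        float mag = (exp == 0) ? man * 0.5f : (1.0f + man * 0.5f) * (1 << (exp - 1));
        return sign == 1 ? -mag : mag; // code 8 yields -0.0f
    }

    /** Decodes one packed byte: element 0 from the low nibble, element 1 from the high nibble. */
    static float[] decodeByte(byte b) {
        return new float[] {decode(b & 0x0F), decode((b >> 4) & 0x0F)};
    }
}
```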
> 
> *## llama.cpp implementation*
> 
> Using SIMD intrinsics such as `_mm_shuffle_epi8`, the C++ implementation 
> (e.g., quants.c in llama.cpp <https://github.com/ggml-org/llama.cpp/blob/12bbc3fa50b6df03318a4451c9a2210200a0b28d/ggml/src/ggml-cpu/arch/x86/quants.c#L790>) 
> can perform 16 parallel table lookups per instruction, efficiently 
> converting nibbles into 8-bit table values that then feed an integer 
> dot product. Taking AVX2 as an example:
> ```
> #if defined __AVX2__
> 
>      const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
>      const __m128i m4b  = _mm_set1_epi8(0x0f);
>      const __m256i mone = _mm256_set1_epi16(1);
> 
>      __m256 accum1 = _mm256_setzero_ps();
>      __m256 accum2 = _mm256_setzero_ps();
>      for (; ib + 1 < nb; ib += 2) {
>          const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
>          const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
>          const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
>          const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
>          const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
>                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
>          const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
>                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
>          const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
>          const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
>          const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
>          const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
>          accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
>                  _mm256_cvtepi32_ps(p_1), accum1);
>          accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
>                  _mm256_cvtepi32_ps(p_2), accum2);
>      }
> 
>      sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
> ```
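
For reference, the per-lane behavior of `_mm_shuffle_epi8` (SSSE3 `pshufb`) on one 128-bit register is simple. If bit 7 of an index byte is set, the output byte is 0; otherwise it is `table[index & 0x0F]`. In the loop above the indices are masked with `m4b`, so only the second case occurs. Below is a scalar Java model (hypothetical name), handy as a reference when testing a vectorized port.

```java
// Hypothetical scalar model of _mm_shuffle_epi8 on one 128-bit register.
public class PshufbModel {
    static byte[] pshufb(byte[] table, byte[] idx) {
        byte[] out = new byte[16];
        for (int i = 0; i < 16; i++) {
            // High bit of the index byte zeroes the lane; otherwise the low 4 bits select a table byte.
            out[i] = (idx[i] & 0x80) != 0 ? 0 : table[idx[i] & 0x0F];
        }
        return out;
    }
}
```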
> 
> *## Problem in Java Vector API*
> 
> In Java, there doesn’t seem to be an equivalent way to perform dynamic 
> per-byte lookups using `ByteVector`.
> 
> While `VectorShuffle` supports lane rearrangements, its indices must 
> fall within 0 to VLENGTH - 1, i.e. they select lanes of the vector being 
> rearranged. The indices here are element indices of kvalues_mxfp4 
> (0 to 15), so a lookup into a table of float values only works with 
> 512-bit vectors, which have exactly 16 float lanes. This means that 
> dynamic, per-lane table lookups like `_mm_shuffle_epi8` cannot be 
> expressed directly in Java.
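
A note on the constraint above: the shuffle itself need not be a compile-time constant, because it can be created from runtime data with `Vector.toShuffle()` or `VectorShuffle.fromArray()`. The real limit is the lane count, since indices address lanes of the vector being rearranged. The sketch below (hypothetical names, assuming `--add-modules jdk.incubator.vector`) performs the 16-entry float lookup with `FloatVector.SPECIES_512`. That species is usable on any CPU, but without 512-bit hardware support it may not compile to a single permute instruction.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShuffle;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical sketch: 16-entry float table lookup driven by runtime nibble indices.
public class FloatTableLookup512 {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;   // 16 byte lanes
    static final VectorSpecies<Integer> I512 = IntVector.SPECIES_512; // 16 int lanes
    static final VectorSpecies<Float> F512 = FloatVector.SPECIES_512; // 16 float lanes

    static final float[] MXFP4_VALUES = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
    };
    static final FloatVector TABLE = FloatVector.fromArray(F512, MXFP4_VALUES, 0);

    /** Decodes the low nibbles of 16 packed bytes into 16 floats with one rearrange. */
    static float[] decodeLowNibbles(byte[] packed) {
        ByteVector nibbles = ByteVector.fromArray(B128, packed, 0).and((byte) 0x0F);
        // Widen the 16 byte indices to 16 int lanes, then build the shuffle at run time.
        VectorShuffle<Float> s = nibbles.convertShape(VectorOperators.B2I, I512, 0)
                                        .toShuffle()
                                        .cast(F512);
        return TABLE.rearrange(s).toArray();
    }
}
```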
> 
> The current workaround requires scalar operations:
> ```
> for (int i = 0; i < lanes; i++) {
>      int nibble = byteVector.lane(i) & 0x0F;
>      weights[i] = MXFP4_VALUES[nibble];  // Scalar operation
> }
> ```
> 
> This prevents MXFP4 dequantization from being fully vectorized: each 
> nibble must be individually decoded and indexed in scalar form.
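
Here is a sketch of how the scalar loop could be avoided with 128-bit byte vectors, following the same idea as the llama.cpp kernel: look up bytes, then widen. Every MXFP4 value times 2 is a small integer, so a doubled table fits in 16 signed bytes. This encoding is an assumption of the sketch, and the sign of -0 is lost, which is harmless in a product. The lookup is a byte `selectFrom`, `convertShape(VectorOperators.B2F, ...)` widens the 16 results into two 8-lane float vectors, and the factor 1/2 is folded into the scale. Class and method names are hypothetical.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical sketch: vectorized nibble -> float decoding via a doubled int8 table.
public class Mxfp4ByteLookup {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
    static final VectorSpecies<Float> F256 = FloatVector.SPECIES_256;

    // Assumed encoding: 2 * MXFP4 value for codes 0..15.
    static final byte[] MXFP4_X2 = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};
    static final ByteVector TABLE = ByteVector.fromArray(B128, MXFP4_X2, 0);

    /** Decodes 16 nibbles (one per byte lane, values 0..15) into out[outOff .. outOff+15], times scale. */
    static void decode16(ByteVector nibbles, float scale, float[] out, int outOff) {
        ByteVector doubled = nibbles.selectFrom(TABLE); // 16 parallel lookups
        float s = scale * 0.5f;                          // undo the x2 encoding
        FloatVector f0 = (FloatVector) doubled.convertShape(VectorOperators.B2F, F256, 0); // lanes 0..7
        FloatVector f1 = (FloatVector) doubled.convertShape(VectorOperators.B2F, F256, 1); // lanes 8..15
        f0.mul(s).intoArray(out, outOff);
        f1.mul(s).intoArray(out, outOff + 8);
    }
}
```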
> 
> Below is an experimental version for 256-bit vectors:
> 
> ```
> import java.lang.foreign.MemorySegment;
> import java.lang.foreign.ValueLayout;
> import java.nio.ByteOrder;
> import jdk.incubator.vector.ByteVector;
> import jdk.incubator.vector.FloatVector;
> import jdk.incubator.vector.VectorOperators;
> import jdk.incubator.vector.VectorSpecies;
> 
> // The fields and the method below are members of an enclosing class.
> private static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_128;
> private static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_256;
> 
> private static final float[] MXFP4_VALUES = {
>         +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
>         -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
> };
> 
> /**
>  * Computes the dot product of the input vector with n MXFP4-encoded weights.
>  *
>  * @param vec              input vector (length 2880, float32)
>  * @param allExpertWeights MemorySegment storing the MXFP4 weights to multiply with the input vector
>  * @param offset           expert weights offset in number of weights, not a physical byte offset;
>  *                         the vectorized loop assumes it is a multiple of 32 (block-aligned)
>  * @param n                input dimension
>  * @param scales           one scale per 32-weight block
>  * @return the dot product
>  */
> private static float vectorizedMXFP4DotProduct(final float[] vec,
>                                                final MemorySegment allExpertWeights,
>                                                final int offset,
>                                                final int n,
>                                                final float[] scales) {
>     float acc = 0.0f;
>     int i = 0;
> 
>     FloatVector accV = FloatVector.zero(F_SPECIES);
>     float[] weights = new float[F_SPECIES.length()];
> 
>     // ---------- Experimental: MXFP4-precision weights for the projection ----------
>     // Process blocks of 32 elements (16 bytes = 32 FP4 values).
>     // With 256-bit vectors, 8 floats are processed per vector operation.
>     // (To build a JAR without running the tests: ./gradlew shadowJar)
>     while (i + 32 <= n) {
>         int blockIdx = (offset + i) / 32;
>         float scale = scales[blockIdx];
>         FloatVector vScales = FloatVector.broadcast(F_SPECIES, scale);
> 
>         ByteVector wBytes = ByteVector.fromMemorySegment(B_SPECIES,
>                 allExpertWeights, blockIdx * 16L, ByteOrder.LITTLE_ENDIAN);
>         ByteVector loNibbles = wBytes.and((byte) 0x0F);
>         ByteVector hiNibbles = wBytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F);
> 
>         // Each iteration decodes 4 bytes into 8 weights: low nibble first, then high nibble.
>         for (int bytePair = 0; bytePair < 4; bytePair++) {
>             int byte1 = bytePair * 4;
>             int byte2 = bytePair * 4 + 1;
>             int byte3 = bytePair * 4 + 2;
>             int byte4 = bytePair * 4 + 3;
> 
>             // Scalar step: these lookups cannot run in parallel unless
>             // 512-bit vectors (16 float lanes) are available.
>             weights[0] = MXFP4_VALUES[loNibbles.lane(byte1) & 0x0F];
>             weights[1] = MXFP4_VALUES[hiNibbles.lane(byte1) & 0x0F];
>             weights[2] = MXFP4_VALUES[loNibbles.lane(byte2) & 0x0F];
>             weights[3] = MXFP4_VALUES[hiNibbles.lane(byte2) & 0x0F];
> 
>             weights[4] = MXFP4_VALUES[loNibbles.lane(byte3) & 0x0F];
>             weights[5] = MXFP4_VALUES[hiNibbles.lane(byte3) & 0x0F];
>             weights[6] = MXFP4_VALUES[loNibbles.lane(byte4) & 0x0F];
>             weights[7] = MXFP4_VALUES[hiNibbles.lane(byte4) & 0x0F];
> 
>             FloatVector wv = FloatVector.fromArray(F_SPECIES, weights, 0);
>             int vecPos = i + (bytePair * 8);
>             FloatVector xv = FloatVector.fromArray(F_SPECIES, vec, vecPos);
> 
>             // accV += (w * scale) * x
>             accV = wv.mul(vScales).fma(xv, accV);
>         }
> 
>         i += 32;
>     }
> 
>     acc = accV.reduceLanes(VectorOperators.ADD);
> 
>     // Handle the remaining elements (scalar tail).
>     while (i < n) {
>         int blockIdx = (offset + i) / 32;
>         int elemInBlock = (offset + i) % 32;
>         float scale = scales[blockIdx];
> 
>         int byteIdx = elemInBlock / 2;
>         long blockStart = blockIdx * 16L;
>         byte packed = allExpertWeights.get(ValueLayout.JAVA_BYTE, blockStart + byteIdx);
> 
>         // Even elements use the low nibble, odd elements the high nibble.
>         int fp4Idx = (elemInBlock % 2 == 0) ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
>         float weight = MXFP4_VALUES[fp4Idx] * scale;
>         acc += weight * vec[i];
>         i++;
>     }
> 
>     return acc;
> }
> ```
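
Below is a fully vectorized variant of one block iteration, as a sketch under stated assumptions: hypothetical names, weights in a `byte[]` rather than a `MemorySegment` for brevity, and a doubled int8 table (the MXFP4 values times 2). Instead of interleaving decoded weights, it assumes the input has been split once into even- and odd-indexed elements (`xEven`, `xOdd`). This is a hypothetical preprocessing step that can be reused by every expert row multiplying the same input. Low-nibble weights then pair with `xEven` and high-nibble weights with `xOdd`, matching the low-nibble-first layout of the tail loop.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical sketch: one 32-weight MXFP4 block dotted with a pre-split input.
public class Mxfp4BlockDot {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
    static final VectorSpecies<Float> F256 = FloatVector.SPECIES_256;
    // Assumed encoding: 2 * MXFP4 value, so the table fits in signed bytes.
    static final ByteVector TABLE_X2 = ByteVector.fromArray(B128,
            new byte[] {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12}, 0);

    /**
     * Byte j of the block holds weight 2j in its low nibble and weight 2j+1 in its high nibble.
     * xEven[xHalfOff + j] must equal x[2j] and xOdd[xHalfOff + j] must equal x[2j+1], j in 0..15.
     */
    static float blockDot(byte[] w, int wOff, float scale,
                          float[] xEven, float[] xOdd, int xHalfOff) {
        ByteVector bytes = ByteVector.fromArray(B128, w, wOff);
        // 16 parallel lookups per nibble half; results are doubled weights.
        ByteVector lo = bytes.and((byte) 0x0F).selectFrom(TABLE_X2);
        ByteVector hi = bytes.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0F).selectFrom(TABLE_X2);
        FloatVector acc = FloatVector.zero(F256);
        for (int part = 0; part < 2; part++) { // part 0: bytes 0..7, part 1: bytes 8..15
            FloatVector wLo = (FloatVector) lo.convertShape(VectorOperators.B2F, F256, part);
            FloatVector wHi = (FloatVector) hi.convertShape(VectorOperators.B2F, F256, part);
            acc = wLo.fma(FloatVector.fromArray(F256, xEven, xHalfOff + 8 * part), acc);
            acc = wHi.fma(FloatVector.fromArray(F256, xOdd, xHalfOff + 8 * part), acc);
        }
        // Undo the x2 encoding and apply the block scale once.
        return acc.reduceLanes(VectorOperators.ADD) * (scale * 0.5f);
    }
}
```

The floating-point additions are reassociated relative to the scalar loop, so results can differ in the last bits.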
> 
> *## Questions*
> 
> 1. Is my understanding correct that the current Java Vector API cannot 
> perform dynamic per-byte table lookups? If so, since `VectorShuffle` 
> seems to work only with constant indices (and, for this table, 
> effectively only with 512-bit / 16-float-lane vectors), is there any 
> plan to support this in future releases?
> 
> 2. As a follow-up, MXFP4 decoding also requires interleaving low and 
> high nibbles. Are there any performant ways or patterns in the API to 
> handle this efficiently?
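
For the interleaving question, one portable pattern is two single-vector `rearrange` calls with constant shuffles, combined by `blend` under an odd-lane mask. This is a sketch with hypothetical names. It is not claimed to lower to `punpcklbw`/`punpckhbw`-style unpack instructions, so it is worth benchmarking.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorShuffle;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical sketch: zip two 16-lane byte vectors into lo0, hi0, lo1, hi1, ...
public class NibbleZip {
    static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
    // Output lane i takes source lane i/2 (first half) or 8 + i/2 (second half).
    static final VectorShuffle<Byte> ZIP_LO = VectorShuffle.fromOp(B128, i -> i / 2);
    static final VectorShuffle<Byte> ZIP_HI = VectorShuffle.fromOp(B128, i -> 8 + i / 2);
    // Mask bits 1, 3, 5, ...: odd output lanes come from the high-nibble vector.
    static final VectorMask<Byte> ODD = VectorMask.fromLong(B128, 0xAAAAL);

    /** Returns {lo0, hi0, lo1, hi1, ..., lo7, hi7}. */
    static ByteVector zipLow(ByteVector lo, ByteVector hi) {
        return lo.rearrange(ZIP_LO).blend(hi.rearrange(ZIP_LO), ODD);
    }

    /** Returns {lo8, hi8, ..., lo15, hi15}. */
    static ByteVector zipHigh(ByteVector lo, ByteVector hi) {
        return lo.rearrange(ZIP_HI).blend(hi.rearrange(ZIP_HI), ODD);
    }
}
```

Applied to the looked-up low- and high-nibble bytes of one block, `zipLow` yields elements 0 to 15 of the 32-element block and `zipHigh` elements 16 to 31, in low-nibble-first order.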
> 
> Any insights from the Project Panama developers would be greatly 
> appreciated!
> 
> Thanks,
> Xu


