RFR: 8342103: C2 compiler support for Float16 type and associated operations [v2]

Emanuel Peter epeter at openjdk.org
Tue Nov 26 07:36:46 UTC 2024


On Mon, 25 Nov 2024 19:55:27 GMT, Jatin Bhateja <jbhateja at openjdk.org> wrote:

>> I heard no argument about why you did not split this up. Please do that in the future. It is hard to review well when there is this much code. If it is really necessary, then sure. Here it does not seem necessary to deliver all at once.
>> 
>>> The patch includes IR framework-based scalar constant folding test points.
>>
>> You mention this IR test:
>> https://github.com/openjdk/jdk/pull/21490/files#diff-3f8786f9f62662eda4b4a5c76c01fa04534c94d870d496501bfc20434ad45579R169-R174
>> 
>> Here I only see the use of very trivial values. I think we need more complicated cases.
>> 
>> What about these:
>> - Add/Sub/Mul/Div/Min/Max ... with NaN and infinity.
>> - Same where it would overflow the FP16 range.
>> - Negative zero tests.
>> - Division by powers of 2.
>> 
>> It would, for example, be nice if you could iterate over all inputs. FP16 with 2 inputs is only 32 bits, which can be iterated in just a few seconds. Then you can run the computation with constants in the interpreter and compare it to the results in compiled code.
>
> [ScalarFloat16OperationsTest.java](https://github.com/openjdk/jdk/pull/21490/files#diff-6afb7e66ce0fcdac61df60af0231010b20cf16489ec7e4d5b0b41852db8796a0) 
> has a specialized data provider that generates test vectors with special values; our functional validation is covering the entire Float16 value range.

@jatin-bhateja 

> [ScalarFloat16OperationsTest.java](https://github.com/openjdk/jdk/pull/21490/files#diff-6afb7e66ce0fcdac61df60af0231010b20cf16489ec7e4d5b0b41852db8796a0)
> has a specialized data provider that generates test vectors with special values; our functional validation is covering the entire Float16 value range.

Maybe I'm not making myself clear here. The test vectors will never constant fold: as far as the compiler knows, a value read by an array load can be anything in the full range of its type, so it is never a constant. And you added constant folding IGVN optimizations.

So we should test both:
- Compile-time variables: for this you can use array element loads. You have to generate the values randomly beforehand, spanning the whole Float16 value range. I think this is covered somewhat adequately.
- Compile-time constants: for this you cannot use array element loads, because they will not be constants. You have to use literals; alternatively, you can set `static final int val = RANDOM.nextInt();`, which will constant fold during compilation, or you can use `MethodHandles.constant(int.class, 1);` (for example as the target of a `MutableCallSite`) to get compile-time constants that you can change, triggering recompilation with the new "constant".
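For illustration, a minimal Java sketch of the two constant sources listed above. All class and method names are hypothetical, and FP16 values are kept as raw `short` bits converted with the JDK 20+ methods `Float.float16ToFloat` and `Float.floatToFloat16`. One wiring choice is an assumption: a handle from `MethodHandles.constant` is immutable by itself, so to change the "constant" the sketch installs it as the target of a `MutableCallSite`; calling `setTarget` invalidates compiled code that depends on the old target.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MutableCallSite;
import java.util.Random;

// Hypothetical helper: two ways to feed compile-time constants to a JIT-compiled method.
final class ConstantInputs {
    private static final Random RANDOM = new Random();

    // Option 1: static final fields. Their values are fixed at class initialization,
    // so C2 can treat them as constants when compiling methods that read them.
    static final short A_BITS = (short) RANDOM.nextInt(); // random raw FP16 bits
    static final short B_BITS = (short) RANDOM.nextInt();

    static short addFromStaticFinals() {
        return Float.floatToFloat16(Float.float16ToFloat(A_BITS) + Float.float16ToFloat(B_BITS));
    }

    // Option 2: a constant method handle behind a MutableCallSite. Swapping the
    // target invalidates dependent compiled code; a recompilation then sees the
    // new "constant".
    private static final MutableCallSite SITE =
            new MutableCallSite(MethodHandles.constant(short.class, (short) 0));
    private static final MethodHandle CONSTANT = SITE.dynamicInvoker();

    static void setConstant(short bits) {
        SITE.setTarget(MethodHandles.constant(short.class, bits));
    }

    static short readConstant() throws Throwable {
        return (short) CONSTANT.invokeExact();
    }
}
```

A test would call a method like `addFromStaticFinals()` (or one that calls `readConstant()`) in a loop until it is compiled, then compare against the interpreter's result; with option 2 it can call `setConstant` and repeat with a new value.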

It starts with something as simple as your constant folding of addition:

```cpp
// Supplied function returns the sum of the inputs.
// This also type-checks the inputs for sanity.  Guaranteed never to
// be passed a TOP or BOTTOM type, these are filtered out by pre-check.
const Type* AddHFNode::add_ring(const Type* t0, const Type* t1) const {
  if (!t0->isa_half_float_constant() || !t1->isa_half_float_constant()) {
    return bottom_type();
  }
  return TypeH::make(t0->getf() + t1->getf());
}
```


Which uses this code:

```cpp
const TypeH *TypeH::make(float f) {
  assert( StubRoutines::f2hf_adr() != nullptr, "");
  short hf = StubRoutines::f2hf(f);
  return (TypeH*)(new TypeH(hf))->hashcons();
}
```


You are doing the addition in `float` and then converting back to `half_float`. Probably correct. But does it round correctly? Does it deal with infinity and `NaN` correctly? Probably, but I would like to see tests for that.
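To make the rounding, infinity, and NaN questions concrete, here is a hedged Java model of what `AddHFNode::add_ring` computes: widen both FP16 operands to `float`, add, and round back to FP16 with round-to-nearest-even. It assumes `Float.floatToFloat16` (JDK 20+) is the reference that `StubRoutines::f2hf` must agree with; the class name is hypothetical. By a known double-rounding result, one `float` operation followed by a single rounding to FP16 should give the correctly rounded FP16 sum, because `float` has at least 2·11 + 2 = 24 significand bits; the cases below are the boundaries where a stub or folding bug would show up.

```java
// Hypothetical model of AddHFNode::add_ring on raw FP16 bits:
// widen both operands to float, add in float, then round back to FP16.
final class AddHFFoldModel {
    static short addHF(short a, short b) {
        return Float.floatToFloat16(Float.float16ToFloat(a) + Float.float16ToFloat(b));
    }

    static boolean isNaN16(short bits) {
        return Float.isNaN(Float.float16ToFloat(bits));
    }

    static String hex(short bits) {
        return Integer.toHexString(bits & 0xFFFF);
    }

    public static void main(String[] args) {
        short max = (short) 0x7BFF;       // 65504.0, largest finite FP16 value
        short sixteen = (short) 0x4C00;   // 16.0, half an ulp of max
        short belowHalf = (short) 0x4BFF; // 15.9921875, just under half an ulp
        short one = (short) 0x3C00;       // 1.0
        short tiny = (short) 0x1000;      // 2^-11, half an ulp of 1.0

        // 65504 + 16 = 65520 is a tie; ties-to-even rounds up and overflows to +Inf.
        System.out.println(hex(addHF(max, sixteen)));                       // 7c00
        // 65519.9921875 is below the tie point and rounds down to 65504.
        System.out.println(hex(addHF(max, belowHalf)));                     // 7bff
        // +Inf + -Inf is NaN.
        System.out.println(isNaN16(addHF((short) 0x7C00, (short) 0xFC00))); // true
        // 1 + 2^-11 is a tie between 1.0 and 1.0 + 2^-10; ties-to-even keeps 1.0.
        System.out.println(hex(addHF(one, tiny)));                          // 3c00
        // Signed zeros: -0.0 + -0.0 is -0.0, while +0.0 + -0.0 is +0.0.
        System.out.println(hex(addHF((short) 0x8000, (short) 0x8000)));     // 8000
        System.out.println(hex(addHF((short) 0x0000, (short) 0x8000)));     // 0
    }
}
```

In a C2 test, each pair would be fed as compile-time constants and the folded result compared against these interpreter-computed bits.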

This is the simple stuff. Then there are more complex cases.

```cpp
const Type* MinHFNode::add_ring(const Type* t0, const Type* t1) const {
  const TypeH* r0 = t0->isa_half_float_constant();
  const TypeH* r1 = t1->isa_half_float_constant();
  if (r0 == nullptr || r1 == nullptr) {
    return bottom_type();
  }

  if (r0->is_nan()) {
    return r0;
  }
  if (r1->is_nan()) {
    return r1;
  }

  float f0 = r0->getf();
  float f1 = r1->getf();
  if (f0 != 0.0f || f1 != 0.0f) {
    return f0 < f1 ? r0 : r1;
  }

  // As per IEEE 754 specification, floating point comparison considers +ve and -ve
  // zeros as equal. Thus, performing signed integral comparison for min value
  // detection.
  return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1;
}
```
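As an illustration of an exhaustive check over all inputs, here is a minimal Java sketch with hypothetical names. `foldMin` ports the `MinHFNode::add_ring` logic above to raw FP16 bits (`Float.floatToRawIntBits` stands in for `jint_cast`), and the oracle assumes that Float16 min is meant to follow `Math.min` semantics: NaN wins, and -0.0 is smaller than +0.0. Both sides run as plain Java here, so this only validates the folding logic; in a real test, one side would be the C2-folded result of a method compiled with constant inputs.

```java
// Hypothetical exhaustive check of a Java port of the MinHFNode folding logic.
final class MinHFFoldingCheck {
    // Port of MinHFNode::add_ring on raw FP16 bits.
    static short foldMin(short r0, short r1) {
        float f0 = Float.float16ToFloat(r0);
        float f1 = Float.float16ToFloat(r1);
        if (Float.isNaN(f0)) return r0;
        if (Float.isNaN(f1)) return r1;
        if (f0 != 0.0f || f1 != 0.0f) {
            return f0 < f1 ? r0 : r1;
        }
        // Both are zeros: the sign bit breaks the tie, since -0.0 has a negative int image.
        return Float.floatToRawIntBits(f0) < Float.floatToRawIntBits(f1) ? r0 : r1;
    }

    // Oracle: Math.min semantics, rounded back to FP16 (exact for FP16 inputs).
    static short referenceMin(short a, short b) {
        return Float.floatToFloat16(Math.min(Float.float16ToFloat(a), Float.float16ToFloat(b)));
    }

    // Any NaN matches any NaN; everything else must match bit for bit.
    static boolean sameResult(short x, short y) {
        boolean xNaN = Float.isNaN(Float.float16ToFloat(x));
        boolean yNaN = Float.isNaN(Float.float16ToFloat(y));
        return (xNaN && yNaN) || x == y;
    }

    // stride == 1 visits all 2^32 input pairs.
    static long countMismatches(int stride) {
        long mismatches = 0;
        for (int a = 0; a <= 0xFFFF; a += stride) {
            for (int b = 0; b <= 0xFFFF; b += stride) {
                short x = (short) a;
                short y = (short) b;
                if (!sameResult(foldMin(x, y), referenceMin(x, y))) {
                    mismatches++;
                }
            }
        }
        return mismatches;
    }

    public static void main(String[] args) {
        int stride = args.length > 0 ? Integer.parseInt(args[0]) : 1;
        System.out.println(countMismatches(stride) + " mismatches");
    }
}
```

Running with stride 1 covers every pair, including both signed zeros, infinities, subnormals, and all NaN encodings.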

Is this adequately tested over the whole range of inputs? Of course, the inputs have to be **constant**: if you only use array loads, the values are obviously variable, so they fail the `isa_half_float_constant` check and the folding code is never exercised.

You do have some constant folding tests like this:

```java
@Test
@IR(counts = {IRNode.MIN_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
    applyIfCPUFeature = {"avx512_fp16", "true"})
public void testMinConstantFolding() {
    assertResult(min(valueOf(1.0f), valueOf(2.0f)).floatValue(), 1.0f, "testMinConstantFolding");
    assertResult(min(valueOf(0.0f), valueOf(-0.0f)).floatValue(), -0.0f, "testMinConstantFolding");
}
```

But these are **only 2 examples for min**. They do not cover all cases by a long shot; they cover 2 "nice" cases.

I do not think that is sufficient. Often the bugs are hiding in special cases.
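As a sketch of which special cases a constant-folding test could enumerate (negative zero, overflow, and division by powers of 2), here is a hypothetical table of FP16 special values with Java models of multiply and divide folding. The models assume the same widen-to-`float`-then-round approach as `AddHFNode::add_ring`; whether the PR's mul/div folding works that way is an assumption. The division cases show where ties-to-even matters in the subnormal range.

```java
// Hypothetical table of FP16 special values plus models of MulHF/DivHF folding
// (widen to float, compute, round back to FP16).
final class Fp16SpecialCases {
    static final short[] SPECIALS = {
        (short) 0x0000, // +0.0
        (short) 0x8000, // -0.0
        (short) 0x0001, // smallest subnormal, 2^-24
        (short) 0x0003, // 3 * 2^-24
        (short) 0x03FF, // largest subnormal
        (short) 0x0400, // smallest normal, 2^-14
        (short) 0x3C00, // 1.0
        (short) 0xBC00, // -1.0
        (short) 0x7BFF, // 65504.0, largest finite value
        (short) 0x7C00, // +Infinity
        (short) 0xFC00, // -Infinity
        (short) 0x7E00, // a quiet NaN
    };

    static short mulHF(short a, short b) {
        return Float.floatToFloat16(Float.float16ToFloat(a) * Float.float16ToFloat(b));
    }

    static short divHF(short a, short b) {
        return Float.floatToFloat16(Float.float16ToFloat(a) / Float.float16ToFloat(b));
    }

    // Prints the expected folded result for every pair of special values, e.g. as
    // input for generating a test whose operands are written as literals.
    public static void main(String[] args) {
        for (short a : SPECIALS) {
            for (short b : SPECIALS) {
                System.out.printf("a=%04x b=%04x mul=%04x div=%04x%n",
                        a & 0xFFFF, b & 0xFFFF, mulHF(a, b) & 0xFFFF, divHF(a, b) & 0xFFFF);
            }
        }
    }
}
```

For example, dividing the smallest subnormal by 2.0 (`0x4000`) gives exactly 2^-25, a tie that rounds to +0.0, while `0x0003 / 2` rounds up to `0x0002` and `0x03FF / 2` rounds up to `0x0200`.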

Testing is really important to me. I have experienced this myself: when I did not test an optimization well, it later turned into a bug.

Comments like these do not give me much confidence:
> functional validation is covering the entire Float16 value range.

Then I review the tests and see that not all cases are covered. What am I supposed to do as a reviewer? It makes it harder for me to trust what you say in the future. Maybe this is all a misunderstanding; if so, I hope my lengthy explanation clarifies what I mean.

What do you think @Bhavana-Kilambi @PaulSandoz ?

-------------

PR Comment: https://git.openjdk.org/jdk/pull/21490#issuecomment-2499876085


More information about the hotspot-compiler-dev mailing list