RFR: 8342103: C2 compiler support for Float16 type and associated operations

Jatin Bhateja jbhateja at openjdk.org
Tue Nov 19 19:57:14 UTC 2024


On Mon, 14 Oct 2024 11:40:01 GMT, Jatin Bhateja <jbhateja at openjdk.org> wrote:

> Hi All,
> 
> This patch adds C2 compiler support for various Float16 operations added by [PR#22128](https://github.com/openjdk/jdk/pull/22128)
> 
> The following is a summary of the changes included in this patch:
> 
> 1. Detection of various Float16 operations through inline expansion or pattern folding idealizations.
> 2. Float16 operations like add, sub, mul, div, max, and min are inferred through pattern folding idealization.
> 3. Float16 SQRT and FMA operations are inferred through inline expansion, and their corresponding entry points are defined in the newly added Float16Math class.
>    - These intrinsics receive unwrapped short arguments encoding IEEE 754 binary16 values.
> 4. New specialized IR nodes for Float16 operations, associated idealizations, and constant folding routines.
> 5. A new Ideal type for constant and non-constant Float16 IR nodes. Please refer to the [FAQs](https://github.com/openjdk/jdk/pull/21490#issuecomment-2482867818) for more details.
> 6. Since Float16 uses short as its storage type, raw FP16 values are always loaded into a general-purpose register. FP16 ISA instructions, however, generally operate on floating-point registers, so the compiler injects reinterpretation IR before and after Float16 operation nodes to move the short value into a floating-point register and back.
> 7. New idealization routines to optimize redundant reinterpretation chains, e.g. HF2S + S2HF = HF.
> 8. Auto-vectorization of newly supported scalar operations.
> 9. X86 and AARCH64 backend implementation for all supported intrinsics.
> 10. Functional and performance validation tests.
> 
> **Missing Pieces:-**
> **-  AARCH64 Backend.**
> 
> Kindly review and share your feedback.
> 
> Best Regards,
> Jatin

Building on John's thoughts:
![image](https://github.com/user-attachments/assets/c795e79f-a857-4991-9b8a-c36d8525ba73)

![image](https://github.com/user-attachments/assets/264eeeea-86a0-43ed-a365-88b91e85d9cc)

There are two possible pattern matches here: one rooted at node **A** and the other at **B**.

With a pattern match rooted at **A**, we would need to inject an additional ConvHF2F after replacing AddF with AddHF to preserve the type semantics of the IR graph. The [significand bit preservation constraints](https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Float.java#L1103) for NaN values imposed by the Float.float16ToFloat API make it infeasible to idealize that trailing conversion away, which halves the operating vector size for the FP16 operation relative to what is otherwise possible, as depicted by the following Ideal graph fragment.

![image](https://github.com/user-attachments/assets/0094e613-2c11-40db-b2bb-84ddf6b251f2)

Thus, the only feasible match is the one rooted at node **B**.
 
![image](https://github.com/user-attachments/assets/22576617-9533-40e2-94f0-dd6048e295dd)
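At the Java source level, the match rooted at **B** corresponds to FP16 arithmetic written with the `java.lang.Float` conversion helpers, which C2 sees as ConvHF2F inputs feeding an AddF whose result goes through ConvF2HF. The following is a minimal sketch of the kind of scalar and loop code such a pattern match would target; the class and method names are hypothetical, it assumes JDK 20+ for `Float.float16ToFloat` / `Float.floatToFloat16`, and it is not taken from the PR.

```java
public class Fp16AddPattern {
    // Scalar FP16 addition expressed through float, i.e. the shape
    //   ConvF2HF(AddF(ConvHF2F a, ConvHF2F b))  ==>  AddHF(a, b)
    // Every binary16 value is exact in float, and float carries
    // 24 >= 2*11 + 2 significand bits, so the extra rounding through
    // float does not change the binary16 result (the 2p+2 argument).
    static short addHF(short a, short b) {
        return Float.floatToFloat16(Float.float16ToFloat(a) + Float.float16ToFloat(b));
    }

    // A loop over short[] storage: each short lane holds one IEEE 754
    // binary16 value, which is what the auto-vectorizer would see.
    static void addArrays(short[] a, short[] b, short[] r) {
        for (int i = 0; i < r.length; i++) {
            r[i] = addHF(a[i], b[i]);
        }
    }

    public static void main(String[] args) {
        short one = Float.floatToFloat16(1.0f);
        short two = Float.floatToFloat16(2.0f);
        short three = addHF(one, two);
        System.out.printf("0x%04X -> %s%n", three, Float.float16ToFloat(three));
    }
}
```

Running `main` should print `0x4200 -> 3.0`. Keeping the conversions explicit at both ends is what lets the whole ConvHF2F/AddF/ConvF2HF subgraph collapse into a single 16-bit operation.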


Please consider the Java-side implementation of Float16.sqrt:

public static Float16 sqrt(Float16 radicand) {
    return valueOf(Math.sqrt(radicand.doubleValue()));
}


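Here, the radicand is first upcast to a double value. Following the 2p+2 rule for IEEE 754 formats, a square root computed at double precision does not incur a double-rounding penalty when the final result is down-cast to a Float16 value (double carries 53 significand bits, well above the 2 × 11 + 2 = 24 required for binary16).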

The following is the C2 IR for the Float16.sqrt implementation above:


 T0 = Param0 (TypeInt::SHORT)

 T1 = ConvHF2F T0
 T2 = ConvF2D T1
 T3 = SqrtD T2

 T4 = ConvD2F T3
 T5 = ConvF2HF T4
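The 2p+2 argument behind this Float16.sqrt implementation can be checked exhaustively, since binary16 has only 31,743 positive finite values. The sketch below is a hypothetical standalone check (class and method names are illustrative), assuming JDK 20+ for `Float.float16ToFloat` and `Float.floatToFloat16`. Because `java.lang.Float` has no direct double-to-binary16 conversion, it narrows through float, adding one more rounding than `Float16.valueOf(double)`; the same argument covers that chain, since 53 ≥ 2 × 24 + 2 and 24 ≥ 2 × 11 + 2.

```java
public class Fp16SqrtCheck {
    // Same shape as Float16.sqrt: widen, take a correctly rounded double
    // sqrt, then narrow. java.lang.Float (JDK 20+) only converts
    // float <-> binary16, so this narrows double -> float -> binary16.
    static short sqrtViaDouble(short h) {
        double d = Float.float16ToFloat(h);
        return Float.floatToFloat16((float) Math.sqrt(d));
    }

    // True if r is the correctly rounded binary16 sqrt of h (h positive, finite).
    // Midpoints between adjacent binary16 values need at most 12 significand
    // bits, so their squares are exact in double and the comparison is exact.
    // Ties cannot occur: an odd 12-bit significand squared needs more than 11 bits.
    static boolean isCorrectlyRounded(short h, short r) {
        double x = Float.float16ToFloat(h);
        double v = Float.float16ToFloat(r);
        double lo = (Float.float16ToFloat((short) (r - 1)) + v) / 2;
        double hi = (v + Float.float16ToFloat((short) (r + 1))) / 2;
        return lo * lo < x && x < hi * hi;
    }

    public static void main(String[] args) {
        int checked = 0;
        // 0x0001..0x7BFF covers every positive subnormal and normal binary16 value.
        for (int bits = 0x0001; bits < 0x7C00; bits++) {
            short h = (short) bits;
            if (!isCorrectlyRounded(h, sqrtViaDouble(h))) {
                throw new AssertionError("misrounded sqrt for 0x" + Integer.toHexString(bits));
            }
            checked++;
        }
        System.out.println(checked + " inputs checked, all correctly rounded");
    }
}
```

The midpoint test avoids needing a separate reference sqrt: a result is correctly rounded exactly when the input lies strictly between the squares of the midpoints to its binary16 neighbours.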


To replace SqrtD with SqrtHF, we need the following IR modifications:


 T0 = Param0 (TypeInt::SHORT)
 // Replace IR T1-T3 of the original fragment with the following IR T1-T6.
 T1 = ReinterpretS2HF T0
 T3 = SqrtHF T1
 T4 = ReinterpretHF2S T3
 T5 = ConvHF2F T4
 T6 = ConvF2D T5

 T7 = ConvD2F T6
 T8 = ConvF2HF T7

  
Simplified IR after applying Identity rules:


 T0 = Param0 (TypeInt::SHORT)
 // ConvD2F/ConvF2D and ConvF2HF/ConvHF2F pairs fold away; T4 replaces the original T5.
 T1 = ReinterpretS2HF T0
 T3 = SqrtHF T1
 T4 = ReinterpretHF2S T3
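The effect of the Identity rules on this chain can be illustrated with a toy model. The sketch below is purely hypothetical (the `Node` record and `fold` method are illustrative, not C2's Node/Identity API): it treats each listed conversion pair as cancelling, `outer(inner(x)) == x`, and folds bottom-up.

```java
import java.util.Map;

public class IdentityFoldSketch {
    // Hypothetical toy IR: a unary node with an opcode and a single input
    // (null for leaves such as Param0). This is NOT C2's Node class.
    record Node(String op, Node in) {
        @Override
        public String toString() {
            return in == null ? op : op + "(" + in + ")";
        }
    }

    // outer opcode -> inner opcode such that outer(inner(x)) == x.
    // The Conv pairs are the ones that cancel in the sqrt example;
    // ReinterpretS2HF/ReinterpretHF2S is the "HF2S + S2HF = HF" rule.
    private static final Map<String, String> CANCELS = Map.of(
            "ConvD2F", "ConvF2D",
            "ConvF2HF", "ConvHF2F",
            "ReinterpretS2HF", "ReinterpretHF2S");

    // Folds bottom-up so that a parent always sees its already-simplified input.
    static Node fold(Node n) {
        if (n == null) {
            return null;
        }
        Node in = fold(n.in());
        String inverse = CANCELS.get(n.op());
        if (in != null && in.op().equals(inverse)) {
            return in.in(); // outer(inner(x)) -> x
        }
        return new Node(n.op(), in);
    }

    static Node unary(String op, Node in) {
        return new Node(op, in);
    }

    public static void main(String[] args) {
        Node param = new Node("Param0", null);
        // IR after replacing SqrtD with SqrtHF, before simplification.
        Node root = unary("ConvF2HF", unary("ConvD2F", unary("ConvF2D",
                unary("ConvHF2F", unary("ReinterpretHF2S",
                unary("SqrtHF", unary("ReinterpretS2HF", param)))))));
        System.out.println(fold(root));
    }
}
```

Running it prints `ReinterpretHF2S(SqrtHF(ReinterpretS2HF(Param0)))`, the same three-node chain as the simplified IR. Folding bottom-up matters here: ConvF2HF only becomes adjacent to ConvHF2F after ConvD2F(ConvF2D x) has been removed.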

  
While the above transformations are valid replacements for the current intrinsic approach, which uses explicit entry points in the newly defined Float16Math helper class, they deviate from the implementation of several intrinsified java.lang methods that could otherwise be replaced by pattern matches, e.g.:
https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Math.java#L2022
https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Math.java#L2116

I think we need to choose carefully, picking pattern matching over intrinsification where the former handles more general cases.

If our intention is to capture Float16 operation patterns in user code that does not directly use the Float16 API, then pattern matching looks appealing. However, APIs like SQRT and FMA are very carefully drafted with their rounding impact in mind, and such patterns will be hard to find in user code, so it should be OK to take the intrinsic route for them. Simpler cases like add / sub / mul / div / max / min can be handled through pattern matching.

There are also some issues around VM symbol creation for intrinsic entries defined in non-java.base modules, which did not surface when Float16 and Float16Math were part of the java.base module.

For this PR, taking a hybrid approach comprising both pattern matching and intrinsification looks reasonable to me.

Please let me know if you have any comments.

Some FAQs on the newly added Ideal type for half-float IR nodes:

Q. Why do we not use the existing TypeInt::SHORT instead of creating a new TypeH type?
A. The newly defined half-float type, TypeH, is special: its basic type is T_SHORT while its ideal register type is RegF. Thus, the C2 type system views its associated IR node as a 16-bit short value, while the register allocator assigns it a floating-point register.

Q. What is the problem with ConF?
A. During auto-vectorization, ConF replication constrains the operational vector lane count to half of what could otherwise be used for a regular Float16 operation: only 16 floats fit in a 512-bit vector (versus 32 FP16 values), which limits the lane count of the vectors in its use-def chain. One possible workaround is a kludge in the auto-vectorizer that casts such constants to 16-bit constants by analyzing their context. The newly defined Float16 constant nodes, 'ConH', are inherently 16-bit encoded IEEE 754 FP16 values and can be packed efficiently to leverage the full target vector width.

All Float16 IR nodes now carry the newly defined Type::HALF_FLOAT type instead of Type::FLOAT, so we no longer need special handling in the auto-vectorizer to prune their container type to short.

-------------

PR Comment: https://git.openjdk.org/jdk/pull/21490#issuecomment-2425873278
PR Comment: https://git.openjdk.org/jdk/pull/21490#issuecomment-2482867818


More information about the hotspot-compiler-dev mailing list