[code-reflection] RFR: Integrate Java Triton example with Intel Triton Backend

Paul Sandoz psandoz at openjdk.org
Tue Nov 5 22:51:45 UTC 2024


On Wed, 25 Sep 2024 15:32:53 GMT, hanklo6 <duke at openjdk.org> wrote:

> The Babylon Java Triton example uses code reflection to translate Java source code written with the Java Triton API into a code model.
>  
> In this PR, we traverse the given code model, emit the Triton MLIR dialect in generic form, and inject the generated MLIR into the Intel Triton backend. The Intel Triton backend then compiles the Triton MLIR into a SPIR-V module. Finally, we use `Jextract` to create Java bindings for the Intel Level Zero runtime and use them to launch the given kernel function on Intel GPUs.
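Some readers may not know MLIR's generic operation form. Below is a minimal Java sketch of printing one operation in that form. The `Op` record, the `print` method and the example values are hypothetical simplifications for illustration only. They are not the code model API or the printer used in this PR.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class GenericFormSketch {
    // Hypothetical, simplified operation: result name, quoted op name,
    // operands, attributes, operand types and a single result type.
    // Not the Babylon code model API.
    record Op(String result, String name, List<String> operands,
              Map<String, String> attributes,
              List<String> operandTypes, String resultType) {}

    // Prints one operation in MLIR generic form:
    //   %r = "dialect.op"(%a, %b) {attr = value} : (t0, t1) -> tr
    // Use a LinkedHashMap for attributes if their printed order matters.
    static String print(Op op) {
        StringBuilder sb = new StringBuilder();
        sb.append(op.result()).append(" = \"").append(op.name()).append("\"(")
          .append(String.join(", ", op.operands())).append(")");
        if (!op.attributes().isEmpty()) {
            sb.append(" {")
              .append(op.attributes().entrySet().stream()
                      .map(e -> e.getKey() + " = " + e.getValue())
                      .collect(Collectors.joining(", ")))
              .append("}");
        }
        sb.append(" : (").append(String.join(", ", op.operandTypes()))
          .append(") -> ").append(op.resultType());
        return sb.toString();
    }

    public static void main(String[] args) {
        // Signed less-than integer comparison (predicate value 2 = slt)
        Op cmp = new Op("%2", "arith.cmpi", List.of("%0", "%1"),
                Map.of("predicate", "2 : i64"), List.of("i32", "i32"), "i1");
        System.out.println(print(cmp));
    }
}
```

Running `main` prints `%2 = "arith.cmpi"(%0, %1) {predicate = 2 : i64} : (i32, i32) -> i1`. The generic form quotes the operation name and spells out the full function type. An emitter targeting it therefore does not need to know each dialect's custom assembly syntax.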
> 
> ## Usage
> Navigate to the `cr-examples/triton` directory and execute `mvn clean test`. This will generate multiple MLIR files in the `result` directory, ready to be processed by the Triton backend.
> 
> Next, modify the `compiler.py` file in the `intel-xpu-backend-for-triton` project by applying the patch with `git apply add-mlir-insertion.patch`. Then run the Triton backend with `python3 translate.py`.
> 
> The Triton backend will generate SPIR-V files, which will be located under `~/.triton/cache/{hash_value}/{kernel_name}/{kernel_name}.spv`.
> 
> To create Java bindings for Level Zero, run the following commands:
> 
> $JEXTRACT_DIR/bin/jextract --output src/gen/java -I /usr/include -t oneapi.levelzero level-zero/include/ze_api.h
> $JAVA_HOME/bin/javac -cp target/classes -d target/classes src/gen/java/oneapi/levelzero/*.java
> $JAVA_HOME/bin/jar cf levelzero.jar -C target/classes/ .
> 
> This will generate `levelzero.jar` in the current directory.
> 
> After obtaining the JAR files for Level Zero and `JSON-java`, compile and run the launcher `LevelZero.java` with the following commands:
> 
> babylon/build/linux-x86_64-server-release/jdk/bin/javac -cp .:levelzero.jar:json-java.jar LevelZero.java
> babylon/build/linux-x86_64-server-release/jdk/bin/java -ea -cp .:levelzero.jar:json-java.jar LevelZero
> 
> 
> Ensure the hash values in `~/.triton/cache` match those used in `LevelZero.java`.
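The hash directories have to be matched by hand. A small helper can list which hashes under the cache contain a given kernel. Below is a minimal sketch that assumes the `{hash_value}/{kernel_name}/{kernel_name}.spv` layout described above. The class and method names and the default kernel name `matmul_kernel` are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Stream;

public class TritonCacheSketch {
    // Default cache root: ~/.triton/cache
    static Path defaultCacheRoot() {
        return Path.of(System.getProperty("user.home"), ".triton", "cache");
    }

    // Assumed layout: {cacheRoot}/{hash}/{kernel}/{kernel}.spv
    static Path spvPath(Path cacheRoot, String hash, String kernelName) {
        return cacheRoot.resolve(hash).resolve(kernelName).resolve(kernelName + ".spv");
    }

    // Returns the (sorted) names of hash directories that contain a SPIR-V
    // file for the given kernel, for comparison with the values in LevelZero.java.
    static List<String> findHashes(Path cacheRoot, String kernelName) throws IOException {
        if (!Files.isDirectory(cacheRoot)) {
            return List.of();
        }
        try (Stream<Path> dirs = Files.list(cacheRoot)) {
            return dirs.filter(Files::isDirectory)
                       .filter(d -> Files.isRegularFile(spvPath(cacheRoot, d.getFileName().toString(), kernelName)))
                       .map(d -> d.getFileName().toString())
                       .sorted()
                       .toList();
        }
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical default kernel name; pass the real one as the first argument
        String kernel = args.length > 0 ? args[0] : "matmul_kernel";
        for (String hash : findHashes(defaultCacheRoot(), kernel)) {
            System.out.println(hash + " -> " + spvPath(defaultCacheRoot(), hash, kernel));
        }
    }
}
```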
> 
> ## Dependencies
> - [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton)
> - [intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch)
> - [Intel oneAPI base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html)
> - [Jextract](https://github.com/openjdk/jextract)
> - [Level Zero loader](https://github.com/oneapi-src/level-zero)
> - [compute-runtime](https://github.com/intel/compute-runtime/releases)
> - [JSON-java](https://github.com/stleary/JSON-java)

Thank you for this effort. It's not easy to hook this up and you did it without any help!

Initially I have focused only on the code in the triton package.

cr-examples/triton/src/main/java/oracle/code/triton/ArithMathOps.java line 398:

> 396:     public static class CompareOp extends ArithMathOp implements Op.Pure {
> 397:         public static final String NAME = "arith.cmp";
> 398:         public static final String ATTRIBUTE_CONSTANT_VALUE = "predicate";

Suggestion:

        public static final String ATTRIBUTE_PREDICATE = "predicate";

cr-examples/triton/src/main/java/oracle/code/triton/ArithMathOps.java line 401:

> 399: 
> 400:         public enum CompareKind {
> 401:             eq,

Can you link to https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate and also state in the comment that the enum ordinal corresponds to the MLIR symbol's value? Also add a comment that this will need refining when we consider comparisons of floating-point numbers, which are in a different namespace.
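Below is a sketch of what the requested comment could look like. It assumes the constants follow the order of the MLIR `arith.cmpi` predicate enum (`eq` = 0 through `uge` = 9). The wrapper class and the `attributeValue` helper are hypothetical. They only illustrate how the ordinal maps to the printed `predicate` attribute.

```java
public class CompareKindSketch {
    /**
     * Integer comparison predicates for arith.cmpi, see
     * https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate
     *
     * The ordinal of each constant is the integer value of the corresponding
     * MLIR symbol, so the declaration order must not change.
     *
     * Floating-point comparisons (arith.cmpf) use a different predicate
     * namespace and will need separate handling.
     */
    public enum CompareKind {
        eq, ne, slt, sle, sgt, sge, ult, ule, ugt, uge;

        // Hypothetical helper: value of the "predicate" attribute in generic form
        public String attributeValue() {
            return ordinal() + " : i64";
        }
    }

    public static void main(String[] args) {
        for (CompareKind k : CompareKind.values()) {
            System.out.println(k + " = " + k.attributeValue());
        }
    }
}
```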

cr-examples/triton/src/main/java/oracle/code/triton/TritonTransformer.java line 1054:

> 1052:             a = block.context().getValue(a);
> 1053:             b = block.context().getValue(b);
> 1054:             Object zero;

Add a comment such as

// Computed result is tensor of floats, regardless of inputs

as we do when computing the type. Since the result type is hard-coded, we don't need reflection and can use the constant expression "0.0" directly.

cr-examples/triton/src/test/java/oracle/code/triton/TestMatrix.java line 264:

> 262:                     LessThan);
> 263:             var b = load(b_ptrs, broadcast(offs_k_m_3, b_ptrs.type()));
> 264:             // We accumulate along the K dimension.

Were these changes necessary because you encountered a bug?

cr-examples/triton/src/test/java/oracle/code/triton/TestMatrix.java line 471:

> 469:             int stride_am, @Constant int stride_ak,
> 470:             int stride_bk, @Constant int stride_bn,
> 471:             int stride_cm, @Constant int stride_cn,

Why are these marked as constants?

-------------

PR Review: https://git.openjdk.org/babylon/pull/241#pullrequestreview-2416747104
PR Review Comment: https://git.openjdk.org/babylon/pull/241#discussion_r1830036689
PR Review Comment: https://git.openjdk.org/babylon/pull/241#discussion_r1830035724
PR Review Comment: https://git.openjdk.org/babylon/pull/241#discussion_r1830100308
PR Review Comment: https://git.openjdk.org/babylon/pull/241#discussion_r1830086221
PR Review Comment: https://git.openjdk.org/babylon/pull/241#discussion_r1830095437


More information about the babylon-dev mailing list