[code-reflection] RFR: Support for Onnx functions

Adam Sotona asotona at openjdk.org
Thu Apr 10 14:40:09 UTC 2025


Support for Onnx functions:
 - An Onnx model is represented by a `ModuleOp`.
 - Initializers are self-described by the `OnnxType.Initializer` type (initializer values are no longer passed separately from `OnnxTransformer`).
 - `OnnxTransformer` is decomposed into small, independent transformations.
 - Single-use functions are inlined; functions used more than once are declared in the model.
 - The Protobuf model is improved to better reflect the Onnx model.
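The inlining rule above can be illustrated with a minimal Java sketch that counts call sites per callee and inlines a function only when it has exactly one. The `Call` record, the `Action` enum and the `decide` method are hypothetical and are not part of the `OnnxTransformer` API. The call sites in `main` mirror the `func.call` operations in the sample model in this post (three to `step`, two each to `turnLeft` and `isWallAt`), plus a hypothetical single-use `helper`.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: decide which functions to inline and which to declare. */
public class InlineDecision {

    /** One static call site from a caller to a callee (hypothetical representation). */
    record Call(String caller, String callee) {}

    enum Action { INLINE, DECLARE }

    /** Functions with exactly one call site are inlined; all others are declared. */
    static Map<String, Action> decide(List<Call> calls) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (Call c : calls) {
            counts.merge(c.callee(), 1, Integer::sum); // count call sites per callee
        }
        Map<String, Action> result = new LinkedHashMap<>();
        counts.forEach((fn, n) -> result.put(fn, n == 1 ? Action.INLINE : Action.DECLARE));
        return result;
    }

    public static void main(String[] args) {
        // Call sites of the sample model, plus a hypothetical single-use "helper".
        List<Call> calls = List.of(
                new Call("walkAroundTheMaze", "step"),
                new Call("walkAroundTheMaze", "step"),
                new Call("walkAroundTheMaze", "step"),
                new Call("walkAroundTheMaze", "turnLeft"),
                new Call("walkAroundTheMaze", "turnLeft"),
                new Call("walkAroundTheMaze", "isWallAt"),
                new Call("walkAroundTheMaze", "isWallAt"),
                new Call("walkAroundTheMaze", "helper"));
        // prints {step=DECLARE, turnLeft=DECLARE, isWallAt=DECLARE, helper=INLINE}
        System.out.println(decide(calls));
    }
}
```

Counting static call sites is one plausible criterion. It is consistent with the sample, where every declared function has at least two call sites, but the exact rule used in the PR may differ.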

Sample Onnx model:

module ()void -> {
    func @"step" (%0 : tensor<int64>, %1 : tensor<int64>, %2 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %3 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepEast>, %4 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %5 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepNorth>, %6 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %7 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepWest>, %8 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepSouth>)tensor<int64> -> {
        %9 : tensor<bool> = Equal %1 %2;
        %10 : tensor<int64> = If %9
            ()tensor<int64> -> {
                %11 : tensor<int64> = Add %0 %3;
                return %11 @loc="102:17";
            }
            ()tensor<int64> -> {
                %12 : tensor<bool> = Equal %1 %4;
                %13 : tensor<int64> = If %12
                    ()tensor<int64> -> {
                        %14 : tensor<int64> = Add %0 %5;
                        return %14 @loc="104:25";
                    }
                    ()tensor<int64> -> {
                        %15 : tensor<bool> = Equal %1 %6;
                        %16 : tensor<int64> = If %15
                            ()tensor<int64> -> {
                                %17 : tensor<int64> = Add %0 %7;
                                return %17 @loc="106:29";
                            }
                            ()tensor<int64> -> {
                                %18 : tensor<int64> = Add %0 %8;
                                return %18 @loc="107:29";
                            };
                        return %16 @loc="105:25";
                    };
                return %13 @loc="103:17";
            };
        return %10 @loc="101:9";
    };
    func @"turnLeft" (%19 : tensor<int64>, %20 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %21 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %22 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %23 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionSouth>)tensor<int64> -> {
        %24 : tensor<bool> = Equal %19 %20;
        %25 : tensor<int64> = If %24
            ()tensor<int64> -> {
                %26 : tensor<int64> = Identity %21;
                return %26 @loc="80:17";
            }
            ()tensor<int64> -> {
                %27 : tensor<bool> = Equal %19 %21;
                %28 : tensor<int64> = If %27
                    ()tensor<int64> -> {
                        %29 : tensor<int64> = Identity %22;
                        return %29 @loc="82:25";
                    }
                    ()tensor<int64> -> {
                        %30 : tensor<bool> = Equal %19 %22;
                        %31 : tensor<int64> = If %30
                            ()tensor<int64> -> {
                                %32 : tensor<int64> = Identity %23;
                                return %32 @loc="84:29";
                            }
                            ()tensor<int64> -> {
                                %33 : tensor<int64> = Identity %20;
                                return %33 @loc="85:29";
                            };
                        return %31 @loc="83:25";
                    };
                return %28 @loc="81:17";
            };
        return %25 @loc="79:9";
    };
    func @"isWallAt" (%34 : tensor<int64>, %35 : init<tensor<uint8>, oracle.code.onnx.WalkTheMazeTest.maze>, %36 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.oneOne>, %37 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.wall>)tensor<bool> -> {
        %38 : tensor<int64> = Add %34 %36;
        %39 : tensor<uint8> = Slice %35 %34 %38;
        %40 : tensor<int64> = CastLike %39 %37;
        %41 : tensor<bool> = Equal %40 %37;
        return %41 @loc="96:9";
    };
    func @"walkAroundTheMaze" (%42 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %43 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepEast>, %44 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %45 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepNorth>, %46 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %47 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepWest>, %48 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepSouth>, %49 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionSouth>, %50 : init<tensor<uint8>, oracle.code.onnx.WalkTheMazeTest.maze>, %51 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.oneOne>, %52 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.wall>, %53 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.homePos>, %54 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.limit>, %55 : init<tensor<bool>, oracle.code.onnx.WalkTheMazeTest._true>, %56 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.three>, %57 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.scalarShape>)tensor<uint8> -> {
        %58 : tensor<uint8> = Cast %42 @to="2";
        %59 : Tuple<init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.homePos>, init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, tensor<uint8>> = tuple %53 %42 %58;
        %60 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>> = Loop %54 %55 %59 (%61 : tensor<int64>, %62 : tensor<bool>, %63 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>)Tuple<tensor<bool>, Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>> -> {
            %64 : tensor<int64> = tuple.load %63 @"0";
            %65 : tensor<int64> = tuple.load %63 @"1";
            %66 : tensor<int64> = func.call %64 %65 %42 %43 %44 %45 %46 %47 %48 @"step";
            %67 : tensor<int64> = tuple.load %63 @"1";
            %68 : tensor<int64> = Loop %56 %55 %67 (%69 : tensor<int64>, %70 : tensor<bool>, %71 : tensor<int64>)Tuple<tensor<bool>, tensor<int64>> -> {
                %72 : tensor<int64> = func.call %71 %42 %44 %46 %49 @"turnLeft";
                %73 : Tuple<tensor<bool>, tensor<int64>> = tuple %70 %72;
                return %73 @loc="90:46";
            };
            %74 : tensor<int64> = func.call %66 %68 %42 %43 %44 %45 %46 %47 %48 @"step";
            %75 : tensor<bool> = func.call %74 %50 %51 %52 @"isWallAt";
            %76 : tensor<bool> = Reshape %75 %57;
            %77 : tensor<int64> = Loop %54 %76 %68 (%78 : tensor<int64>, %79 : tensor<bool>, %80 : tensor<int64>)Tuple<tensor<bool>, tensor<int64>> -> {
                %81 : tensor<int64> = func.call %80 %42 %44 %46 %49 @"turnLeft";
                %82 : tensor<int64> = func.call %66 %81 %42 %43 %44 %45 %46 %47 %48 @"step";
                %83 : tensor<bool> = func.call %82 %50 %51 %52 @"isWallAt";
                %84 : Tuple<tensor<bool>, tensor<int64>> = tuple %83 %81;
                return %84 @loc="120:17";
            };
            %85 : tensor<bool> = Equal %66 %53;
            %86 : tensor<bool> = ReduceMin %85;
            %87 : tensor<bool> = Not %86;
            %88 : tensor<uint8> = tuple.load %63 @"2";
            %89 : tensor<uint8> = Cast %77 @to="2";
            %90 : tensor<uint8> = Concat %88 %89 @axis="0";
            %91 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>> = tuple %66 %77 %90;
            %92 : Tuple<tensor<bool>, Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>> = tuple %87 %91;
            return %92 @loc="140:13";
        };
        %93 : tensor<uint8> = tuple.load %60 @"2";
        return %93 @loc="148:69";
    };
    unreachable;
};
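To make the nested `If`/`Equal` chains easier to follow, here is a plain-Java reading of `step`, `turnLeft` and `isWallAt` as they appear in the IR above. It is a hypothetical reconstruction, not the test's source: the direction codes, step vectors and wall value are made-up stand-ins for the `WalkTheMazeTest` initializers, and the maze is assumed to be a 2-D grid indexed by a `{row, column}` position, with `oneOne` assumed to be `{1, 1}` so that `Slice` selects a single cell.

```java
import java.util.Arrays;

/** Hypothetical plain-Java reading of three functions from the sample IR. */
public class WalkTheMazeSketch {

    // Hypothetical direction codes (the IR only shows they are int64 initializers).
    static final long DIRECTION_EAST = 0, DIRECTION_NORTH = 1,
                      DIRECTION_WEST = 2, DIRECTION_SOUTH = 3;

    // Hypothetical {row, column} step vectors.
    static final long[] STEP_EAST = {0, 1}, STEP_NORTH = {-1, 0},
                        STEP_WEST = {0, -1}, STEP_SOUTH = {1, 0};

    // Hypothetical wall marker stored in the uint8 maze.
    static final byte WALL = '#';

    // Element-wise addition, like the IR's "Add %0 %3".
    static long[] add(long[] a, long[] b) {
        long[] r = new long[a.length];
        for (int i = 0; i < a.length; i++) r[i] = a[i] + b[i];
        return r;
    }

    // func @"step": each If/Equal pair tests one direction; the final else branch moves south.
    static long[] step(long[] pos, long direction) {
        if (direction == DIRECTION_EAST) return add(pos, STEP_EAST);
        if (direction == DIRECTION_NORTH) return add(pos, STEP_NORTH);
        if (direction == DIRECTION_WEST) return add(pos, STEP_WEST);
        return add(pos, STEP_SOUTH);
    }

    // func @"turnLeft": East -> North -> West -> South -> East.
    static long turnLeft(long direction) {
        if (direction == DIRECTION_EAST) return DIRECTION_NORTH;
        if (direction == DIRECTION_NORTH) return DIRECTION_WEST;
        if (direction == DIRECTION_WEST) return DIRECTION_SOUTH;
        return DIRECTION_EAST;
    }

    // func @"isWallAt": Slice maze[pos : pos + oneOne] picks one cell, then Equal with wall.
    static boolean isWallAt(long[] pos, byte[][] maze) {
        return maze[(int) pos[0]][(int) pos[1]] == WALL;
    }

    public static void main(String[] args) {
        byte[][] maze = {"###".getBytes(), "# #".getBytes(), "###".getBytes()};
        long[] home = {1, 1};
        System.out.println(Arrays.toString(step(home, DIRECTION_EAST))); // [1, 2]
        System.out.println(turnLeft(DIRECTION_SOUTH));                   // 0
        System.out.println(isWallAt(step(home, DIRECTION_EAST), maze));  // true
    }
}
```

In the model these scalar comparisons become `tensor<bool>` values that drive `If` subgraphs, which is why every branch ends in its own `return`.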


Protobuf model fragment of the sample Onnx model:

![WalkTheMaze](https://github.com/user-attachments/assets/ca7a8288-bd4e-42d4-967a-fc560a425344)
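In the ONNX protobuf format, model-local functions are stored as `FunctionProto` entries in `ModelProto.functions`, and a `NodeProto` invokes one by using the function's domain and name as its own `domain` and `op_type`; standard operators use the default empty-string domain. The sketch below models only that lookup with simplified Java records. They are hypothetical stand-ins, not the generated protobuf classes or the `OnnxProtoBuilder` API, and the domain string `oracle.code.onnx` and the input names are assumed examples.

```java
import java.util.List;
import java.util.Optional;

/** Simplified, hypothetical stand-ins for ONNX ModelProto / FunctionProto / NodeProto. */
public class OnnxFunctionLookup {

    record NodeProto(String opType, String domain, List<String> input, List<String> output) {}

    record FunctionProto(String name, String domain, List<String> input, List<String> output,
                         List<NodeProto> node) {}

    record ModelProto(List<FunctionProto> functions) {
        // A node calls a model-local function when its (domain, op_type)
        // pair equals the function's (domain, name) pair.
        Optional<FunctionProto> resolve(NodeProto call) {
            return functions.stream()
                    .filter(f -> f.domain().equals(call.domain()) && f.name().equals(call.opType()))
                    .findFirst();
        }
    }

    public static void main(String[] args) {
        String domain = "oracle.code.onnx"; // hypothetical custom domain name
        FunctionProto isWallAt = new FunctionProto("isWallAt", domain,
                List.of("pos", "maze", "oneOne", "wall"), List.of("result"), List.of());
        ModelProto model = new ModelProto(List.of(isWallAt));

        NodeProto call = new NodeProto("isWallAt", domain,
                List.of("pos", "maze", "oneOne", "wall"), List.of("hit"));
        NodeProto builtin = new NodeProto("Equal", "", List.of("a", "b"), List.of("c"));

        System.out.println(model.resolve(call).map(FunctionProto::name).orElse("<none>"));    // isWallAt
        System.out.println(model.resolve(builtin).map(FunctionProto::name).orElse("<none>")); // <none>
    }
}
```

A complete model would also need the custom domain listed in `opset_import` so that nodes referring to it validate.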

-------------

Commit messages:
 - OnnxProtoBuilder.Indexer improved to match computed global names by OpWriter
 - OnnxProtoBuilder support for custom domain name and named graphs + cleanup
 - model main func name calculation
 - inlining single-called functions
 - OnnxTransformer cleanup
 - Initializers encoded into OnnxType.Initializer
 - OnnxProtoBuilder support for ModuleOp - work in progress
 - OnnxProtoBuilder support for ModuleOp - work in progress
 - fixed SimpleTest::testConcat to avoid non-tensor in args
 - Onnx model functions into module - work in progress
 - ... and 1 more: https://git.openjdk.org/babylon/compare/bd0569ee...3c68af28

Changes: https://git.openjdk.org/babylon/pull/378/files
  Webrev: https://webrevs.openjdk.org/?repo=babylon&pr=378&range=00
  Stats: 487 lines in 8 files changed: 315 ins; 78 del; 94 mod
  Patch: https://git.openjdk.org/babylon/pull/378.diff
  Fetch: git fetch https://git.openjdk.org/babylon.git pull/378/head:pull/378

PR: https://git.openjdk.org/babylon/pull/378

