[code-reflection] RFR: Support for Onnx functions
Adam Sotona
asotona at openjdk.org
Thu Apr 10 14:40:09 UTC 2025
Support for Onnx functions:
- Onnx model is represented by `ModuleOp`.
- Initializers are self-described by the `OnnxType.Initializer` type (initializer values are no longer passed along separately from `OnnxTransformer`).
- `OnnxTransformer` is decomposed into small independent transformations.
- Single-use functions are inlined, multiple-use functions are declared in the model.
- Protobuf model is improved to better reflect Onnx model.
Sample Onnx model:
module ()void -> {
func @"step" (%0 : tensor<int64>, %1 : tensor<int64>, %2 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %3 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepEast>, %4 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %5 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepNorth>, %6 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %7 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepWest>, %8 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepSouth>)tensor<int64> -> {
%9 : tensor<bool> = Equal %1 %2;
%10 : tensor<int64> = If %9
()tensor<int64> -> {
%11 : tensor<int64> = Add %0 %3;
return %11 @loc="102:17";
}
()tensor<int64> -> {
%12 : tensor<bool> = Equal %1 %4;
%13 : tensor<int64> = If %12
()tensor<int64> -> {
%14 : tensor<int64> = Add %0 %5;
return %14 @loc="104:25";
}
()tensor<int64> -> {
%15 : tensor<bool> = Equal %1 %6;
%16 : tensor<int64> = If %15
()tensor<int64> -> {
%17 : tensor<int64> = Add %0 %7;
return %17 @loc="106:29";
}
()tensor<int64> -> {
%18 : tensor<int64> = Add %0 %8;
return %18 @loc="107:29";
};
return %16 @loc="105:25";
};
return %13 @loc="103:17";
};
return %10 @loc="101:9";
};
func @"turnLeft" (%19 : tensor<int64>, %20 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %21 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %22 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %23 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionSouth>)tensor<int64> -> {
%24 : tensor<bool> = Equal %19 %20;
%25 : tensor<int64> = If %24
()tensor<int64> -> {
%26 : tensor<int64> = Identity %21;
return %26 @loc="80:17";
}
()tensor<int64> -> {
%27 : tensor<bool> = Equal %19 %21;
%28 : tensor<int64> = If %27
()tensor<int64> -> {
%29 : tensor<int64> = Identity %22;
return %29 @loc="82:25";
}
()tensor<int64> -> {
%30 : tensor<bool> = Equal %19 %22;
%31 : tensor<int64> = If %30
()tensor<int64> -> {
%32 : tensor<int64> = Identity %23;
return %32 @loc="84:29";
}
()tensor<int64> -> {
%33 : tensor<int64> = Identity %20;
return %33 @loc="85:29";
};
return %31 @loc="83:25";
};
return %28 @loc="81:17";
};
return %25 @loc="79:9";
};
func @"isWallAt" (%34 : tensor<int64>, %35 : init<tensor<uint8>, oracle.code.onnx.WalkTheMazeTest.maze>, %36 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.oneOne>, %37 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.wall>)tensor<bool> -> {
%38 : tensor<int64> = Add %34 %36;
%39 : tensor<uint8> = Slice %35 %34 %38;
%40 : tensor<int64> = CastLike %39 %37;
%41 : tensor<bool> = Equal %40 %37;
return %41 @loc="96:9";
};
func @"walkAroundTheMaze" (%42 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, %43 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepEast>, %44 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionNorth>, %45 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepNorth>, %46 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionWest>, %47 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepWest>, %48 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.stepSouth>, %49 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionSouth>, %50 : init<tensor<uint8>, oracle.code.onnx.WalkTheMazeTest.maze>, %51 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.oneOne>, %52 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.wall>, %53 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.homePos>, %54 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.limit>, %55 : init<tensor<bool>, oracle.code.onnx.WalkTheMazeTest._true>, %56 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.three>, %57 : init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.scalarShape>)tensor<uint8> -> {
%58 : tensor<uint8> = Cast %42 @to="2";
%59 : Tuple<init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.homePos>, init<tensor<int64>, oracle.code.onnx.WalkTheMazeTest.directionEast>, tensor<uint8>> = tuple %53 %42 %58;
%60 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>> = Loop %54 %55 %59 (%61 : tensor<int64>, %62 : tensor<bool>, %63 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>)Tuple<tensor<bool>, Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>> -> {
%64 : tensor<int64> = tuple.load %63 @"0";
%65 : tensor<int64> = tuple.load %63 @"1";
%66 : tensor<int64> = func.call %64 %65 %42 %43 %44 %45 %46 %47 %48 @"step";
%67 : tensor<int64> = tuple.load %63 @"1";
%68 : tensor<int64> = Loop %56 %55 %67 (%69 : tensor<int64>, %70 : tensor<bool>, %71 : tensor<int64>)Tuple<tensor<bool>, tensor<int64>> -> {
%72 : tensor<int64> = func.call %71 %42 %44 %46 %49 @"turnLeft";
%73 : Tuple<tensor<bool>, tensor<int64>> = tuple %70 %72;
return %73 @loc="90:46";
};
%74 : tensor<int64> = func.call %66 %68 %42 %43 %44 %45 %46 %47 %48 @"step";
%75 : tensor<bool> = func.call %74 %50 %51 %52 @"isWallAt";
%76 : tensor<bool> = Reshape %75 %57;
%77 : tensor<int64> = Loop %54 %76 %68 (%78 : tensor<int64>, %79 : tensor<bool>, %80 : tensor<int64>)Tuple<tensor<bool>, tensor<int64>> -> {
%81 : tensor<int64> = func.call %80 %42 %44 %46 %49 @"turnLeft";
%82 : tensor<int64> = func.call %66 %81 %42 %43 %44 %45 %46 %47 %48 @"step";
%83 : tensor<bool> = func.call %82 %50 %51 %52 @"isWallAt";
%84 : Tuple<tensor<bool>, tensor<int64>> = tuple %83 %81;
return %84 @loc="120:17";
};
%85 : tensor<bool> = Equal %66 %53;
%86 : tensor<bool> = ReduceMin %85;
%87 : tensor<bool> = Not %86;
%88 : tensor<uint8> = tuple.load %63 @"2";
%89 : tensor<uint8> = Cast %77 @to="2";
%90 : tensor<uint8> = Concat %88 %89 @axis="0";
%91 : Tuple<tensor<int64>, tensor<int64>, tensor<uint8>> = tuple %66 %77 %90;
%92 : Tuple<tensor<bool>, Tuple<tensor<int64>, tensor<int64>, tensor<uint8>>> = tuple %87 %91;
return %92 @loc="140:13";
};
%93 : tensor<uint8> = tuple.load %60 @"2";
return %93 @loc="148:69";
};
unreachable;
};
Protobuf model fragment of the sample Onnx model:

-------------
Commit messages:
- OnnxProtoBuilder.Indexer improved to match computed global names by OpWriter
- OnnxProtoBuilder support for custom domain name and named graphs + cleanup
- model main func name calculation
- inlining single-called functions
- OnnxTransformer cleanup
- Initializers encoded into OnnxType.Initializer
- OnnxProtoBuilder support for ModuleOp - work in progress
- OnnxProtoBuilder support for ModuleOp - work in progress
- fixed SimpleTest::testConcat to avoid non-tensor in args
- Onnx model functions into module - work in progress
- ... and 1 more: https://git.openjdk.org/babylon/compare/bd0569ee...3c68af28
Changes: https://git.openjdk.org/babylon/pull/378/files
Webrev: https://webrevs.openjdk.org/?repo=babylon&pr=378&range=00
Stats: 487 lines in 8 files changed: 315 ins; 78 del; 94 mod
Patch: https://git.openjdk.org/babylon/pull/378.diff
Fetch: git fetch https://git.openjdk.org/babylon.git pull/378/head:pull/378
PR: https://git.openjdk.org/babylon/pull/378
More information about the babylon-dev mailing list