From 7a8bd1eb35536a3259eee7d4b5cd4a0f78d9749d Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 21 Dec 2024 10:21:35 +0100 Subject: [PATCH 1/6] Added Pow operator --- ops/binary_op.go | 5 ++++ ops/convert.go | 28 +++++++++++++++++++ ops/errors.go | 4 +++ ops/pow/pow.go | 63 +++++++++++++++++++++++++++++++++++++++++ ops/pow/pow_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++ ops/pow/versions.go | 15 ++++++++++ ops_test.go | 6 ++++ opset.go | 2 ++ 8 files changed, 191 insertions(+) create mode 100644 ops/pow/pow.go create mode 100644 ops/pow/pow_test.go create mode 100644 ops/pow/versions.go diff --git a/ops/binary_op.go b/ops/binary_op.go index 3df36e8..be0361d 100644 --- a/ops/binary_op.go +++ b/ops/binary_op.go @@ -48,6 +48,11 @@ func Mul(A, B tensor.Tensor) (tensor.Tensor, error) { return tensor.Mul(A, B) } +// Pow raises the first tensor to the power of the second tensor. +func Pow(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Pow(A, B) +} + // Sub subtracts 1 tensor from the other. func Sub(A, B tensor.Tensor) (tensor.Tensor, error) { return tensor.Sub(A, B) diff --git a/ops/convert.go b/ops/convert.go index 0637f49..e4fa67c 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -51,6 +51,34 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { return tensor.New(tensor.WithShape(t.Shape()...), tensor.WithBacking(newBacking)), nil } +func DTypeToONNXType(t tensor.Dtype) (int32, error) { + switch t { + case tensor.Float32: + return int32(onnx.TensorProto_FLOAT), nil + case tensor.Float64: + return int32(onnx.TensorProto_DOUBLE), nil + case tensor.Int8: + return int32(onnx.TensorProto_INT8), nil + case tensor.Int16: + return int32(onnx.TensorProto_INT16), nil + case tensor.Int32: + return int32(onnx.TensorProto_INT32), nil + case tensor.Int64: + return int32(onnx.TensorProto_INT64), nil + case tensor.Uint8: + return int32(onnx.TensorProto_UINT8), nil + case tensor.Uint16: + return int32(onnx.TensorProto_UINT16), nil + case tensor.Uint32: + return int32(onnx.TensorProto_UINT32), nil + case tensor.Uint64: + return int32(onnx.TensorProto_UINT64), nil + default: + return 0, ErrUnknownTensorONNXDtype(t) + } + +} + func convertBacking[B Number](backing []B, dataType int32) (any, error) { switch onnx.TensorProto_DataType(dataType) { case onnx.TensorProto_FLOAT: diff --git a/ops/errors.go b/ops/errors.go index 0518d6f..bf6fe87 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -282,6 +282,10 @@ func ErrConversionNotSupported(dType int32) error { return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) } +func ErrUnknownTensorONNXDtype(dType tensor.Dtype) error { + return fmt.Errorf("%w: tensor with dtype %v does not have a corresponding onnx type", ErrCast, dType) +} + var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented") func ErrActivationNotImplemented(activation string) error { diff --git a/ops/pow/pow.go b/ops/pow/pow.go new file mode 100644 index 0000000..1e08a9e --- /dev/null +++ b/ops/pow/pow.go @@ -0,0 +1,63 @@ +package pow + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var pow7TypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + +var powTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Pow represents the ONNX pow operator. +type Pow struct { + ops.BaseOperator +} + +// newPow creates a new pow operator. +func newPow(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Pow{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "pow", + ), + } +} + +// Init initializes the pow operator. +func (a *Pow) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the pow operator. +func (a *Pow) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + powTensor := inputs[1] + if inputs[0].Dtype() != powTensor.Dtype() { + to, err := ops.DTypeToONNXType(inputs[0].Dtype()) + if err != nil { + return nil, err + } + + powTensor, err = ops.ConvertTensorDtype(powTensor, to) + if err != nil { + return nil, err + } + } + + return ops.ApplyBinaryOperation( + inputs[0], + powTensor, + ops.Pow, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/pow/pow_test.go b/ops/pow/pow_test.go new file mode 100644 index 0000000..9bbe05c --- /dev/null +++ b/ops/pow/pow_test.go @@ -0,0 +1,68 @@ +package pow + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestPowInit(t *testing.T) { + p := &Pow{} + err := p.Init(nil) + assert.Nil(t, err) +} + +func TestPow(t *testing.T) { + tests := []struct { + version int64 + backing0 any + backing1 any + shapes [][]int + expected []float32 + }{ + { + 13, + []float32{0, 1, 2, 3}, + []float32{1, 1, 1, 1}, + [][]int{{2, 2}, {2, 2}}, + []float32{0, 1, 2, 3}, + }, + { + 13, + []float32{0, 1, 2, 3, 4, 5}, + []float32{2, 2, 2, 2, 2, 2}, + [][]int{{3, 2}, {3, 2}}, + []float32{0, 1, 4, 9, 16, 25}, + }, + { + 13, + []float32{0, 1}, + []float32{0, 1, 2, 3}, + [][]int{{2}, {2, 2}}, + []float32{1, 1, 0, 1}, + }, + { + 13, + []int32{1, 2, 3}, + []int32{4, 5, 6}, + [][]int{{3}, {3}}, + []float32{1, 1, 0, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing0, test.shapes[0]...), + ops.TensorWithBackingFixture(test.backing1, test.shapes[1]...), + } + + pow := powVersions[test.version]() + + res, err := pow.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expected, res[0].Data()) + } +} diff --git a/ops/pow/versions.go b/ops/pow/versions.go new file mode 100644 index 0000000..eb77068 --- /dev/null +++ b/ops/pow/versions.go @@ -0,0 +1,15 @@ +package pow + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var powVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newPow, 7, pow7TypeConstraints), + 12: ops.NewOperatorConstructor(newPow, 12, powTypeConstraints), + 13: ops.NewOperatorConstructor(newPow, 13, powTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return powVersions +} diff --git a/ops_test.go b/ops_test.go index 07fb258..88fedd6 100644 --- a/ops_test.go +++ b/ops_test.go @@ -172,11 +172,16 @@ func TestOps(t *testing.T) { runnedTests := []string{} for opName := range operators { + if opName != "Pow" { + continue + } + tests, err := getTestCasesForOp(opName) assert.Nil(t, err) for _, test := range tests { t.Run(test.name, func(t *testing.T) { + fmt.Println(test.name, test.inputs) outputs, err := test.model.Run(test.inputs) assert.Nil(t, err) @@ -301,6 +306,7 @@ func readTestModel(folder string) (*Model, error) { // Currently we support Opset 7-13, hence we enforce this in our tests. All // tests that fail because of this are ignored. + fmt.Println(folder, mp.OpsetImport[0].Version) if mp.OpsetImport[0].Version < MinSupportedOpset { mp.OpsetImport[0].Version = MinSupportedOpset } else if mp.OpsetImport[0].Version > MaxSupportedOpset { diff --git a/opset.go b/opset.go index 0bfcdb9..35b34b2 100644 --- a/opset.go +++ b/opset.go @@ -37,6 +37,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/mul" "github.com/advancedclimatesystems/gonnx/ops/not" "github.com/advancedclimatesystems/gonnx/ops/or" + "github.com/advancedclimatesystems/gonnx/ops/pow" "github.com/advancedclimatesystems/gonnx/ops/prelu" "github.com/advancedclimatesystems/gonnx/ops/reducemax" "github.com/advancedclimatesystems/gonnx/ops/reducemin" @@ -103,6 +104,7 @@ var operators = map[string]ops.OperatorVersions{ "Mul": mul.GetMulVersions(), "Not": not.GetNotVersions(), "Or": or.GetOrVersions(), + "Pow": pow.GetVersions(), "PRelu": prelu.GetPReluVersions(), "ReduceMax": reducemax.GetReduceMaxVersions(), "ReduceMin": reducemin.GetReduceMinVersions(), From d9afb57fd0bfa391b2d11a418983ca1a047be825 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 21 Dec 2024 14:52:05 +0100 Subject: [PATCH 2/6] Handle non-float pow csaes --- onnx/graph_proto.go | 2 +- ops/binary_op.go | 41 ++++++++++++++++++++++++++++++++++++++++- ops/pow/pow_test.go | 4 ++-- ops/types.go | 5 +++++ ops_test.go | 20 +++++++++++++------- 5 files changed, 61 insertions(+), 11 deletions(-) diff --git a/onnx/graph_proto.go b/onnx/graph_proto.go index 046b7e0..88e5d54 100644 --- a/onnx/graph_proto.go +++ b/onnx/graph_proto.go @@ -556,7 +556,7 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) { // ReadUint64ArrayFromBytes reads data and parses it to an array of uint64. func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) { buffer := bytes.NewReader(data) - element := make([]byte, int32Size) + element := make([]byte, int64Size) var ( err error diff --git a/ops/binary_op.go b/ops/binary_op.go index be0361d..ba2aa19 100644 --- a/ops/binary_op.go +++ b/ops/binary_op.go @@ -1,6 +1,8 @@ package ops import ( + "slices" + "gorgonia.org/tensor" ) @@ -49,8 +51,45 @@ func Mul(A, B tensor.Tensor) (tensor.Tensor, error) { } // Pow raises the first tensor to the power of the second tensor. +// Because the gorgonia.Tensor 'Pow' operation only supports float32 and float64, +// we need to convert the tensors to float64 if they are of a different type. +// After the operation is done, we convert the result back to the original type. func Pow(A, B tensor.Tensor) (tensor.Tensor, error) { - return tensor.Pow(A, B) + needsConversion := false + if slices.Contains(IntTypes, A.Dtype()) { + needsConversion = true + } + + if !needsConversion { + return tensor.Pow(A, B) + } + + oldType, err := DTypeToONNXType(A.Dtype()) + if err != nil { + return nil, err + } + + newType, err := DTypeToONNXType(tensor.Float64) + if err != nil { + return nil, err + } + + A, err = ConvertTensorDtype(A, newType) + if err != nil { + return nil, err + } + + B, err = ConvertTensorDtype(B, newType) + if err != nil { + return nil, err + } + + out, err := tensor.Pow(A, B) + if err != nil { + return nil, err + } + + return ConvertTensorDtype(out, oldType) } // Sub subtracts 1 tensor from the other. diff --git a/ops/pow/pow_test.go b/ops/pow/pow_test.go index 9bbe05c..b85a0ec 100644 --- a/ops/pow/pow_test.go +++ b/ops/pow/pow_test.go @@ -20,7 +20,7 @@ func TestPow(t *testing.T) { backing0 any backing1 any shapes [][]int - expected []float32 + expected any }{ { 13, @@ -48,7 +48,7 @@ func TestPow(t *testing.T) { []int32{1, 2, 3}, []int32{4, 5, 6}, [][]int{{3}, {3}}, - []float32{1, 1, 0, 1}, + []int32{1, 32, 729}, }, } diff --git a/ops/types.go b/ops/types.go index edea5fb..75ade06 100644 --- a/ops/types.go +++ b/ops/types.go @@ -16,3 +16,8 @@ var AllTypes = []tensor.Dtype{ tensor.String, tensor.Bool, } + +var IntTypes = []tensor.Dtype{ + tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, + tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, +} diff --git a/ops_test.go b/ops_test.go index 88fedd6..5dcd9a7 100644 --- a/ops_test.go +++ b/ops_test.go @@ -172,16 +172,11 @@ func TestOps(t *testing.T) { runnedTests := []string{} for opName := range operators { - if opName != "Pow" { - continue - } - tests, err := getTestCasesForOp(opName) assert.Nil(t, err) for _, test := range tests { t.Run(test.name, func(t *testing.T) { - fmt.Println(test.name, test.inputs) outputs, err := test.model.Run(test.inputs) assert.Nil(t, err) @@ -192,7 +187,7 @@ func TestOps(t *testing.T) { if expectedTensor.Dtype() == tensor.Bool { assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data()) } else { - assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001) + assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.001) } } }) @@ -306,7 +301,6 @@ func readTestModel(folder string) (*Model, error) { // Currently we support Opset 7-13, hence we enforce this in our tests. All // tests that fail because of this are ignored. - fmt.Println(folder, mp.OpsetImport[0].Version) if mp.OpsetImport[0].Version < MinSupportedOpset { mp.OpsetImport[0].Version = MinSupportedOpset } else if mp.OpsetImport[0].Version > MaxSupportedOpset { @@ -477,6 +471,18 @@ var expectedTests = []string{ "test_or_bcast4v2d", "test_or_bcast4v3d", "test_or_bcast4v4d", + "test_pow", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + "test_pow_types_float32_int32", + "test_pow_types_float32_int64", + "test_pow_types_float32_uint32", + "test_pow_types_float32_uint64", + "test_pow_types_int32_float32", + "test_pow_types_int32_int32", + "test_pow_types_int64_float32", + "test_pow_types_int64_int64", "test_prelu_broadcast", "test_prelu_example", "test_relu", From 86509570c5e6c38d81331666d2845d6e1081c980 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 21 Dec 2024 14:52:48 +0100 Subject: [PATCH 3/6] Fix lint --- ops/convert.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ops/convert.go b/ops/convert.go index e4fa67c..dd313cc 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -76,7 +76,6 @@ func DTypeToONNXType(t tensor.Dtype) (int32, error) { default: return 0, ErrUnknownTensorONNXDtype(t) } - } func convertBacking[B Number](backing []B, dataType int32) (any, error) { From ea82830e1ec0f9b09be12bd8d7d71c7387da91d0 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 21 Dec 2024 14:58:20 +0100 Subject: [PATCH 4/6] Update pipeline go versions --- .github/workflows/go.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index d53567a..0b6497d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Install linter run: make install_lint @@ -34,7 +34,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Install dependencies run: make install @@ -56,7 +56,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Build amd64 run: make build_amd64 @@ -69,7 +69,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Build arm64 run: make build_arm64 From b8c5188b80a967315029b2502846f6df81c9c8e7 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 22 Dec 2024 09:48:20 +0100 Subject: [PATCH 5/6] Use go1.23 --- .github/workflows/go.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0b6497d..ab7b1a6 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.21 + go-version: 1.23 - name: Install linter run: make install_lint @@ -34,7 +34,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.21 + go-version: 1.23 - name: Install dependencies run: make install @@ -56,7 +56,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.21 + go-version: 1.23 - name: Build amd64 run: make build_amd64 @@ -69,7 +69,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.21 + go-version: 1.23 - name: Build arm64 run: make build_arm64 From 2cdcbad9fa637c4d381d34ac45044dc40692dffa Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 22 Dec 2024 10:02:48 +0100 Subject: [PATCH 6/6] Back to go1.21 --- .github/workflows/go.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index ab7b1a6..0b6497d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.23 + go-version: 1.21 - name: Install linter run: make install_lint @@ -34,7 +34,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.23 + go-version: 1.21 - name: Install dependencies run: make install @@ -56,7 +56,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.23 + go-version: 1.21 - name: Build amd64 run: make build_amd64 @@ -69,7 +69,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.23 + go-version: 1.21 - name: Build arm64 run: make build_arm64