From ee2d9492591f0b7403fd9bc0e5716c9a5fe45555 Mon Sep 17 00:00:00 2001 From: Usamah Zaheer Date: Tue, 19 May 2026 22:34:34 +0100 Subject: [PATCH] Arm backend: Decompose integral float pow exponents Treat positive integral float scalar exponents like integer exponents in DecomposeIntPowPass. This avoids lowering pow(x, 2.0) to TOSA POW, whose reference model rejects negative bases even when the exponent is mathematically integral. Keep zero and negative float exponents on their existing paths so the current zero decomposition and TOSA constraint handling stay unchanged. This unblocks Swin2SR-style graphs on the TOSA reference model. Ticket: MLETORCH-2134 Test Plan: - lintrunner on changed files - pytest test_decompose_int_pow_pass.py and test_pow.py - TinySwin2SR FP/INT TOSA reference smoke Signed-off-by: Usamah Zaheer Change-Id: I650190a63fc8cfc676dbdde4ce33200d71e9aa4c --- .../arm/_passes/decompose_int_pow_pass.py | 25 ++++++++++++-- backends/arm/test/ops/test_pow.py | 13 -------- .../passes/test_decompose_int_pow_pass.py | 33 ++++++++++++++++++- 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py index bb29d34d6bf..a31a9415e23 100644 --- a/backends/arm/_passes/decompose_int_pow_pass.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Set, Type +from typing import Optional, Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops @@ -21,6 +21,18 @@ class DecomposeIntPowPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod + def _get_decomposable_integer_exponent(exp) -> Optional[int]: + if isinstance(exp, int): + return exp + # Exported models can represent positive integer-valued exponents as + # floats, for example pow(x, 2.0). Only exact values are decomposed: + # rounding near-integer floats would change fractional pow semantics, + # especially for negative bases. + if isinstance(exp, float) and exp > 0 and exp.is_integer(): + return int(exp) + return None + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: return super().call_operator(op, args, kwargs, meta) @@ -43,9 +55,18 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.add.Tensor, (zeros, ones), {}, meta, True ) - if not isinstance(exp, int): + exp = self._get_decomposable_integer_exponent(exp) + if exp is None: return super().call_operator(op, args, kwargs, meta) + if exp == 1: + ones = super().call_operator( + exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True + ) + return super().call_operator( + exir_ops.edge.aten.mul.Tensor, (x, ones), {}, meta, True + ) + # Handle negative exponent if exp < 0: x = super().call_operator( diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index 6d304ce0627..28f95f2c622 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -141,22 +141,11 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t): pipeline.run() -x_fail = { - "exp_two": "TOSA constraints: If x <0 .", -} - -x_fail_FP = { - "exp_two": "TOSA constraints: If x <0 .", -} - - @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16 | Pow_TensorScalar.test_data_bf16, - xfails=x_fail_FP, - strict=False, ) def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): base, exp = test_data() @@ -211,8 +200,6 @@ def test_pow_tensor_scalar_u85_INT(test_data: Pow_TensorScalar.input_t): @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16, - x_fail_FP, - strict=False, ) @common.SkipIfNoModelConverter def test_pow_tensor_scalar_vgf_no_quant(test_data: Pow_TensorScalar.input_t): diff --git a/backends/arm/test/passes/test_decompose_int_pow_pass.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py index 7761c031e2c..ac6a03a68eb 100644 --- a/backends/arm/test/passes/test_decompose_int_pow_pass.py +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -35,7 +35,7 @@ def get_inputs(self) -> input_t: class Pow(torch.nn.Module): """Basic squaring.""" - def __init__(self, exponent: int) -> None: + def __init__(self, exponent: int | float) -> None: super().__init__() self.exponent = exponent @@ -48,12 +48,20 @@ def get_inputs(self) -> input_t: test_data: Dict[str, TestParam] = { "square": (Square(), 1), + "pow_1": (Pow(1), 1), + "pow_1_float": (Pow(1.0), 1), "pow_2": (Pow(2), 1), + "pow_2_float": (Pow(2.0), 1), "pow_3": (Pow(3), 2), "pow_0": (Pow(0), 0), "pow_neg_2": (Pow(-2), 1), } +non_integer_float_test_data: Dict[str, ModuleWithInputs] = { + "pow_1_999999999": Pow(1.999999999), + "pow_2_000000001": Pow(2.000000001), +} + @common.parametrize("data", test_data) def test_decompose_int_pow_tosa_FP(data: TestParam) -> None: @@ -74,3 +82,26 @@ def test_decompose_int_pow_tosa_FP(data: TestParam) -> None: pass_list=[DecomposeIntPowPass], ) pipeline.run() + + +@common.parametrize("module_with_inputs", non_integer_float_test_data) +def test_decompose_int_pow_tosa_FP_non_integer_float( + module_with_inputs: ModuleWithInputs, +) -> None: + module = cast(torch.nn.Module, module_with_inputs) + pow_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar" + pipeline = PassPipeline[input_t]( + module, + module_with_inputs.get_inputs(), + quantize=False, + ops_before_pass={ + pow_op: 1, + }, + ops_not_before_pass=[], + ops_after_pass={ + pow_op: 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 0, + }, + pass_list=[DecomposeIntPowPass], + ) + pipeline.run()