diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py index bb29d34d6bf..a31a9415e23 100644 --- a/backends/arm/_passes/decompose_int_pow_pass.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Set, Type +from typing import Optional, Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops @@ -21,6 +21,18 @@ class DecomposeIntPowPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod + def _get_decomposable_integer_exponent(exp) -> Optional[int]: + if isinstance(exp, int): + return exp + # Exported models can represent positive integer-valued exponents as + # floats, for example pow(x, 2.0). Only exact values are decomposed: + # rounding near-integer floats would change fractional pow semantics, + # especially for negative bases. + if isinstance(exp, float) and exp > 0 and exp.is_integer(): + return int(exp) + return None + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: return super().call_operator(op, args, kwargs, meta) @@ -43,9 +55,18 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.add.Tensor, (zeros, ones), {}, meta, True ) - if not isinstance(exp, int): + exp = self._get_decomposable_integer_exponent(exp) + if exp is None: return super().call_operator(op, args, kwargs, meta) + if exp == 1: + ones = super().call_operator( + exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True + ) + return super().call_operator( + exir_ops.edge.aten.mul.Tensor, (x, ones), {}, meta, True + ) + # Handle negative exponent if exp < 0: x = super().call_operator( diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index 6d304ce0627..28f95f2c622 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -141,22 +141,11 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t): pipeline.run() -x_fail = { - "exp_two": "TOSA constraints: If x <0 .", -} - -x_fail_FP = { - "exp_two": "TOSA constraints: If x <0 .", -} - - @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16 | Pow_TensorScalar.test_data_bf16, - xfails=x_fail_FP, - strict=False, ) def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): base, exp = test_data() @@ -211,8 +200,6 @@ def test_pow_tensor_scalar_u85_INT(test_data: Pow_TensorScalar.input_t): @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16, - x_fail_FP, - strict=False, ) @common.SkipIfNoModelConverter def test_pow_tensor_scalar_vgf_no_quant(test_data: Pow_TensorScalar.input_t): diff --git a/backends/arm/test/passes/test_decompose_int_pow_pass.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py index 7761c031e2c..ac6a03a68eb 100644 --- a/backends/arm/test/passes/test_decompose_int_pow_pass.py +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -35,7 +35,7 @@ def get_inputs(self) -> input_t: class Pow(torch.nn.Module): """Basic squaring.""" - def __init__(self, exponent: int) -> None: + def __init__(self, exponent: int | float) -> None: super().__init__() self.exponent = exponent @@ -48,12 +48,20 @@ def get_inputs(self) -> input_t: test_data: Dict[str, TestParam] = { "square": (Square(), 1), + "pow_1": (Pow(1), 1), + "pow_1_float": (Pow(1.0), 1), "pow_2": (Pow(2), 1), + "pow_2_float": (Pow(2.0), 1), "pow_3": (Pow(3), 2), "pow_0": (Pow(0), 0), "pow_neg_2": (Pow(-2), 1), } +non_integer_float_test_data: Dict[str, ModuleWithInputs] = { + "pow_1_999999999": Pow(1.999999999), + "pow_2_000000001": Pow(2.000000001), +} + @common.parametrize("data", test_data) def test_decompose_int_pow_tosa_FP(data: TestParam) -> None: @@ -74,3 +82,26 @@ def test_decompose_int_pow_tosa_FP(data: TestParam) -> None: pass_list=[DecomposeIntPowPass], ) pipeline.run() + + +@common.parametrize("module_with_inputs", non_integer_float_test_data) +def test_decompose_int_pow_tosa_FP_non_integer_float( + module_with_inputs: ModuleWithInputs, +) -> None: + module = cast(torch.nn.Module, module_with_inputs) + pow_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar" + pipeline = PassPipeline[input_t]( + module, + module_with_inputs.get_inputs(), + quantize=False, + ops_before_pass={ + pow_op: 1, + }, + ops_not_before_pass=[], + ops_after_pass={ + pow_op: 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 0, + }, + pass_list=[DecomposeIntPowPass], + ) + pipeline.run()