diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 4a8857be82a..20bddf17793 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -43,7 +43,7 @@ from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa from .decompose_einsum_pass import DecomposeEinsumPass # noqa -from .decompose_elu_pass import DecomposeEluPass # noqa +from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_erfinv_pass import DecomposeErfinvPass # noqa from .decompose_expm1_pass import DecomposeExpm1Pass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index bf39cbe44ea..5a135696463 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -20,6 +20,7 @@ ConstantFoldingPass, ControlFlowConstInlinePass, Conv1dUnsqueezePass, + ConvertEluFamilyToEluPass, ConvertELUParamsPass, ConvertExpandCopyToRepeatPass, ConvertFullLikeToFullPass, @@ -403,6 +404,7 @@ def _tosa_pipeline( DecomposeLayerNormPass(), DecomposeVarPass(), DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec), + ConvertEluFamilyToEluPass(), ConvertELUParamsPass(), ControlFlowConstInlinePass(), NormalizeWhileInitialArgsPass(use_exir_clone=True), @@ -607,6 +609,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): RewriteInplaceArithmeticPass(tfa_pass=True), DecomposeAddSubAlphaPass(tfa_pass=True), DecomposeLeakyReLUPass(tfa_pass=True), + ConvertEluFamilyToEluPass(tfa_pass=True), DecomposeGroupNormPass(tfa_pass=True), DecomposeLayerNormPass(tfa_pass=True), DecomposeVarPass(tfa_pass=True), diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 0e4ec86f516..6c5d8b93143 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -14,14 +14,20 @@ class ConvertELUParamsPass(ArmPass): - """Pass to convert the input_scale kwarg of ELU operator from float to int. + """The int8 ELU operator crashes when the alpha, scale or input scale + parameters are not integers. - It has been set to 2 as the outputs seem to stay the same regardless of what - the value of input_scale is, as long as that value is not 1. + This pass temporarily converts quantized ELU parameters to int and stores + the original float values in the meta dict to be able to recover them later. """ - _passes_required_after: Set[Type[ExportPass]] = set() + @property + def _passes_required_after(self) -> Set[Type[ExportPass]]: + # Lazy import to avoid circular dependency between passes + from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass + + return {InsertTableOpsPass} def call(self, graph_module: torch.fx.GraphModule): modified_graph = False @@ -36,29 +42,45 @@ def call(self, graph_module: torch.fx.GraphModule): ) if not is_quantized or not self.allowed_to_transform(node.meta): continue + with graph.inserting_after(node): replace_node = create_node( graph, exir_ops.edge.aten.elu.default, from_node=node ) - old_args = list(node.args) - alpha = old_args[1] if len(old_args) > 1 else 1.0 - scale = 1.0 - input_scale = 2.0 + old_args = list(node.args) + alpha = ( + old_args[1] if len(old_args) > 1 else node.kwargs.get("alpha", 1.0) + ) + scale = ( + old_args[2] if len(old_args) > 2 else node.kwargs.get("scale", 1.0) + ) + input_scale = ( + old_args[3] + if len(old_args) > 3 + else node.kwargs.get("input_scale", 1.0) + ) replace_node.args = (old_args[0],) + # Set placeholder int values updated_kwargs = dict(node.kwargs) - updated_kwargs["alpha"] = int(alpha) - updated_kwargs["scale"] = int(scale) - updated_kwargs["input_scale"] = int(input_scale) - + updated_kwargs["alpha"] = 1 + updated_kwargs["scale"] = 1 + updated_kwargs["input_scale"] = ( + 2 # Keep input_scale away from 1 to avoid fake execution type checks. + ) replace_node.kwargs = updated_kwargs + # Save correct parameters + replace_node.meta["float_alpha"] = alpha + replace_node.meta["float_scale"] = scale + replace_node.meta["float_input_scale"] = input_scale + node.replace_all_uses_with(replace_node) graph.erase_node(node) - modified_graph = True + if modified_graph: graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index d212c3ec9cb..548a508d914 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -5,11 +5,22 @@ from typing import Set, Type +import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass edge_elu_ops = (exir_ops.edge.aten.elu.default,) +edge_selu_ops = (exir_ops.edge.aten.selu.default,) +edge_celu_ops = (exir_ops.edge.aten.celu.default,) +edge_elu_family_ops = edge_elu_ops + edge_selu_ops + edge_celu_ops +torch_selu_ops = (torch.ops.aten.selu.default,) +torch_celu_ops = (torch.ops.aten.celu.default,) +selu_ops = edge_selu_ops + torch_selu_ops +celu_ops = edge_celu_ops + torch_celu_ops + +SELU_ALPHA = 1.6732632423543772 +SELU_SCALE = 1.0507009873554805 def get_elu_decomposition(op) -> tuple: @@ -29,7 +40,7 @@ def get_elu_decomposition(op) -> tuple: """ - if op in edge_elu_ops: + if op in edge_elu_family_ops: return ( exir_ops.edge.aten.expm1.default, exir_ops.edge.aten.ge.Scalar, @@ -40,15 +51,64 @@ def get_elu_decomposition(op) -> tuple: raise RuntimeError(f"Can't get elu decomposition for op {op}") +def _get_elu_parameter(args, kwargs, index, name): + if len(args) > index: + return args[index] + + return kwargs.get(name, 1.0) + + +def _get_elu_parameters(op, args, kwargs): + if op in selu_ops: + return SELU_ALPHA, SELU_SCALE, 1.0 + if op in celu_ops: + alpha = _get_elu_parameter(args, kwargs, 1, "alpha") + return alpha, 1.0, 1.0 / alpha + + alpha = _get_elu_parameter(args, kwargs, 1, "alpha") + scale = _get_elu_parameter(args, kwargs, 2, "scale") + input_scale = _get_elu_parameter(args, kwargs, 3, "input_scale") + return alpha, scale, input_scale + + +class ConvertEluFamilyToEluPass(ArmPass): + """Convert SELU/CELU ops to equivalent parameterized ELU ops.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if op not in selu_ops + celu_ops or not self.allowed_to_transform(meta): + return super().call_operator(op, args, kwargs, meta, updated=False) + + input_ = args[0] + alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs) + elu_op = ( + torch.ops.aten.elu.default + if op in torch_selu_ops + torch_celu_ops + else exir_ops.edge.aten.elu.default + ) + return super().call_operator( + elu_op, + (input_, alpha, scale, input_scale), + {}, + meta, + updated=True, + ) + + class DecomposeEluPass(ArmPass): """A transformation pass that decomposes unsupported 'aten.elu' operations into a combination of supported TOSA-equivalent operations. Since TOSA does not provide a native ELU operator, this pass rewrites: - elu(x) → where(greater_or_eq(x, 0), (alpha*(exp(x)-1)), x) + elu(x) → scale * where( + greater_or_eq(x, 0), x, alpha * expm1(input_scale * x) + ) Supported input ops: - - exir_ops.edge.aten.elu.Tensor(x) + - exir_ops.edge.aten.elu.default + - exir_ops.edge.aten.selu.default + - exir_ops.edge.aten.celu.default These are replaced with: - exir_ops.edge.aten.expm1.default @@ -61,7 +121,7 @@ class DecomposeEluPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): - if op not in edge_elu_ops: + if op not in edge_elu_family_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): @@ -76,11 +136,11 @@ def call_operator(self, op, args, kwargs, meta): ) = get_elu_decomposition(op) input = args[0] - alpha = args[1] if len(args) > 1 else 1.0 + alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs) if alpha == 0: relu_op = exir_ops.edge.aten.clamp.default - return super().call_operator( + relu_node = super().call_operator( relu_op, ( input, @@ -90,14 +150,35 @@ def call_operator(self, op, args, kwargs, meta): meta, updated=True, ) + if scale == 1: + return relu_node - expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True) + return super().call_operator( + mul_op, (relu_node, scale), {}, meta, updated=True + ) + + expm1_input = input + if input_scale != 1: + expm1_input = super().call_operator( + mul_op, (input, input_scale), {}, meta, updated=True + ) + expm1_node = super().call_operator( + expm1_op, (expm1_input,), {}, meta, updated=True + ) mul_node = super().call_operator( mul_op, (expm1_node, alpha), {}, meta, updated=True ) ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True) + positive_node = input + if scale != 1: + positive_node = super().call_operator( + mul_op, (input, scale), {}, meta, updated=True + ) + mul_node = super().call_operator( + mul_op, (mul_node, scale), {}, meta, updated=True + ) where_node = super().call_operator( - where_op, (ge_node, input, mul_node), {}, meta, updated=True + where_op, (ge_node, positive_node, mul_node), {}, meta, updated=True ) return where_node diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 7d5af66a742..10b85149dad 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -100,9 +100,11 @@ def __getitem__(self, node: Node): x, approximate=approximate ).flatten() case exir_ops.edge.aten.elu.default: - input_alpha = cast(int, node.kwargs["alpha"]) - return lambda x: torch.nn.functional.elu( - x, alpha=input_alpha + input_alpha = cast(float, node.meta["float_alpha"]) + input_scale = cast(float, node.meta.get("float_input_scale", 1.0)) + scale = cast(float, node.meta.get("float_scale", 1.0)) + return lambda x: torch.ops.aten.elu.default( + x, input_alpha, scale, input_scale ).flatten() case exir_ops.edge.aten.remainder.Scalar: divisor = cast(float | int, node.args[1]) diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index 3c3aa57774f..c96f966a2e2 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -121,6 +121,8 @@ exir_ops.edge.aten.cosh.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.selu.default, + exir_ops.edge.aten.celu.default, exir_ops.edge.aten.bitwise_not.default, exir_ops.edge.aten.copy.default, exir_ops.edge.aten.tan.default, @@ -244,6 +246,8 @@ exir_ops.edge.aten.logit.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.selu.default, + exir_ops.edge.aten.celu.default, exir_ops.edge.aten.copy.default, exir_ops.edge.aten.floor_divide.default, exir_ops.edge.aten.tan.default, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 7c2ce6b74d1..0a4c8fe1f6f 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -481,6 +481,8 @@ def _match_pattern( torch.ops.aten.exp.default, torch.ops.aten.expm1.default, torch.ops.aten.elu.default, + torch.ops.aten.selu.default, + torch.ops.aten.celu.default, torch.ops.aten.floor.default, torch.ops.aten.log.default, torch.ops.aten.reciprocal.default, diff --git a/backends/arm/scripts/docgen/ethos-u/ethos-u-getting-started-tutorial.md.in b/backends/arm/scripts/docgen/ethos-u/ethos-u-getting-started-tutorial.md.in index ecd63afd8ba..909153d7596 100644 --- a/backends/arm/scripts/docgen/ethos-u/ethos-u-getting-started-tutorial.md.in +++ b/backends/arm/scripts/docgen/ethos-u/ethos-u-getting-started-tutorial.md.in @@ -20,7 +20,7 @@ In this tutorial you will learn how to export a simple PyTorch model for the Exe ```{tip} If you are already familiar with this delegate, you may want to jump directly to the examples: * [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm) -* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py) +* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py) ``` This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on Arm® Ethos™-U targets. It is based on `ethos_u_minimal_example.ipynb`, provided in Arm’s examples folder. @@ -69,9 +69,10 @@ The example below shows how to quantize a model consisting of a single addition, $MINIMAL_EXAMPLE ```{tip} -For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte. +For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte. To produce a pte file equivalent to the one above, run -`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte` +`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`. +For production use, you should instead use the stable Python API shown above. ``` ### Runtime: diff --git a/backends/arm/scripts/docgen/vgf/vgf-getting-started-tutorial.md.in b/backends/arm/scripts/docgen/vgf/vgf-getting-started-tutorial.md.in index 531dea14b37..1fea93e2f86 100644 --- a/backends/arm/scripts/docgen/vgf/vgf-getting-started-tutorial.md.in +++ b/backends/arm/scripts/docgen/vgf/vgf-getting-started-tutorial.md.in @@ -26,7 +26,7 @@ You may encounter some rough edges and features which may be documented or plann ```{tip} If you are already familiar with this delegate, you may want to jump directly to the examples: * [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm) -* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py) +* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py) ``` This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipyb`, provided in Arm's example folder. @@ -78,9 +78,10 @@ The example below shows how to quantize a model consisting of a single addition, $MINIMAL_EXAMPLE ```{tip} -For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte. +For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte. To produce a pte file equivalent to the one above, run -`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf` +`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf`. +For production use, you should instead use the stable Python API shown above. ``` ## Runtime diff --git a/backends/arm/test/ops/test_elu.py b/backends/arm/test/ops/test_elu.py index c748f8385dc..5fbb3feac48 100644 --- a/backends/arm/test/ops/test_elu.py +++ b/backends/arm/test/ops/test_elu.py @@ -1,9 +1,10 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from dataclasses import dataclass +from typing import Callable, Tuple import torch import torch.nn as nn @@ -17,117 +18,224 @@ VgfPipeline, ) -test_data_suite = { - # (test_name, test_data) - "zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)), - "ones_default": lambda: (1.0, torch.ones(10, 10, 10)), - "rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5), - "randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10), - "randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10), - "ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)), - "large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7), - "large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)), - "small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)), - "small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)), - "zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)), - "ones_custom": lambda: (2.0, torch.ones(10, 10, 10)), - "rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5), - "randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10), - "randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10), - "ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)), - "large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7), - "large_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e5, 1e8)), - "small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)), - "small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)), - "zeros_zero": lambda: (0.0, torch.zeros(1, 10, 10, 10)), - "ones_zero": lambda: (0.0, torch.ones(10, 10, 10)), - "rand_zero": lambda: (0.0, torch.rand(10, 10) - 0.5), - "randn_pos_zero": lambda: (0.0, torch.randn(1, 3, 3) + 10), - "randn_neg_zero": lambda: (0.0, torch.randn(1, 2, 4, 3) - 10), - "ramp_zero": lambda: (0.0, torch.arange(-16, 16, 0.2)), - "large_pos_zero": lambda: (0.0, torch.randn(3, 3) * 1e6 + 1e7), - "large_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e5, 1e8)), - "small_pos_zero": lambda: (0.0, torch.empty(5).uniform_(1e-8, 1e-5)), - "small_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e-8, 1e-5)), -} - class Elu(nn.Module): aten_op = "torch.ops.aten.elu.default" exir_op = "executorch_exir_dialects_edge__ops_aten__elu_default" + quantized_aten_op = aten_op + quantized_exir_op = exir_op + + def __init__( + self, + input_alpha: float = 1.0, + scale: float = 1.0, + input_scale: float = 1.0, + ): + super().__init__() + self.input_alpha = input_alpha + self.scale = scale + self.input_scale = input_scale + + def forward(self, input_: torch.Tensor): + return torch.ops.aten.elu.default( + input_, self.input_alpha, self.scale, self.input_scale + ) + + +class Selu(nn.Module): + aten_op = "torch.ops.aten.selu.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_selu_default" + quantized_aten_op = Elu.aten_op + quantized_exir_op = Elu.exir_op + + def __init__(self): + super().__init__() + self.selu = torch.nn.SELU() + + def forward(self, input_: torch.Tensor): + return self.selu(input_) + - def __init__(self, input_alpha: float = 1.0): +class Celu(nn.Module): + aten_op = "torch.ops.aten.celu.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_celu_default" + quantized_aten_op = Elu.aten_op + quantized_exir_op = Elu.exir_op + + def __init__(self, alpha: float = 1.0): super().__init__() - self.elu = torch.nn.ELU(alpha=input_alpha) + self.celu = torch.nn.CELU(alpha=alpha) def forward(self, input_: torch.Tensor): - return self.elu(input_) + return self.celu(input_) input_t1 = Tuple[torch.Tensor] -@common.parametrize("test_module", test_data_suite) -def test_elu_tosa_FP(test_module: input_t1): - alpha, test_data = test_module() +@dataclass +class EluTestCase: + model: nn.Module + example_inputs: input_t1 | Callable[[], input_t1] + + def get_example_inputs(self) -> input_t1: + if callable(self.example_inputs): + return self.example_inputs() + return self.example_inputs + + def aten_op(self, quantized: bool = False) -> str: + attr = "quantized_aten_op" if quantized else "aten_op" + return getattr(self.model, attr) + + def exir_op(self, quantized: bool = False) -> str: + attr = "quantized_exir_op" if quantized else "exir_op" + return getattr(self.model, attr) + + +def _input(input_: torch.Tensor) -> input_t1: + return (input_,) + + +test_suite = { + "elu_rand_default": lambda: EluTestCase(Elu(), _input(torch.rand(10, 10) - 0.5)), + "elu_randn_pos_default": lambda: EluTestCase( + Elu(), _input(torch.randn(1, 2, 3, 3) + 10) + ), + "elu_randn_neg_default": lambda: EluTestCase( + Elu(1.0), _input(torch.randn(2, 4, 3) - 10) + ), + "elu_large_pos_default": lambda: EluTestCase( + Elu(1.0), _input(torch.randn(3, 3) * 1e6 + 1e7) + ), + "elu_large_neg_default": lambda: EluTestCase( + Elu(1.0), _input(-torch.empty(5).uniform_(1e5, 1e8)) + ), + "elu_small_pos_default": lambda: EluTestCase( + Elu(1), _input(torch.empty(5).uniform_(1e-8, 1e-5)) + ), + "elu_small_neg_default": lambda: EluTestCase( + Elu(1), _input(-torch.empty(5).uniform_(1e-8, 1e-5)) + ), + "elu_rand_custom": lambda: EluTestCase(Elu(2.5), _input(torch.rand(10, 10) - 0.5)), + "elu_randn_pos_custom": lambda: EluTestCase( + Elu(2.0), _input(torch.randn(1, 3, 3) + 10) + ), + "elu_ramp_custom": lambda: EluTestCase( + Elu(10.0), _input(torch.arange(-16, 16, 0.2)) + ), + "elu_large_pos_custom": lambda: EluTestCase( + Elu(2.0), _input(torch.randn(3, 3) * 1e6 + 1e7) + ), + "elu_large_neg_custom": lambda: EluTestCase( + Elu(2.0), _input(-torch.empty(5).uniform_(1e5, 1e8)) + ), + "elu_small_pos_custom": lambda: EluTestCase( + Elu(2.0), _input(torch.empty(5).uniform_(1e-8, 1e-5)) + ), + "elu_small_neg_custom": lambda: EluTestCase( + Elu(2.0), _input(-torch.empty(5).uniform_(1e-8, 1e-5)) + ), + "elu_rand_zero": lambda: EluTestCase(Elu(0.0), _input(torch.rand(10, 10) - 0.5)), + "elu_ramp_zero": lambda: EluTestCase(Elu(0.0), _input(torch.arange(-16, 16, 0.2))), + "elu_large_pos_zero": lambda: EluTestCase( + Elu(0.0), _input(torch.randn(3, 3) * 1e6 + 1e7) + ), + "elu_large_neg_zero": lambda: EluTestCase( + Elu(0.0), _input(-torch.empty(5).uniform_(1e5, 1e8)) + ), + "elu_selu_params_ramp": lambda: EluTestCase( + Elu(1.6732632423543772, 1.0507009873554805, 1.0), + _input(torch.arange(-16, 16, 0.2)), + ), + "elu_celu_alpha_0_5_params_rand": lambda: EluTestCase( + Elu(0.5, 1.0, 2.0), _input(torch.rand(10, 10) - 0.5) + ), + "elu_celu_alpha_2_params_ramp": lambda: EluTestCase( + Elu(2.0, 1.0, 0.5), _input(torch.arange(-16, 16, 0.2)) + ), + "nn_selu_ramp": lambda: EluTestCase(Selu(), _input(torch.arange(-16, 16, 0.2))), + "nn_celu_alpha_0_5_rand": lambda: EluTestCase( + Celu(0.5), _input(torch.rand(10, 10) - 0.5) + ), + "nn_celu_alpha_2_ramp": lambda: EluTestCase( + Celu(2.0), _input(torch.arange(-16, 16, 0.2)) + ), +} + + +@common.parametrize("test_case", test_suite) +def test_elu_tosa_FP(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = TosaPipelineFP[input_t1]( - Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op + test_case.model, + test_case.get_example_inputs(), + aten_op=test_case.aten_op(), + exir_op=test_case.exir_op(), ) pipeline.run() -@common.parametrize("test_module", test_data_suite) -def test_elu_tosa_INT(test_module: input_t1): - alpha, test_data = test_module() +@common.parametrize("test_case", test_suite) +def test_elu_tosa_INT(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = TosaPipelineINT[input_t1]( - Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op + test_case.model, + test_case.get_example_inputs(), + aten_op=test_case.aten_op(quantized=True), + exir_op=test_case.exir_op(quantized=True), ) pipeline.run() @common.XfailIfNoCorstone300 -@common.parametrize("test_module", test_data_suite) -def test_elu_u55_INT(test_module: input_t1): - alpha, test_data = test_module() +@common.parametrize("test_case", test_suite) +def test_elu_u55_INT(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = EthosU55PipelineINT[input_t1]( - Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op + test_case.model, + test_case.get_example_inputs(), + aten_ops=test_case.aten_op(quantized=True), + exir_ops=test_case.exir_op(quantized=True), ) pipeline.run() @common.XfailIfNoCorstone320 -@common.parametrize("test_module", test_data_suite) -def test_elu_u85_INT(test_module: input_t1): - alpha, test_data = test_module() +@common.parametrize("test_case", test_suite) +def test_elu_u85_INT(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = EthosU85PipelineINT[input_t1]( - Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op + test_case.model, + test_case.get_example_inputs(), + aten_ops=test_case.aten_op(quantized=True), + exir_ops=test_case.exir_op(quantized=True), ) pipeline.run() @common.SkipIfNoModelConverter -@common.parametrize("test_module", test_data_suite) -def test_elu_vgf_no_quant(test_module: input_t1): - alpha, test_data = test_module() +@common.parametrize("test_case", test_suite) +def test_elu_vgf_no_quant(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = VgfPipeline[input_t1]( - Elu(alpha), - (test_data,), - aten_op=Elu.aten_op, - exir_op=Elu.exir_op, + test_case.model, + test_case.get_example_inputs(), + aten_op=test_case.aten_op(), + exir_op=test_case.exir_op(), quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -@common.parametrize("test_module", test_data_suite) -def test_elu_vgf_quant(test_module: input_t1): - alpha, test_data = test_module() +@common.parametrize("test_case", test_suite) +def test_elu_vgf_quant(test_case: Callable[[], EluTestCase]): + test_case = test_case() pipeline = VgfPipeline[input_t1]( - Elu(alpha), - (test_data,), - aten_op=Elu.aten_op, - exir_op=Elu.exir_op, + test_case.model, + test_case.get_example_inputs(), + aten_op=test_case.aten_op(quantized=True), + exir_op=test_case.exir_op(quantized=True), quantize=True, ) pipeline.run() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 07548eb5d69..12c2c11c692 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -485,11 +485,13 @@ def ops_to_not_decompose( # noqa: C901 torch.ops.aten.pad.default, } ops_to_not_decompose_if_fp = { + torch.ops.aten.celu.default, torch.ops.aten.eye.default, torch.ops.aten.logit.default, torch.ops.aten.linear.default, torch.ops.aten.linspace.default, torch.ops.aten.pad.default, + torch.ops.aten.selu.default, } ops_to_not_decompose_always = { torch.ops.aten.logit.default,