Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
from .decompose_elu_pass import DecomposeEluPass # noqa
from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
from .decompose_erfinv_pass import DecomposeErfinvPass # noqa
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ConstantFoldingPass,
ControlFlowConstInlinePass,
Conv1dUnsqueezePass,
ConvertEluFamilyToEluPass,
ConvertELUParamsPass,
ConvertExpandCopyToRepeatPass,
ConvertFullLikeToFullPass,
Expand Down Expand Up @@ -403,6 +404,7 @@ def _tosa_pipeline(
DecomposeLayerNormPass(),
DecomposeVarPass(),
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
ConvertEluFamilyToEluPass(),
ConvertELUParamsPass(),
ControlFlowConstInlinePass(),
NormalizeWhileInitialArgsPass(use_exir_clone=True),
Expand Down Expand Up @@ -607,6 +609,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
RewriteInplaceArithmeticPass(tfa_pass=True),
DecomposeAddSubAlphaPass(tfa_pass=True),
DecomposeLeakyReLUPass(tfa_pass=True),
ConvertEluFamilyToEluPass(tfa_pass=True),
DecomposeGroupNormPass(tfa_pass=True),
DecomposeLayerNormPass(tfa_pass=True),
DecomposeVarPass(tfa_pass=True),
Expand Down
48 changes: 35 additions & 13 deletions backends/arm/_passes/convert_elu_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@


class ConvertELUParamsPass(ArmPass):
"""Pass to convert the input_scale kwarg of ELU operator from float to int.
"""The int8 ELU operator crashes when the alpha, scale or input scale
parameters are not integers.

It has been set to 2 as the outputs seem to stay the same regardless of what
the value of input_scale is, as long as that value is not 1.
This pass temporarily converts quantized ELU parameters to int and stores
the original float values in the meta dict to be able to recover them later.

"""

_passes_required_after: Set[Type[ExportPass]] = set()
@property
def _passes_required_after(self) -> Set[Type[ExportPass]]:
# Lazy import to avoid circular dependency between passes
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass

return {InsertTableOpsPass}

def call(self, graph_module: torch.fx.GraphModule):
modified_graph = False
Expand All @@ -36,29 +42,45 @@ def call(self, graph_module: torch.fx.GraphModule):
)
if not is_quantized or not self.allowed_to_transform(node.meta):
continue

with graph.inserting_after(node):
replace_node = create_node(
graph, exir_ops.edge.aten.elu.default, from_node=node
)
old_args = list(node.args)

alpha = old_args[1] if len(old_args) > 1 else 1.0
scale = 1.0
input_scale = 2.0
old_args = list(node.args)
alpha = (
old_args[1] if len(old_args) > 1 else node.kwargs.get("alpha", 1.0)
)
scale = (
old_args[2] if len(old_args) > 2 else node.kwargs.get("scale", 1.0)
)
input_scale = (
old_args[3]
if len(old_args) > 3
else node.kwargs.get("input_scale", 1.0)
)

replace_node.args = (old_args[0],)

# Set placeholder int values
updated_kwargs = dict(node.kwargs)
updated_kwargs["alpha"] = int(alpha)
updated_kwargs["scale"] = int(scale)
updated_kwargs["input_scale"] = int(input_scale)

updated_kwargs["alpha"] = 1
updated_kwargs["scale"] = 1
updated_kwargs["input_scale"] = (
2 # Keep input_scale away from 1 to avoid fake execution type checks.
)
replace_node.kwargs = updated_kwargs

# Save correct parameters
replace_node.meta["float_alpha"] = alpha
replace_node.meta["float_scale"] = scale
replace_node.meta["float_input_scale"] = input_scale

node.replace_all_uses_with(replace_node)
graph.erase_node(node)

modified_graph = True

if modified_graph:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
97 changes: 89 additions & 8 deletions backends/arm/_passes/decompose_elu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_elu_ops = (exir_ops.edge.aten.elu.default,)
edge_selu_ops = (exir_ops.edge.aten.selu.default,)
edge_celu_ops = (exir_ops.edge.aten.celu.default,)
edge_elu_family_ops = edge_elu_ops + edge_selu_ops + edge_celu_ops
torch_selu_ops = (torch.ops.aten.selu.default,)
torch_celu_ops = (torch.ops.aten.celu.default,)
selu_ops = edge_selu_ops + torch_selu_ops
celu_ops = edge_celu_ops + torch_celu_ops

SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805


def get_elu_decomposition(op) -> tuple:
Expand All @@ -29,7 +40,7 @@ def get_elu_decomposition(op) -> tuple:

"""

if op in edge_elu_ops:
if op in edge_elu_family_ops:
return (
exir_ops.edge.aten.expm1.default,
exir_ops.edge.aten.ge.Scalar,
Expand All @@ -40,15 +51,64 @@ def get_elu_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get elu decomposition for op {op}")


def _get_elu_parameter(args, kwargs, index, name):
if len(args) > index:
return args[index]

return kwargs.get(name, 1.0)


def _get_elu_parameters(op, args, kwargs):
if op in selu_ops:
return SELU_ALPHA, SELU_SCALE, 1.0
if op in celu_ops:
alpha = _get_elu_parameter(args, kwargs, 1, "alpha")
Comment thread
AdrianLundell marked this conversation as resolved.
return alpha, 1.0, 1.0 / alpha

alpha = _get_elu_parameter(args, kwargs, 1, "alpha")
scale = _get_elu_parameter(args, kwargs, 2, "scale")
input_scale = _get_elu_parameter(args, kwargs, 3, "input_scale")
return alpha, scale, input_scale


class ConvertEluFamilyToEluPass(ArmPass):
"""Convert SELU/CELU ops to equivalent parameterized ELU ops."""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op not in selu_ops + celu_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta, updated=False)

input_ = args[0]
alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs)
elu_op = (
torch.ops.aten.elu.default
if op in torch_selu_ops + torch_celu_ops
else exir_ops.edge.aten.elu.default
)
return super().call_operator(
elu_op,
(input_, alpha, scale, input_scale),
{},
meta,
updated=True,
)


class DecomposeEluPass(ArmPass):
"""A transformation pass that decomposes unsupported 'aten.elu' operations
into a combination of supported TOSA-equivalent operations.

Since TOSA does not provide a native ELU operator, this pass rewrites:
elu(x) → where(greater_or_eq(x, 0), (alpha*(exp(x)-1)), x)
elu(x) → scale * where(
greater_or_eq(x, 0), x, alpha * expm1(input_scale * x)
)

Supported input ops:
- exir_ops.edge.aten.elu.Tensor(x)
- exir_ops.edge.aten.elu.default
- exir_ops.edge.aten.selu.default
- exir_ops.edge.aten.celu.default

These are replaced with:
- exir_ops.edge.aten.expm1.default
Expand All @@ -61,7 +121,7 @@ class DecomposeEluPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op not in edge_elu_ops:
if op not in edge_elu_family_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)

if self._is_quantized_meta(meta):
Expand All @@ -76,11 +136,11 @@ def call_operator(self, op, args, kwargs, meta):
) = get_elu_decomposition(op)

input = args[0]
alpha = args[1] if len(args) > 1 else 1.0
alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs)

if alpha == 0:
relu_op = exir_ops.edge.aten.clamp.default
return super().call_operator(
relu_node = super().call_operator(
relu_op,
(
input,
Expand All @@ -90,14 +150,35 @@ def call_operator(self, op, args, kwargs, meta):
meta,
updated=True,
)
if scale == 1:
return relu_node

expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
return super().call_operator(
mul_op, (relu_node, scale), {}, meta, updated=True
)

expm1_input = input
if input_scale != 1:
expm1_input = super().call_operator(
mul_op, (input, input_scale), {}, meta, updated=True
)
expm1_node = super().call_operator(
expm1_op, (expm1_input,), {}, meta, updated=True
)
mul_node = super().call_operator(
mul_op, (expm1_node, alpha), {}, meta, updated=True
)
ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
positive_node = input
if scale != 1:
positive_node = super().call_operator(
mul_op, (input, scale), {}, meta, updated=True
)
mul_node = super().call_operator(
mul_op, (mul_node, scale), {}, meta, updated=True
)
where_node = super().call_operator(
where_op, (ge_node, input, mul_node), {}, meta, updated=True
where_op, (ge_node, positive_node, mul_node), {}, meta, updated=True
)

return where_node
8 changes: 5 additions & 3 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ def __getitem__(self, node: Node):
x, approximate=approximate
).flatten()
case exir_ops.edge.aten.elu.default:
input_alpha = cast(int, node.kwargs["alpha"])
return lambda x: torch.nn.functional.elu(
x, alpha=input_alpha
input_alpha = cast(float, node.meta["float_alpha"])
input_scale = cast(float, node.meta.get("float_input_scale", 1.0))
scale = cast(float, node.meta.get("float_scale", 1.0))
Comment thread
AdrianLundell marked this conversation as resolved.
return lambda x: torch.ops.aten.elu.default(
x, input_alpha, scale, input_scale
).flatten()
case exir_ops.edge.aten.remainder.Scalar:
divisor = cast(float | int, node.args[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
exir_ops.edge.aten.cosh.default,
exir_ops.edge.aten.acos.default,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.selu.default,
exir_ops.edge.aten.celu.default,
exir_ops.edge.aten.bitwise_not.default,
exir_ops.edge.aten.copy.default,
exir_ops.edge.aten.tan.default,
Expand Down Expand Up @@ -244,6 +246,8 @@
exir_ops.edge.aten.logit.default,
exir_ops.edge.aten.acos.default,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.selu.default,
exir_ops.edge.aten.celu.default,
exir_ops.edge.aten.copy.default,
exir_ops.edge.aten.floor_divide.default,
exir_ops.edge.aten.tan.default,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ def _match_pattern(
torch.ops.aten.exp.default,
torch.ops.aten.expm1.default,
torch.ops.aten.elu.default,
torch.ops.aten.selu.default,
torch.ops.aten.celu.default,
torch.ops.aten.floor.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ In this tutorial you will learn how to export a simple PyTorch model for the Exe
```{tip}
If you are already familiar with this delegate, you may want to jump directly to the examples:
* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm)
* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
```

This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on Arm® Ethos™-U targets. It is based on `ethos_u_minimal_example.ipynb`, provided in Arm’s examples folder.
Expand Down Expand Up @@ -69,9 +69,10 @@ The example below shows how to quantize a model consisting of a single addition,
$MINIMAL_EXAMPLE

```{tip}
For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
To produce a pte file equivalent to the one above, run
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`.
For production use, you should instead use the stable Python API shown above.
```

### Runtime:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ You may encounter some rough edges and features which may be documented or plann
```{tip}
If you are already familiar with this delegate, you may want to jump directly to the examples:
* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm)
* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
```

This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipyb`, provided in Arm's example folder.
Expand Down Expand Up @@ -78,9 +78,10 @@ The example below shows how to quantize a model consisting of a single addition,
$MINIMAL_EXAMPLE

```{tip}
For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
To produce a pte file equivalent to the one above, run
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf`
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf`.
For production use, you should instead use the stable Python API shown above.
```

## Runtime
Expand Down
Loading
Loading