Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@
from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa
from .decompose_sign_pass import DecomposeSignPass # noqa
from .decompose_sinh_pass import DecomposeSinhPass # noqa
from .decompose_slice_scatter_pass import DecomposeSliceScatterPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
from .decompose_strided_slice_copy_pass import DecomposeStridedSliceCopyPass # noqa
from .decompose_sum_pass import DecomposeSumPass # noqa
from .decompose_tan_pass import DecomposeTanPass # noqa
from .decompose_tosa_unsupported_clamp_pass import ( # noqa
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@
DecomposeSelectScatterPass,
DecomposeSignPass,
DecomposeSinhPass,
DecomposeSliceScatterPass,
DecomposeSoftmaxPass,
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
DecomposeStridedSliceCopyPass,
DecomposeSumPass,
DecomposeTanPass,
DecomposeTOSAUnsupportedClampPass,
Expand Down Expand Up @@ -292,6 +294,7 @@ def _tosa_pipeline(
DecomposeUnfoldToGatherPass(),
DecomposeEmbeddingPass(),
DecomposeIndexSelectToGatherPass(),
DecomposeStridedSliceCopyPass(),
Conv1dUnsqueezePass(),
]
)
Expand All @@ -313,6 +316,7 @@ def _tosa_pipeline(
# Node transformation passes (post scalar-removal)
self.add_passes(
[
DecomposeSliceScatterPass(),
AccumulateIndexPutPass(),
RewriteIndexPutPass(),
DecomposeRemainderPass(),
Expand Down Expand Up @@ -394,6 +398,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_passes(
[
DecomposeSelectScatterPass(tfa_pass=True),
DecomposeSliceScatterPass(tfa_pass=True),
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
Expand Down
193 changes: 193 additions & 0 deletions backends/arm/_passes/decompose_slice_scatter_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.accumulate_index_put_pass import (
AccumulateIndexPutPass,
)
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# slice_scatter overloads handled by DecomposeSliceScatterPass: the edge-dialect
# overload and the core-aten overload respectively.
edge_slice_scatter_ops = (exir_ops.edge.aten.slice_scatter.default,)
aten_slice_scatter_ops = (torch.ops.aten.slice_scatter.default,)


def _get_slice_scatter_decomposition(op) -> tuple:
    """Return the (arange, slice_copy, cat, permute_copy, index_put) overloads.

    The overloads come from the same dialect as ``op``: edge dialect for the
    edge slice_scatter overload, core aten for the core-aten overload.

    Raises:
        RuntimeError: if ``op`` is not a supported slice_scatter overload.
    """
    if op in edge_slice_scatter_ops:
        namespace = exir_ops.edge.aten
    elif op in aten_slice_scatter_ops:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get slice_scatter decomposition for op {op}")
    return (
        namespace.arange.start_step,
        namespace.slice_copy.Tensor,
        namespace.cat.default,
        namespace.permute_copy.default,
        namespace.index_put.default,
    )


def _fixup_start(start, dim_size: int) -> int:
s = 0 if start is None else int(start)
return max(0, min(s % dim_size if s < 0 else s, dim_size))


def _fixup_end(end, dim_size: int) -> int:
e = dim_size if end is None else int(end)
return max(0, min(e % dim_size if e < 0 else e, dim_size))


class DecomposeSliceScatterPass(ArmPass):
    """
    Decompose slice_scatter into:
    - Fast path (step == 1): slice_copy + cat (contiguous update), or
    - General path (step > 1): arange + index_put (strided / interleaved).

    Limitations:
    - Does not broadcast src: requires src.shape to exactly match the slice
      shape being updated

    For dim != 0, permute input/src so that the updated dimension is first,
    apply index_put with a single index tensor, then permute back.
    """

    _passes_required_after: Set[Type[ExportPass]] = {
        AccumulateIndexPutPass,
        RewriteIndexPutPass,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in (edge_slice_scatter_ops + aten_slice_scatter_ops):
            return super().call_operator(op, args, kwargs, meta)

        (
            arange_op,
            slice_copy_op,
            cat_op,
            permute_op,
            index_put_op,
        ) = _get_slice_scatter_decomposition(op)

        # slice_scatter(input, src, dim=0, start=None, end=None, step=1)
        input = args[0]
        src = args[1]
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
        start = args[3] if len(args) > 3 else kwargs.get("start", None)
        end = args[4] if len(args) > 4 else kwargs.get("end", None)
        step = args[5] if len(args) > 5 else kwargs.get("step", 1)

        if not isinstance(dim, int) or not isinstance(step, int):
            raise NotImplementedError("slice_scatter expects constant dim/step")

        if step <= 0:
            raise NotImplementedError("slice_scatter expects step > 0")

        input_val = input.data  # FakeTensor
        input_shape = input_val.shape  # [d0, d1, ..., d{r-1}]
        input_device = input_val.device
        input_rank = len(input_shape)
        dim_norm = dim % input_rank
        dim_size = int(input_shape[dim_norm])

        start_i = _fixup_start(start, dim_size)
        end_i = _fixup_end(end, dim_size)
        if end_i <= start_i:
            # Empty slice: nothing is updated, pass the input through.
            return input

        # W = number of updated positions along dim_norm
        #   = len(range(start_i, end_i, step)).
        num_updates = (end_i - start_i + step - 1) // step

        src_val = src.data
        src_shape = src_val.shape
        # slice_shape is input_shape with dim_norm replaced by W
        # input_shape: [d0, ..., D, ..., d{r-1}]
        # -> slice_shape: [d0, ..., W, ..., d{r-1}]
        slice_shape = list(input_shape)
        slice_shape[dim_norm] = num_updates

        # slice_scatter does not broadcast src: require exact shape match.
        if tuple(src_shape) != tuple(slice_shape):
            raise NotImplementedError(
                "slice_scatter requires src.shape to match the slice shape; "
                f"got src.shape={tuple(src_shape)}, expected={tuple(slice_shape)}"
            )

        # ---- fast path: contiguous update (step == 1) ----
        if step == 1:
            # prefix = input[..., :start_i, ...] along dim_norm
            prefix = super().call_operator(
                slice_copy_op,
                (input, dim_norm, 0, start_i, 1),
                {},
                meta,
                updated=True,
            )
            # suffix = input[..., end_i:, ...] along dim_norm
            suffix = super().call_operator(
                slice_copy_op,
                (input, dim_norm, end_i, dim_size, 1),
                {},
                meta,
                updated=True,
            )
            # concat(prefix, src, suffix) along dim_norm
            updated = super().call_operator(
                cat_op,
                ([prefix, src, suffix], dim_norm),
                {},
                meta,
                updated=True,
            )
            return updated

        # ---- general path: strided update (step > 1) ----
        # Emit the index tensor only on this path: the fast path above never
        # uses it, and an unconditionally emitted arange would leave a dead
        # (potentially unsupported) op in the graph if DCE does not run
        # before partitioning/lowering.
        # index_positions: [W] where W = len(arange(start_i, end_i, step))
        index_positions = super().call_operator(
            arange_op,
            (start_i, end_i, step),
            {"dtype": torch.int32, "device": input_device},
            meta,
            updated=True,
        )

        # Move updated dim to front to use a single index tensor.
        if dim_norm != 0:
            perm = [dim_norm] + [i for i in range(input_rank) if i != dim_norm]
            inv_perm = [0] * input_rank
            for i, p in enumerate(perm):
                inv_perm[p] = i

            # input: [d0, ..., d{dim_norm-1}, d{dim_norm}, d{dim_norm+1}, ..., d{r-1}]
            # -> input_perm: [d{dim_norm}, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}]
            input = super().call_operator(
                permute_op, (input, perm), {}, meta, updated=True
            )
            # src: [d0, ..., d{dim_norm-1}, W, d{dim_norm+1}, ..., d{r-1}]
            # -> src_perm: [W, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}]
            src = super().call_operator(permute_op, (src, perm), {}, meta, updated=True)

        # Puts values from src into input along the first dimension
        # using a single 1D index tensor index_positions.
        updated = super().call_operator(
            index_put_op,
            (input, (index_positions,), src, False),
            {},
            meta,
            updated=True,
        )

        if dim_norm != 0:
            # updated_perm: [d{dim_norm}, ...] -> updated: [d0, d1, ..., d{r-1}]
            updated = super().call_operator(
                permute_op, (updated, inv_perm), {}, meta, updated=True
            )

        return updated
146 changes: 146 additions & 0 deletions backends/arm/_passes/decompose_strided_slice_copy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def _get_strided_slice_copy_decomposition(op):
    """Return the (slice_copy, cat, view_copy) overloads used by this decomposition.

    Raises:
        RuntimeError: if ``op`` is not the edge slice_copy overload.
    """
    if op != exir_ops.edge.aten.slice_copy.Tensor:
        raise RuntimeError(f"Can't get strided slice_copy decomposition for op {op}")
    return (
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.view_copy.default,
    )


def _fixup_start(start, dim_size):
"""Normalize start and clamp into [0, dim_size]."""
s = 0 if start is None else start
if s < 0:
s = s % dim_size
return max(0, min(s, dim_size))


def _fixup_end(end, dim_size):
"""Normalize end and clamp into [0, dim_size]."""
if end is None:
return dim_size
e = end
if e > dim_size:
e = dim_size
if e < 0:
e = e % dim_size
return max(0, min(e, dim_size))


class DecomposeStridedSliceCopyPass(ArmPass):
    """
    Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops.

    Given:
        out = slice_copy(x, dim, start, end, step) with step > 1

    Produce:
        1) y = slice_copy(x, dim, start, end, 1)   # span with unit step
        2) pad y on the right to make length divisible by step (if needed)
        3) y2 = view_copy(y, ..., U, step, ...)    # split the sliced dim
        4) y3 = slice_copy(y2, dim_i + 1, 0, 1, 1) # pick index 0 in each group
        5) out = view_copy(y3, ...)                # collapse the singleton dim

    This implements "take every step-th element" using only unit-step slice + reshape.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    _TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor}

    def call_operator(self, op, args, kwargs, meta):
        if op not in self._TARGET_OPS:
            return super().call_operator(op, args, kwargs, meta)

        # Only handle the non-unit-step case; leave unit-step to existing lowering.
        if not (len(args) == 5 and args[4] != 1):
            return super().call_operator(op, args, kwargs, meta)

        x, dim, start, end, step = args
        # Raise (not assert): asserts are stripped under -O, and this is
        # graph validation, consistent with DecomposeSliceScatterPass.
        if step <= 0:
            raise NotImplementedError("slice_copy expects step > 0")

        shape = x.data.shape
        rank = len(shape)

        # Normalize dim into [0, rank).
        dim_i = dim % rank
        dim_size = shape[dim_i]

        # Normalize/clamp start/end into valid bounds.
        start_i = _fixup_start(start, dim_size)
        end_i = _fixup_end(end, dim_size)

        L = end_i - start_i
        if L <= 0:
            # slice_copy would return empty; keep default behavior.
            return super().call_operator(op, args, kwargs, meta)

        slice_op, cat_op, view_op = _get_strided_slice_copy_decomposition(op)

        # 1) Unit-step slice of the requested span:
        #    y = x[..., start_i:end_i, ...]
        y = super().call_operator(
            slice_op, (x, dim_i, start_i, end_i, 1), {}, meta, updated=True
        )

        # 2) Compute:
        #    U = ceil(L / step)       (# of output elements along dim_i)
        #    pad_right = U*step - L   (so that padded length becomes U*step)
        U = (L + step - 1) // step
        pad_right = U * step - L

        # 3) If needed, right-pad along dim_i so that:
        #    after padding, y.shape[dim_i] == U*step
        if pad_right > 0:
            y_data = y.data
            pad_shape = list(y_data.shape)
            pad_shape[dim_i] = pad_right

            # z: zeros with same dtype/device as y, shape matches y except
            # z.shape[dim_i] = pad_right. The padded tail is never selected
            # by the index-0 slice below, so its value does not matter.
            fill_value = False if y_data.dtype == torch.bool else 0
            z = super().call_operator(
                op=exir_ops.edge.aten.full.default,
                args=(pad_shape, fill_value),
                kwargs={"dtype": y_data.dtype, "device": y_data.device},
                meta=meta,
                updated=True,
            )

            # Concatenate on the right:
            #    y.shape[dim_i] : L -> L + pad_right == U*step
            y = super().call_operator(cat_op, ([y, z], dim_i), {}, meta, updated=True)

        # 4) Split the sliced dim: (U*step) -> (U, step)
        y_t2 = y.data
        split_shape = list(y_t2.shape)
        split_shape[dim_i] = U
        split_shape.insert(dim_i + 1, step)

        y2 = super().call_operator(view_op, (y, split_shape), {}, meta, updated=True)

        # 5) Take index 0 in the inserted "step" dimension:
        #    [..., U, step, ...] -> [..., U, 1, ...]
        y3 = super().call_operator(
            slice_op, (y2, dim_i + 1, 0, 1, 1), {}, meta, updated=True
        )

        # 6) Collapse y3's singleton step dim: [..., U, 1, ...] -> [..., U, ...].
        out_shape = list(y_t2.shape)  # y_t2: [..., U*step, ...]
        out_shape[dim_i] = U  # out_shape: [..., U, ...]

        return super().call_operator(view_op, (y3, out_shape), {}, meta, updated=True)
Loading
Loading