Arm backend: Support for aten.slice_scatter and slice_copy with non-unit step #17413
@@ -0,0 +1,193 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.accumulate_index_put_pass import (
    AccumulateIndexPutPass,
)
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_slice_scatter_ops = (exir_ops.edge.aten.slice_scatter.default,)
aten_slice_scatter_ops = (torch.ops.aten.slice_scatter.default,)


def _get_slice_scatter_decomposition(op) -> tuple:
    if op in edge_slice_scatter_ops:
        return (
            exir_ops.edge.aten.arange.start_step,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.index_put.default,
        )
    if op in aten_slice_scatter_ops:
        return (
            torch.ops.aten.arange.start_step,
            torch.ops.aten.slice_copy.Tensor,
            torch.ops.aten.cat.default,
            torch.ops.aten.permute_copy.default,
            torch.ops.aten.index_put.default,
        )
    raise RuntimeError(f"Can't get slice_scatter decomposition for op {op}")


def _fixup_start(start, dim_size: int) -> int:
    s = 0 if start is None else int(start)
    return max(0, min(s % dim_size if s < 0 else s, dim_size))


def _fixup_end(end, dim_size: int) -> int:
    e = dim_size if end is None else int(end)
    return max(0, min(e % dim_size if e < 0 else e, dim_size))


class DecomposeSliceScatterPass(ArmPass):
    """
    Decompose slice_scatter into:
    - Fast path (step == 1): slice_copy + cat (contiguous update), or
    - General path (step > 1): arange + index_put (strided / interleaved).

    Limitations:
    - Does not broadcast src: requires src.shape to exactly match the slice
      shape being updated.

    For dim != 0, permute input/src so that the updated dimension is first,
    apply index_put with a single index tensor, then permute back.
    """

    _passes_required_after: Set[Type[ExportPass]] = {
        AccumulateIndexPutPass,
        RewriteIndexPutPass,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in (edge_slice_scatter_ops + aten_slice_scatter_ops):
            return super().call_operator(op, args, kwargs, meta)

        (
            arange_op,
            slice_copy_op,
            cat_op,
            permute_op,
            index_put_op,
        ) = _get_slice_scatter_decomposition(op)

        input = args[0]
        src = args[1]
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
        start = args[3] if len(args) > 3 else kwargs.get("start", None)
        end = args[4] if len(args) > 4 else kwargs.get("end", None)
        step = args[5] if len(args) > 5 else kwargs.get("step", 1)

        if not isinstance(dim, int) or not isinstance(step, int):
            raise NotImplementedError("slice_scatter expects constant dim/step")

        if step <= 0:
            raise NotImplementedError("slice_scatter expects step > 0")

        input_val = input.data  # FakeTensor
        input_shape = input_val.shape  # [d0, d1, ..., d{r-1}]
        input_device = input_val.device
        input_rank = len(input_shape)
        dim_norm = dim % input_rank
        dim_size = int(input_shape[dim_norm])

        start_i = _fixup_start(start, dim_size)
        end_i = _fixup_end(end, dim_size)
        if end_i <= start_i:
            return input

        # index_positions: [W] where W = len(arange(start_i, end_i, step))
        index_positions = super().call_operator(
            arange_op,
            (start_i, end_i, step),
            {"dtype": torch.int32, "device": input_device},
            meta,
            updated=True,
        )

        src_val = src.data
        src_shape = src_val.shape
        index_shape = index_positions.data.shape
        # slice_shape is input_shape with dim_norm replaced by W
        # input_shape:    [d0, ..., D, ..., d{r-1}]
        # -> slice_shape: [d0, ..., W, ..., d{r-1}]
        slice_shape = list(input_shape)
        slice_shape[dim_norm] = int(index_shape[0])
        # slice_scatter does not broadcast src: require exact shape match.
        if tuple(src_shape) != tuple(slice_shape):
            raise NotImplementedError(
                "slice_scatter requires src.shape to match the slice shape; "
                f"got src.shape={tuple(src_shape)}, expected={tuple(slice_shape)}"
            )

        # ---- fast path: contiguous update (step == 1) ----
        if step == 1:
            # prefix = input[..., :start_i, ...] along dim_norm
            prefix = super().call_operator(
                slice_copy_op,
                (input, dim_norm, 0, start_i, 1),
                {},
                meta,
                updated=True,
            )
            # suffix = input[..., end_i:, ...] along dim_norm
            suffix = super().call_operator(
                slice_copy_op,
                (input, dim_norm, end_i, dim_size, 1),
                {},
                meta,
                updated=True,
            )
            # concat(prefix, src, suffix) along dim_norm
            updated = super().call_operator(
                cat_op,
                ([prefix, src, suffix], dim_norm),
                {},
                meta,
                updated=True,
            )
            return updated

        # ---- general path: strided update (step > 1) ----
        # Move updated dim to front to use a single index tensor.
        if dim_norm != 0:
            perm = [dim_norm] + [i for i in range(input_rank) if i != dim_norm]
            inv_perm = [0] * input_rank
            for i, p in enumerate(perm):
                inv_perm[p] = i

            # input: [d0, ..., d{dim_norm-1}, d{dim_norm}, d{dim_norm+1}, ..., d{r-1}]
            # -> input_perm: [d{dim_norm}, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}]
            input = super().call_operator(
                permute_op, (input, perm), {}, meta, updated=True
            )
            # src: [d0, ..., d{dim_norm-1}, W, d{dim_norm+1}, ..., d{r-1}]
            # -> src_perm: [W, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}]
            src = super().call_operator(permute_op, (src, perm), {}, meta, updated=True)

        # Puts values from src into input along the first dimension
        # using a single 1D index tensor index_positions.
        updated = super().call_operator(
            index_put_op,
            (input, (index_positions,), src, False),
            {},
            meta,
            updated=True,
        )

        if dim_norm != 0:
            # updated_perm: [d{dim_norm}, ...] -> updated: [d0, d1, ..., d{r-1}]
            updated = super().call_operator(
                permute_op, (updated, inv_perm), {}, meta, updated=True
            )

        return updated
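For reference, both paths can be sanity-checked against eager PyTorch. The following is a minimal sketch, not code from this PR: the shapes and values are made up for illustration, and eager index_put_ / torch.slice_scatter stand in for the edge-dialect ops the pass emits.

import torch

x = torch.zeros(3, 8)
src = torch.ones(3, 3)

# General path (step > 1): permute the updated dim to the front, write the
# rows named by a single 1D index tensor with index_put_, permute back.
dim, start, end, step = 1, 0, 8, 3
idx = torch.arange(start, end, step)  # index_positions: tensor([0, 3, 6])
xp = x.permute(1, 0).clone()          # updated dim first: shape [8, 3]
sp = src.permute(1, 0)                # rows to scatter: shape [3, 3]
xp.index_put_((idx,), sp, accumulate=False)
out = xp.permute(1, 0)
assert torch.equal(out, torch.slice_scatter(x, src, dim, start, end, step))

# Fast path (step == 1): prefix + src + suffix concatenated along dim.
src2 = torch.full((3, 4), 2.0)
out2 = torch.cat([x[:, :2], src2, x[:, 6:]], dim=1)
assert torch.equal(out2, torch.slice_scatter(x, src2, 1, 2, 6, 1))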
@@ -0,0 +1,146 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def _get_strided_slice_copy_decomposition(op):
    """Return the operator overloads used by this decomposition."""
    if op == exir_ops.edge.aten.slice_copy.Tensor:
        return (
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.view_copy.default,
        )
    raise RuntimeError(f"Can't get strided slice_copy decomposition for op {op}")


def _fixup_start(start, dim_size):
    """Normalize start and clamp into [0, dim_size]."""
    s = 0 if start is None else start
    if s < 0:
        s = s % dim_size
    return max(0, min(s, dim_size))


def _fixup_end(end, dim_size):
    """Normalize end and clamp into [0, dim_size]."""
    if end is None:
        return dim_size
    e = end
    if e > dim_size:
        e = dim_size
    if e < 0:
        e = e % dim_size
    return max(0, min(e, dim_size))


class DecomposeStridedSliceCopyPass(ArmPass):
    """
    Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops.

    Given:
        out = slice_copy(x, dim, start, end, step) with step > 1

    Produce:
        1) y = slice_copy(x, dim, start, end, 1)    # span with unit step
        2) pad y on the right to make length divisible by step (if needed)
        3) y2 = view_copy(y, ..., U, step, ...)     # split the sliced dim
        4) y3 = slice_copy(y2, dim_i + 1, 0, 1, 1)  # pick index 0 in each group
        5) out = view_copy(y3, ...)                 # collapse the singleton dim

    This implements "take every step-th element" using only unit-step slice + reshape.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    _TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor}

    def call_operator(self, op, args, kwargs, meta):
        if op not in self._TARGET_OPS:
            return super().call_operator(op, args, kwargs, meta)

        # Only handle the non-unit-step case; leave unit-step to existing lowering.
        if not (len(args) == 5 and args[4] != 1):
            return super().call_operator(op, args, kwargs, meta)

        x, dim, start, end, step = args
        assert step > 0, "slice_copy step must be positive"
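To make the docstring's five steps concrete, here is a minimal eager-mode sketch, assuming a [4, 10] input with dim=1, start=1, end=9, step=3. It mirrors the recipe above but is not the pass's code (the pass emits slice_copy/cat/view_copy edge ops instead of tensor methods):

import torch

x = torch.arange(40).reshape(4, 10)
dim, start, end, step = 1, 1, 9, 3

y = x[:, start:end]                      # 1) unit-step span, length L = 8
pad = (-y.shape[dim]) % step             # 2) right-pad so step divides length
if pad:
    y = torch.cat([y, y.new_zeros(y.shape[0], pad)], dim=dim)
U = y.shape[dim] // step
y2 = y.reshape(4, U, step)               # 3) split the sliced dim into [U, step]
y3 = y2[:, :, 0:1]                       # 4) keep index 0 of each group of `step`
out = y3.reshape(4, U)                   # 5) collapse the singleton dim
assert torch.equal(out, x[:, start:end:step])  # columns 1, 4, 7

The padding is safe because fewer than `step` elements are appended, so index 0 of every group is always a real element of the original span.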
A review comment suggests replacing the assert with an explicit error:

Suggested change:
-        assert step > 0, "slice_copy step must be positive"
+        if step <= 0:
+            raise RuntimeError(f"slice_copy step must be positive, got {step}")
A second review comment: arange is always emitted even when step == 1 (fast path), and the resulting index_positions is unused in that branch. If dead-code elimination doesn't reliably remove it before partitioning/lowering, this can introduce an unnecessary (and potentially unsupported) op in the step == 1 path. Move index_positions = arange(...) (and any dependent shape checks) inside the step > 1 branch; for step == 1, compute the expected slice length directly from start_i/end_i for validation.
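A hypothetical fragment showing the restructuring the reviewer describes, meant to slot into DecomposeSliceScatterPass.call_operator above; the expected_w name and the exact step == 1 validation are assumptions, not code from this PR:

        # Hypothetical sketch: emit arange only on the strided path (step > 1),
        # so the step == 1 fast path adds no unused op to the graph.
        if step == 1:
            # Unit-step slice length follows directly from the normalized bounds.
            expected_w = end_i - start_i
        else:
            index_positions = super().call_operator(
                arange_op,
                (start_i, end_i, step),
                {"dtype": torch.int32, "device": input_device},
                meta,
                updated=True,
            )
            expected_w = int(index_positions.data.shape[0])

        # The shape check now covers both paths without a dead arange node.
        slice_shape = list(input_shape)
        slice_shape[dim_norm] = expected_w
        if tuple(src.data.shape) != tuple(slice_shape):
            raise NotImplementedError(
                "slice_scatter requires src.shape to match the slice shape; "
                f"got src.shape={tuple(src.data.shape)}, expected={tuple(slice_shape)}"
            )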