diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index da15044d938..eb0bd9f2938 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -76,9 +76,11 @@ from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa from .decompose_sign_pass import DecomposeSignPass # noqa from .decompose_sinh_pass import DecomposeSinhPass # noqa +from .decompose_slice_scatter_pass import DecomposeSliceScatterPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa from .decompose_sqrt_pass import DecomposeSqrtPass # noqa +from .decompose_strided_slice_copy_pass import DecomposeStridedSliceCopyPass # noqa from .decompose_sum_pass import DecomposeSumPass # noqa from .decompose_tan_pass import DecomposeTanPass # noqa from .decompose_tosa_unsupported_clamp_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 649952f314c..564b1c5111e 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -77,9 +77,11 @@ DecomposeSelectScatterPass, DecomposeSignPass, DecomposeSinhPass, + DecomposeSliceScatterPass, DecomposeSoftmaxPass, DecomposeSoftmaxUnstablePass, DecomposeSqrtPass, + DecomposeStridedSliceCopyPass, DecomposeSumPass, DecomposeTanPass, DecomposeTOSAUnsupportedClampPass, @@ -292,6 +294,7 @@ def _tosa_pipeline( DecomposeUnfoldToGatherPass(), DecomposeEmbeddingPass(), DecomposeIndexSelectToGatherPass(), + DecomposeStridedSliceCopyPass(), Conv1dUnsqueezePass(), ] ) @@ -313,6 +316,7 @@ def _tosa_pipeline( # Node transformation passes (post scalar-removal) self.add_passes( [ + DecomposeSliceScatterPass(), AccumulateIndexPutPass(), RewriteIndexPutPass(), DecomposeRemainderPass(), @@ -394,6 +398,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_passes( [ DecomposeSelectScatterPass(tfa_pass=True), + DecomposeSliceScatterPass(tfa_pass=True), ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), ConvertInt64OutputOpsToInt32Pass(tfa_pass=True), InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True), diff --git a/backends/arm/_passes/decompose_slice_scatter_pass.py b/backends/arm/_passes/decompose_slice_scatter_pass.py new file mode 100644 index 00000000000..55a9e228b75 --- /dev/null +++ b/backends/arm/_passes/decompose_slice_scatter_pass.py @@ -0,0 +1,193 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.accumulate_index_put_pass import ( + AccumulateIndexPutPass, +) +from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_slice_scatter_ops = (exir_ops.edge.aten.slice_scatter.default,) +aten_slice_scatter_ops = (torch.ops.aten.slice_scatter.default,) + + +def _get_slice_scatter_decomposition(op) -> tuple: + if op in edge_slice_scatter_ops: + return ( + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.index_put.default, + ) + if op in aten_slice_scatter_ops: + return ( + torch.ops.aten.arange.start_step, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.cat.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.index_put.default, + ) + raise RuntimeError(f"Can't get slice_scatter decomposition for op {op}") + + +def _fixup_start(start, dim_size: int) -> int: + s = 0 if start is None else int(start) + return max(0, min(s % dim_size if s < 0 else s, dim_size)) + + +def _fixup_end(end, dim_size: int) -> int: + e = dim_size if end is None else int(end) + return max(0, min(e % dim_size if e < 0 else e, dim_size)) + + +class DecomposeSliceScatterPass(ArmPass): + """ + Decompose slice_scatter into: + - Fast path (step == 1): slice_copy + cat (contiguous update), or + - General path (step > 1): arange + index_put (strided / interleaved). + + Limitations: + - Does not broadcast src: requires src.shape to exactly match the slice + shape being updated + + For dim != 0, permute input/src so that the updated dimension is first, + apply index_put with a single index tensor, then permute back. 
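+
+    Example (illustrative values, showing the step > 1 semantics this pass
+    reproduces):
+
+        x = torch.zeros(6); src = torch.ones(2)
+        torch.slice_scatter(x, src, dim=0, start=1, end=5, step=2)
+        -> tensor([0., 1., 0., 1., 0., 0.])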
+ """ + + _passes_required_after: Set[Type[ExportPass]] = { + AccumulateIndexPutPass, + RewriteIndexPutPass, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_slice_scatter_ops + aten_slice_scatter_ops): + return super().call_operator(op, args, kwargs, meta) + + ( + arange_op, + slice_copy_op, + cat_op, + permute_op, + index_put_op, + ) = _get_slice_scatter_decomposition(op) + + input = args[0] + src = args[1] + dim = args[2] if len(args) > 2 else kwargs.get("dim", 0) + start = args[3] if len(args) > 3 else kwargs.get("start", None) + end = args[4] if len(args) > 4 else kwargs.get("end", None) + step = args[5] if len(args) > 5 else kwargs.get("step", 1) + + if not isinstance(dim, int) or not isinstance(step, int): + raise NotImplementedError("slice_scatter expects constant dim/step") + + if step <= 0: + raise NotImplementedError("slice_scatter expects step > 0") + + input_val = input.data # FakeTensor + input_shape = input_val.shape # [d0, d1, ..., d{r-1}] + input_device = input_val.device + input_rank = len(input_shape) + dim_norm = dim % input_rank + dim_size = int(input_shape[dim_norm]) + + start_i = _fixup_start(start, dim_size) + end_i = _fixup_end(end, dim_size) + if end_i <= start_i: + return input + + # index_positions: [W] where W = len(arange(start_i, end_i, step)) + index_positions = super().call_operator( + arange_op, + (start_i, end_i, step), + {"dtype": torch.int32, "device": input_device}, + meta, + updated=True, + ) + + src_val = src.data + src_shape = src_val.shape + index_shape = index_positions.data.shape + # slice_shape is input_shape with dim_norm replaced by W + # input_shape: [d0, ..., D, ..., d{r-1}] + # -> slice_shape: [d0, ..., W, ..., d{r-1}] + slice_shape = list(input_shape) + slice_shape[dim_norm] = int(index_shape[0]) + # slice_scatter does not broadcast src: require exact shape match. + if tuple(src_shape) != tuple(slice_shape): + raise NotImplementedError( + "slice_scatter requires src.shape to match the slice shape; " + f"got src.shape={tuple(src_shape)}, expected={tuple(slice_shape)}" + ) + + # ---- fast path: contiguous update (step == 1) ---- + if step == 1: + # prefix = input[..., :start_i, ...] along dim_norm + prefix = super().call_operator( + slice_copy_op, + (input, dim_norm, 0, start_i, 1), + {}, + meta, + updated=True, + ) + # suffix = input[..., end_i:, ...] along dim_norm + suffix = super().call_operator( + slice_copy_op, + (input, dim_norm, end_i, dim_size, 1), + {}, + meta, + updated=True, + ) + # concat(prefix, src, suffix) along dim_norm + updated = super().call_operator( + cat_op, + ([prefix, src, suffix], dim_norm), + {}, + meta, + updated=True, + ) + return updated + + # ---- general path: strided update (step > 1) ---- + # Move updated dim to front to use a single index tensor. 
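+        # Illustrative walk-through (assumed shapes, not checked here): for
+        # input of shape [2, 6] with dim=1, start=0, end=6, step=3,
+        # index_positions is [0, 3]. Permuting dim 1 to the front gives
+        # input_perm of shape [6, 2] and src_perm of shape [2, 2]; index_put
+        # then overwrites rows 0 and 3 of input_perm, i.e. columns 0 and 3 of
+        # the original input.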
+ if dim_norm != 0: + perm = [dim_norm] + [i for i in range(input_rank) if i != dim_norm] + inv_perm = [0] * input_rank + for i, p in enumerate(perm): + inv_perm[p] = i + + # input: [d0, ..., d{dim_norm-1}, d{dim_norm}, d{dim_norm+1}, ..., d{r-1}] + # -> input_perm: [d{dim_norm}, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}] + input = super().call_operator( + permute_op, (input, perm), {}, meta, updated=True + ) + # src: [d0, ..., d{dim_norm-1}, W, d{dim_norm+1}, ..., d{r-1}] + # -> src_perm: [W, d0, ..., d{dim_norm-1}, d{dim_norm+1}, ..., d{r-1}] + src = super().call_operator(permute_op, (src, perm), {}, meta, updated=True) + + # Puts values from src into input along the first dimension + # using a single 1D index tensor index_positions. + updated = super().call_operator( + index_put_op, + (input, (index_positions,), src, False), + {}, + meta, + updated=True, + ) + + if dim_norm != 0: + # updated_perm: [d{dim_norm}, ...] -> updated: [d0, d1, ..., d{r-1}] + updated = super().call_operator( + permute_op, (updated, inv_perm), {}, meta, updated=True + ) + + return updated diff --git a/backends/arm/_passes/decompose_strided_slice_copy_pass.py b/backends/arm/_passes/decompose_strided_slice_copy_pass.py new file mode 100644 index 00000000000..1ddf12e66e0 --- /dev/null +++ b/backends/arm/_passes/decompose_strided_slice_copy_pass.py @@ -0,0 +1,146 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +def _get_strided_slice_copy_decomposition(op): + """Return the operator overloads used by this decomposition.""" + if op == exir_ops.edge.aten.slice_copy.Tensor: + return ( + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.view_copy.default, + ) + raise RuntimeError(f"Can't get strided slice_copy decomposition for op {op}") + + +def _fixup_start(start, dim_size): + """Normalize start and clamp into [0, dim_size].""" + s = 0 if start is None else start + if s < 0: + s = s % dim_size + return max(0, min(s, dim_size)) + + +def _fixup_end(end, dim_size): + """Normalize end and clamp into [0, dim_size].""" + if end is None: + return dim_size + e = end + if e > dim_size: + e = dim_size + if e < 0: + e = e % dim_size + return max(0, min(e, dim_size)) + + +class DecomposeStridedSliceCopyPass(ArmPass): + """ + Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops. + + Given: + out = slice_copy(x, dim, start, end, step) with step > 1 + + Produce: + 1) y = slice_copy(x, dim, start, end, 1) # span with unit step + 2) pad y on the right to make length divisible by step (if needed) + 3) y2 = view_copy(y, ..., U, step, ...) # split the sliced dim + 4) y3 = slice_copy(y2, dim_i + 1, 0, 1, 1) # pick index 0 in each group + 5) out = view_copy(y3, ...) # collapse the singleton dim + + This implements "take every step-th element" using only unit-step slice + reshape. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + _TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor} + + def call_operator(self, op, args, kwargs, meta): + if op not in self._TARGET_OPS: + return super().call_operator(op, args, kwargs, meta) + + # Only handle the non-unit-step case; leave unit-step to existing lowering. 
+ if not (len(args) == 5 and args[4] != 1): + return super().call_operator(op, args, kwargs, meta) + + x, dim, start, end, step = args + assert step > 0, "slice_copy step must be positive" + + shape = x.data.shape + rank = len(shape) + + # Normalize dim into [0, rank). + dim_i = dim % rank + dim_size = shape[dim_i] + + # Normalize/clamp start/end into valid bounds. + start_i = _fixup_start(start, dim_size) + end_i = _fixup_end(end, dim_size) + + L = end_i - start_i + if L <= 0: + # slice_copy would return empty; keep default behavior. + return super().call_operator(op, args, kwargs, meta) + + slice_op, cat_op, view_op = _get_strided_slice_copy_decomposition(op) + + # 1) Unit-step slice of the requested span: + # y = x[..., start_i:end_i, ...] + y = super().call_operator( + slice_op, (x, dim_i, start_i, end_i, 1), {}, meta, updated=True + ) + + # 2) Compute: + # U = ceil(L / step) (# of output elements along dim_i) + # pad_right = U*step - L (so that padded length becomes U*step) + U = (L + step - 1) // step + pad_right = U * step - L + + # 3) If needed, right-pad along dim_i so that: + # after padding, y.shape[dim_i] == U*step + if pad_right > 0: + y_data = y.data + pad_shape = list(y_data.shape) + pad_shape[dim_i] = pad_right + + # z: zeros with same dtype/device as y, shape matches y except + # z.shape[dim_i] = pad_right. + fill_value = False if y_data.dtype == torch.bool else 0 + z = super().call_operator( + op=exir_ops.edge.aten.full.default, + args=(pad_shape, fill_value), + kwargs={"dtype": y_data.dtype, "device": y_data.device}, + meta=meta, + updated=True, + ) + + # Concatenate on the right: + # y.shape[dim_i] : L -> L + pad_right == U*step + y = super().call_operator(cat_op, ([y, z], dim_i), {}, meta, updated=True) + + # 4) Split the sliced dim: (U*step) -> (U, step) + y_t2 = y.data + split_shape = list(y_t2.shape) + split_shape[dim_i] = U + split_shape.insert(dim_i + 1, step) + + y2 = super().call_operator(view_op, (y, split_shape), {}, meta, updated=True) + + # 5) Take index 0 in the inserted "step" dimension: + # [..., U, step, ...] -> [..., U, 1, ...] + y3 = super().call_operator( + slice_op, (y2, dim_i + 1, 0, 1, 1), {}, meta, updated=True + ) + + # 6) Collapse y3's singleton step dim: [..., U, 1, ...] -> [..., U, ...]. + out_shape = list(y_t2.shape) # y_t2: [..., U*step, ...] + out_shape[dim_i] = U # out_shape: [..., U, ...] + + return super().call_operator(view_op, (y3, out_shape), {}, meta, updated=True) diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index 847514a2f5b..6e4b1b93b21 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -1,15 +1,21 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Declare operator support for ``aten.slice_copy`` in TOSA. -Support slicing with unit step only; emit a warning and reject otherwise. +Rely on preprocessing (e.g. DecomposeStridedSliceCopyPass) to rewrite any +non-unit-step slicing into supported ops. Assume static shapes and constant +slicing parameters. -""" +Check: +- args length is 4 or 5 +- If present, require step > 0. +- Require dtype compatible with the selected TOSA profile (allow bool in both). 
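+
+For example (illustrative): an int8 slice is accepted only when the spec
+supports the INT profile, while fp16/fp32 slices are accepted under either
+profile, since quantization may lower a float graph to INT.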
-import logging +""" +import torch import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( register_tosa_support_check, @@ -18,8 +24,6 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops -logger = logging.getLogger(__name__) - @register_tosa_support_check class SliceCopySupported(SupportedTOSAOperatorCheck): @@ -30,14 +34,56 @@ class SliceCopySupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - """Return True if the node is supported by TOSA. + if len(node.args) not in (4, 5): + self.reporter.report_reject( + node, + f"{node.target}: expected 4 or 5 args, got {len(node.args)}.", + ) + return False + + if len(node.args) == 5: + step = node.args[4] + if step <= 0: # type: ignore[operator] + self.reporter.report_reject( + node, + f"{node.target}: step must be > 0, got {step}.", + ) + return False - Accept slice_copy when the step is 1 (or unspecified). Warn and reject - non-unit step sizes. + values_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr] - """ - args = node.args - if len(args) == 5 and (step := args[4]) != 1: - logger.warning(f"{node.target} with step size of {step} not supported.") + SUPPORTED_INT_DTYPES = (torch.int8, torch.int16, torch.int32) + SUPPORTED_FLOAT_DTYPES = (torch.float16, torch.float32) + SUPPORTED_DTYPES = (torch.bool,) + SUPPORTED_INT_DTYPES + SUPPORTED_FLOAT_DTYPES + + # bool is supported in both INT and FP profiles + if values_dtype == torch.bool: + return True + # ints require INT profile + elif values_dtype in SUPPORTED_INT_DTYPES: + if not tosa_spec.support_integer(): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires INT profile.", + ) + return False + + # fp16/fp32: either FP profile, or INT profile (via quantization) + elif values_dtype in SUPPORTED_FLOAT_DTYPES: + if not (tosa_spec.support_float() or tosa_spec.support_integer()): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires FP profile or " + "INT profile (with quantization).", + ) + return False + + else: + self.reporter.report_reject( + node, + f"{node.target}: unsupported values dtype {values_dtype}; " + f"expected one of {SUPPORTED_DTYPES}.", + ) return False + return True diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index d636197ff4b..2fef371a68f 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -203,6 +203,7 @@ exir_ops.edge.aten.round.default, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index e73ca6f3640..74e2afe47f0 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -83,13 +83,14 @@ def define_node( self.tosa_spec, ) - # See slice_copy_support.py + # TOSA.SLICE has no stride parameter. Any non-unit-step slice_copy must have been + # rewritten earlier (e.g. by DecomposeStridedSliceCopyPass), so only step=1 is legal here. 
if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") # aten.slice_copy supports slicing in 1d at a time. # The arguments are the actual input, dimension of slicing, start index, end index and optinal step or stride. - input_node, dim, start, end = inputs + input_node, dim, start, end = inputs[:4] # Translate and check parameters in Pytorch dim order. shape = input_node.shape diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py index b68ad3979f1..f11310eb24c 100644 --- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -6,6 +6,8 @@ from typing import Tuple +import pytest + import torch from executorch.backends.arm._passes import ( ConvertInt64ConstOpsToInt32Pass, @@ -76,6 +78,9 @@ def prepare_model_and_inputs(self): return text_encoder_model, text_encoder_model_inputs +@pytest.mark.xfail( + reason="MLETORCH-1601: Delegate output order mismatch from TOSA reference model." +) def test_clip_text_with_projection_tosa_FP(): text_encoder_model, text_encoder_model_inputs = ( TestCLIPTextModelWithProjection().prepare_model_and_inputs() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 6af389a5a68..239698980d8 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -187,3 +187,175 @@ def test_slice_tensor_16a8w_u85_INT(test_data: torch.Tensor): get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() + + +#################################### +## Non-unit step / stride slicing ## +#################################### + +input_t_step = Tuple[torch.Tensor, int, int, int, int] # (x, dim, start, end, step) + + +test_data_step_fp = { + # x[0:10:2] == x[::2] + "arange_fp32_1d_step2": lambda: ( + torch.arange(10, dtype=torch.float32), + 0, + 0, + 10, + 2, + ), + # x[:, 1:10:4] + "arange_fp32_2d_step4": lambda: ( + torch.arange(40, dtype=torch.float32).reshape(4, 10), + 1, + 1, + 10, + 4, + ), + # x[:, 0:4:2, :] + "arange_fp32_3d_dim1_step2": lambda: ( + torch.arange(2 * 4 * 17, dtype=torch.float32).reshape(2, 4, 17), + 1, + 0, + 4, + 2, + ), + # x[:, :, :, 0:17:4] + "arange_fp32_4d_dim3_step4": lambda: ( + torch.arange(2 * 3 * 5 * 17, dtype=torch.float32).reshape(2, 3, 5, 17), + 3, + 0, + 17, + 4, + ), + # x[:, 0:12:4] + "bool_2d_step4": lambda: ( + (torch.rand((2, 12)) < 0.5), # [2,12], dtype=bool + 1, + 0, + 12, + 4, + ), +} + +test_data_step_int = { + # x[:, 0:9:3] + "rand_int8_2d_step3": lambda: ( + torch.randint(-8, 8, size=(3, 9), dtype=torch.int8), + 1, + 0, + 9, + 3, + ), + # x[:, 0:6:2, :] + "arange_int32_3d_step2_dim1": lambda: ( + torch.arange(2 * 6 * 4, dtype=torch.int32).reshape(2, 6, 4), + 1, + 0, + 6, + 2, + ), + # x[:, :, :, 0:19:4] + "arange_int8_4d_dim3_step4": lambda: ( + torch.arange(2 * 2 * 4 * 19, dtype=torch.int8).reshape(2, 2, 4, 19), + 3, + 0, + 19, + 4, + ), + # x[:, 0:12:4] + "bool_2d_step4": lambda: ( + (torch.rand((2, 12)) < 0.5), # [2,12], dtype=bool + 1, + 0, + 12, + 4, + ), +} + + +class 
SliceWithStep(torch.nn.Module):
+    def forward(
+        self, x: torch.Tensor, dim_: int, start_: int, end_: int, step_: int
+    ) -> torch.Tensor:
+        # Use aten.slice to generate a slice_copy in Edge for lowering.
+        return torch.ops.aten.slice.Tensor(x, dim_, start_, end_, step_)
+
+
+@common.parametrize("test_data", test_data_step_fp)
+def test_slice_tensor_tosa_FP_step(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_step_int | test_data_step_fp)
+def test_slice_tensor_tosa_INT_step(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_step_int | test_data_step_fp,
+    xfails={
+        "bool_2d_step4": "MLETORCH-1744: bool test fails",
+    },
+)
+@common.XfailIfNoCorstone300
+def test_slice_tensor_u55_INT_step(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_step_int | test_data_step_fp)
+@common.XfailIfNoCorstone320
+def test_slice_tensor_u85_INT_step(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_step_int | test_data_step_fp)
+@common.SkipIfNoModelConverter
+def test_slice_tensor_vgf_no_quant_step(test_data: Tuple):
+    pipeline = VgfPipeline[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_op=aten_op,
+        exir_op=exir_op,
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_step_int | test_data_step_fp)
+@common.SkipIfNoModelConverter
+def test_slice_tensor_vgf_quant_step(test_data: Tuple):
+    pipeline = VgfPipeline[input_t_step](
+        SliceWithStep(),
+        test_data(),
+        aten_op=aten_op,
+        exir_op=exir_op,
+        quantize=True,
+    )
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_slice_scatter.py b/backends/arm/test/ops/test_slice_scatter.py
new file mode 100644
index 00000000000..ed151cd8d20
--- /dev/null
+++ b/backends/arm/test/ops/test_slice_scatter.py
@@ -0,0 +1,261 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
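+
+# Test data below follows slice_scatter(input, src, dim, start, end, step);
+# e.g. "rank3_step2" writes src of shape (2, 4, 2) into x[:, :, 1:5:2] of a
+# (2, 4, 6) input, i.e. positions 1 and 3 along dim 2.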
+ +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + OpNotSupportedPipeline, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +test_data_fp_step1 = { + "rank2_step1": lambda: ( + torch.rand((5, 9), dtype=torch.float32), + torch.rand((5, 5), dtype=torch.float32), + 1, + 2, + 7, + 1, + ), + "rank4_negative": lambda: ( + torch.rand((1, 2, 4, 5), dtype=torch.float32), + torch.rand((1, 2, 2, 5), dtype=torch.float32), + 2, + -3, + -1, + 1, + ), +} + +test_data_fp_stepN = { + "rank3_step2": lambda: ( + torch.rand((2, 4, 6), dtype=torch.float32), + torch.rand((2, 4, 2), dtype=torch.float32), + 2, + 1, + 5, + 2, + ), + "rank3_end_none": lambda: ( + torch.rand((3, 5, 4), dtype=torch.float32), + torch.rand((3, 2, 4), dtype=torch.float32), + 1, + 1, + None, + 2, + ), +} + +test_data_int_step1 = { + "rank2_step1_int8": lambda: ( + torch.randint(-5, 5, (5, 9), dtype=torch.int8), + torch.randint(-5, 5, (5, 5), dtype=torch.int8), + 1, + 2, + 7, + 1, + ), +} + +test_data_int_stepN = { + "rank3_step2_int32": lambda: ( + torch.randint(-50, 50, (2, 4, 6), dtype=torch.int32), + torch.randint(-50, 50, (2, 4, 2), dtype=torch.int32), + 2, + 1, + 5, + 2, + ), +} + +test_data_bf16 = { + "rank2_step1_bf16": lambda: ( + torch.rand((4, 8), dtype=torch.bfloat16), + torch.rand((4, 3), dtype=torch.bfloat16), + 1, + 2, + 5, + 1, + ), +} + + +class SliceScatter(torch.nn.Module): + fp_aten_op = "torch.ops.aten.slice_scatter.default" + fp_exir_op = ["executorch_exir_dialects_edge__ops_aten_slice_scatter_default"] + int_aten_ops_step1 = [ + "torch.ops.aten.slice_copy.Tensor", + "torch.ops.aten.cat.default", + ] + int_aten_ops_stepN = [ + "torch.ops.aten.arange.start_step", + "torch.ops.aten.permute_copy.default", + "torch.ops.aten.index_put.default", + ] + int_exir_ops_step1 = [ + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", + "executorch_exir_dialects_edge__ops_aten_cat_default", + ] + int_exir_ops_stepN = [ + "executorch_exir_dialects_edge__ops_aten_arange_start_step", + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + "executorch_exir_dialects_edge__ops_aten_index_put_default", + ] + u55_not_supported = { + "executorch_exir_dialects_edge__ops_aten_index_put_default": 1, + } + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + dim: int, + start: int | None, + end: int | None, + step: int, + ): + return x.slice_scatter(y, dim=dim, start=start, end=end, step=step) + + +input_t = Tuple[torch.Tensor, torch.Tensor, int, int | None, int | None, int] + + +@common.parametrize( + "test_module", test_data_fp_step1 | test_data_fp_stepN | test_data_bf16 +) +def test_slice_scatter_tosa_FP(test_module: input_t): + pipeline = TosaPipelineFP[input_t]( + SliceScatter(), + test_module(), + aten_op=SliceScatter.fp_aten_op, + exir_op=SliceScatter.fp_exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_fp_step1 | test_data_int_step1) +def test_slice_scatter_tosa_INT_step1(test_module: input_t): + pipeline = TosaPipelineINT[input_t]( + SliceScatter(), + test_module(), + aten_op=SliceScatter.int_aten_ops_step1, + exir_op=SliceScatter.int_exir_ops_step1, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_fp_stepN | test_data_int_stepN) +def test_slice_scatter_tosa_INT_stepN(test_module: input_t): + pipeline = TosaPipelineINT[input_t]( + SliceScatter(), + test_module(), 
+        aten_op=SliceScatter.int_aten_ops_stepN,
+        exir_op=SliceScatter.int_exir_ops_stepN,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_fp_step1 | test_data_int_step1)
+@common.XfailIfNoCorstone300
+def test_slice_scatter_u55_INT_step1(test_module: input_t):
+    # slice_scatter with unit-step is supported on U55
+    pipeline = EthosU55PipelineINT[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_ops=SliceScatter.int_aten_ops_step1,
+        exir_ops=SliceScatter.int_exir_ops_step1,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_fp_stepN | test_data_int_stepN)
+def test_slice_scatter_u55_INT_stepN(test_module: input_t):
+    # slice_scatter with non unit-step is not supported on U55
+    pipeline = OpNotSupportedPipeline[input_t](
+        SliceScatter(),
+        test_module(),
+        SliceScatter.u55_not_supported,
+        quantize=True,
+        u55_subset=True,
+        n_expected_delegates=2,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_module", test_data_fp_step1 | test_data_int_step1)
+def test_slice_scatter_u85_INT_step1(test_module: input_t):
+    pipeline = EthosU85PipelineINT[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_ops=SliceScatter.int_aten_ops_step1,
+        exir_ops=SliceScatter.int_exir_ops_step1,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_module", test_data_fp_stepN | test_data_int_stepN)
+def test_slice_scatter_u85_INT_stepN(test_module: input_t):
+    pipeline = EthosU85PipelineINT[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_ops=SliceScatter.int_aten_ops_stepN,
+        exir_ops=SliceScatter.int_exir_ops_stepN,
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize(
+    "test_module",
+    test_data_int_step1 | test_data_int_stepN | test_data_fp_step1 | test_data_fp_stepN,
+    xfails={
+        "rank2_step1_int8": "MLETORCH-1823: Fix quantized-node detection",
+        "rank3_step2_int32": "MLETORCH-1823: Fix quantized-node detection",
+    },
+)
+def test_slice_scatter_vgf_no_quant(test_module: input_t):
+    pipeline = VgfPipeline[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_op=SliceScatter.fp_aten_op,
+        exir_op=SliceScatter.fp_exir_op,
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_fp_step1 | test_data_int_step1)
+def test_slice_scatter_vgf_quant_step1(test_module: input_t):
+    pipeline = VgfPipeline[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_op=SliceScatter.int_aten_ops_step1,
+        exir_op=SliceScatter.int_exir_ops_step1,
+        quantize=True,
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_fp_stepN | test_data_int_stepN)
+def test_slice_scatter_vgf_quant_stepN(test_module: input_t):
+    pipeline = VgfPipeline[input_t](
+        SliceScatter(),
+        test_module(),
+        aten_op=SliceScatter.int_aten_ops_stepN,
+        exir_op=SliceScatter.int_exir_ops_stepN,
+        quantize=True,
+    )
+    pipeline.run()