diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index f82157d3cf0..7391c3bacc4 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -31,6 +31,7 @@ from .decompose_reciprocal import DecomposeReciprocal from .decompose_remainder import DecomposeRemainder from .decompose_roll import DecomposeRoll +from .decompose_select_scatter import DecomposeSelectScatter from .decompose_silu import DecomposeSilu from .decompose_tan import DecomposeTan from .decompose_threshold import DecomposeThreshold @@ -88,6 +89,7 @@ DecomposeReciprocal, DecomposeRemainder, DecomposeRoll, + DecomposeSelectScatter, DecomposeSilu, DecomposeTan, DecomposeThreshold, diff --git a/backends/qualcomm/_passes/decompose_select_scatter.py b/backends/qualcomm/_passes/decompose_select_scatter.py new file mode 100644 index 00000000000..46a63b3e761 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_select_scatter.py @@ -0,0 +1,96 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + + +class DecomposeSelectScatter(ExportPass): + """ + Decompose select_scatter into unsqueeze + slice_scatter. + + select_scatter(input, src, dim, index) replaces a single index along the given dimension. + If input has shape [m, n, p] and dim=1, then src must have shape [m, p] (the selected dimension is removed). + slice_scatter operates on a sliced view where the dimension is preserved. + When slicing a single index, the target region has shape [m, 1, p]. + + Therefore, src must be unsqueezed along dim (from [m, p] to [m, 1, p]) to match the slice shape. + So, the equivalence is: + select_scatter(input, src, dim, index) == slice_scatter(input, src.unsqueeze(dim), dim, index, index+1, 1) + """ + + def __init__(self): + super(DecomposeSelectScatter, self).__init__() + self.select_scatter_targets = { + torch.ops.aten.select_scatter.default, + exir_ops.edge.aten.select_scatter.default, + } + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + + for node in list(graph.nodes): + if ( + node.op == "call_function" + and node.target in self.select_scatter_targets + ): + input_node = node.args[0] + src_node = node.args[1] + dim = node.args[2] + index = node.args[3] + + # Normalize negative index + if index < 0: + size = input_node.meta["val"].shape[dim] + index = index + size + + is_edge = isinstance(node.target, EdgeOpOverload) + meta = node.meta + + unsqueeze_op = ( + exir_ops.edge.aten.unsqueeze_copy.default + if is_edge + else torch.ops.aten.unsqueeze.default + ) + slice_scatter_op = ( + exir_ops.edge.aten.slice_scatter.default + if is_edge + else torch.ops.aten.slice_scatter.default + ) + + with graph.inserting_before(node): + # unsqueeze src along dim to restore the missing dimension + unsqueeze_node = graph.create_node( + "call_function", unsqueeze_op, (src_node, dim) + ) + # Compute unsqueeze output shape for meta + src_val = src_node.meta.get("val", None) + if src_val is not None: + unsqueeze_val = src_val.unsqueeze(dim) + unsqueeze_node.meta = copy_meta( + meta, + callback=lambda m, val=unsqueeze_val: {**m, "val": val}, + ) + else: + unsqueeze_node.meta = copy_meta(meta) + + slice_scatter_node = graph.create_node( + "call_function", + slice_scatter_op, + (input_node, unsqueeze_node, dim, index, index + 1, 1), + ) + slice_scatter_node.meta = copy_meta(meta) + + for user in node.users.copy(): + user.replace_input_with(node, slice_scatter_node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index b0913bbefd9..a31b6a1f42f 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -36,6 +36,7 @@ DecomposeReciprocal, DecomposeRemainder, DecomposeRoll, + DecomposeSelectScatter, DecomposeSilu, DecomposeTan, DecomposeThreshold, @@ -251,6 +252,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): # TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager) self.add_pass(DecomposeReciprocal()) self.add_pass(DecomposeRemainder()) + self.add_pass(DecomposeSelectScatter()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeLogVariants()) self.add_pass(ReplaceInfValues()) @@ -266,6 +268,7 @@ def transform_for_export_pipeline( self.add_pass(DecomposePad()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) + self.add_pass(DecomposeSelectScatter()) self.add_pass(DecomposeThreshold()) self.add_pass(DecomposeTriu()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index a443ed0905c..89115a0150c 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -517,6 +517,7 @@ The following PyTorch operators are supported through decomposition or annotatio | `aten.reflection_pad2d` | `DecomposePad` | | `aten.remainder.Scalar`, `aten.remainder.Tensor` | `DecomposeRemainder` | | `aten.roll` | `DecomposeRoll` | +| `aten.select_scatter` | `DecomposeSelectScatter` | | `aten.silu` | `DecomposeSilu` | | `aten.tan` | `DecomposeTan` | | `aten.threshold` | `DecomposeThreshold` | diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index e190828f06a..cb9305b65a3 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -2319,6 +2319,16 @@ def forward(self, x, y): ) +class SelectScatter(torch.nn.Module): + def __init__(self, dim, index): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, x, y): + return x.select_scatter(y, dim=self.dim, index=self.index) + + class SliceScatter(torch.nn.Module): def __init__(self, dim, start, end, step): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 940c54c2f8d..6d5b44d7a35 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2014,6 +2014,41 @@ def test_qnn_backend_select_copy(self): sample_input = (torch.randn([1, 3, 3, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_select_scatter(self): + test_comb = [ + { + QCOM_MODULE: [ + SelectScatter(dim=0, index=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.randn(4, 8), + torch.randn(8), + ) + ], + }, + { + QCOM_MODULE: [ + SelectScatter(dim=1, index=0), # noqa: F405 + SelectScatter(dim=1, index=-1), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.randn(3, 4, 5), + torch.randn(3, 5), + ) + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_slice_copy(self): modules = [ SliceCopyDefaultParameter(), # noqa: F405 @@ -4834,6 +4869,42 @@ def test_qnn_backend_select_copy(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_select_scatter(self): + test_comb = [ + { + QCOM_MODULE: [ + SelectScatter(dim=0, index=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.randn(4, 8), + torch.randn(8), + ) + ], + }, + { + QCOM_MODULE: [ + SelectScatter(dim=1, index=0), # noqa: F405 + SelectScatter(dim=1, index=-1), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.randn(3, 4, 5), + torch.randn(3, 5), + ) + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) + def test_qnn_backend_sigmoid(self): module = Sigmoid() # noqa: F405 sample_input = (torch.randn([1, 3, 3, 3]),)