From 8b56e75dd9abfc00ce201405e0e4f3ba8e7233af Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Wed, 20 May 2026 22:35:10 -0700
Subject: [PATCH] Fix IndexError in RemovePermutesAroundElementwiseOps for rank-mismatched views (#19713)

Summary:
Fix an IndexError crash in the RemovePermutesAroundElementwiseOps compiler pass that occurs when a squeeze/unsqueeze view_copy node is included in a subgraph via upstream traversal with a permutation whose rank does not match the view input tensor rank.

The bug manifests as `IndexError: tuple index out of range` in `update_view_copy` at the line:
```
unpermuted_in = [in_shape[inverse_permute[i]] for i in range(len(in_shape))]
```
because `inverse_permute` contains indices from a rank-N permutation but `in_shape` has fewer than N dimensions.

The fix adds a validation guard at the top of `permute_subgraph()` that checks all view_copy nodes for rank consistency between their `node_start_permute` and their input tensor shape. If a mismatch is found, the entire subgraph optimisation is skipped, preserving graph correctness.

The return type of `permute_subgraph` is changed from `None` to `bool` so the caller only marks the graph as modified when the subgraph was actually transformed.

Also adds a regression test that constructs the exact graph pattern (upstream unsqueeze view_copy with a 2D constant feeding into a 3D permuted add) that triggered the crash.

Reviewed By: ethansfng Differential Revision: D105787161 --- ...ve_permutes_around_elementwise_tosa_ops.py | 4 +- .../remove_permutes_around_elementwise_ops.py | 21 +++++- .../test/test_permute_optimization_passes.py | 64 +++++++++++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py index fa6f6f7988c..72688d17ef2 100644 --- a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py +++ b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py @@ -21,7 +21,7 @@ def __init__(self) -> None: } ) - def permute_subgraph(self, subgraph): + def permute_subgraph(self, subgraph) -> bool: # Original function will always permute constant nodes which is wrong for table ops # Remove constant tosa.TABLE edges before running full function new_constant_edges_in = set() @@ -32,4 +32,4 @@ def permute_subgraph(self, subgraph): new_constant_edges_in.add((const_node, user_node)) subgraph.constant_edges_in = new_constant_edges_in - super().permute_subgraph(subgraph) + return super().permute_subgraph(subgraph) diff --git a/backends/transforms/remove_permutes_around_elementwise_ops.py b/backends/transforms/remove_permutes_around_elementwise_ops.py index eec6bdc4e08..b992afaeb53 100644 --- a/backends/transforms/remove_permutes_around_elementwise_ops.py +++ b/backends/transforms/remove_permutes_around_elementwise_ops.py @@ -240,8 +240,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 modified = False for subgraph in subgraphs_found: - self.permute_subgraph(subgraph) - modified = True + if self.permute_subgraph(subgraph): + modified = True if modified: graph_module.graph.eliminate_dead_code() @@ -399,7 +399,20 @@ def is_node_permutable(self, node: torch.fx.Node) -> bool: return True return self._is_pointwise(node.target) - def permute_subgraph(self, subgraph: Subgraph) -> None: # 
noqa: C901 + def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901 + # Validate: every view_copy node's permutation rank must match its + # input tensor rank. A mismatch can occur when a squeeze/unsqueeze + # view is reached via upstream traversal with a permutation that was + # already adapted to a different rank. Applying the optimisation in + # this case would produce an invalid graph, so skip the subgraph. + for node in subgraph.nodes: + if node.target in self._VIEW_OPS: + perm = subgraph.node_start_permute.get(node, subgraph.start_permute) + inp = node.args[0] + if isinstance(inp, torch.fx.Node) and inp.meta.get("val") is not None: + if len(perm) != len(inp.meta["val"].shape): + return False + # Handle dimension related node arguments FIRST, before # bypassing permutes (which changes node inputs/metadata). for node in subgraph.nodes: @@ -480,6 +493,8 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: # noqa: C901 assert out.target == exir_ops.edge.aten.permute_copy.default out.replace_all_uses_with(inp) + return True + def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: dim = get_arg(node, "dim", int) set_arg(node, "dim", start_permute[dim]) diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py index 808a599f81f..dd356aad8a2 100644 --- a/backends/transforms/test/test_permute_optimization_passes.py +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -998,3 +998,67 @@ def test_permute_unsqueeze_copy_neg_dim_mul_squeeze_copy_permute(self) -> None: [x_data], "permute_unsqueeze_copy_neg_dim_mul_squeeze_copy_permute", ) + + def test_upstream_view_rank_mismatch_no_crash(self) -> None: + """Regression test for IndexError when a squeeze/unsqueeze view_copy + is reached via upstream traversal with a permutation whose rank does + not match the view's input tensor rank. 
+ + Graph: + full([16, 128], 1.0) x [1, 128, 16] + | | + view_copy (unsqueeze 2D→3D) permute [0, 2, 1] + [1, 16, 128] [1, 16, 128] + \\ / + ---- add (3D) ----------- + | + permute [0, 2, 1] + | + output + + The view_copy (unsqueeze) is reached as an upstream input to `add`. + Its node_start_permute gets the 3D permutation [0, 2, 1], but its + input (the full op) is 2D. Before the fix, update_view_copy would + crash with IndexError: tuple index out of range.""" + builder = GraphBuilder() + x_data = torch.randn(1, 128, 16) + x = builder.placeholder("x", x_data) + # 2D constant — treated as compile-time constant by _is_constant + const_2d = builder.call_operator( + op=exir_ops.edge.aten.full.default, args=([16, 128], 1.0) + ) + # Unsqueeze via view_copy: [16, 128] → [1, 16, 128] + view_unsq = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(const_2d, [1, 16, 128]) + ) + # Start permute: [1, 128, 16] → [1, 16, 128] + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 1]) + ) + # Add the permuted input with the unsqueezed constant + add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(p1, view_unsq) + ) + # End permute: [1, 16, 128] → [1, 128, 16] + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(add, [0, 2, 1]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + # Should not crash, and should skip the subgraph due to rank mismatch + p = RemovePermutesAroundElementwiseOps() + result = cast(PassResult, p(original)) + # The subgraph is skipped, so the graph should be unmodified + self.assertFalse(result.modified) + # Both permutes are preserved + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 2 + ) + validate_numerics( + gm_before, + result.graph_module, + [x_data], + "upstream_view_rank_mismatch_no_crash", + )