diff --git a/exir/program/_program.py b/exir/program/_program.py index baacd5eaec4..dd0c1b0e5da 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1122,6 +1122,16 @@ def keep(op): ) return False + # Fallback: torchgen does not detect alias annotations on ops + # returning lists of aliased tensors (e.g. split.Tensor returns + # Tensor(a)[]). Check op._schema.returns directly. + for ret in schema.returns: + if ret.alias_info is not None: + log_warning( + f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output." + ) + return False + # Explicit block list of ops that don't work if asked for # preservation if op in [ diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index f683384f8f9..8a084ba491a 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -940,6 +940,43 @@ def body(i, h, h_accum): torch.allclose(prog.exported_program().module()(inp), model(inp)) ) + def test_remove_invalid_ops_filters_aliased_list_returns(self) -> None: + """Verify _remove_invalid_ops_for_not_decompose filters ops that return + aliased tensor lists (e.g. split, chunk) even when torchgen's + aliased_return_names() fails to detect them. Regression test for + https://github.com/pytorch/executorch/issues/11723 + """ + from executorch.exir.program._program import ( + _remove_invalid_ops_for_not_decompose, + ) + + # These ops return Tensor(a)[] — a list of aliased views. + # torchgen's aliased_return_names() misses the alias annotation on + # list returns, so the fallback check on op._schema.returns is needed. + aliased_list_ops = [ + torch.ops.aten.split.Tensor, + torch.ops.aten.chunk.default, + torch.ops.aten.tensor_split.sections, + torch.ops.aten.split_with_sizes.default, + ] + for op in aliased_list_ops: + result = _remove_invalid_ops_for_not_decompose([op]) + self.assertNotIn( + op, + result, + f"{op} should be filtered out because it returns aliased tensors", + ) + + # Non-aliased ops should be preserved. + preserved_ops = [torch.ops.aten.linear.default] + for op in preserved_ops: + result = _remove_invalid_ops_for_not_decompose([op]) + self.assertIn( + op, + result, + f"{op} should be preserved because it has no aliased returns", + ) + def test_convert_symb_ops(self) -> None: class Foo(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: