diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 728012ad457..9fdb5f79695 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -306,6 +306,8 @@ def tosa_support_factory( negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) negative_checks.append(EthosU55CastCheck(reporter)) + if not tosa_spec.support_extension("shape"): + negative_checks.append(SymbolicShapeSupportCheck(reporter)) return chain( reporter.wrap_check( @@ -316,6 +318,72 @@ def tosa_support_factory( ) +class SymbolicShapeSupportCheck(OperatorSupportBase): + """Reject symbolic tensor shapes for specs without the shape extension.""" + + def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ + self.reporter = reporter + + @staticmethod + def _has_symbolic_shape(node: fx.Node) -> bool: + val = node.meta.get("val") + vals = val if isinstance(val, (list, tuple)) else (val,) + for node_val in vals: + if isinstance(node_val, torch.SymInt): + return True + + shape = getattr(node_val, "shape", None) + if shape is not None and any( + isinstance(dim, torch.SymInt) for dim in shape + ): + return True + + return False + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + """Return False for nodes with symbolic tensor input or output shapes. + + Dynamic shapes require the TOSA shape extension. Reject nodes with + symbolic tensor dimensions before partitioning when the active spec + does not enable that extension. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: False if rejected by constraints; otherwise, True. + + """ + if node.op in ("placeholder", "output"): + return True + if node.op == "call_function" and node.target in (*Q_OPS, *DQ_OPS): + return True + + if self._has_symbolic_shape(node) or any( + self._has_symbolic_shape(input_node) for input_node in node.all_input_nodes + ): + if node.target == exir_ops.edge.aten.upsample_nearest2d.vec: + return True + + self.reporter.report_reject( + node, + "Node has symbolic shape but the TOSA spec does not support " + "the shape extension.", + ) + return False + + return True + + class TOSAProINTSupportList(OperatorSupportBase): """Provide the INT profile support list for TOSA. diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index 9f06335b7a8..3742f710494 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -9,15 +9,23 @@ import torch import torch.nn.functional as F +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + SymbolicShapeSupportCheck, +) from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, ) from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import ( TosaPipelineFP, TosaPipelineINT, VgfPipeline, ) +from executorch.exir import to_edge +from executorch.exir.backend.utils import WhyNoPartitionReporter +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import Dim, export aten_op = "torch.ops.aten.pad.default" exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default" @@ -143,6 +151,128 @@ def forward(self, x: torch.Tensor): return F.pad(x, pad=self.pad, mode=self.mode, value=self.value) +class RawConstantPadND(torch.nn.Module): + def __init__(self, pad: Tuple, value: float = 0.0): + super().__init__() + self.pad = pad + self.value = value + + def forward(self, x: torch.Tensor): + return F.pad(x, pad=self.pad, mode="constant", value=self.value) + + +def _constant_pad_nd_node( + module: torch.nn.Module, + example_inputs: tuple[torch.Tensor, ...], + dynamic_shapes=None, +) -> torch.fx.Node: + edge = to_edge( + export(module, example_inputs, dynamic_shapes=dynamic_shapes, strict=True) + ) + return next( + n + for n in edge.exported_program().graph.nodes + if n.target == exir_ops.edge.aten.constant_pad_nd.default + ) + + +def _is_tosa_without_shape_extension_supported(node: torch.fx.Node) -> bool: + return SymbolicShapeSupportCheck(WhyNoPartitionReporter()).is_node_supported( + {}, node + ) + + +def test_constant_pad_nd_no_target_u55_symbolic_padded_axis_not_delegated(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + width = Dim("width", min=4, max=10) + node = _constant_pad_nd_node( + RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)), + (input_tensor,), + dynamic_shapes={"x": {4: width}}, + ) + + assert not _is_tosa_without_shape_extension_supported(node) + + +def test_constant_pad_nd_no_target_u55_symbolic_unpadded_axis_not_delegated(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + width = Dim("width", min=4, max=10) + node = _constant_pad_nd_node( + RawConstantPadND((0, 0, 1, 0, 0, 0, 0, 0)), + (input_tensor,), + dynamic_shapes={"x": {4: width}}, + ) + + assert not _is_tosa_without_shape_extension_supported(node) + + +def test_constant_pad_nd_no_target_u55_static_padded_axis_supported(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + node = _constant_pad_nd_node( + RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)), + (input_tensor,), + ) + + assert _is_tosa_without_shape_extension_supported(node) + + +def test_constant_pad_nd_u55_INT_dynamic_padded_axis_not_delegated(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + width = Dim("width", min=4, max=10) + tester = ArmTester( + RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)), + (input_tensor,), + common.get_u55_compile_spec(), + dynamic_shapes=({4: width},), + ) + + tester.quantize().export().to_edge().partition() + targets = { + node.target + for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes + } + + assert exir_ops.edge.aten.constant_pad_nd.default in targets + assert torch.ops.higher_order.executorch_call_delegate not in targets + + +def test_constant_pad_nd_u85_INT_dynamic_padded_axis_not_delegated(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + width = Dim("width", min=4, max=10) + tester = ArmTester( + RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)), + (input_tensor,), + common.get_u85_compile_spec(), + dynamic_shapes=({4: width},), + ) + + tester.quantize().export().to_edge().partition() + targets = { + node.target + for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes + } + + assert exir_ops.edge.aten.constant_pad_nd.default in targets + assert torch.ops.higher_order.executorch_call_delegate not in targets + + +def test_constant_pad_nd_u55_INT_static_5d_padded_axis_delegated(): + input_tensor = torch.rand(1, 3, 8, 8, 5) + tester = ArmTester( + RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)), + (input_tensor,), + common.get_u55_compile_spec(), + ) + + tester.quantize().export().to_edge_transform_and_lower() + targets = { + node.target + for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes + } + + assert torch.ops.higher_order.executorch_call_delegate in targets + + @common.parametrize( "test_data", test_data_suite | test_data_suite_bf16 | test_data_suite_fp16,