From 1afcbab59204fe34db305ed6370dc6df8e7f8079 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Wed, 29 Apr 2026 13:04:15 +0100 Subject: [PATCH] Arm backend: Fix nested control-flow partition checks - Updates so that the outer cond graph is picked up. - Removes need for increased threshold. Signed-off-by: Saoirse Stewart --- backends/arm/_passes/arm_pass_utils.py | 49 +------- .../arm/_passes/control_flow_const_inline.py | 8 +- backends/arm/_passes/insert_rescales_pass.py | 8 +- .../arm/_passes/scalars_to_attribute_pass.py | 8 +- .../operator_support/control_flow_support.py | 26 +++-- backends/arm/operators/op_cond_if.py | 19 +++- backends/arm/operators/op_while.py | 19 +++- backends/arm/quantizer/arm_quantizer.py | 105 ++++++++++++------ backends/arm/test/ops/test_cond.py | 2 - backends/arm/tosa/backend.py | 61 +++++++++- backends/arm/tosa/mapping.py | 1 + backends/arm/tosa/partitioner.py | 8 +- 12 files changed, 193 insertions(+), 121 deletions(-) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 000f92135eb..f66b17b9da2 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -9,7 +9,7 @@ import operator import traceback from inspect import isclass -from typing import cast, List, Optional, Sequence, Tuple +from typing import cast, Optional, Sequence import torch import torch.fx @@ -19,10 +19,6 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.graph_module import ( - _get_control_flow_submodules, - get_control_flow_submodules, -) from executorch.exir.pass_base import NodeMetadata from torch._export.utils import ( @@ -36,7 +32,6 @@ from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor from torch.export.graph_signature import InputKind -from torch.fx import GraphModule, Node def is_submodule_node(node: torch.fx.Node): @@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value): raise RuntimeError("Invalid type") -def is_nested_control_flow_graph(graph_module: GraphModule) -> bool: - """Returns True if graph_module is a nested control-flow graph.""" - - # Find all top-level control-flow submodules - top_cf = get_control_flow_submodules(graph_module) - # For each submodule, see if it itself has control-flow inside - for _, submod, _ in top_cf: - if get_control_flow_submodules(submod): - return True - return False - - -def get_cond_while_submodules_nested( - graph_module: GraphModule, - apply_quantization: bool = False, -) -> List[Tuple[str, GraphModule, Node]]: - """Recursively find cond/while_loop submodules in an GraphModule. - - In nested control flow graphs, FX records the submodule functions - (true/false or cond/body) in reverse order compared to top-level graphs. We - must swap the indices when nested so that cond (first) and body/true_fn - (second) are consistently identified across all nesting levels. - - """ - - # Determine arg indices based on nesting and whether only cond branch is needed - nested = is_nested_control_flow_graph(graph_module) - # cond: [true_fn, false_fn] or swapped if nested - cond_indices = [2, 1] if nested else [1, 2] - # while_loop: [cond_fn, body_fn] or swapped if nested - while_indices = [1, 0] if nested else [0, 1] - if apply_quantization: - # only keep the cond_fn for while_loop (first index) when quantizing. - while_indices = [while_indices[0]] - mapping = { - torch.ops.higher_order.cond: cond_indices, - torch.ops.higher_order.while_loop: while_indices, - } - # collect cond/while submodules (using mapping indices) - return _get_control_flow_submodules(graph_module, mapping) - - def to_2tuple(value): """Normalizes scalars, and 1-element sequences to a tuple of length 2.""" if isinstance(value, int): diff --git a/backends/arm/_passes/control_flow_const_inline.py b/backends/arm/_passes/control_flow_const_inline.py index cc76e5d9957..177ad30754e 100644 --- a/backends/arm/_passes/control_flow_const_inline.py +++ b/backends/arm/_passes/control_flow_const_inline.py @@ -7,12 +7,10 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - is_submodule_node, -) +from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node from executorch.backends.transforms.utils import is_get_attr_node from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule @@ -37,7 +35,7 @@ class ControlFlowConstInlinePass(ArmPass): def _convert_getattr(self, graph_module): modified = False - for _, submodule, _ in get_cond_while_submodules_nested(graph_module): + for _, submodule, _ in get_cond_while_submodules(graph_module): for submodule_node in submodule.graph.nodes: if submodule_node.target in self._targeted_ops: self._convert_getattr(submodule) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 06c27005440..45374c12c3b 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -509,7 +509,13 @@ def _rescale_submodule_inputs( input_node = input_nodes[qargs_index] if len(input_node.users) == 0: continue - if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1: + out_qparams_map = input_node.meta.get("output_qparams", {}) + if len(out_qparams_map) == 0: + # Nested control-flow submodules may also expose frozen captured + # values as placeholders. Those are not control-flow boundary + # inputs, so there is no qparam pair to bridge with a RESCALE. + continue + if len(out_qparams_map) != 1: raise ValueError( f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}" ) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 0473caf91e7..63a38b8cb2f 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -8,11 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - get_first_fake_tensor, -) +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.exir.graph_module import get_cond_while_submodules from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix @@ -98,7 +96,7 @@ def handle_control_nodes(self, graph_module: GraphModule) -> None: """Apply scalar argument conversion on subgraphs of control-flow nodes. """ - for _, submodule, _ in get_cond_while_submodules_nested(graph_module): + for _, submodule, _ in get_cond_while_submodules(graph_module): for submodule_node in submodule.graph.nodes: self._convert_scalar_args(submodule, submodule_node) diff --git a/backends/arm/operator_support/control_flow_support.py b/backends/arm/operator_support/control_flow_support.py index b34ebeaece0..f5251357cd3 100644 --- a/backends/arm/operator_support/control_flow_support.py +++ b/backends/arm/operator_support/control_flow_support.py @@ -19,6 +19,13 @@ from torch.fx.passes.operator_support import OperatorSupportBase +def _owning_graph_module(node: fx.Node) -> fx.GraphModule: + graph_module = getattr(node.graph, "owning_module", None) + if not isinstance(graph_module, fx.GraphModule): + raise RuntimeError(f"Could not resolve owning GraphModule for node {node}") + return graph_module + + def _fully_partitioned(submodule: fx.GraphModule) -> bool: """Check that all nested control-flow ops within this submodule are also fully partitioned. @@ -27,8 +34,8 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool: for submodule_node in submodule.graph.nodes: if submodule_node.target in ControlFlowOpSupported._targeted_ops: - if _submodules_fully_partitioned(submodule_node, submodule): - return True + if not _submodules_fully_partitioned(submodule_node, submodule): + return False if submodule_node.op != "call_function": continue @@ -56,13 +63,18 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool: return True -def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool: +def _submodules_fully_partitioned( + node: fx.Node, graph_module: fx.GraphModule | None = None +) -> bool: """Returns whether the submodule arguments to a cond node were fully partitioned. Updates "val" meta of the submodules if they are. """ + if graph_module is None: + graph_module = _owning_graph_module(node) + match node.target: case torch.ops.higher_order.cond: submodule_args = node.args[1:3] @@ -129,9 +141,7 @@ def is_node_supported( node, f"Submodule had unsupported user {user}" ) return False - if not _submodules_fully_partitioned( - user, self.exported_program.graph_module - ): + if not _submodules_fully_partitioned(user): self.reporter.report_reject( node, "One submodule was not fully partitioned" ) @@ -174,9 +184,7 @@ def is_node_supported( ) return False - if not _submodules_fully_partitioned( - node, self.exported_program.graph_module - ): + if not _submodules_fully_partitioned(node): self.reporter.report_reject( node, "Submodule was not fully partitioned." ) diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py index 05d38e2a1f0..513100c2b15 100644 --- a/backends/arm/operators/op_cond_if.py +++ b/backends/arm/operators/op_cond_if.py @@ -17,7 +17,11 @@ validate_num_inputs, validate_valid_dtype, ) -from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from executorch.backends.arm.tosa.mapping import ( # type: ignore + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, + TosaArg, +) from torch.fx import Node @@ -38,7 +42,12 @@ def define_node( validate_cf_extension(self.target, self.tosa_spec) attr = ts.TosaSerializerAttribute() - if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + if_graph, else_graph = ( + cast(Node, arg).meta.get( + TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target) + ) + for arg in node.args[1:3] + ) attr.CondIfAttribute(if_graph, else_graph) self._serialize_operator( @@ -47,7 +56,11 @@ def define_node( ts.Op.COND_IF, [ inputs[0].name, - *(subgraph_input.name for subgraph_input in inputs[-1].special), + *( + subgraph_input.name + + subgraph_input.meta.get(TOSA_TENSOR_NAME_META, "") + for subgraph_input in inputs[-1].special + ), ], output.multiple_output_names, attr, diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py index 2b6314d3454..58501dd3ba0 100644 --- a/backends/arm/operators/op_while.py +++ b/backends/arm/operators/op_while.py @@ -15,8 +15,14 @@ validate_cf_extension, validate_num_inputs, ) -from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa.mapping import ( + map_dtype, + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, + TosaArg, +) from executorch.backends.arm.tosa.utils import normalize_symint + from torch.fx import Node @@ -46,7 +52,12 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - cond_graph, body_graph = (str(cast(Node, arg).target) for arg in node.args[:2]) + cond_graph, body_graph = ( + cast(Node, arg).meta.get( + TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target) + ) + for arg in node.args[:2] + ) attr.WhileLoopAttribute(cond_graph, body_graph) input_names: list[str] = [] @@ -55,7 +66,9 @@ def define_node( raise ValueError( f"{self.target}: Unsupported carried input type {type(loop_input)}." ) - input_names.append(loop_input.name) + input_names.append( + loop_input.name + loop_input.meta.get(TOSA_TENSOR_NAME_META, "") + ) num_inputs = len(input_names) num_outputs = len(output.multiple_output_names) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index f1dfb5f1323..3508410509c 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -40,6 +40,10 @@ from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter +from executorch.exir.graph_module import ( + _get_control_flow_submodules, + get_cond_while_submodules, +) from torch._ops import OpOverload @@ -52,10 +56,6 @@ from executorch.backends.arm.common.arm_compile_spec import ( ArmCompileSpec, ) # isort: skip -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - is_submodule_node, -) from executorch.backends.arm.quantizer.arm_quantizer_utils import ( _get_int32_bias_qspec, @@ -107,6 +107,29 @@ logger = logging.getLogger(__name__) +def get_cond_while_submodules_ao( + graph_module: GraphModule, + apply_quantization: bool = False, +) -> list[tuple[str, GraphModule, Node]]: + """Return cond/while submodules for the current graph module. + + Quantization handles ``while_loop`` body functions natively in torchao, so + only the ``while_loop`` cond function is processed explicitly there. + + """ + + if not apply_quantization: + return get_cond_while_submodules(graph_module) + + return _get_control_flow_submodules( + graph_module, + { + torch.ops.higher_order.cond: [1, 2], + torch.ops.higher_order.while_loop: [0], + }, + ) + + @functools.lru_cache def get_symmetric_quantization_config( is_per_channel: bool = True, @@ -810,42 +833,56 @@ def _quantize_with_submodules( prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e prepared = prepare_fn(model, self) - # Prepare conditional submodules (e.g., if/while bodies) - # prepare only cond branches and while_loop cond_fn - for name, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - prepared.set_submodule( - nested_name, prepare_fn(nested_sub, self), strict=True - ) + + def _prepare_control_flow_submodules( + source_graph_module: GraphModule, prefix: str = "" + ) -> None: + for name, submodule, _ in get_cond_while_submodules_ao( + source_graph_module, apply_quantization=True + ): + qualified_name = f"{prefix}.{name}" if prefix else name + prepared.set_submodule( + qualified_name, prepare_fn(submodule, self), strict=True + ) + _prepare_control_flow_submodules(submodule, qualified_name) + + _prepare_control_flow_submodules(prepared) for inp in calibration_samples: prepared(*inp) - # Prepare conditional submodules (e.g., if/while bodies) - # convert only cond branches and while_loop cond_fn - for _, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True + def _convert_control_flow_submodule( + graph_module: GraphModule, + ) -> GraphModule: + converted_submodules: list[tuple[str, GraphModule]] = [] + for name, submodule, _ in get_cond_while_submodules_ao( + graph_module, apply_quantization=True + ): + converted_submodules.append( + (name, _convert_control_flow_submodule(submodule)) + ) + converted_graph_module = convert_pt2e( + graph_module, fold_quantize=fold_quantize + ) + for name, converted_submodule in converted_submodules: + converted_graph_module.set_submodule( + name, converted_submodule, strict=True + ) + return converted_graph_module + + converted_top_level_submodules: list[tuple[str, GraphModule]] = [] + for name, submodule, _ in list( + get_cond_while_submodules_ao(prepared, apply_quantization=True) ): - converted = convert_pt2e(submodule, fold_quantize=fold_quantize) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - converted.set_submodule( - nested_name, - convert_pt2e(nested_sub, fold_quantize=fold_quantize), - strict=True, - ) + converted_top_level_submodules.append( + (name, _convert_control_flow_submodule(submodule)) + ) + + converted = convert_pt2e(prepared, fold_quantize=fold_quantize) + for name, converted_submodule in converted_top_level_submodules: + converted.set_submodule(name, converted_submodule, strict=True) - return convert_pt2e(prepared, fold_quantize=fold_quantize) + return converted class _TOSAQuantizerV1(Quantizer): diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index 8c6d9ef329c..6f489f0ab01 100644 --- a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -250,8 +250,6 @@ def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): example_inputs, aten_op, tosa_extensions=["cf"], - frobenius_threshold=0.8, - cosine_threshold=0.8, # MLETORCH-1808 ) _set_branch_calibration_samples(pipeline, module, example_inputs) # Make sure no cond ops are left after partitioning. diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 6b864e284b1..b0cae15022d 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -23,9 +23,6 @@ import tosa_serializer as ts -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, -) from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook @@ -35,9 +32,13 @@ process_placeholder, ) from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META +from executorch.backends.arm.tosa.mapping import ( + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, +) from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -45,6 +46,15 @@ logger = logging.getLogger(__name__) +def _qualify_control_flow_region_name( + parent_region_name: str | None, child_region_name: str +) -> str: + """Return a globally unique TOSA region name for nested control flow.""" + if parent_region_name is None: + return child_region_name + return f"{parent_region_name}__{child_region_name}" + + def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: """Assign deterministic output IDs to leaf outputs. @@ -325,6 +335,43 @@ def _preprocess_module( # noqa: C901 RuntimeError: If an FX node with an unsupported op kind is found. """ + + def _annotate_control_flow_region_names( + graph_module: GraphModule, parent_region_name: str | None + ) -> None: + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + match node.target: + case torch.ops.higher_order.cond: + arg_indices = [1, 2] + case torch.ops.higher_order.while_loop: + arg_indices = [0, 1] + case _: + continue + + for arg_index in arg_indices: + submodule_node = node.args[arg_index] + if not isinstance(submodule_node, Node): + raise RuntimeError( + f"Expected control flow submodule arg {arg_index} to be a Node." + ) + if submodule_node.op != "get_attr": + raise RuntimeError( + f"Expected control flow submodule arg {arg_index} to be a get_attr node." + ) + if not isinstance(submodule_node.target, str): + raise RuntimeError( + "Expected control flow submodule target to be a string." + ) + + submodule_node.meta[TOSA_CONTROL_FLOW_REGION_NAME_META] = ( + _qualify_control_flow_region_name( + parent_region_name, submodule_node.target + ) + ) + tosa_spec = compile_spec.tosa_spec node_to_id_map = _annotate_external_ids(graph_module.graph) artifact_path = compile_spec._get_intermediate_path() @@ -348,6 +395,8 @@ def _preprocess_module( # noqa: C901 else: logger.debug("No re-sorting outputs (workaround) during TOSA lowering.") + _annotate_control_flow_region_names(graph_module, submodule_name) + if submodule_name is not None: tosa_graph.startRegion(submodule_name) tosa_graph.currRegion.addBasicBlock(submodule_name) @@ -396,7 +445,7 @@ def _preprocess_module( # noqa: C901 raise # Recursively preprocess controlflow submodules. - for name, submodule, control_flow_node in get_cond_while_submodules_nested( + for name, submodule, control_flow_node in get_cond_while_submodules( graph_module ): TOSABackend._regularize_submodule(submodule, control_flow_node) @@ -406,7 +455,7 @@ def _preprocess_module( # noqa: C901 compile_spec, tosa_graph, debug_hook, - submodule_name=name, + submodule_name=_qualify_control_flow_region_name(submodule_name, name), containing_graph_module=graph_module, ) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 9741ca167c0..40ed99838d1 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -17,6 +17,7 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification +TOSA_CONTROL_FLOW_REGION_NAME_META = "tosa_control_flow_region_name" TOSA_TENSOR_NAME_META = "tosa_tensor_name" UNSUPPORTED_DTYPES = ( diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 07548eb5d69..35c5cee7657 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -20,10 +20,7 @@ from typing import Callable, cast, List, Optional, Sequence, Tuple import torch -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - get_first_fake_tensor, -) +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) @@ -42,6 +39,7 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition @@ -328,7 +326,7 @@ def _tag_module( # noqa tags: set[str] = set() if tag_iterator is None: tag_iterator = count(0) - for _, submodule, _ in get_cond_while_submodules_nested(module): + for _, submodule, _ in get_cond_while_submodules(module): submodule_tags = self._tag_module( submodule, containing_program, reporter, tag_iterator )