Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 1 addition & 48 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import operator
import traceback
from inspect import isclass
from typing import cast, List, Optional, Sequence, Tuple
from typing import cast, Optional, Sequence

import torch
import torch.fx
Expand All @@ -19,10 +19,6 @@
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.graph_module import (
_get_control_flow_submodules,
get_control_flow_submodules,
)
from executorch.exir.pass_base import NodeMetadata

from torch._export.utils import (
Expand All @@ -36,7 +32,6 @@
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import InputKind
from torch.fx import GraphModule, Node


def is_submodule_node(node: torch.fx.Node):
Expand Down Expand Up @@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
raise RuntimeError("Invalid type")


def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
"""Returns True if graph_module is a nested control-flow graph."""

# Find all top-level control-flow submodules
top_cf = get_control_flow_submodules(graph_module)
# For each submodule, see if it itself has control-flow inside
for _, submod, _ in top_cf:
if get_control_flow_submodules(submod):
return True
return False


def get_cond_while_submodules_nested(
graph_module: GraphModule,
apply_quantization: bool = False,
) -> List[Tuple[str, GraphModule, Node]]:
"""Recursively find cond/while_loop submodules in an GraphModule.

In nested control flow graphs, FX records the submodule functions
(true/false or cond/body) in reverse order compared to top-level graphs. We
must swap the indices when nested so that cond (first) and body/true_fn
(second) are consistently identified across all nesting levels.

"""

# Determine arg indices based on nesting and whether only cond branch is needed
nested = is_nested_control_flow_graph(graph_module)
# cond: [true_fn, false_fn] or swapped if nested
cond_indices = [2, 1] if nested else [1, 2]
# while_loop: [cond_fn, body_fn] or swapped if nested
while_indices = [1, 0] if nested else [0, 1]
if apply_quantization:
# only keep the cond_fn for while_loop (first index) when quantizing.
while_indices = [while_indices[0]]
mapping = {
torch.ops.higher_order.cond: cond_indices,
torch.ops.higher_order.while_loop: while_indices,
}
# collect cond/while submodules (using mapping indices)
return _get_control_flow_submodules(graph_module, mapping)


def to_2tuple(value):
"""Normalizes scalars, and 1-element sequences to a tuple of length 2."""
if isinstance(value, int):
Expand Down
8 changes: 3 additions & 5 deletions backends/arm/_passes/control_flow_const_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
is_submodule_node,
)
from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node
from executorch.backends.transforms.utils import is_get_attr_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_cond_while_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule

Expand All @@ -37,7 +35,7 @@ class ControlFlowConstInlinePass(ArmPass):

def _convert_getattr(self, graph_module):
modified = False
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
for _, submodule, _ in get_cond_while_submodules(graph_module):
for submodule_node in submodule.graph.nodes:
if submodule_node.target in self._targeted_ops:
self._convert_getattr(submodule)
Expand Down
8 changes: 7 additions & 1 deletion backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,13 @@ def _rescale_submodule_inputs(
input_node = input_nodes[qargs_index]
if len(input_node.users) == 0:
continue
if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1:
out_qparams_map = input_node.meta.get("output_qparams", {})
if len(out_qparams_map) == 0:
# Nested control-flow submodules may also expose frozen captured
# values as placeholders. Those are not control-flow boundary
# inputs, so there is no qparam pair to bridge with a RESCALE.
continue
if len(out_qparams_map) != 1:
raise ValueError(
f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}"
)
Expand Down
8 changes: 3 additions & 5 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
get_first_fake_tensor,
)
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.exir.graph_module import get_cond_while_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
Expand Down Expand Up @@ -98,7 +96,7 @@ def handle_control_nodes(self, graph_module: GraphModule) -> None:
"""Apply scalar argument conversion on subgraphs of control-flow
nodes.
"""
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
for _, submodule, _ in get_cond_while_submodules(graph_module):
for submodule_node in submodule.graph.nodes:
self._convert_scalar_args(submodule, submodule_node)

Expand Down
26 changes: 17 additions & 9 deletions backends/arm/operator_support/control_flow_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from torch.fx.passes.operator_support import OperatorSupportBase


def _owning_graph_module(node: fx.Node) -> fx.GraphModule:
graph_module = getattr(node.graph, "owning_module", None)
if not isinstance(graph_module, fx.GraphModule):
raise RuntimeError(f"Could not resolve owning GraphModule for node {node}")
return graph_module


def _fully_partitioned(submodule: fx.GraphModule) -> bool:
"""Check that all nested control-flow ops within this submodule are also
fully partitioned.
Expand All @@ -27,8 +34,8 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:

for submodule_node in submodule.graph.nodes:
if submodule_node.target in ControlFlowOpSupported._targeted_ops:
if _submodules_fully_partitioned(submodule_node, submodule):
return True
if not _submodules_fully_partitioned(submodule_node, submodule):
return False

if submodule_node.op != "call_function":
continue
Expand Down Expand Up @@ -56,13 +63,18 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:
return True


def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool:
def _submodules_fully_partitioned(
node: fx.Node, graph_module: fx.GraphModule | None = None
) -> bool:
"""Returns whether the submodule arguments to a cond node were fully
partitioned.

Updates "val" meta of the submodules if they are.

"""
if graph_module is None:
graph_module = _owning_graph_module(node)

match node.target:
case torch.ops.higher_order.cond:
submodule_args = node.args[1:3]
Expand Down Expand Up @@ -129,9 +141,7 @@ def is_node_supported(
node, f"Submodule had unsupported user {user}"
)
return False
if not _submodules_fully_partitioned(
user, self.exported_program.graph_module
):
if not _submodules_fully_partitioned(user):
self.reporter.report_reject(
node, "One submodule was not fully partitioned"
)
Expand Down Expand Up @@ -174,9 +184,7 @@ def is_node_supported(
)
return False

if not _submodules_fully_partitioned(
node, self.exported_program.graph_module
):
if not _submodules_fully_partitioned(node):
self.reporter.report_reject(
node, "Submodule was not fully partitioned."
)
Expand Down
19 changes: 16 additions & 3 deletions backends/arm/operators/op_cond_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
validate_num_inputs,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore
from executorch.backends.arm.tosa.mapping import ( # type: ignore
TOSA_CONTROL_FLOW_REGION_NAME_META,
TOSA_TENSOR_NAME_META,
TosaArg,
)
from torch.fx import Node


Expand All @@ -38,7 +42,12 @@ def define_node(
validate_cf_extension(self.target, self.tosa_spec)

attr = ts.TosaSerializerAttribute()
if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3])
if_graph, else_graph = (
cast(Node, arg).meta.get(
TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target)
)
for arg in node.args[1:3]
)
attr.CondIfAttribute(if_graph, else_graph)

self._serialize_operator(
Expand All @@ -47,7 +56,11 @@ def define_node(
ts.Op.COND_IF,
[
inputs[0].name,
*(subgraph_input.name for subgraph_input in inputs[-1].special),
*(
subgraph_input.name
+ subgraph_input.meta.get(TOSA_TENSOR_NAME_META, "")
for subgraph_input in inputs[-1].special
),
],
output.multiple_output_names,
attr,
Expand Down
19 changes: 16 additions & 3 deletions backends/arm/operators/op_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
validate_cf_extension,
validate_num_inputs,
)
from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa.mapping import (
map_dtype,
TOSA_CONTROL_FLOW_REGION_NAME_META,
TOSA_TENSOR_NAME_META,
TosaArg,
)
from executorch.backends.arm.tosa.utils import normalize_symint

from torch.fx import Node


Expand Down Expand Up @@ -46,7 +52,12 @@ def define_node(
)

attr = ts.TosaSerializerAttribute()
cond_graph, body_graph = (str(cast(Node, arg).target) for arg in node.args[:2])
cond_graph, body_graph = (
cast(Node, arg).meta.get(
TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target)
)
for arg in node.args[:2]
)
attr.WhileLoopAttribute(cond_graph, body_graph)

input_names: list[str] = []
Expand All @@ -55,7 +66,9 @@ def define_node(
raise ValueError(
f"{self.target}: Unsupported carried input type {type(loop_input)}."
)
input_names.append(loop_input.name)
input_names.append(
loop_input.name + loop_input.meta.get(TOSA_TENSOR_NAME_META, "")
)

num_inputs = len(input_names)
num_outputs = len(output.multiple_output_names)
Expand Down
Loading
Loading