Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions backends/nxp/aten_passes/convert_unsqueeze_to_view.py

This file was deleted.

4 changes: 0 additions & 4 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import torch

from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
ConvertUnsqueezeToViewPass,
)
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
DecomposeSplitToSlicesPass,
)
Expand Down Expand Up @@ -50,7 +47,6 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
RemoveNodesWithKnownOutputs(),
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
ConvertUnsqueezeToViewPass(),
]

if not qat_mode:
Expand Down
103 changes: 103 additions & 0 deletions backends/nxp/edge_passes/convert_reshaping_nodes_to_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class ConvertReshapingNodesToViewPass(NeutronEdgePass):
"""Replaces:
- 'aten.squeeze.default', 'aten.squeeze.dims' and 'aten.squeeze.dim' with 'aten.view_copy.default'.

x x
│ │
┌──────────────▼──────────────┐ replace with ┌───────────────▼────────────────┐
│ aten.[un]squeeze(x, dim) │ ──────────────► │ aten.view_copy.default(x, S) │
└──────────────┬──────────────┘ └───────────────┬────────────────┘
│ │
▼ ▼
out out

- 'aten.unsqueeze.default' with 'aten.view_copy.default'.

x x
│ │
┌─────────────▼─────────────┐ replace with ┌───────────────▼────────────────┐
│ aten.unsqueeze(x, dim) │ ──────────────► │ aten.view_copy.default(x, S) │
└─────────────┬─────────────┘ └───────────────┬────────────────┘
│ │
▼ ▼
out out
"""

graph_module: GraphModule

@staticmethod
def _is_squeeze(node_: Node) -> bool:
return node_.op == "call_function" and (
node_.target == exir_ops.edge.aten.squeeze_copy.dim
or node_.target == exir_ops.edge.aten.squeeze_copy.dims
or node_.target == exir_ops.edge.aten.squeeze_copy.default
)

@staticmethod
def _is_unsqueeze(node_: Node) -> bool:
return (
node_.op == "call_function"
and node_.target == exir_ops.edge.aten.unsqueeze_copy.default
)

def _create_view_copy_node(self, *view_args) -> Node:
view_target = exir_ops.edge.aten.view_copy.default
view_node = self.graph_module.graph.call_function(view_target, view_args)

view_node.meta["source_fn_stack"] = [
(view_node.name, exir_ops.edge.aten.view_copy.default)
]

x_val = view_args[0].meta["val"]
with FakeTensorMode() as mode:
fake_input = FakeTensor.from_tensor(
torch.empty(x_val.shape, dtype=x_val.dtype), mode
)
output_shape = view_target(fake_input, *view_args[1:]).shape
view_node.meta["val"] = FakeTensor.from_tensor(
torch.empty(output_shape, dtype=x_val.dtype), mode
)

return view_node

def run(self, graph_module: GraphModule) -> Optional[PassResult]:
self.graph_module = graph_module

for node in list(graph_module.graph.nodes):
if not (self._is_squeeze(node) or self._is_unsqueeze(node)):
continue

input_node = node.all_input_nodes[0]
target_shape = node.meta["val"].shape

with self.graph_module.graph.inserting_after(node):
view_copy_node = self._create_view_copy_node(input_node, target_shape)

node.replace_all_uses_with(view_copy_node)
self.graph_module.graph.erase_node(node)

self.graph_module.recompile()
self.graph_module.graph.eliminate_dead_code()

# Return immediately to avoid traversing a modified graph.
# The parent class will call this pass again.
return PassResult(graph_module, True)

# The graph was not modified.
return PassResult(graph_module, False)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 NXP
# Copyright 2025-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -13,15 +13,18 @@

# Operator aliases for better readability.
AddMM = exir_ops.edge.aten.addmm.default
ViewCopy = exir_ops.edge.aten.view_copy.default
MM = exir_ops.edge.aten.mm.default
AvgPool2D = exir_ops.edge.aten.avg_pool2d.default
Conv = exir_ops.edge.aten.convolution.default
Clone = exir_ops.edge.aten.clone.default
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
HardTanh = exir_ops.edge.aten.hardtanh.default
MM = exir_ops.edge.aten.mm.default
Relu = exir_ops.edge.aten.relu.default
Sigmoid = exir_ops.edge.aten.sigmoid.default
SqueezeCopy = exir_ops.edge.aten.squeeze_copy.dims
Tanh = exir_ops.edge.aten.tanh.default
Clone = exir_ops.edge.aten.clone.default
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
UnsqueezeCopy = exir_ops.edge.aten.unsqueeze_copy.default
ViewCopy = exir_ops.edge.aten.view_copy.default


def insert_qdq_pair_after_node(
Expand Down Expand Up @@ -105,6 +108,13 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
ViewCopy,
],
ViewCopy: [Clone, CloneDimOrder],
Conv: [
ViewCopy, # For 1D conv
],
AvgPool2D: [
ViewCopy, # For 1D AvgPool
UnsqueezeCopy,
],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
Expand Down Expand Up @@ -200,8 +210,13 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
Relu,
Sigmoid,
Tanh,
ViewCopy, # For 1D conv.
],
ViewCopy: [Clone, CloneDimOrder],
AvgPool2D: [
ViewCopy,
SqueezeCopy,
],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
Expand Down
4 changes: 4 additions & 0 deletions backends/nxp/edge_passes/neutron_edge_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.edge_passes.convert_reshaping_nodes_to_view import (
ConvertReshapingNodesToViewPass,
)
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
Expand All @@ -21,6 +24,7 @@ def __init__(self, passes: list[NeutronEdgePass] = None):
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
RemoveUselessAsStridedCopyNodes(),
ConvertReshapingNodesToViewPass(),
]

super().__init__(
Expand Down
8 changes: 5 additions & 3 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,15 @@ class QDQCluster:

AUXILIARY_OPS = [
operator.getitem,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
]

Expand Down
14 changes: 12 additions & 2 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
AdaptiveAvgPoolPattern,
AddmmPattern,
AddTensorPattern,
AvgPoolPattern,
AvgPool1DPattern,
AvgPool2DPattern,
BatchNormPattern,
CatPattern,
Conv1dPattern,
Expand All @@ -43,10 +44,14 @@
SigmoidPattern,
SliceTensorPattern,
SoftMaxPattern,
SqueezeDimPattern,
SqueezeDimsPattern,
SqueezePattern,
SubTensorPattern,
TanhInPlacePattern,
TanhPattern,
TransposeIntPattern,
UnsqueezePattern,
UpsampleBilinear2DPattern,
UpsampleNearest2DPattern,
ViewPattern,
Expand Down Expand Up @@ -248,7 +253,8 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AvgPool1DPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AvgPool2DPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
Expand All @@ -271,10 +277,14 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SqueezeDimPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SqueezeDimsPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SqueezePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(UnsqueezePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(UpsampleBilinear2DPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(UpsampleNearest2DPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
Expand Down
Loading
Loading