Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions backends/samsung/_passes/compose_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Qualcomm Innovation Center, Inc
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class RecomposeRmsNorm(ExportPass):
    """Merge the operators that ``torch.nn.RMSNorm`` decomposes into back
    into a single ``aten.rms_norm`` edge node.

    For every source partition originating from ``torch.nn.RMSNorm``, the
    pass recovers the activation input, the gamma (weight) tensor and the
    epsilon constant, then replaces the decomposed subgraph with one
    ``exir_ops.edge.aten.rms_norm.default`` call.
    """

    def __init__(self):
        super().__init__()

    def _get_eps_node(self, nodes):
        """Return the epsilon constant from the args of the partition's add
        node (the decomposition adds eps to the variance before rsqrt).

        Returns ``None`` when no add node or no suitable arg is found, so
        the fused node is emitted with ``eps=None`` instead of crashing.
        """
        add_nodes = [n for n in nodes if hasattr(n, "name") and "add" in n.name]
        if not add_nodes:
            return None
        for arg in add_nodes[0].args:
            # eps is either a plain python float or a non-call_function
            # node (e.g. a placeholder/get_attr holding the constant).
            if isinstance(arg, float) or (
                isinstance(arg, torch.fx.Node) and arg.op != "call_function"
            ):
                return arg
        return None

    def _get_gamma_node(self, output_node):
        """Return the gamma (weight) input of the partition's output node:
        the first arg that is not produced by a call_function node.
        Returns ``None`` when no such arg exists.
        """
        for arg in output_node.args:
            if isinstance(arg, torch.fx.Node) and arg.op != "call_function":
                return arg
        return None

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every RMSNorm source partition into one rms_norm node."""
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.RMSNorm])

        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                input_len = len(src_partition.input_nodes)
                if input_len == 1:
                    input_node = src_partition.input_nodes[0]
                elif input_len == 2:
                    # The activation input is expected to have two users
                    # inside the decomposition (normalization branch and
                    # final mul); the other input is gamma.
                    inp_0, inp_1 = src_partition.input_nodes
                    input_node = inp_0 if len(inp_0.users) == 2 else inp_1
                else:
                    raise RuntimeError(
                        f"Found unsupported case of rms_node {src_partition}, "
                        f"which has {input_len} inputs"
                    )

                output_node = src_partition.output_nodes[0]
                eps_node = self._get_eps_node(src_partition.nodes)
                gamma_node = self._get_gamma_node(output_node)
                if gamma_node is None:
                    # Without gamma we cannot build the fused node's args;
                    # fail with a clear message instead of crashing on
                    # `None.meta` below.
                    raise RuntimeError(
                        f"Could not find gamma (weight) input for rms_norm "
                        f"partition {src_partition}"
                    )

                with graph.inserting_before(output_node):
                    # args schema
                    # (Tensor input, int[] normalized_shape, Tensor?
                    # weight=None, float? eps=None) -> Tensor
                    rms_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.rms_norm.default,
                        (
                            input_node,
                            list(gamma_node.meta["val"].shape),
                            gamma_node,
                            eps_node,
                        ),
                    )
                    users = output_node.users.copy()
                    for user in users:
                        user.replace_input_with(output_node, rms_node)
                    # copy metadata
                    rms_node.meta = output_node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
4 changes: 4 additions & 0 deletions backends/samsung/aot/PyEnnWrapperAdaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class PyEnnWrapper {
return py::array_t<char>();
}

auto perf_mode = option_buf_->perf_mode();
graphgen_set_perf_mode(
graphgen_instance_, static_cast<PerformanceMode>(perf_mode));

auto m_buf_info = model_buffer.request();
auto* model_buf_ptr = reinterpret_cast<uint8_t*>(m_buf_info.ptr);
NNCBuffer* nnc_buffer = nullptr;
Expand Down
24 changes: 24 additions & 0 deletions backends/samsung/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@
op_clamp,
op_constant_pad_nd,
op_conv2d,
op_cos,
op_dequantize,
op_div,
op_embedding,
op_expand_copy,
op_gelu,
op_getitem,
op_group_norm,
op_hardsigmoid,
op_hardswish,
op_hardtanh,
op_index,
op_layer_norm,
op_leaky_relu,
op_linear,
op_log,
op_log_softmax,
op_max_pool2d,
op_maximum,
Expand All @@ -34,17 +38,25 @@
op_mul,
op_permute,
op_pixel_shuffle,
op_pow,
op_quantize,
op_relu,
op_reshape,
op_rms_norm,
op_rsqrt,
op_select,
op_sigmoid,
op_sin,
op_slice_copy,
op_softmax,
op_split_with_sizes_copy,
op_sqrt,
op_squeeze,
op_sub,
op_sum_int_list,
op_tanh,
op_to_copy,
op_topk,
op_unsqueeze,
op_upsample_bilinear2d,
op_upsample_nearest2d,
Expand All @@ -60,37 +72,49 @@
op_clamp,
op_conv2d,
op_constant_pad_nd,
op_cos,
op_dequantize,
op_div,
op_embedding,
op_expand_copy,
op_gelu,
op_getitem,
op_group_norm,
op_hardswish,
op_hardtanh,
op_hardsigmoid,
op_index,
op_layer_norm,
op_leaky_relu,
op_linear,
op_log_softmax,
op_log,
op_max_pool2d,
op_maximum,
op_mean_dim,
op_minimum,
op_mul,
op_permute,
op_pixel_shuffle,
op_pow,
op_quantize,
op_relu,
op_reshape,
op_rms_norm,
op_rsqrt,
op_select,
op_sigmoid,
op_sin,
op_slice_copy,
op_softmax,
op_split_with_sizes_copy,
op_sqrt,
op_squeeze,
op_sub,
op_sum_int_list,
op_tanh,
op_to_copy,
op_topk,
op_unsqueeze,
op_upsample_bilinear2d,
op_upsample_nearest2d,
Expand Down
31 changes: 31 additions & 0 deletions backends/samsung/builders/op_cos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.samsung.builders.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph


@register_node_visitor
class CosVisitor(NodeVisitor):
    """Serializes ``aten.cos.default`` as a single-input ENN ``Cos`` op."""

    target = "aten.cos.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # Register the sole operand and the result tensor, then emit the op.
        in_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        out_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "Cos", [in_id], [out_id])
46 changes: 46 additions & 0 deletions backends/samsung/builders/op_group_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

import torch
from executorch.backends.samsung.builders.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph


@register_node_visitor
class GroupNormVisitor(NodeVisitor):
    """Serializes ``aten.native_group_norm.default`` as an ENN
    ``GROUPNORM`` op with input/weight/bias tensors and
    ``num_groups``/``epsilon`` params.
    """

    target = "aten.native_group_norm.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # native_group_norm args schema:
        # (input, weight, bias, N, C, HxW, group, eps)
        all_input_tensors = []
        input_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        all_input_tensors.append(input_id)

        weight_node = node.args[1]
        weight_id = self.define_tensor(weight_node, enn_graph, vals_to_ids)
        all_input_tensors.append(weight_id)
        bias_node = node.args[2]
        bias_id = self.define_tensor(bias_node, enn_graph, vals_to_ids)
        all_input_tensors.append(bias_id)

        num_groups = cast(int, node.args[6])
        # cast for consistency with num_groups above
        epsilon = cast(float, node.args[7])

        params = {"num_groups": num_groups, "epsilon": epsilon}

        # native_group_norm returns (out, mean, rstd); only the normalized
        # output (index 0) is mapped to the ENN graph.
        output_id = self.define_tensor(node, enn_graph, vals_to_ids, output_idx=0)
        enn_graph.define_op(
            node.name, "GROUPNORM", all_input_tensors, [output_id], params
        )
49 changes: 49 additions & 0 deletions backends/samsung/builders/op_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.samsung.builders.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph


@register_node_visitor
class IndexVisitor(NodeVisitor):
    """Serializes ``aten.index.Tensor`` with exactly one index tensor as an
    ENN ``GATHER`` op along the dimension the index tensor applies to.
    """

    target = "aten.index.Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # Renamed from `input` to avoid shadowing the builtin.
        input_node = node.args[0]
        input_id = self.define_tensor(input_node, enn_graph, vals_to_ids)

        # node.args[1] holds one entry per leading dim; the position of the
        # single non-None entry determines the gather axis.
        axis = 0
        target_indices_node = None
        for indices_node in node.args[1]:
            if indices_node is not None:
                if target_indices_node is not None:
                    raise NotImplementedError("Not support multi indices node.")
                target_indices_node = indices_node
            elif target_indices_node is None:
                # Still before the index tensor: keep advancing the axis.
                axis += 1
        if target_indices_node is None:
            # All entries were None; define_tensor(None) would fail further
            # down with a confusing error, so report the cause directly.
            raise ValueError(
                "aten.index.Tensor requires at least one index tensor"
            )

        indices_id = self.define_tensor(target_indices_node, enn_graph, vals_to_ids)

        output_id = self.define_tensor(node, enn_graph, vals_to_ids)

        params = {"axis": axis, "input_type": "params"}
        enn_graph.define_op(
            node.name, "GATHER", [input_id, indices_id], [output_id], params
        )
32 changes: 32 additions & 0 deletions backends/samsung/builders/op_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.samsung.builders.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph


@register_node_visitor
class LogVisitor(NodeVisitor):
    """Serializes ``aten.log.default`` as a single-input ENN ``LOG`` op."""

    target = "aten.log.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # Register the operand and result tensors, then emit the op.
        src_node = node.args[0]
        src_id = self.define_tensor(src_node, enn_graph, vals_to_ids)
        out_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "LOG", [src_id], [out_id])
42 changes: 42 additions & 0 deletions backends/samsung/builders/op_pow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.samsung.builders.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.samsung.builders.utils import get_tensor
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph


@register_node_visitor
class PowVisitor(NodeVisitor):
    """Serializes ``aten.pow.Tensor_Tensor`` as an ENN ``POW`` op.

    Only float32 base and exponent tensors are supported.
    """

    target = "aten.pow.Tensor_Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        input1 = node.args[0]
        input2 = node.args[1]
        input_tensor_1 = get_tensor(self.exported_program, input1)
        input_tensor_2 = get_tensor(self.exported_program, input2)
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and this is input validation, not an invariant.
        if (
            input_tensor_1.dtype != torch.float32
            or input_tensor_2.dtype != torch.float32
        ):
            raise ValueError("Requires the two inputs are all float type")

        input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids)
        input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids)

        output_id = self.define_tensor(node, enn_graph, vals_to_ids)

        enn_graph.define_op(node.name, "POW", [input_id_1, input_id_2], [output_id])
Loading
Loading