class RecomposeRmsNorm(ExportPass):
    """
    Merge decomposed operators back to one super node.

    ``torch.nn.RMSNorm`` is decomposed during export into elementwise
    primitives (pow/mean/add/rsqrt/mul); this pass locates each decomposed
    source partition and replaces it with a single
    ``edge.aten.rms_norm.default`` node so the Samsung backend can lower it
    as one op.
    """

    def __init__(self):
        super().__init__()

    def _get_eps_node(self, nodes):
        # eps: one of inputs of add node
        # NOTE(review): picks the first partition node whose name contains
        # "add" — assumes exactly one add carries eps; raises IndexError if
        # the partition has no add node. Confirm for all decompositions.
        add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0]
        for a in add_node.args:
            # eps is either a Python float literal or a non-call_function
            # node (e.g. a lifted constant). Returns None if neither found.
            if isinstance(a, float) or a.op != "call_function":
                return a

    def _get_gamma_node(self, output_node):
        # gamma: one of inputs of output node
        # The weight (gamma) is the only non-call_function input of the
        # partition's final node (the output elementwise multiply).
        for a in output_node.args:
            if a.op != "call_function":
                return a

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every decomposed RMSNorm partition into one rms_norm node.

        Returns a PassResult marking the graph as modified.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.RMSNorm])

        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                input_len = len(src_partition.input_nodes)
                if input_len == 1:
                    input_node = src_partition.input_nodes[0]
                elif input_len == 2:
                    # Two partition inputs: the activation (consumed twice
                    # inside the partition, as x * rsqrt(mean(x^2)+eps)) and
                    # the weight. NOTE(review): relies on the activation
                    # having exactly 2 users inside the partition — confirm
                    # this holds for every decomposition shape handled here.
                    inp_0, inp_1 = src_partition.input_nodes
                    input_node = inp_0 if len(inp_0.users) == 2 else inp_1
                else:
                    raise RuntimeError(
                        f"Found unsupported case of rms_node {src_partition}, "
                        f"which has {input_len} inputs"
                    )

                output_node = src_partition.output_nodes[0]
                eps_node = self._get_eps_node(src_partition.nodes)
                gamma_node = self._get_gamma_node(output_node)

                with graph.inserting_before(output_node):
                    # args schema
                    # (Tensor input, int[] normalized_shape, Tensor?
                    # weight=None, float? eps=None) -> Tensor
                    rms_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.rms_norm.default,
                        (
                            input_node,
                            list(gamma_node.meta["val"].shape),
                            gamma_node,
                            eps_node,
                        ),
                    )
                    # Re-point every consumer of the partition's output to
                    # the fused node; the decomposed subgraph then becomes
                    # dead and is removed below.
                    users = output_node.users.copy()
                    for user in users:
                        user.replace_input_with(output_node, rms_node)
                    # copy metadata
                    rms_node.meta = output_node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
op_topk, op_unsqueeze, op_upsample_bilinear2d, op_upsample_nearest2d, @@ -60,19 +72,23 @@ op_clamp, op_conv2d, op_constant_pad_nd, + op_cos, op_dequantize, op_div, op_embedding, op_expand_copy, op_gelu, op_getitem, + op_group_norm, op_hardswish, op_hardtanh, op_hardsigmoid, + op_index, op_layer_norm, op_leaky_relu, op_linear, op_log_softmax, + op_log, op_max_pool2d, op_maximum, op_mean_dim, @@ -80,17 +96,25 @@ op_mul, op_permute, op_pixel_shuffle, + op_pow, op_quantize, op_relu, op_reshape, + op_rms_norm, op_rsqrt, op_select, + op_sigmoid, + op_sin, op_slice_copy, op_softmax, + op_split_with_sizes_copy, op_sqrt, op_squeeze, op_sub, + op_sum_int_list, + op_tanh, op_to_copy, + op_topk, op_unsqueeze, op_upsample_bilinear2d, op_upsample_nearest2d, diff --git a/backends/samsung/builders/op_cos.py b/backends/samsung/builders/op_cos.py new file mode 100644 index 00000000000..bd746db91cd --- /dev/null +++ b/backends/samsung/builders/op_cos.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
@register_node_visitor
class CosVisitor(NodeVisitor):
    """Serializes ``aten.cos.default`` nodes into the ENN ``Cos`` op."""

    target = "aten.cos.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # One tensor in, one tensor out: elementwise cosine.
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        dst_id = self.define_tensor(node, enn_graph, vals_to_ids)
        # NOTE(review): op type "Cos" is mixed-case unlike most ENN op names
        # here ("LOG", "TANH", ...) — confirm the backend expects this casing.
        enn_graph.define_op(node.name, "Cos", [src_id], [dst_id])
@register_node_visitor
class GroupNormVisitor(NodeVisitor):
    """Serializes ``aten.native_group_norm.default`` into the ENN ``GROUPNORM`` op."""

    target = "aten.native_group_norm.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # native_group_norm args: (input, weight, bias, N, C, HxW, group, eps).
        # Inputs 0..2 (input, weight, bias) become ENN input tensors.
        in_ids = [
            self.define_tensor(node.args[idx], enn_graph, vals_to_ids)
            for idx in range(3)
        ]

        params = {
            "num_groups": cast(int, node.args[6]),
            "epsilon": node.args[7],
        }

        # Only the normalized output (index 0) is serialized; mean/rstd
        # outputs of native_group_norm are not emitted.
        out_id = self.define_tensor(node, enn_graph, vals_to_ids, output_idx=0)
        enn_graph.define_op(node.name, "GROUPNORM", in_ids, [out_id], params)
@register_node_visitor
class IndexVisitor(NodeVisitor):
    """Serializes ``aten.index.Tensor`` into an ENN ``GATHER`` op.

    Only indexing with a single index tensor along one axis is supported;
    every other entry of the index list must be a ``None`` placeholder.
    """

    target = "aten.index.Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        input_node = node.args[0]
        input_id = self.define_tensor(input_node, enn_graph, vals_to_ids)

        # node.args[1] holds one optional index tensor per leading dimension;
        # `axis` is the position of the single real index tensor, i.e. the
        # number of None placeholders preceding it.
        axis = 0
        target_indices_node = None
        for indices_node in node.args[1]:
            if indices_node is not None:
                if target_indices_node is not None:
                    raise NotImplementedError("Not support multi indices node.")
                target_indices_node = indices_node
            elif target_indices_node is None:
                axis += 1
        # Fixed: previously an all-None index list fell through to
        # define_tensor(None, ...) and failed with an opaque error.
        if target_indices_node is None:
            raise RuntimeError(
                f"aten.index node '{node.name}' contains no index tensor."
            )

        indices_id = self.define_tensor(target_indices_node, enn_graph, vals_to_ids)
        output_id = self.define_tensor(node, enn_graph, vals_to_ids)

        params = {"axis": axis, "input_type": "params"}
        enn_graph.define_op(
            node.name, "GATHER", [input_id, indices_id], [output_id], params
        )
@register_node_visitor
class LogVisitor(NodeVisitor):
    """Serializes ``aten.log.default`` nodes into the ENN ``LOG`` op."""

    target = "aten.log.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # One tensor in, one tensor out: elementwise natural logarithm.
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        dst_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "LOG", [src_id], [dst_id])
@register_node_visitor
class PowVisitor(NodeVisitor):
    """Serializes ``aten.pow.Tensor_Tensor`` nodes into the ENN ``POW`` op."""

    target = "aten.pow.Tensor_Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        base, exponent = node.args[0], node.args[1]

        # The ENN POW lowering only accepts float32 operands.
        assert (
            get_tensor(self.exported_program, base).dtype == torch.float32
            and get_tensor(self.exported_program, exponent).dtype == torch.float32
        ), "Requires the two inputs are all float type"

        base_id = self.define_tensor(base, enn_graph, vals_to_ids)
        exp_id = self.define_tensor(exponent, enn_graph, vals_to_ids)
        out_id = self.define_tensor(node, enn_graph, vals_to_ids)

        enn_graph.define_op(node.name, "POW", [base_id, exp_id], [out_id])
@register_node_visitor
class RmsNormVisitor(NodeVisitor):
    """Serializes ``aten.rms_norm.default`` nodes into the ENN ``RMSNORM`` op."""

    target = "aten.rms_norm.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # Node args: (input, normalized_shape, weight, eps).
        in_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        gamma_id = self.define_tensor(node.args[2], enn_graph, vals_to_ids)

        eps = node.args[3]
        if isinstance(eps, torch.fx.Node):
            # eps may arrive as a lifted constant tensor node; unwrap it to
            # a plain Python scalar for the op params.
            eps = get_tensor(self.exported_program, eps)
            eps = eps.item()

        params = {
            "normalize_shape": cast(List[int], node.args[1]),
            # NOTE(review): param_num=2 reproduced as-is — confirm its
            # meaning (number of parameter tensors?) against the ENN schema.
            "param_num": 2,
            "epsilon": eps,
        }

        out_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(
            node.name, "RMSNORM", [in_id, gamma_id], [out_id], params
        )
@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    """Serializes ``aten.sigmoid.default`` nodes into the ENN ``SIGMOID`` op."""

    target = "aten.sigmoid.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # One tensor in, one tensor out: elementwise logistic sigmoid.
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        dst_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "SIGMOID", [src_id], [dst_id])
@register_node_visitor
class SinVisitor(NodeVisitor):
    """Serializes ``aten.sin.default`` nodes into the ENN ``Sin`` op."""

    target = "aten.sin.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # One tensor in, one tensor out: elementwise sine.
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        dst_id = self.define_tensor(node, enn_graph, vals_to_ids)
        # NOTE(review): op type "Sin" is mixed-case unlike most ENN op names
        # here ("LOG", "TANH", ...) — confirm the backend expects this casing.
        enn_graph.define_op(node.name, "Sin", [src_id], [dst_id])
@register_node_visitor
class SplitVisitor(NodeVisitor):
    """Serializes ``aten.split_with_sizes_copy.default`` into the ENN ``SPLIT`` op."""

    target = "aten.split_with_sizes_copy.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ):
        # Node args: (input, split_sizes, dim=0).
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)

        split_sizes = node.args[1]
        # One serialized output tensor per requested chunk.
        out_ids = [
            self.define_tensor(node, enn_graph, vals_to_ids, output_idx=i)
            for i in range(len(split_sizes))
        ]

        params = {
            "axis": node.args[2] if len(node.args) > 2 else 0,
            "point": split_sizes,
        }
        enn_graph.define_op(node.name, "SPLIT", [src_id], out_ids, params)
@register_node_visitor
class SumDimIntListVisitor(NodeVisitor):
    """Serializes ``aten.sum.dim_IntList`` into the ENN ``REDUCESUM`` op."""

    target = "aten.sum.dim_IntList"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        src = node.args[0]
        src_id = self.define_tensor(src, enn_graph, vals_to_ids)

        # Normalize negative reduction dims into [0, rank).
        rank = len(get_shape(src))
        axes = [d % rank for d in cast(List[int], node.args[1])]

        # keepdim defaults to False when absent from the node args.
        keep = cast(bool, node.args[2]) if len(node.args) > 2 else False
        params = {"keep_dims": keep, "axis": axes}

        out_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "REDUCESUM", [src_id], [out_id], params)
@register_node_visitor
class TanhVisitor(NodeVisitor):
    """Serializes ``aten.tanh.default`` nodes into the ENN ``TANH`` op."""

    target = "aten.tanh.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        # One tensor in, one tensor out: elementwise hyperbolic tangent.
        src_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids)
        dst_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "TANH", [src_id], [dst_id])
@register_node_visitor
class TopKVisitor(NodeVisitor):
    """Serializes ``aten.topk.default`` into the ENN ``TopK`` op.

    Only the last dimension, ``largest=True`` and ``sorted=True`` are
    supported; both outputs (values and indices) are serialized.
    """

    target = "aten.topk.default"

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.Tensor, int],
    ) -> None:
        input = node.args[0]
        input_id = self.define_tensor(input, enn_graph, vals_to_ids)

        k = cast(int, node.args[1])
        params = {"k_dims": k}
        in_shape_len = len(get_shape(input))
        # dim defaults to the last dimension; normalize negative dims.
        dim = cast(int, node.args[2]) if len(node.args) > 2 else in_shape_len - 1
        if dim < 0:
            dim = dim + in_shape_len
        if dim != in_shape_len - 1:
            raise AssertionError("Not supported dim not being last dimension!")

        # topk has two outputs (values, indices). Each getitem consumer is
        # mapped directly onto the matching output tensor id so later
        # visitors resolve them without re-defining tensors.
        # NOTE(review): assumes node.users preserves (values, indices)
        # getitem order — confirm this holds after all graph passes.
        all_output_tensors = []
        users = list(node.users.keys())
        output_val_idx = 0
        output_val_id = self.define_tensor(
            node,
            enn_graph,
            vals_to_ids,
            output_idx=output_val_idx,
        )
        if len(users) > 0 and users[0].target.__name__ == "getitem":
            vals_to_ids[users[0]] = output_val_id
        all_output_tensors.append(output_val_id)

        output_indices_idx = 1
        output_indices_id = self.define_tensor(
            node,
            enn_graph,
            vals_to_ids,
            output_idx=output_indices_idx,
        )
        if len(users) > 1 and users[1].target.__name__ == "getitem":
            vals_to_ids[users[1]] = output_indices_id
        all_output_tensors.append(output_indices_id)

        # Reject the unsupported largest=False / sorted=False variants.
        if len(node.args) > 3:
            largest = cast(bool, node.args[3])
            if not largest:
                raise AssertionError("Not supported largest = False.")

        if len(node.args) > 4:
            # NOTE: local name shadows the `sorted` builtin (scoped to this
            # method only).
            sorted = cast(bool, node.args[4])
            if not sorted:
                raise AssertionError("Not supported sorted = False.")

        enn_graph.define_op(node.name, "TopK", [input_id], all_output_tensors, params)
0847ec0adeb..07114759211 100644 --- a/backends/samsung/enn_preprocess.py +++ b/backends/samsung/enn_preprocess.py @@ -13,6 +13,7 @@ from executorch.backends.samsung._passes.annotate_scalar_parameters import ( AnnotateScalarParametersPass, ) +from executorch.backends.samsung._passes.compose_rms_norm import RecomposeRmsNorm from executorch.backends.samsung._passes.conv1d_to_conv2d import Conv1dToConv2d from executorch.backends.samsung._passes.customized_constant_prop import ( ConstantPropPass, @@ -69,6 +70,7 @@ def preprocess( RemoveGetItemPass(), InsertQDQPass(edge_program), AnnotateScalarParametersPass(edge_program), + RecomposeRmsNorm(), ] ) pass_result = enn_preprocess_passes(edge_program.graph_module) diff --git a/backends/samsung/partition/enn_partitioner.py b/backends/samsung/partition/enn_partitioner.py index 368d069c380..dddba01aefa 100644 --- a/backends/samsung/partition/enn_partitioner.py +++ b/backends/samsung/partition/enn_partitioner.py @@ -38,6 +38,7 @@ exir_ops.edge.aten.sub.Scalar, exir_ops.edge.aten.mul.Scalar, exir_ops.edge.aten.div.Scalar, + exir_ops.edge.aten.pow.Tensor_Scalar, ] diff --git a/backends/samsung/serialization/compile_options.py b/backends/samsung/serialization/compile_options.py index 67d9f57a25d..88c18569a83 100644 --- a/backends/samsung/serialization/compile_options.py +++ b/backends/samsung/serialization/compile_options.py @@ -15,6 +15,7 @@ from executorch.exir._serialize._dataclass import _DataclassEncoder from executorch.exir._serialize._flatbuffer import _flatc_compile +from executorch.exir._warnings import experimental from executorch.exir.backend.backend_details import CompileSpec @@ -25,9 +26,20 @@ class SamsungChipset(IntEnum): E9965 = 9965 +@experimental( + "This API is experimental. 
If you use this mode, you should verify pte file on device farm which can be used on " + "exynos developer society site ( https://soc-developer.semiconductor.samsung.com/)" +) +@unique +class PerformanceMode(IntEnum): + DEFAULT = 0 + HIGH_PERFORMANCE = 1 + + @dataclass class EnnExecuTorchOptions: chipset: SamsungChipset = SamsungChipset.UNDEFINED_CHIP_V + perf_mode: PerformanceMode = PerformanceMode.DEFAULT ENN_COMPILE_OPTION_TITLE = "enn_compile_options" @@ -61,6 +73,7 @@ def gen_samsung_backend_compile_spec_core(options: EnnExecuTorchOptions) -> Comp def gen_samsung_backend_compile_spec( chipset: str, + perf_mode: PerformanceMode = None, ): """ A function to generate an ExecuTorch binary for Samsung Backend. @@ -71,8 +84,12 @@ def gen_samsung_backend_compile_spec( Returns: CompileSpec: key is COMPILE_OPTION_SCHEMA_NAME, value is serialization binary of fb schema """ + + perf_mode = PerformanceMode.DEFAULT if perf_mode is None else perf_mode + option = EnnExecuTorchOptions( getattr(SamsungChipset, chipset.upper()), + perf_mode, ) return gen_samsung_backend_compile_spec_core(option) diff --git a/backends/samsung/serialization/compile_options_def.fbs b/backends/samsung/serialization/compile_options_def.fbs index d38c2772715..ac7b2a6322d 100644 --- a/backends/samsung/serialization/compile_options_def.fbs +++ b/backends/samsung/serialization/compile_options_def.fbs @@ -14,10 +14,16 @@ file_identifier "EETO"; // Extension of written files. file_extension "eeto"; +enum PerformanceMode : byte { + DEFAULT = 0, + HIGH_PERFORMANCE = 1, +} table EnnExecuTorchOptions { // The version of chipset. Specify the soc to compile and execute model. 
chipset: int; } + perf_mode: PerformanceMode = DEFAULT; + root_type EnnExecuTorchOptions; diff --git a/backends/samsung/test/ops/test_cos.py b/backends/samsung/test/ops/test_cos.py new file mode 100644 index 00000000000..df77ec217da --- /dev/null +++ b/backends/samsung/test/ops/test_cos.py @@ -0,0 +1,46 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. + +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Cos(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cos(x) + + +class TestCos(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.cos.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_cos_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_cos(self): + inputs = (torch.randn(1, 3, 56, 56),) + self._test(Cos(), inputs) diff --git a/backends/samsung/test/ops/test_group_norm.py b/backends/samsung/test/ops/test_group_norm.py new file mode 100644 index 00000000000..b323b3de572 --- /dev/null +++ b/backends/samsung/test/ops/test_group_norm.py @@ -0,0 +1,52 @@ +# Copyright (c) Samsung Electronics Co. 
LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details.. + +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class GroupNorm(torch.nn.Module): + def __init__(self, groups, in_channels) -> None: + super().__init__() + self.in_channels = in_channels + self.module = torch.nn.GroupNorm(groups, self.in_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + + +class TestGroupNorm(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.group_norm.default": 1}) + .to_edge_transform_and_lower() + .check_not( + ["executorch_exir_dialects_edge__ops_aten_native_group_norm_default"] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_group_norm(self): + groups = 3 + in_channels = 12 + inputs = (torch.randn(1, in_channels, 8, 8),) + self._test(GroupNorm(groups, in_channels), inputs) diff --git a/backends/samsung/test/ops/test_index.py b/backends/samsung/test/ops/test_index.py new file mode 100644 index 00000000000..ba7ab02bb53 --- /dev/null +++ b/backends/samsung/test/ops/test_index.py @@ -0,0 +1,72 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. 
See the license file in the root +# directory of this source tree for more details. + +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Index(torch.nn.Module): + def __init__(self, indices: tuple[torch.Tensor]) -> None: + super().__init__() + self.indices = indices + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[self.indices] + + +class TestIndex(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.index.Tensor": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_index_Tensor"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_index_on_axis0(self): + indices = torch.tensor([2, 0], dtype=torch.int32) + inputs = (torch.randn(4, 16, 8, 8),) + self._test(Index(indices), inputs) + + def test_fp32_index_on_axis1(self): + indices = (slice(None), torch.tensor([0, 3, 2, 11, 8], dtype=torch.int32)) + inputs = (torch.randn(4, 16, 8, 8),) + self._test(Index(indices), inputs) + + def test_fp32_index_on_axis2(self): + indices = ( + slice(None), + slice(None), + torch.tensor([1, 2, 5, 6], dtype=torch.int32), + ) + inputs = (torch.randn(4, 16, 8, 8),) + self._test(Index(indices), inputs) + + def test_fp32_index_on_axis3(self): + indices = ( + slice(None), + slice(None), + slice(None), + torch.tensor([0, 3, 6, 4], dtype=torch.int32), + ) + inputs = (torch.randn(4, 16, 8, 8),) + self._test(Index(indices), inputs) diff --git a/backends/samsung/test/ops/test_log.py b/backends/samsung/test/ops/test_log.py new 
file mode 100644 index 00000000000..f7e4c8aa895 --- /dev/null +++ b/backends/samsung/test/ops/test_log.py @@ -0,0 +1,47 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. + + +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Log(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.log(x) + + +class TestLog(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.log.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten__log_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + + def test_fp32_log(self): + inputs = (torch.randn(1, 16, 56, 56),) + self._test(Log(), inputs) diff --git a/backends/samsung/test/ops/test_pow.py b/backends/samsung/test/ops/test_pow.py new file mode 100644 index 00000000000..002e80fc974 --- /dev/null +++ b/backends/samsung/test/ops/test_pow.py @@ -0,0 +1,46 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Pow(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x**2.0 + + +class TestPow(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.pow.Tensor_Scalar": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_pow(self): + inputs = (torch.randn(1, 1, 16, 8),) + self._test(Pow(), inputs) diff --git a/backends/samsung/test/ops/test_rms_norm.py b/backends/samsung/test/ops/test_rms_norm.py new file mode 100644 index 00000000000..ddf0640760f --- /dev/null +++ b/backends/samsung/test/ops/test_rms_norm.py @@ -0,0 +1,52 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest +from typing import List + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalize_shape: List[int]) -> None: + super().__init__() + self.normalize_shape = normalize_shape + self.module = torch.nn.RMSNorm(self.normalize_shape, eps=1e-5) + self.module.weight = torch.nn.Parameter(torch.ones(self.normalize_shape) * 2.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + + +class TestRMSNorm(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + + ( + tester.export() + .check_count({"torch.ops.aten.rms_norm.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_rms_norm_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_rms_norm(self): + normalize_shape = [196] + inputs = (torch.randn(1, *normalize_shape),) + self._test(RMSNorm(normalize_shape), inputs) diff --git a/backends/samsung/test/ops/test_sigmoid.py b/backends/samsung/test/ops/test_sigmoid.py new file mode 100644 index 00000000000..f3836c7b277 --- /dev/null +++ b/backends/samsung/test/ops/test_sigmoid.py @@ -0,0 +1,47 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Sigmoid(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.module = torch.nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + + +class TestSigmoid(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.sigmoid.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_sigmoid(self): + inputs = (torch.randn(1, 4, 32, 32),) + self._test(Sigmoid(), inputs) diff --git a/backends/samsung/test/ops/test_sin.py b/backends/samsung/test/ops/test_sin.py new file mode 100644 index 00000000000..4634fe32117 --- /dev/null +++ b/backends/samsung/test/ops/test_sin.py @@ -0,0 +1,46 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Sin(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sin(x) + + +class TestSin(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.sin.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_sin_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_sin(self): + inputs = (torch.randn(1, 3, 56, 56),) + self._test(Sin(), inputs) diff --git a/backends/samsung/test/ops/test_split_with_sizes_copy.py b/backends/samsung/test/ops/test_split_with_sizes_copy.py new file mode 100644 index 00000000000..ca566f1e831 --- /dev/null +++ b/backends/samsung/test/ops/test_split_with_sizes_copy.py @@ -0,0 +1,58 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Split(torch.nn.Module): + def __init__(self, split_sizes=1, dim=0) -> None: + super().__init__() + self.split_sizes = split_sizes + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.split(x, self.split_sizes, dim=self.dim) + + +class TestSplit(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + + ( + tester.export() + .to_edge_transform_and_lower() + .check_not( + ["executorch_exir_dialects_edge__ops_aten_split_with_sizes_default"] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_split_default(self): + inputs = (torch.randn(6, 6),) + self._test(Split(3), inputs) + + def test_fp32_split_chunk(self): + inputs = (torch.randn(6, 6),) + self._test(Split([3, 3]), inputs) + + def test_fp32_split_dim1(self): + inputs = (torch.randn(6, 6),) + self._test(Split([3, 2, 1], dim=1), inputs) diff --git a/backends/samsung/test/ops/test_sum_int_list.py b/backends/samsung/test/ops/test_sum_int_list.py new file mode 100644 index 00000000000..0c792298ea3 --- /dev/null +++ b/backends/samsung/test/ops/test_sum_int_list.py @@ -0,0 +1,51 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class SumDimIntList(torch.nn.Module): + def __init__(self, keep_dims=True) -> None: + super().__init__() + self.keep_dims = keep_dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sum(x, [2, 3], keepdim=self.keep_dims) + + +class TestSumDimIntList(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.sum.dim_IntList": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_sum_dim_IntList"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(atol=0.005, rtol=0.005) + ) + + def test_fp32_sum_dim_with_keep_dims(self): + inputs = (torch.randn(1, 16, 8, 8),) + self._test(SumDimIntList(), inputs) + + def test_fp32_sum_dim_without_keep_dims(self): + inputs = (torch.randn(1, 16, 8, 8),) + self._test(SumDimIntList(keep_dims=False), inputs) diff --git a/backends/samsung/test/ops/test_tanh.py b/backends/samsung/test/ops/test_tanh.py new file mode 100644 index 00000000000..edd6cff2800 --- /dev/null +++ b/backends/samsung/test/ops/test_tanh.py @@ -0,0 +1,47 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class Tanh(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.tanh = torch.nn.Tanh().to(torch.float) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.tanh(x) + + +class TestTanh(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.tanh.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(atol=0.008) + ) + + def test_fp32_tanh(self): + inputs = (torch.randn(1, 3, 8, 8),) + self._test(Tanh(), inputs) diff --git a/backends/samsung/test/ops/test_topk.py b/backends/samsung/test/ops/test_topk.py new file mode 100644 index 00000000000..4df439d1a2d --- /dev/null +++ b/backends/samsung/test/ops/test_topk.py @@ -0,0 +1,52 @@ +# Copyright (c) Samsung Electronics Co. LTD +# All rights reserved +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. 
+ +import unittest + +import torch + +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.test.tester import SamsungTester +from executorch.backends.samsung.test.utils.utils import TestConfig + + +class TopK(torch.nn.Module): + def __init__(self, k=1, dim=-1) -> None: + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return tuple(torch.topk(x, self.k, dim=self.dim)) + + +class TestTopK(unittest.TestCase): + def _test(self, module: torch.nn.Module, inputs): + tester = SamsungTester( + module, + inputs, + [gen_samsung_backend_compile_spec(TestConfig.chipset)], + ) + ( + tester.export() + .check_count({"torch.ops.aten.topk.default": 1}) + .to_edge_transform_and_lower() + .check_not(["executorch_exir_dialects_edge__ops_aten_topk_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs() + ) + + def test_fp32_topk_dim3(self): + inputs = (torch.randn(1, 16, 8, 8),) + self._test(TopK(k=5, dim=3), inputs) + + def test_fp32_topk_dim_negative1(self): + inputs = (torch.randn(1, 16, 8, 8),) + self._test(TopK(k=5, dim=-1), inputs) diff --git a/examples/samsung/README.md b/examples/samsung/README.md index e72530cbcae..8b21c48a34f 100644 --- a/examples/samsung/README.md +++ b/examples/samsung/README.md @@ -38,6 +38,10 @@ Take `EXECUTORCH_ROOT` as work directory and here is an example for ic3. python -m executorch.examples.samsung.aot_compiler --chipset E9955 (or E9965) -m ic3 --output_dir ic3_artifact ``` +The examples use the "PerformanceMode.HIGH_PERFORMANCE" mode; this mode is experimental. +Before using this mode with your own model, first verify the model's stability on the device farm available through the Samsung developer society site +
(https://soc-developer.semiconductor.samsung.com/) + ## Execution After lowering, we could get a pte model and then run it on mobile phone. diff --git a/examples/samsung/aot_compiler.py b/examples/samsung/aot_compiler.py index 30e60d0e503..1ea97566cc0 100644 --- a/examples/samsung/aot_compiler.py +++ b/examples/samsung/aot_compiler.py @@ -9,6 +9,7 @@ from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( to_edge_transform_and_lower_to_enn, @@ -74,7 +75,12 @@ model = model.eval() outputs = model(*example_inputs) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec( + args.chipset, + PerformanceMode.HIGH_PERFORMANCE, + ) + ] edge = to_edge_transform_and_lower_to_enn( model, example_inputs, compile_specs=compile_specs ) diff --git a/examples/samsung/scripts/deeplab_v3.py b/examples/samsung/scripts/deeplab_v3.py index 8538515db6d..8484eb17cd5 100644 --- a/examples/samsung/scripts/deeplab_v3.py +++ b/examples/samsung/scripts/deeplab_v3.py @@ -15,6 +15,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -143,7 +144,9 @@ def get_dataset( test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/edsr.py b/examples/samsung/scripts/edsr.py index e501f16f099..8237b737a2c 100644 --- a/examples/samsung/scripts/edsr.py +++ b/examples/samsung/scripts/edsr.py @@ -12,6 +12,7 @@ from executorch.backends.samsung.quantizer 
import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -156,7 +157,9 @@ def __call__(self, x): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/inception_v3.py b/examples/samsung/scripts/inception_v3.py index 17b745b239d..f562037ac7e 100644 --- a/examples/samsung/scripts/inception_v3.py +++ b/examples/samsung/scripts/inception_v3.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/inception_v4.py b/examples/samsung/scripts/inception_v4.py index b47a853759f..acdd1da5d05 100644 --- a/examples/samsung/scripts/inception_v4.py +++ b/examples/samsung/scripts/inception_v4.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -142,7 +143,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = 
[gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/mobilebert_finetune.py b/examples/samsung/scripts/mobilebert_finetune.py index d90bfc3f8f1..25e54305a9d 100644 --- a/examples/samsung/scripts/mobilebert_finetune.py +++ b/examples/samsung/scripts/mobilebert_finetune.py @@ -12,6 +12,7 @@ from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( to_edge_transform_and_lower_to_enn, @@ -245,7 +246,9 @@ def validate(self, model, val_data_loader): example_inputs = mobilebert_finetune.get_example_inputs() output = model(*example_inputs) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] edge = to_edge_transform_and_lower_to_enn( model, example_inputs, compile_specs=compile_specs ) diff --git a/examples/samsung/scripts/mobilenet_v2.py b/examples/samsung/scripts/mobilenet_v2.py index e36fa6d9de9..aa594e0cd24 100644 --- a/examples/samsung/scripts/mobilenet_v2.py +++ b/examples/samsung/scripts/mobilenet_v2.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/mobilenet_v3.py 
b/examples/samsung/scripts/mobilenet_v3.py index 83361522f4d..a35041b4b3e 100644 --- a/examples/samsung/scripts/mobilenet_v3.py +++ b/examples/samsung/scripts/mobilenet_v3.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/resnet18.py b/examples/samsung/scripts/resnet18.py index 3fd263f61af..3835aa4a51a 100644 --- a/examples/samsung/scripts/resnet18.py +++ b/examples/samsung/scripts/resnet18.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/resnet50.py b/examples/samsung/scripts/resnet50.py index 04dd05b0c1a..40c4dbfb0c0 100644 --- a/examples/samsung/scripts/resnet50.py +++ b/examples/samsung/scripts/resnet50.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from 
executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/vit.py b/examples/samsung/scripts/vit.py index 7393cd5c53e..96767cad8d8 100644 --- a/examples/samsung/scripts/vit.py +++ b/examples/samsung/scripts/vit.py @@ -13,6 +13,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -144,7 +145,9 @@ def get_data_loader(): test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module( diff --git a/examples/samsung/scripts/wav2letter.py b/examples/samsung/scripts/wav2letter.py index 79202d7ad63..9f95cdf5747 100644 --- a/examples/samsung/scripts/wav2letter.py +++ b/examples/samsung/scripts/wav2letter.py @@ -14,6 +14,7 @@ from executorch.backends.samsung.quantizer import Precision from executorch.backends.samsung.serialization.compile_options import ( gen_samsung_backend_compile_spec, + PerformanceMode, ) from executorch.backends.samsung.utils.export_utils import ( quantize_module, @@ -210,7 +211,9 @@ def get_dataset( test_in = inputs[0] float_out = model(*test_in) - compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + compile_specs = [ + gen_samsung_backend_compile_spec(args.chipset, PerformanceMode.HIGH_PERFORMANCE) + ] if args.precision: model = quantize_module(