Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/setup-samsung-linux-deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ while [[ $# -gt 0 ]]; do
esac
done

LITECORE_VERSION="v1.0"
LITECORE_VERSION="v1.1"
LITECORE_FILE_NAME="ai-litecore-ubuntu2204-${LITECORE_VERSION}.tar.gz"
DEVICEFARM_CLI_VERSION="beta-v1.1.0"
DEVICEFARM_FILE_NAME="devicefarmcli-${DEVICEFARM_CLI_VERSION}.zip"
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,7 @@ jobs:
fi

# Test models
#python -m executorch.backends.samsung.test.utils.run_tests --chipset E9955
python -m unittest discover -s backends/samsung/test/models -p "test_*.py"

test-vulkan-models-linux:
Expand Down
39 changes: 37 additions & 2 deletions backends/samsung/_passes/annotate_qparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch._export.utils import get_buffer
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateQparamsPass(ExportPass):
Expand Down Expand Up @@ -148,13 +149,34 @@ def _check_same(requant_obj, ori_obj) -> bool:
_check_same(ori_quant_attrs[key], requantize_attrs[key])
for key in key_map.values()
):
requantize_map[idx] = requantize_attrs
if (
ori_quant_attrs[QuantConstants.QUANT_KEY.quant_dtype]
!= requantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
):
# For Q-DQ who will change quant dtype, we will insert requantization node
requantize_map[idx] = requantize_attrs
else:
node.meta["quantize_attrs"] = requantize_attrs

def _annotate(self, graph_module: GraphModule):
for node in graph_module.graph.nodes:
if key_map := QuantConstants.DEQUANT_OPS_KEY_MAP.get(node.target, None):
# We will fold node with constant output in the future pass as a constant node
# example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one
# We need to store the q-params from last DQ params for quantizing constant value
quant_attrs = self.get_quant_attrs(node, key_map)
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
node.meta["quantize_attrs"] = quant_attrs
else:
node.args[0].meta["quantize_attrs"] = quant_attrs
continue
key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
if not key_map:
continue
quant_attrs = self.get_quant_attrs(node, key_map)
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
node.meta["quantize_attrs"] = quant_attrs
continue
source_node = node.args[0]
if source_node.target in (
*QuantConstants.QUANT_OPS_KEY_MAP,
Expand All @@ -164,13 +186,26 @@ def _annotate(self, graph_module: GraphModule):
continue
elif source_node.target == operator.getitem:
source_node = source_node.args[0]
quant_attrs = self.get_quant_attrs(node, key_map)

source_node.meta["quantize_attrs"] = quant_attrs
self._annotate_requantize(source_node)
self._propagate_quant_params(source_node)

def _annotate_decomposed_mm(self, graph_module: GraphModule):
    """Propagate output q-params onto the ``bmm`` node of decomposed matmuls.

    ``torch.matmul`` decomposes into view/expand/``bmm`` nodes; the quantize
    attrs land on the partition's final view node, but the backend needs them
    on the ``bmm`` itself. Copy them over for every matmul source partition.
    """
    # NOTE: .get() default is a *list* — the value is a list of
    # SourcePartition objects (the original used {} which only worked
    # by accident because iterating an empty dict yields nothing).
    for partition in get_source_partitions(graph_module.graph, ["matmul"]).get(
        "matmul", []
    ):
        final_view = partition.output_nodes[0]
        # Partitions whose output was never annotated carry no q-params.
        if not (quantize_attrs := final_view.meta.get("quantize_attrs")):
            continue
        for node in partition.nodes:
            if node.target == exir_ops.edge.aten.bmm.default:
                node.meta["quantize_attrs"] = quantize_attrs
                break  # one bmm per matmul decomposition

def call(self, graph_module: GraphModule):
    """Annotate quantize attrs on the graph, then recompile and report success."""
    gm = graph_module
    self._annotate(gm)
    self._annotate_decomposed_mm(gm)
    gm.recompile()
    return PassResult(gm, True)

Expand Down
50 changes: 30 additions & 20 deletions backends/samsung/_passes/annotate_scalar_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -25,6 +24,7 @@ class AnnotateScalarParametersPass(ExportPass):
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.sub.Tensor,
}

def __init__(self, edge_program: ExportedProgram):
Expand All @@ -35,27 +35,37 @@ def annotate(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
continue
torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
for input_arg in node.all_input_nodes:
if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
self.edge_program, input_arg
input0, input1 = node.all_input_nodes[0], node.all_input_nodes[1]
if input0.op not in ("placeholder", "get_attr") or not is_param_node(
self.edge_program, input0
):
if input1.op not in ("placeholder", "get_attr") or not is_param_node(
self.edge_program, input1
):
continue
else:
tensor = get_param_tensor(self.edge_program, input_arg)
if not tensor.shape:
qparams = {
QuantConstants.QUANT_KEY.scale: float(tensor),
QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
torch_quant_dtype
).max,
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
torch_quant_dtype
).min,
QuantConstants.QUANT_KEY.zero_point: 0,
}
input_arg.meta["quantize_attrs"] = qparams
ifm_node, param_tensor_node = input0, input1
else:
ifm_node, param_tensor_node = input1, input0
if not (quantize_attrs := ifm_node.meta.get("quantize_attrs")):
continue
param_tensor = get_param_tensor(self.edge_program, param_tensor_node)
if not param_tensor.shape:
scale = (
float(param_tensor) if param_tensor > 0 else -float(param_tensor)
)
else:
continue
q_dtype = quantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
if scale == 0:
scale = 1.0
qparams = {
QuantConstants.QUANT_KEY.scale: scale,
QuantConstants.QUANT_KEY.quant_dtype: q_dtype,
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(q_dtype).max,
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(q_dtype).min,
QuantConstants.QUANT_KEY.zero_point: 0,
}
param_tensor_node.meta["quantize_attrs"] = qparams

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def map_hardtan_relux(tanhnode: torch.fx.node.Node) -> Optional[str]:
return None


class FuseConvActPass(ExportPass):
class FuseActivationPass(ExportPass):
TARGET_ACTS_MAP = {
exir_ops.edge.aten.relu.default: (lambda x: "RELU"),
exir_ops.edge.aten.relu_.default: (lambda x: "RELU"),
Expand All @@ -33,45 +33,45 @@ class FuseConvActPass(ExportPass):
exir_ops.edge.aten.hardtanh.default: map_hardtan_relux,
exir_ops.edge.aten.hardtanh_.default: map_hardtan_relux,
}
TARGET_SOURCE_NODES = {
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.linear.default,
}

def _fuse(
self,
graph_module: GraphModule,
):
for target_conv, target_act in self.get_target_conv_act(graph_module):
for target_src, target_act in self.get_target_src_act(graph_module):
assert (
act_name := self.TARGET_ACTS_MAP.get(target_act.target)(target_act)
), f"Not supported {target_act.name} now."
target_conv.meta["activation"] = act_name
target_src.meta["activation"] = act_name
if "quantize_attrs" in target_act.meta:
target_conv.meta["quantize_attrs"] = target_act.meta["quantize_attrs"]

# If we merge the real out activation to conv, the conv should be the real out
if "real_out" in target_act.meta:
target_conv.meta["real_out"] = target_act.meta["real_out"]
target_src.meta["quantize_attrs"] = target_act.meta["quantize_attrs"]
else:
continue
for user in [user for user in target_act.users.keys()]: # noqa: C416
user.replace_input_with(target_act, target_conv)
user.replace_input_with(target_act, target_src)
graph_module.graph.erase_node(target_act)

def get_target_conv_act(self, graph_module: GraphModule):
def get_target_src_act(self, graph_module: GraphModule):
for node in graph_module.graph.nodes:
if node.target != exir_ops.edge.aten.convolution.default:
if node.target not in self.TARGET_SOURCE_NODES:
continue
if len(node.users) != 1:
# Such cases couldn't be conv + act
# Such cases couldn't be src + act
continue
act_node = list(node.users.keys())[0]
if act_node.target not in self.TARGET_ACTS_MAP:
continue
if "quantize_attrs" in node.meta:
# If the conv's output is quantized
# We do not fuse them
# If we merge the real out activation to source, the source should be the real out
continue
yield node, act_node

def call(self, graph_module: GraphModule):
self._fuse(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
_ = super().call(graph_module).graph_module
return PassResult(graph_module, True)
9 changes: 9 additions & 0 deletions backends/samsung/_passes/insert_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,18 @@ def _add_qdq(self, graph_module: GraphModule):
elif is_graph_output(node):
self._add_dq_after(graph_module, node)

def _add_q_for_cast(self, graph_module: GraphModule):
    """Insert a quantize node after each annotated ``_to_copy`` (cast) node.

    Only cast nodes that carry ``quantize_attrs`` in their meta are touched;
    their output must be re-quantized for downstream quantized consumers.
    """
    # Snapshot the node list: _add_q_after mutates the graph while we iterate.
    for node in list(graph_module.graph.nodes):
        # Merged guards; original used the `not x == y` anti-idiom.
        if (
            node.target == exir_ops.edge.aten._to_copy.default
            and "quantize_attrs" in node.meta
        ):
            self._add_q_after(graph_module, node)

def call(self, graph_module: GraphModule):
    """Run all Q/DQ insertion steps, prune dead code, and recompile."""
    insertion_steps = (
        self._add_qdq,
        self._add_qdq_for_requantize,
        self._add_q_for_cast,
    )
    for step in insertion_steps:
        step(graph_module)
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
    return PassResult(graph_module, True)
105 changes: 105 additions & 0 deletions backends/samsung/_passes/transform_quantized_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass
from torch.export import ExportedProgram
from torch.fx import GraphModule


class TransformQuantizedMaskPass(ExportPass):
    """Rewrite the quantized attention-mask pattern into explicit multiplies.

    Locates the mask pattern rooted at the graph input named
    ``attention_mask`` (mask -> Q/DQ/sub/cast/unsqueeze chain -> mul) and,
    for every consumer ``add``, replaces the shared mask-mul with two new
    ``mul`` nodes that carry the consumer's quantize attrs.
    """

    def __init__(self, edge_program: ExportedProgram):
        super().__init__()
        # Kept for parity with the other passes in this backend; not read
        # within this pass itself.
        self.edge_program = edge_program

    def get_mask_mul(self, graph_module: GraphModule):
        """Return the ``mul`` node terminating the attention-mask pattern.

        Walks forward from the placeholder whose target is ``attention_mask``
        through the whitelisted pattern ops until a ``mul.Tensor`` is reached.
        Returns ``None`` if no mask input exists or the walk leaves the
        pattern before reaching a ``mul``.
        """
        # Ops allowed to appear between the mask input and the final mul.
        nodes_in_pattern = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten._to_copy.default,
            exir_ops.edge.aten.unsqueeze_copy.default,
            exir_ops.edge.aten.mul.Tensor,
        )
        mask_node = None
        for node in graph_module.graph.nodes:
            if node.target != "attention_mask":
                continue
            else:
                mask_node = node
                break
        if mask_node is None:
            # Graph has no attention mask input; nothing to transform.
            return None
        # `node` is the mask placeholder here (loop variable survives the
        # break above); follow users forward until we hit the mul.
        while node.target != exir_ops.edge.aten.mul.Tensor:
            find_next = False
            for successor in list(node.users.keys()):
                if successor.target in nodes_in_pattern:
                    node = successor
                    find_next = True
                    break
            if not find_next:
                # Walk escaped the expected pattern — treat as no match.
                return None
        return node

    def transform(
        self,
        graph_module: GraphModule,
    ):
        """Rewire each ``add`` consumer of the mask-mul onto two new muls.

        For every consumer ``add`` (presumably the score+mask add in
        attention — TODO confirm against the exporting model), a constant
        buffer holding ``(quant_min - zero_point) * scale`` of the add's
        other input is registered and multiplied with the mask, and the
        other input is rescaled by the mask pattern's sub operand.
        """
        mask_mul = self.get_mask_mul(graph_module)
        if mask_mul is None:
            return
        # NOTE(review): assumes mask_mul.args[0] is the (r)sub node of the
        # pattern and add.args[0] is the non-mask operand — confirm against
        # the decomposition this pass targets.
        rsub_node = mask_mul.args[0]
        manual_mul_idx = 0
        for add in list(mask_mul.users.keys()):
            custom_tensor_name = f"_custom_tensor_{manual_mul_idx}"
            div_node = add.args[0]
            if "quantize_attrs" not in div_node.meta:
                # Bail out entirely if any consumer lacks q-params.
                return
            div_quant_args = div_node.meta["quantize_attrs"]
            # Scalar constant: (qmin - zero_point) * scale of the other
            # operand, i.e. the most negative representable real value.
            custom_tensor = torch.tensor(
                (
                    div_node.meta["quantize_attrs"][QuantConstants.QUANT_KEY.quant_min]
                    - div_node.meta["quantize_attrs"][
                        QuantConstants.QUANT_KEY.zero_point
                    ]
                )
                * div_node.meta["quantize_attrs"][QuantConstants.QUANT_KEY.scale],
                dtype=torch.float32,
            )
            graph_module.register_buffer(custom_tensor_name, custom_tensor)
            # The add inherits the other operand's quantize attrs.
            add.meta["quantize_attrs"] = div_quant_args
            with graph_module.graph.inserting_after(rsub_node):
                custom_attr = graph_module.graph.get_attr(custom_tensor_name)
            with graph_module.graph.inserting_after(custom_attr):
                new_mul = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.mul.Tensor,
                    (mask_mul.args[0], custom_attr),
                )
                new_mul.meta["quantize_attrs"] = div_quant_args
            # Route the add onto the new mask * constant multiply.
            add.replace_input_with(mask_mul, new_mul)

            rsub_in = rsub_node.args[1]
            with graph_module.graph.inserting_before(add):
                new_mul = graph_module.graph.create_node(
                    "call_function", exir_ops.edge.aten.mul.Tensor, (div_node, rsub_in)
                )
                new_mul.meta["quantize_attrs"] = div_quant_args
            # Rescale the non-mask operand by the pattern's sub operand.
            add.replace_input_with(div_node, new_mul)
            manual_mul_idx += 1

    def call(self, graph_module: GraphModule):
        """Apply the mask transform, recompile, and clean up dead nodes."""
        self.transform(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
2 changes: 2 additions & 0 deletions backends/samsung/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
op_mul,
op_permute,
op_pixel_shuffle,
op_placeholder,
op_quantize,
op_relu,
op_reshape,
Expand Down Expand Up @@ -80,6 +81,7 @@
op_mul,
op_permute,
op_pixel_shuffle,
op_placeholder,
op_quantize,
op_relu,
op_reshape,
Expand Down
2 changes: 1 addition & 1 deletion backends/samsung/builders/op_constant_pad_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ def define_node(
"padding": "EXPLICIT",
"padding_type": "CONSTANT",
}

self._update_params_qdtype(node, params)
enn_graph.define_op(node.name, "PAD", [input_id], [output_id], params)
1 change: 1 addition & 0 deletions backends/samsung/builders/op_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def define_node(
output_id = self.define_tensor(node, enn_graph, vals_to_ids)

params = {"axis": 0, "input_type": "indices"}
self._update_params_qdtype(node, params)
enn_graph.define_op(
node.name, "GATHER", [input_id, weight_id], [output_id], params
)
Loading
Loading