Skip to content

Commit ddd992f

Browse files
Jiseong-oh, Chen03ZhaoSamsung, and Sangsoo.ko
committed
Support Quantized MobileBert
- update annotator
- Support quantized mobilebert
- update Quantization strategy

Co-authored-by: chen.zhao <chen03.zhao@samsung.com>
Co-authored-by: Sangsoo.ko <sangsoo.ko@samsung.com>
Signed-off-by: jiseong.oh <jiseong.oh@samsung.com>
1 parent b3345fb commit ddd992f

File tree

18 files changed

+158
-246
lines changed

18 files changed

+158
-246
lines changed

backends/samsung/_passes/annotate_qparams.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch._export.utils import get_buffer
1515
from torch.export import ExportedProgram
1616
from torch.fx import GraphModule, Node
17+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1718

1819

1920
class AnnotateQparamsPass(ExportPass):
@@ -148,13 +149,34 @@ def _check_same(requant_obj, ori_obj) -> bool:
148149
_check_same(ori_quant_attrs[key], requantize_attrs[key])
149150
for key in key_map.values()
150151
):
151-
requantize_map[idx] = requantize_attrs
152+
if (
153+
ori_quant_attrs[QuantConstants.QUANT_KEY.quant_dtype]
154+
!= requantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
155+
):
156+
# For Q-DQ who will change quant dtype, we will insert requantization node
157+
requantize_map[idx] = requantize_attrs
158+
else:
159+
node.meta["quantize_attrs"] = requantize_attrs
152160

153161
def _annotate(self, graph_module: GraphModule):
154162
for node in graph_module.graph.nodes:
163+
if key_map := QuantConstants.DEQUANT_OPS_KEY_MAP.get(node.target, None):
164+
# We will fold node with constant output in the future pass as a constant node
165+
# example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one
166+
# We need to store the q-params from last DQ params for quantizing constant value
167+
quant_attrs = self.get_quant_attrs(node, key_map)
168+
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
169+
node.meta["quantize_attrs"] = quant_attrs
170+
else:
171+
node.args[0].meta["quantize_attrs"] = quant_attrs
172+
continue
155173
key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
156174
if not key_map:
157175
continue
176+
quant_attrs = self.get_quant_attrs(node, key_map)
177+
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
178+
node.meta["quantize_attrs"] = quant_attrs
179+
continue
158180
source_node = node.args[0]
159181
if source_node.target in (
160182
*QuantConstants.QUANT_OPS_KEY_MAP,
@@ -164,13 +186,26 @@ def _annotate(self, graph_module: GraphModule):
164186
continue
165187
elif source_node.target == operator.getitem:
166188
source_node = source_node.args[0]
167-
quant_attrs = self.get_quant_attrs(node, key_map)
189+
168190
source_node.meta["quantize_attrs"] = quant_attrs
169191
self._annotate_requantize(source_node)
170192
self._propagate_quant_params(source_node)
171193

194+
def _annotate_decomposed_mm(self, graph_module: GraphModule):
195+
for source_list in get_source_partitions(graph_module.graph, ["matmul"]).get(
196+
"matmul", {}
197+
):
198+
final_view = source_list.output_nodes[0]
199+
if not (quantize_attrs := final_view.meta.get("quantize_attrs")):
200+
continue
201+
for node in source_list.nodes:
202+
if node.target == exir_ops.edge.aten.bmm.default:
203+
node.meta["quantize_attrs"] = quantize_attrs
204+
break
205+
172206
def call(self, graph_module: GraphModule):
173207
self._annotate(graph_module)
208+
self._annotate_decomposed_mm(graph_module)
174209
graph_module.recompile()
175210
return PassResult(graph_module, True)
176211

backends/samsung/_passes/annotate_scalar_parameters.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
98
from executorch.backends.samsung.utils.constants import QuantConstants
109
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
1110
from executorch.exir.dialects._ops import ops as exir_ops
@@ -25,6 +24,7 @@ class AnnotateScalarParametersPass(ExportPass):
2524
exir_ops.edge.aten.mul.Tensor,
2625
exir_ops.edge.aten.add.Tensor,
2726
exir_ops.edge.aten.div.Tensor,
27+
exir_ops.edge.aten.sub.Tensor,
2828
}
2929

3030
def __init__(self, edge_program: ExportedProgram):
@@ -35,27 +35,37 @@ def annotate(self, graph_module: torch.fx.GraphModule):
3535
for node in graph_module.graph.nodes:
3636
if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
3737
continue
38-
torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
39-
for input_arg in node.all_input_nodes:
40-
if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
41-
self.edge_program, input_arg
38+
input0, input1 = node.all_input_nodes[0], node.all_input_nodes[1]
39+
if input0.op not in ("placeholder", "get_attr") or not is_param_node(
40+
self.edge_program, input0
41+
):
42+
if input1.op not in ("placeholder", "get_attr") or not is_param_node(
43+
self.edge_program, input1
4244
):
4345
continue
44-
else:
45-
tensor = get_param_tensor(self.edge_program, input_arg)
46-
if not tensor.shape:
47-
qparams = {
48-
QuantConstants.QUANT_KEY.scale: float(tensor),
49-
QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
50-
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
51-
torch_quant_dtype
52-
).max,
53-
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
54-
torch_quant_dtype
55-
).min,
56-
QuantConstants.QUANT_KEY.zero_point: 0,
57-
}
58-
input_arg.meta["quantize_attrs"] = qparams
46+
ifm_node, param_tensor_node = input0, input1
47+
else:
48+
ifm_node, param_tensor_node = input1, input0
49+
if not (quantize_attrs := ifm_node.meta.get("quantize_attrs")):
50+
continue
51+
param_tensor = get_param_tensor(self.edge_program, param_tensor_node)
52+
if not param_tensor.shape:
53+
scale = (
54+
float(param_tensor) if param_tensor > 0 else -float(param_tensor)
55+
)
56+
else:
57+
continue
58+
q_dtype = quantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
59+
if scale == 0:
60+
scale = 1.0
61+
qparams = {
62+
QuantConstants.QUANT_KEY.scale: scale,
63+
QuantConstants.QUANT_KEY.quant_dtype: q_dtype,
64+
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(q_dtype).max,
65+
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(q_dtype).min,
66+
QuantConstants.QUANT_KEY.zero_point: 0,
67+
}
68+
param_tensor_node.meta["quantize_attrs"] = qparams
5969

6070
def call(self, graph_module: torch.fx.GraphModule):
6171
graph = graph_module.graph

backends/samsung/_passes/fuse_conv_act.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

backends/samsung/_passes/insert_qdq.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,18 @@ def _add_qdq(self, graph_module: GraphModule):
156156
elif is_graph_output(node):
157157
self._add_dq_after(graph_module, node)
158158

159+
def _add_q_for_cast(self, graph_module: GraphModule):
160+
for node in list(graph_module.graph.nodes):
161+
if not node.target == exir_ops.edge.aten._to_copy.default:
162+
continue
163+
if "quantize_attrs" not in node.meta:
164+
continue
165+
self._add_q_after(graph_module, node)
166+
159167
def call(self, graph_module: GraphModule):
160168
self._add_qdq(graph_module)
161169
self._add_qdq_for_requantize(graph_module)
170+
self._add_q_for_cast(graph_module)
162171
graph_module.graph.eliminate_dead_code()
163172
graph_module.recompile()
164173
return PassResult(graph_module, True)

backends/samsung/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_mul,
3535
op_permute,
3636
op_pixel_shuffle,
37+
op_placeholder,
3738
op_quantize,
3839
op_relu,
3940
op_reshape,
@@ -80,6 +81,7 @@
8081
op_mul,
8182
op_permute,
8283
op_pixel_shuffle,
84+
op_placeholder,
8385
op_quantize,
8486
op_relu,
8587
op_reshape,

backends/samsung/builders/op_constant_pad_nd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ def define_node(
5252
"padding": "EXPLICIT",
5353
"padding_type": "CONSTANT",
5454
}
55-
55+
self._update_params_qdtype(node, params)
5656
enn_graph.define_op(node.name, "PAD", [input_id], [output_id], params)

backends/samsung/builders/op_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def define_node(
3636
output_id = self.define_tensor(node, enn_graph, vals_to_ids)
3737

3838
params = {"axis": 0, "input_type": "indices"}
39+
self._update_params_qdtype(node, params)
3940
enn_graph.define_op(
4041
node.name, "GATHER", [input_id, weight_id], [output_id], params
4142
)

backends/samsung/builders/op_slice_copy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ def define_node(
3838
dim = cast(int, node.args[1])
3939
if dim < 0:
4040
dim = dim + len(in_shape)
41-
start_val = cast(int, node.args[2])
41+
start_val = cast(int, node.args[2]) if node.args[2] else 0
4242
if start_val < 0:
4343
start_val = start_val + in_shape[dim]
44-
end_val = min(cast(int, node.args[3]), in_shape[dim])
44+
end_val = (
45+
in_shape[dim]
46+
if len(node.args) < 4
47+
else min(cast(int, node.args[3]), in_shape[dim])
48+
)
4549
if end_val < 0:
4650
end_val = end_val + in_shape[dim]
4751

backends/samsung/builders/op_sub.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,8 @@ def define_node(
3636
# output
3737
output_id = self.define_tensor(node, enn_graph, vals_to_ids)
3838

39-
enn_graph.define_op(node.name, "SUB", [input_id_1, input_id_2], [output_id])
39+
params = {}
40+
self._update_params_qdtype(node, params)
41+
enn_graph.define_op(
42+
node.name, "SUB", [input_id_1, input_id_2], [output_id], params
43+
)

backends/samsung/enn_preprocess.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@
1818
ConstantPropPass,
1919
)
2020
from executorch.backends.samsung._passes.fold_qdq import FoldQDQPass
21+
from executorch.backends.samsung._passes.fuse_activation import FuseActivationPass
2122
from executorch.backends.samsung._passes.insert_qdq import InsertQDQPass
23+
from executorch.backends.samsung._passes.remove_useless_ops import RemoveUselessOpPass
2224
from executorch.backends.samsung._passes.replace_scalar_ops import ReplaceOpsWithScalar
25+
from executorch.backends.samsung._passes.transform_quantized_mask import (
26+
TransformQuantizedMaskPass,
27+
)
2328
from executorch.backends.samsung.builders.node_visitor import get_node_visitors
2429
from executorch.backends.samsung.serialization.compile_options import (
2530
ENN_COMPILE_OPTION_TITLE,
@@ -30,6 +35,7 @@
3035
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
3136
FuseBatchNormWithConvPass,
3237
)
38+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
3339

3440
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
3541

@@ -59,9 +65,13 @@ def preprocess(
5965

6066
enn_preprocess_passes = PassManager(
6167
passes=[
68+
RemoveUselessOpPass(),
69+
RemoveCloneOpsTransform(),
6270
AnnotateQparamsPass(edge_program),
71+
FuseActivationPass(),
6372
FoldQDQPass(),
6473
ConstantPropPass(edge_program),
74+
TransformQuantizedMaskPass(edge_program),
6575
Conv1dToConv2d(edge_program),
6676
FuseBatchNormWithConvPass(edge_program),
6777
AddmmToLinearTransform(),
@@ -79,6 +89,7 @@ def preprocess(
7989
node_visitors = get_node_visitors(edge_program)
8090

8191
vals_to_ids: Dict[torch.fx.Node, int] = {}
92+
placeholder_vistor = node_visitors["placeholder"]
8293
for node in pass_result.graph_module.graph.nodes:
8394
if node.op == "call_function":
8495
logging.info(f"Visiting: {node}, {node.target.__name__}")
@@ -90,9 +101,11 @@ def preprocess(
90101
raise RuntimeError(
91102
f"{node.target.__name__}" " is not supported in ENN Delegate"
92103
)
104+
elif node.op == "placeholder":
105+
logging.info(f"Visiting input of graph: {node}")
106+
placeholder_vistor.define_node(node, enn_graph, vals_to_ids)
93107
elif node.op in [
94108
"get_attr",
95-
"placeholder",
96109
"output",
97110
]:
98111
continue

0 commit comments

Comments (0)