55# LICENSE file in the root directory of this source tree.
66
77import torch
8- from executorch .backends .samsung .quantizer .quantizer import global_quant_info
98from executorch .backends .samsung .utils .constants import QuantConstants
109from executorch .backends .transforms .utils import get_param_tensor , is_param_node
1110from executorch .exir .dialects ._ops import ops as exir_ops
@@ -25,6 +24,7 @@ class AnnotateScalarParametersPass(ExportPass):
2524 exir_ops .edge .aten .mul .Tensor ,
2625 exir_ops .edge .aten .add .Tensor ,
2726 exir_ops .edge .aten .div .Tensor ,
27+ exir_ops .edge .aten .sub .Tensor ,
2828 }
2929
3030 def __init__ (self , edge_program : ExportedProgram ):
@@ -35,27 +35,37 @@ def annotate(self, graph_module: torch.fx.GraphModule):
3535 for node in graph_module .graph .nodes :
3636 if node .target not in self .TARGET_OPS or "quantize_attrs" not in node .meta :
3737 continue
38- torch_quant_dtype = global_quant_info .weight_precison .torch_dtype
39- for input_arg in node .all_input_nodes :
40- if input_arg .op not in ("placeholder" , "get_attr" ) or not is_param_node (
41- self .edge_program , input_arg
38+ input0 , input1 = node .all_input_nodes [0 ], node .all_input_nodes [1 ]
39+ if input0 .op not in ("placeholder" , "get_attr" ) or not is_param_node (
40+ self .edge_program , input0
41+ ):
42+ if input1 .op not in ("placeholder" , "get_attr" ) or not is_param_node (
43+ self .edge_program , input1
4244 ):
4345 continue
44- else :
45- tensor = get_param_tensor (self .edge_program , input_arg )
46- if not tensor .shape :
47- qparams = {
48- QuantConstants .QUANT_KEY .scale : float (tensor ),
49- QuantConstants .QUANT_KEY .quant_dtype : torch_quant_dtype ,
50- QuantConstants .QUANT_KEY .quant_max : torch .iinfo (
51- torch_quant_dtype
52- ).max ,
53- QuantConstants .QUANT_KEY .quant_min : torch .iinfo (
54- torch_quant_dtype
55- ).min ,
56- QuantConstants .QUANT_KEY .zero_point : 0 ,
57- }
58- input_arg .meta ["quantize_attrs" ] = qparams
46+ ifm_node , param_tensor_node = input0 , input1
47+ else :
48+ ifm_node , param_tensor_node = input1 , input0
49+ if not (quantize_attrs := ifm_node .meta .get ("quantize_attrs" )):
50+ continue
51+ param_tensor = get_param_tensor (self .edge_program , param_tensor_node )
52+ if not param_tensor .shape :
53+ scale = (
54+ float (param_tensor ) if param_tensor > 0 else - float (param_tensor )
55+ )
56+ else :
57+ continue
58+ q_dtype = quantize_attrs [QuantConstants .QUANT_KEY .quant_dtype ]
59+ if scale == 0 :
60+ scale = 1.0
61+ qparams = {
62+ QuantConstants .QUANT_KEY .scale : scale ,
63+ QuantConstants .QUANT_KEY .quant_dtype : q_dtype ,
64+ QuantConstants .QUANT_KEY .quant_max : torch .iinfo (q_dtype ).max ,
65+ QuantConstants .QUANT_KEY .quant_min : torch .iinfo (q_dtype ).min ,
66+ QuantConstants .QUANT_KEY .zero_point : 0 ,
67+ }
68+ param_tensor_node .meta ["quantize_attrs" ] = qparams
5969
6070 def call (self , graph_module : torch .fx .GraphModule ):
6171 graph = graph_module .graph
0 commit comments