diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index aed8b591fea..1f4c962fdb3 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -356,12 +356,12 @@ def linear_q8ta_q8csw(
 lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
 qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
 
-############################
-## conv2d_q8ta_q8csw_q8to ##
-############################
+###################
+## q8ta_conv2d_* ##
+###################
 
 
-def conv2d_q8ta_q8csw_q8to(
+def q8ta_conv2d(
     x: torch.Tensor,
     input_scale: float,
     input_zero_point: int,
@@ -425,7 +425,7 @@ def conv2d_q8ta_q8csw_q8to(
     return out
 
 
-name = "conv2d_q8ta_q8csw_q8to"
+name = "q8ta_conv2d"
 lib.define(
     f"""
     {name}(
@@ -445,11 +445,35 @@ def conv2d_q8ta_q8csw_q8to(
         SymInt groups) -> Tensor
     """
 )
-lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd")
-conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
+q8ta_conv2d_op = getattr(getattr(torch.ops, namespace), name)
 
 
-def conv2d_q8ta_q8csw_q8to_dw(
+name = "q8ta_conv2d_pw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        float output_scale,
+        int output_zero_point,
+        Tensor? bias,
+        SymInt[] kernel_size,
+        SymInt[] stride,
+        SymInt[] padding,
+        SymInt[] dilation,
+        SymInt groups) -> Tensor
+    """
+)
+lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
+q8ta_conv2d_pw_op = getattr(getattr(torch.ops, namespace), name)
+
+
+def q8ta_conv2d_dw(
     x: torch.Tensor,
     input_scale: float,
     input_zero_point: int,
@@ -497,7 +521,7 @@ def conv2d_q8ta_q8csw_q8to_dw(
     return out
 
 
-name = "conv2d_q8ta_q8csw_q8to_dw"
+name = "q8ta_conv2d_dw"
 lib.define(
     f"""
     {name}(
@@ -517,7 +541,7 @@ def conv2d_q8ta_q8csw_q8to_dw(
         SymInt groups) -> Tensor
     """
 )
-lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd")
+lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
 conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)
 
 ######################
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index e2f305ca0e0..64d7261f576 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -745,17 +745,46 @@ def check_conv_node(node: torch.fx.Node) -> bool:
 
 
 # =============================================================================
-# QuantizedConvolution.cpp
+# Q8taConv2d*.cpp
 # =============================================================================
 
 
 @update_features(
     [
-        exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
-        exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
+        exir_ops.edge.et_vk.q8ta_conv2d_pw.default,
     ]
 )
-def register_quantizedconvolution_cpp_ops():
+def register_q8ta_conv_pw_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.PACKED_INT8_4W4C_BUFFER,  # input
+            utils.NO_STORAGE,  # input_scale (non tensor)
+            utils.NO_STORAGE,  # input_zero_point (non tensor)
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # weight_sums (prepacked)
+            utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # output_scale (non tensor)
+            utils.NO_STORAGE,  # output_zero_point (non tensor)
+            utils.NO_STORAGE,  # bias (prepacked)
+            utils.NO_STORAGE,  # kernel_size (non tensor)
+            utils.NO_STORAGE,  # stride (non tensor)
+            utils.NO_STORAGE,  # padding (non tensor)
+            utils.NO_STORAGE,  # dilation (non tensor)
+            utils.NO_STORAGE,  # groups (non tensor)
+            utils.NO_STORAGE,  # original OC count (non tensor)
+        ],
+        supports_resize=False,
+        supports_prepacking=True,
+    )
+
+
+@update_features(
+    [
+        exir_ops.edge.et_vk.q8ta_conv2d.default,
+        exir_ops.edge.et_vk.q8ta_conv2d_dw.default,
+    ]
+)
+def register_q8ta_conv2d_ops():
     return OpFeatures(
         inputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,  # input
diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py
index 522a19c58d6..b89dfe9aaab 100644
--- a/backends/vulkan/patterns/quantized_convolution.py
+++ b/backends/vulkan/patterns/quantized_convolution.py
@@ -156,7 +156,7 @@ def find_quantized_convolution_patterns(
 
 
 @register_pattern_replacement("quantized_convolution")
-def make_conv2d_q8ta_q8csw_custom_op(
+def make_q8ta_conv2d_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedConvolutionMatch,
@@ -230,10 +230,20 @@ def make_conv2d_q8ta_q8csw_custom_op(
             data=sum_per_output_channel,
         )
 
+    is_pointwise_conv = (
+        H == 1
+        and W == 1
+        and list(match.stride) == [1, 1]
+        and list(match.dilation) == [1, 1]
+        and list(match.padding) == [0, 0]
+    )
+
     with graph_module.graph.inserting_before(match.output_node):
-        op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default
+        op_target = exir_ops.edge.et_vk.q8ta_conv2d.default
         if is_depthwise_conv:
-            op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default
+            op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default
+        elif is_pointwise_conv:
+            op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default
 
         qconv_node = graph_module.graph.create_node(
             "call_function",
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
index dcc275fbb65..d3fe1afd906 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -422,8 +422,8 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.q8ta_conv2d.default, q8ta_conv2d);
-  VK_REGISTER_OP(etvk.q8ta_conv2d_general.default, q8ta_conv2d_general);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d.default, q8ta_conv2d);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_general.default, q8ta_conv2d_general);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
index 121a577555f..d12bbc0574a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
@@ -436,7 +436,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.q8ta_conv2d_dw.default, q8ta_conv2d_dw);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_dw.default, q8ta_conv2d_dw);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
index f634f8b1773..e89ebc92aba 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
@@ -269,7 +269,7 @@ void q8ta_conv2d_im2col(
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
index 5ff69dac63b..fc883eefeef 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -346,7 +346,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
index 8b4da0f2821..4fed7461ce6 100644
--- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
@@ -82,7 +82,7 @@ void test_q8ta_conv2d_dw(
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
   } else {
     // Use the dedicated depthwise conv2d operator
-    VK_GET_OP_FN("etvk.q8ta_conv2d_dw.default")(graph, conv_args);
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.default")(graph, conv_args);
   }
 
   // Dequantize packed int8 output to floating point
@@ -156,13 +156,13 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
   } else if (impl_selector == "im2col") {
     // Use the im2col-based conv2d operator
-    VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args);
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args);
   } else if (impl_selector == "general") {
     // Use the general q8ta_conv2d operator (no im2col dispatch)
-    VK_GET_OP_FN("etvk.q8ta_conv2d_general.default")(graph, conv_args);
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args);
   } else {
     // Use the new general q8ta_conv2d operator
-    VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args);
+    VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args);
   }
 
   // Dequantize packed int8 output to floating point
@@ -240,7 +240,7 @@ void test_q8ta_conv2d_pw(
   if (impl_selector == "legacy_4w4c") {
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
   } else {
-    VK_GET_OP_FN("etvk.q8ta_conv2d_pw.default")(graph, conv_args);
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.default")(graph, conv_args);
   }
 
   // Dequantize packed int8 output to floating point
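
The net effect of the changes in backends/vulkan/patterns/quantized_convolution.py is a three-way dispatch: depthwise matches keep their dedicated op, 1x1 kernels with unit stride/dilation and zero padding now route to the new pointwise op, and everything else falls back to the general op. A minimal Python sketch of that selection logic follows; the helper name and flattened arguments are illustrative only, not part of the patch, and mirror the fields read from QuantizedConvolutionMatch:

def select_q8ta_conv2d_variant(kernel_hw, stride, padding, dilation, is_depthwise):
    # Depthwise takes priority, matching the if/elif order in
    # make_q8ta_conv2d_custom_op.
    if is_depthwise:
        return "et_vk.q8ta_conv2d_dw.default"
    kh, kw = kernel_hw
    # A 1x1 kernel with unit stride/dilation and no padding is a
    # pointwise convolution, handled by the dedicated pw op.
    is_pointwise = (
        (kh, kw) == (1, 1)
        and list(stride) == [1, 1]
        and list(dilation) == [1, 1]
        and list(padding) == [0, 0]
    )
    if is_pointwise:
        return "et_vk.q8ta_conv2d_pw.default"
    return "et_vk.q8ta_conv2d.default"

Note that q8ta_conv2d_pw shares its schema and Python reference implementation (q8ta_conv2d) with the general op, as defined in custom_ops_lib.py above, so the pw variant differs only in which Vulkan compute kernel the backend dispatches.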