From 8394aa47bdf9a03935147faeaad90f09a9fafbbb Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 11 Feb 2026 12:15:36 -0800 Subject: [PATCH] [ET-VK][qconv] Dynamically select between im2col path and general path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a dispatch layer to `q8ta_conv2d` that dynamically selects between the im2col-based and general convolution implementations at graph build time. The existing `q8ta_conv2d` function is renamed to `q8ta_conv2d_general`, and a new `q8ta_conv2d` dispatcher is introduced that chooses the im2col path when the convolution is non-grouped, has input channels divisible by 4, and kernel size ≤ 3x3. All other cases fall through to the general path. A separate `q8ta_conv2d_general` operator is also registered so tests can directly invoke the general path for comparison. The test suite is updated to exercise both the general and im2col implementations explicitly, and the default impl_selector is changed from "general" to empty (which triggers the new dispatcher). FP buffer storage types are removed from the test matrix since they are not needed. Differential Revision: [D93000162](https://our.internmc.facebook.com/intern/diff/D93000162/) [ghstack-poisoned] --- .../runtime/graph/ops/impl/Q8taConv2d.cpp | 26 ++++++++++++++- .../runtime/graph/ops/impl/Q8taConv2d.h | 2 ++ .../test/custom_ops/impl/TestQ8taConv2d.cpp | 3 ++ .../test/custom_ops/test_q8ta_conv2d.cpp | 33 +++++++++++-------- .../test/custom_ops/test_q8ta_conv2d_dw.cpp | 7 ++-- .../test/custom_ops/test_q8ta_conv2d_pw.cpp | 4 +-- 6 files changed, 52 insertions(+), 23 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index aa4e1e47d27..dcc275fbb65 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -323,7 +323,9 @@ void add_q8ta_conv2d_node( // High level operator impl // -void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { +void q8ta_conv2d_general( + ComputeGraph& graph, + const std::vector& args) { int32_t idx = 0; const ValueRef packed_int8_input = args.at(idx++); const ValueRef input_scale = args.at(idx++); @@ -398,8 +400,30 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { packed_int8_output); } +void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { + // Index into args to extract values needed for dispatch decision + const ValueRef packed_int8_input = args.at(0); + const ValueRef kernel_size = args.at(9); + const ValueRef groups = args.at(13); + + const int32_t groups_val = graph.get_int(groups); + const int64_t IC = graph.size_at(-3, packed_int8_input); + + const int64_t K_h = graph.get_int_list(kernel_size)->at(0); + const int64_t K_w = graph.get_int_list(kernel_size)->at(1); + + // Use im2col path when: non-grouped, input channels multiple of 4, small + // kernel + if (groups_val == 1 && IC % 4 == 0 && K_h <= 3 && K_w <= 3) { + q8ta_conv2d_im2col(graph, args); + } else { + q8ta_conv2d_general(graph, args); + } +} + REGISTER_OPERATORS { VK_REGISTER_OP(etvk.q8ta_conv2d.default, q8ta_conv2d); + VK_REGISTER_OP(etvk.q8ta_conv2d_general.default, q8ta_conv2d_general); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 53a5aa15fe6..9686c873c1b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -113,4 +113,6 @@ void add_q8ta_conv2d_pw_node( const ValueRef packed_bias, const ValueRef packed_int8_output); +void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); + } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index acb5a3d03f5..8b4da0f2821 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -157,6 +157,9 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { } else if (impl_selector == "im2col") { // Use the im2col-based conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args); + } else if (impl_selector == "general") { + // Use the general q8ta_conv2d operator (no im2col dispatch) + VK_GET_OP_FN("etvk.q8ta_conv2d_general.default")(graph, conv_args); } else { // Use the new general q8ta_conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 0de4f546b0d..17dd7a0fc53 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -29,7 +29,7 @@ static TestCase create_test_case_from_config( vkapi::ScalarType input_dtype, utils::StorageType fp_storage_type, utils::GPUMemoryLayout int8_memory_layout, - const std::string& impl_selector = "general") { + const std::string& impl_selector = "") { TestCase test_case; // Calculate output dimensions @@ -53,7 +53,6 @@ static TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -218,8 +217,7 @@ std::vector generate_quantized_conv2d_easy_cases() { }; config.op_name = "conv2d_q8ta_q8csw_q8to"; - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -379,8 +377,7 @@ static std::vector generate_quantized_conv2d_test_cases() { 4}}; // Test with different storage types and memory layouts - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -401,29 +398,37 @@ static std::vector generate_quantized_conv2d_test_cases() { int8_memory_layouts) { config.test_case_name = make_test_case_name( config, is_performance, fp_storage_type, utils::kBuffer); + test_cases.push_back(create_test_case_from_config( - config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"general")); - // For 4W4C layout, also test the legacy implementation - if (int8_memory_layout == utils::kPackedInt8_4W4C) { + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout, - /*impl_selector=*/"legacy_4w4c")); + /*impl_selector=*/"im2col")); } - // Test im2col implementation for non-grouped convolutions with input - // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0) { + // For 4W4C layout, also test the legacy implementation + if (int8_memory_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout, - /*impl_selector=*/"im2col")); + /*impl_selector=*/"legacy_4w4c")); } + + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); } } } diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp index b4583071acd..7ef73d49802 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp @@ -54,7 +54,6 @@ TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -228,8 +227,7 @@ std::vector generate_quantized_conv2d_dw_easy_cases() { }; config.op_name = "conv2d_q8ta_q8csw_q8to"; - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -351,8 +349,7 @@ std::vector generate_quantized_conv2d_dw_test_cases() { 32}}; // Test with different storage types and data types - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 6c7b2d94ecd..51095c649b6 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -53,7 +53,6 @@ static TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -286,8 +285,7 @@ static std::vector generate_quantized_conv2d_pw_test_cases() { }; // Test with different storage types and memory layouts - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths