diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index aa4e1e47d27..dcc275fbb65 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -323,7 +323,9 @@ void add_q8ta_conv2d_node( // High level operator impl // -void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { +void q8ta_conv2d_general( + ComputeGraph& graph, + const std::vector& args) { int32_t idx = 0; const ValueRef packed_int8_input = args.at(idx++); const ValueRef input_scale = args.at(idx++); @@ -398,8 +400,30 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { packed_int8_output); } +void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { + // Index into args to extract values needed for dispatch decision + const ValueRef packed_int8_input = args.at(0); + const ValueRef kernel_size = args.at(9); + const ValueRef groups = args.at(13); + + const int32_t groups_val = graph.get_int(groups); + const int64_t IC = graph.size_at(-3, packed_int8_input); + + const int64_t K_h = graph.get_int_list(kernel_size)->at(0); + const int64_t K_w = graph.get_int_list(kernel_size)->at(1); + + // Use im2col path when: non-grouped, input channels multiple of 4, small + // kernel + if (groups_val == 1 && IC % 4 == 0 && K_h <= 3 && K_w <= 3) { + q8ta_conv2d_im2col(graph, args); + } else { + q8ta_conv2d_general(graph, args); + } +} + REGISTER_OPERATORS { VK_REGISTER_OP(etvk.q8ta_conv2d.default, q8ta_conv2d); + VK_REGISTER_OP(etvk.q8ta_conv2d_general.default, q8ta_conv2d_general); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 53a5aa15fe6..9686c873c1b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -113,4 +113,6 @@ void add_q8ta_conv2d_pw_node( const ValueRef packed_bias, const ValueRef packed_int8_output); +void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); + } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index acb5a3d03f5..8b4da0f2821 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -157,6 +157,9 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { } else if (impl_selector == "im2col") { // Use the im2col-based conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args); + } else if (impl_selector == "general") { + // Use the general q8ta_conv2d operator (no im2col dispatch) + VK_GET_OP_FN("etvk.q8ta_conv2d_general.default")(graph, conv_args); } else { // Use the new general q8ta_conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 0de4f546b0d..17dd7a0fc53 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -29,7 +29,7 @@ static TestCase create_test_case_from_config( vkapi::ScalarType input_dtype, utils::StorageType fp_storage_type, utils::GPUMemoryLayout int8_memory_layout, - const std::string& impl_selector = "general") { + const std::string& impl_selector = "") { TestCase test_case; // Calculate output dimensions @@ -53,7 +53,6 @@ static TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -218,8 +217,7 @@ std::vector generate_quantized_conv2d_easy_cases() { }; config.op_name = "conv2d_q8ta_q8csw_q8to"; - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -379,8 +377,7 @@ static std::vector generate_quantized_conv2d_test_cases() { 4}}; // Test with different storage types and memory layouts - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -401,29 +398,37 @@ static std::vector generate_quantized_conv2d_test_cases() { int8_memory_layouts) { config.test_case_name = make_test_case_name( config, is_performance, fp_storage_type, utils::kBuffer); + test_cases.push_back(create_test_case_from_config( - config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"general")); - // For 4W4C layout, also test the legacy implementation - if (int8_memory_layout == utils::kPackedInt8_4W4C) { + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout, - /*impl_selector=*/"legacy_4w4c")); + /*impl_selector=*/"im2col")); } - // Test im2col implementation for non-grouped convolutions with input - // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0) { + // For 4W4C layout, also test the legacy implementation + if (int8_memory_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout, - /*impl_selector=*/"im2col")); + /*impl_selector=*/"legacy_4w4c")); } + + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); } } } diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp index b4583071acd..7ef73d49802 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp @@ -54,7 +54,6 @@ TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -228,8 +227,7 @@ std::vector generate_quantized_conv2d_dw_easy_cases() { }; config.op_name = "conv2d_q8ta_q8csw_q8to"; - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths @@ -351,8 +349,7 @@ std::vector generate_quantized_conv2d_dw_test_cases() { 32}}; // Test with different storage types and data types - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 6c7b2d94ecd..51095c649b6 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -53,7 +53,6 @@ static TestCase create_test_case_from_config( std::to_string(config.input_size.w) + " " + "g=" + std::to_string(config.groups) + " " + "k=" + std::to_string(config.kernel.h) + " " + - repr_str(fp_storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, int8_memory_layout); if (!impl_selector.empty()) { test_name += " [" + impl_selector + "]"; @@ -286,8 +285,7 @@ static std::vector generate_quantized_conv2d_pw_test_cases() { }; // Test with different storage types and memory layouts - std::vector fp_storage_types = { - utils::kTexture3D, utils::kBuffer}; + std::vector fp_storage_types = {utils::kTexture3D}; // Memory layouts for int8 tensors - test both optimized (4W4C) and general // paths