From 48859d37c520c484c31b813c16d5176e2fa9945d Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 11 Feb 2026 12:16:02 -0800 Subject: [PATCH] [ET-VK] Add nchw_to_int8x4_buffer shader for prepacking int8 staging data This adds a GLSL compute shader and supporting C++ dispatch logic to transfer int8 tensor data from a staging buffer (in NCHW contiguous order) to a GPU buffer in any PackedInt8 layout (4W, 4C, 4W4C, 4H4W, 4C1W). Previously there was no prepack path for kInt8x4 tensors, so constant int8 tensors (TensorRef inputs) could not be transferred to GPU buffers. This is needed to support constant quantized weights in q8ta operators. The shader uses texel-level dispatch where each thread writes one texel (one int32 = 4 packed int8 values). It decomposes the texel index into block-space coordinates using BufferMetadata strides and hashed_layout dim_order, then reads the corresponding int8 bytes from the staging buffer (interpreted as int32 for device compatibility, avoiding the need for 8-bit buffer support). New files: - nchw_to_int8x4_buffer.glsl: Compute shader handling all PackedInt8 layouts - nchw_to_int8x4_buffer.yaml: Shader variant config - Q8taStaging.h/cpp: C++ dispatch function creating the PrepackNode Modified files: - Staging.cpp: Routes kInt8x4 tensors in prepack_op() to the new function - TestQ8taBinary.cpp: Prepacks TensorRef inputs before quantization - test_q8ta_binary.cpp: Adds const_b test cases for constant tensor B inputs This diff was authored with Claude. Differential Revision: [D93000169](https://our.internmc.facebook.com/intern/diff/D93000169/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 1 - .../graph/ops/glsl/nchw_to_int8x4_buffer.glsl | 101 ++++++++++++++++++ .../graph/ops/glsl/nchw_to_int8x4_buffer.yaml | 11 ++ .../runtime/graph/ops/impl/Q8taStaging.cpp | 47 ++++++++ .../runtime/graph/ops/impl/Q8taStaging.h | 20 ++++ .../vulkan/runtime/graph/ops/impl/Staging.cpp | 4 + .../graph/ops/utils/ShaderNameUtils.cpp | 3 + .../test/custom_ops/impl/TestQ8taBinary.cpp | 21 +++- .../test/custom_ops/test_q8ta_binary.cpp | 23 +++- 9 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 0725a39f547..b1b22d01b72 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -504,7 +504,6 @@ def register_q8ta_add(): return OpFeatures( inputs_storage=utils.PACKED_INT8_BUFFER, supports_resize=False, - supports_prepacking=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl new file mode 100644 index 00000000000..c9610de8794 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl @@ -0,0 +1,101 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output buffer: packed int8x4 values (each int32 contains 4 packed int8) +${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")} +// Input staging buffer: raw int8 data interpreted as int32 for device compat +${layout_declare_tensor(B, "r", "nchw_in", "int", "buffer")} + +// Metadata for output tensor +${layout_declare_ubo(B, "BufferMetadata", "outp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +void main() { + const uint texel_idx = gl_GlobalInvocationID.x; + const uint num_texels = numel(outp) / 4; + if (texel_idx >= num_texels) { + return; + } + + const int inner_dim = get_packed_dim(outp_layout); + const int outer_dim = get_outer_packed_dim(outp_layout); + const int inner_block_size = get_packed_dim_block_size(outp_layout); + const int outer_block_size = get_outer_packed_dim_block_size(outp_layout); + const uint texels_per_block = uint(inner_block_size * outer_block_size) >> 2; + + // Decompose texel_idx into block_idx and intra-block texel position + const uint block_idx = texel_idx / texels_per_block; + const uint intra_texel = texel_idx % texels_per_block; + + // Decompose block_idx into block-space tensor coordinates using strides + TensorIndex4D tidx; + uint remaining = block_idx; + [[unroll]] for (int d = 3; d >= 0; d--) { + const int dim = extract_4b(outp_layout, d); + const uint dim_stride = outp.strides[0][dim]; + tidx.data[dim] = int(remaining / dim_stride); + remaining %= dim_stride; + } + + // Convert from block-space to element-space + tidx.data[inner_dim] *= inner_block_size; + tidx.data[outer_dim] *= outer_block_size; + + // Add intra-block offset for outer dimension (block-packed layouts) + tidx.data[outer_dim] += int(intra_texel); + + // Bounds check on outer dimension + if (tidx.data[outer_dim] >= int(outp.sizes[0][outer_dim])) { + return; + } + + // Tensor sizes in WHCN order for NCHW contiguous index computation + const uint W = outp.sizes[0][0]; + const uint H = outp.sizes[0][1]; + const uint C = outp.sizes[0][2]; + + // Pack 4 int8 values along inner dimension into one int32 + int packed = 0; + [[unroll]] for (int i = 0; i < 4; ++i) { + const int elem_inner = tidx.data[inner_dim] + i; + if (elem_inner >= int(outp.sizes[0][inner_dim])) { + break; + } + + // Build element coordinates + ivec4 elem = tidx.data; + elem[inner_dim] = elem_inner; + + // Compute NCHW contiguous index: w + h*W + c*H*W + n*C*H*W + const uint nchw_idx = uint(elem[0]) + uint(elem[1]) * W + + uint(elem[2]) * H * W + uint(elem[3]) * C * H * W; + + // Read int8 from staging buffer (each int32 contains 4 bytes) + const uint int_idx = nchw_idx >> 2; + const uint byte_pos = nchw_idx & 3; + const int staging_val = nchw_in[int_idx]; + const int byte_val = (staging_val >> (byte_pos * 8)) & 0xFF; + + packed |= (byte_val << (i * 8)); + } + + t_outp[texel_idx] = packed; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml new file mode 100644 index 00000000000..514ada71f63 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +nchw_to_int8x4_buffer: + parameter_names_with_default_values: + DTYPE: int + shader_variants: + - NAME: nchw_to_int8x4_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp new file mode 100644 index 00000000000..028f55997d3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void add_staging_to_int8x4_buffer_node( + ComputeGraph& graph, + const ValueRef tensor_data, + const ValueRef tensor) { + VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4); + + std::string kernel_name = "nchw_to_int8x4_buffer"; + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(tensor)); + + // One thread per texel (each texel = one int32 = 4 packed int8) + uint32_t num_texels = + utils::safe_downcast(graph.numel_of(tensor) / 4); + utils::uvec3 global_wg_size = {num_texels, 1, 1}; + utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Input and Output + tensor_data, + tensor, + // Parameter Buffers + param_buffers, + // Specialization Constants + {graph.hashed_layout_of(tensor)})); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h new file mode 100644 index 00000000000..40386551e36 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +void add_staging_to_int8x4_buffer_node( + ComputeGraph& graph, + const ValueRef tensor_data, + const ValueRef tensor); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 9dc4d0a58f8..adcad9f9817 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -327,6 +328,9 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( } void prepack_op(ComputeGraph& graph, const std::vector& args) { + if (graph.dtype_of(args[1]) == vkapi::kInt8x4) { + return add_staging_to_int8x4_buffer_node(graph, args[0], args[1]); + } return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 231e6d0c7f6..59a9d79a6e3 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -64,6 +64,9 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { case vkapi::kUInt64: kernel_name += "_uint64"; break; + case vkapi::kInt8x4: + kernel_name += "_int32"; + break; default: break; } diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp index 53f8859b581..99a61f598e7 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp @@ -10,13 +10,14 @@ #include #include +#include namespace vkcompute { void q8ta_add_test(ComputeGraph& graph, const std::vector& args) { int32_t idx = 0; - const ValueRef fp_input_a = args.at(idx++); - const ValueRef fp_input_b = args.at(idx++); + ValueRef fp_input_a = args.at(idx++); + ValueRef fp_input_b = args.at(idx++); const ValueRef input_a_scale = args.at(idx++); const ValueRef input_a_zp = args.at(idx++); const ValueRef input_b_scale = args.at(idx++); @@ -27,6 +28,22 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector& args) { const ValueRef quant_layout_int = args.at(idx++); const ValueRef fp_output = args.at(idx++); + // Prepack any TensorRef inputs to GPU tensors + if (graph.val_is_tref(fp_input_a)) { + fp_input_a = prepack_standard( + graph, + fp_input_a, + graph.storage_type_of(fp_output), + graph.estimate_memory_layout_of(fp_output)); + } + if (graph.val_is_tref(fp_input_b)) { + fp_input_b = prepack_standard( + graph, + fp_input_b, + graph.storage_type_of(fp_output), + graph.estimate_memory_layout_of(fp_output)); + } + // Extract the layout parameter and cast to GPUMemoryLayout int32_t layout_value = graph.extract_scalar(quant_layout_int); utils::GPUMemoryLayout quant_layout = diff --git a/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp index 1cb364c6f8d..d373c06f28b 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp @@ -29,7 +29,8 @@ TestCase create_test_case_from_config( utils::StorageType storage_type, vkapi::ScalarType input_dtype, utils::GPUMemoryLayout fp_memory_layout, - utils::GPUMemoryLayout quant_layout) { + utils::GPUMemoryLayout quant_layout, + bool const_b = false) { TestCase test_case; // Create a descriptive name for the test case @@ -37,6 +38,9 @@ TestCase create_test_case_from_config( std::string test_name = config.test_case_name + " I=" + shape_str + " " + repr_str(storage_type, fp_memory_layout) + "->" + repr_str(utils::kBuffer, quant_layout); + if (const_b) { + test_name += " const_b"; + } test_case.set_name(test_name); // Set the operator name for the test case @@ -58,6 +62,9 @@ TestCase create_test_case_from_config( storage_type, fp_memory_layout, DataGenType::RANDOM); + if (const_b) { + input_b.set_constant(true); + } // Quantization parameters for input A float input_a_scale_val = 0.007843; // 2/255 approximately @@ -158,6 +165,13 @@ std::vector generate_q8ta_add_easy_cases() { for (const auto& input_dtype : float_types) { test_cases.push_back(create_test_case_from_config( config, storage_type, input_dtype, fp_layout, quant_layout)); + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + input_dtype, + fp_layout, + quant_layout, + /*const_b=*/true)); } } } @@ -236,6 +250,13 @@ std::vector generate_q8ta_add_test_cases() { test_cases.push_back(create_test_case_from_config( config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + vkapi::kFloat, + fp_layout, + quant_layout, + /*const_b=*/true)); } } }