diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b58d6407308..55a92335bc7 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -510,7 +510,6 @@ def register_q8ta_add():
     return OpFeatures(
         inputs_storage=utils.PACKED_INT8_BUFFER,
         supports_resize=False,
-        supports_prepacking=True,
     )
 
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
index 37c47795214..51cda9a3d1d 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -334,19 +334,30 @@ TensorIndex linear_idx_to_tensor_idx(
 
 /*
  * Convert a linear texel index to a TensorIndex4D.
  *
- * This function is used for texel-based dispatch where each thread handles
- * one packed texel (4 elements along the packed dimension). The texel index
- * is decomposed using the dim_order and strides from the tensor's layout.
+ * This is the inverse of tensor4d_idx_to_texel_idx. It handles both
+ * single-packed layouts (outer_block_size == 1) and block-packed layouts
+ * (e.g., 4W4C where outer_block_size > 1).
  *
- * The strides in BufferMetadata should already be in texel space (with packed
- * dimension size divided by 4).
+ * The approach mirrors tensor4d_idx_to_texel_idx by decomposing the problem
+ * into two levels:
+ * 1. Decompose texel_idx into block_idx and an intra-block texel offset
+ * 2. Decompose block_idx into block-space tensor coordinates using strides
+ * 3. Convert block-space coordinates to element-space by multiplying by
+ *    block sizes
+ * 4. Add the intra-block outer-dimension offset
+ *
+ * For single-packed layouts (outer_block_size == 1, inner_dim == outer_dim),
+ * texels_per_block == 1, so block_idx == texel_idx and intra_block_texel == 0.
+ * The only effective multiplication is tidx[inner_dim] *= inner_block_size
+ * (i.e., *= 4), matching the previous single-packed behavior.
  *
  * Parameters:
- *   meta: BufferMetadata with tensor sizes and texel-space strides
+ *   meta: BufferMetadata with block-space strides
  *   texel_idx: Linear index into packed texels (0 to num_texels-1)
  *   hashed_layout: Packed layout info containing dim_order and packed_dim
  *
- * Returns: TensorIndex4D with logical tensor coordinates (packed dim is base of 4-element block)
+ * Returns: TensorIndex4D with logical tensor coordinates (packed dims are
+ *   the base of their respective blocks)
  */
 TensorIndex4D texel_idx_to_tensor4d_idx(
@@ -354,25 +365,35 @@ TensorIndex4D texel_idx_to_tensor4d_idx(
     const BufferMetadata meta,
     uint texel_idx,
     const int hashed_layout) {
   TensorIndex4D tidx;
-  const int packed_dim = get_packed_dim(hashed_layout);
+  const int inner_dim = get_packed_dim(hashed_layout);
+  const int outer_dim = get_outer_packed_dim(hashed_layout);
+  const int inner_block_size = get_packed_dim_block_size(hashed_layout);
+  const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout);
 
-  // Decompose texel_idx using dim_order from hashed_layout and strides from meta
-  // Iterate from slowest-varying dimension (d=3) to fastest (d=0)
-  // This follows the pattern of linear_idx_to_tensor_idx in indexing.glslh
+  // Number of texels per block: each block has inner_block_size *
+  // outer_block_size elements, and each texel holds 4 elements
+  const int texels_per_block = (inner_block_size * outer_block_size) / 4;
+
+  // Decompose texel_idx into block_idx and intra-block texel offset
+  const uint block_idx = texel_idx / texels_per_block;
+  const int intra_block_texel = int(texel_idx % texels_per_block);
+
+  // Decompose block_idx into block-space tensor coordinates using dim_order
+  // and strides. Iterate from slowest-varying (d=3) to fastest (d=0).
+  uint remaining = block_idx;
   [[unroll]] for (int d = 3; d >= 0; d--) {
-    // Get dim index from hashed_layout's dim_order (bits 0-15)
     int dim_idx = extract_4b(hashed_layout, d);
-
-    // Get stride for this dimension from BufferMetadata
     uint dim_stride = meta.strides[0][dim_idx];
-
-    // Compute coordinate for this dimension
-    tidx.data[dim_idx] = int(texel_idx / dim_stride);
-    texel_idx = texel_idx % dim_stride;
+    tidx.data[dim_idx] = int(remaining / dim_stride);
+    remaining = remaining % dim_stride;
   }
 
-  // Convert packed dimension from texel index to element index
-  tidx.data[packed_dim] *= 4;
+  // Convert block-space coordinates to element-space
+  tidx.data[inner_dim] *= inner_block_size;
+  tidx.data[outer_dim] *= outer_block_size;
+
+  // Add intra-block outer-dimension offset
+  tidx.data[outer_dim] += intra_block_texel;
 
   return tidx;
 }
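
To make the block-packed decomposition above easier to follow, here is a small host-side C++ sketch of the same arithmetic for an assumed 4W4C layout (inner packed dim = W, outer packed dim = C, 4x4 = 16 elements, i.e. 4 texels per block). The sizes, block-space strides, and dim_order values are made up for illustration and are not taken from this PR.

```cpp
// Sketch of texel_idx -> TensorIndex4D for an assumed 4W4C block-packed layout.
// Logical sizes (WHCN): W=8, H=2, C=4, N=1 -> block-space sizes {2, 2, 1, 1}.
#include <array>
#include <cstdio>

int main() {
  const int inner_dim = 0;  // W (WHCN indexing)
  const int outer_dim = 2;  // C
  const int inner_block_size = 4;
  const int outer_block_size = 4;

  // Block-space strides for a W-fastest layout, and the dim at each position.
  const std::array<unsigned, 4> block_strides = {1, 2, 4, 4};
  const std::array<int, 4> dim_order = {0, 1, 2, 3};

  const int texels_per_block = (inner_block_size * outer_block_size) / 4;  // 4

  for (unsigned texel_idx = 0; texel_idx < 16; ++texel_idx) {
    const unsigned block_idx = texel_idx / texels_per_block;
    const int intra_block_texel = static_cast<int>(texel_idx % texels_per_block);

    // Decompose block_idx into block-space coordinates, slowest dim first.
    std::array<int, 4> tidx{};
    unsigned remaining = block_idx;
    for (int d = 3; d >= 0; --d) {
      const int dim_idx = dim_order[d];
      tidx[dim_idx] = static_cast<int>(remaining / block_strides[dim_idx]);
      remaining = remaining % block_strides[dim_idx];
    }

    // Block-space -> element-space, then apply the intra-block offset.
    tidx[inner_dim] *= inner_block_size;
    tidx[outer_dim] *= outer_block_size;
    tidx[outer_dim] += intra_block_texel;

    std::printf("texel %2u -> (W=%d, H=%d, C=%d, N=%d)\n",
                texel_idx, tidx[0], tidx[1], tidx[2], tidx[3]);
  }
  return 0;
}
```

With these assumptions, texels 0-3 map to the W=0 block with C offsets 0-3, texel 4 starts the W=4 block, and so on, which is the behavior the shader comment describes.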
diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl
new file mode 100644
index 00000000000..b3e2f25c2bf
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.glsl
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+${define_active_storage_type("buffer")}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+// Output buffer: packed int8x4 values (each int32 contains 4 packed int8)
+${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")}
+// Input staging buffer: raw int8 data interpreted as int32 for device compat
+${layout_declare_tensor(B, "r", "nchw_in", "int", "buffer")}
+
+// Metadata for output tensor
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
+
+void main() {
+  const uint texel_idx = gl_GlobalInvocationID.x;
+  const uint num_texels = numel(outp) / 4;
+  if (texel_idx >= num_texels) {
+    return;
+  }
+
+  const int inner_dim = get_packed_dim(outp_layout);
+  const int outer_dim = get_outer_packed_dim(outp_layout);
+
+  const TensorIndex4D tidx =
+      texel_idx_to_tensor4d_idx(outp, texel_idx, outp_layout);
+
+  // Bounds check on outer dimension
+  if (tidx.data[outer_dim] >= int(outp.sizes[0][outer_dim])) {
+    return;
+  }
+
+  // Tensor sizes in WHCN order for NCHW contiguous index computation
+  const uint W = outp.sizes[0][0];
+  const uint H = outp.sizes[0][1];
+  const uint C = outp.sizes[0][2];
+
+  // Pack 4 int8 values along inner dimension into one int32
+  int packed = 0;
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    const int elem_inner = tidx.data[inner_dim] + i;
+    if (elem_inner >= int(outp.sizes[0][inner_dim])) {
+      break;
+    }
+
+    // Build element coordinates
+    ivec4 elem = tidx.data;
+    elem[inner_dim] = elem_inner;
+
+    // Compute NCHW contiguous index: w + h*W + c*H*W + n*C*H*W
+    const uint nchw_idx = uint(elem[0]) + uint(elem[1]) * W +
+        uint(elem[2]) * H * W + uint(elem[3]) * C * H * W;
+
+    // Read int8 from staging buffer (each int32 contains 4 bytes)
+    const uint int_idx = nchw_idx >> 2;
+    const uint byte_pos = nchw_idx & 3;
+    const int staging_val = nchw_in[int_idx];
+    const int byte_val = (staging_val >> (byte_pos * 8)) & 0xFF;
+
+    packed |= (byte_val << (i * 8));
+  }
+
+  t_outp[texel_idx] = packed;
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml
new file mode 100644
index 00000000000..514ada71f63
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8x4_buffer.yaml
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+nchw_to_int8x4_buffer:
+  parameter_names_with_default_values:
+    DTYPE: int
+  shader_variants:
+    - NAME: nchw_to_int8x4_buffer
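
The per-texel gather in nchw_to_int8x4_buffer.glsl can be mirrored on the host to sanity-check the indexing. The following C++ sketch assumes the inner packed dimension is W and uses made-up sizes and a made-up starting coordinate; it reproduces the nchw_idx computation, the word/byte addressing of the staging buffer, and the int32 packing, but is only an illustration, not part of the PR.

```cpp
// Host-side sketch of the byte gather/packing done per texel by the shader:
// four int8 elements along the inner packed dimension are fetched from an
// NCHW-contiguous staging buffer (viewed as int32 words) and packed into one
// int32, with out-of-range elements left as zero padding.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const uint32_t W = 8, H = 2, C = 3, N = 1;

  // NCHW-contiguous int8 staging data, re-viewed as int32 words (4 bytes each).
  std::vector<int8_t> staging(N * C * H * W);
  for (size_t i = 0; i < staging.size(); ++i) {
    staging[i] = static_cast<int8_t>(static_cast<int>(i) - 64);
  }
  std::vector<int32_t> staging_words((staging.size() + 3) / 4, 0);
  std::memcpy(staging_words.data(), staging.data(), staging.size());

  // Pack the 4 elements starting at (w=4, h=1, c=2, n=0) along W.
  const uint32_t w0 = 4, h = 1, c = 2, n = 0;
  int32_t packed = 0;
  for (uint32_t i = 0; i < 4; ++i) {
    const uint32_t w = w0 + i;
    if (w >= W) {
      break;  // remaining bytes stay zero, mirroring the shader's early break
    }
    const uint32_t nchw_idx = w + h * W + c * H * W + n * C * H * W;
    const uint32_t int_idx = nchw_idx >> 2;   // which int32 word
    const uint32_t byte_pos = nchw_idx & 3;   // which byte within the word
    const int32_t byte_val = (staging_words[int_idx] >> (byte_pos * 8)) & 0xFF;
    packed |= byte_val << (i * 8);
  }
  std::printf("packed texel = 0x%08X\n", static_cast<uint32_t>(packed));
  return 0;
}
```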
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp
new file mode 100644
index 00000000000..8dc3f8156f8
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+
+namespace vkcompute {
+
+void add_staging_to_int8x4_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const ValueRef tensor) {
+  VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
+
+  std::string kernel_name = "nchw_to_int8x4_buffer";
+
+  vkapi::ParamsBindList param_buffers;
+  param_buffers.append(graph.buffer_meta_ubo(tensor));
+
+  // One thread per texel (each texel = one int32 = 4 packed int8).
+  // Use padded_numel to account for dimension padding in packed int8 layouts
+  // (e.g., kPackedInt8_4C with C=3 pads to C=4).
+  uint32_t num_texels =
+      utils::safe_downcast<uint32_t>(graph.padded_numel_of(tensor) / 4);
+  utils::uvec3 global_wg_size = {num_texels, 1, 1};
+  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
+
+  graph.prepack_nodes().emplace_back(new PrepackNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      local_wg_size,
+      // Input and Output
+      tensor_data,
+      tensor,
+      // Parameter Buffers
+      param_buffers,
+      // Specialization Constants
+      {graph.hashed_layout_of(tensor)}));
+}
+
+} // namespace vkcompute
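
A quick worked example of the dispatch sizing above. The padding rule used here (round each packed dimension up to its block size) is an assumption inferred from the kPackedInt8_4C comment; the helper below is hypothetical and only illustrates the arithmetic.

```cpp
// One thread per int32 texel, counted from the padded element count.
#include <cstdint>
#include <cstdio>

uint32_t padded_numel_whcn(
    uint32_t W, uint32_t H, uint32_t C, uint32_t N,
    uint32_t w_block, uint32_t c_block) {
  auto round_up = [](uint32_t x, uint32_t m) { return ((x + m - 1) / m) * m; };
  return round_up(W, w_block) * H * round_up(C, c_block) * N;
}

int main() {
  // e.g. a {N=1, C=3, H=5, W=8} tensor in a 4C-packed layout: C pads to 4.
  const uint32_t padded =
      padded_numel_whcn(8, 5, 3, 1, /*w_block=*/1, /*c_block=*/4);
  const uint32_t num_texels = padded / 4;  // each texel holds 4 int8 values
  std::printf("padded_numel=%u, num_texels=%u\n", padded, num_texels);  // 160, 40
  return 0;
}
```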
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h
new file mode 100644
index 00000000000..40386551e36
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+namespace vkcompute {
+
+void add_staging_to_int8x4_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const ValueRef tensor);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
index 9dc4d0a58f8..adcad9f9817 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -12,6 +12,7 @@
 #include
 #include
+#include
 #include
 #include
 
@@ -327,6 +328,9 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
 }
 
 void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  if (graph.dtype_of(args[1]) == vkapi::kInt8x4) {
+    return add_staging_to_int8x4_buffer_node(graph, args[0], args[1]);
+  }
   return add_prepack_standard_node(graph, args[0], args[1]);
 }
 
diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp
index 231e6d0c7f6..59a9d79a6e3 100644
--- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp
+++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp
@@ -64,6 +64,9 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) {
     case vkapi::kUInt64:
       kernel_name += "_uint64";
      break;
+    case vkapi::kInt8x4:
+      kernel_name += "_int32";
+      break;
     default:
       break;
   }
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp
index 53f8859b581..f5214221359 100644
--- a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp
@@ -10,13 +10,14 @@
 #include
 #include
+#include
 
 namespace vkcompute {
 
 void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int32_t idx = 0;
-  const ValueRef fp_input_a = args.at(idx++);
-  const ValueRef fp_input_b = args.at(idx++);
+  ValueRef fp_input_a = args.at(idx++);
+  ValueRef input_b = args.at(idx++);
   const ValueRef input_a_scale = args.at(idx++);
   const ValueRef input_a_zp = args.at(idx++);
   const ValueRef input_b_scale = args.at(idx++);
   const ValueRef input_b_zp = args.at(idx++);
@@ -32,6 +33,10 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   utils::GPUMemoryLayout quant_layout =
       static_cast<utils::GPUMemoryLayout>(layout_value);
 
+  // Check if input_b is a pre-quantized int8 TensorRef
+  bool input_b_is_int8 =
+      graph.val_is_tref(input_b) && graph.dtype_of(input_b) == vkapi::kChar;
+
   // Create temporary tensors for quantized data with the specified layout
   TmpTensor packed_int8_input_a(
       &graph,
@@ -40,12 +45,8 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       utils::kBuffer,
       quant_layout);
 
-  TmpTensor packed_int8_input_b(
-      &graph,
-      graph.sizes_of(fp_input_b),
-      vkapi::kInt8x4,
-      utils::kBuffer,
-      quant_layout);
+  ValueRef packed_int8_input_b = graph.add_tensor(
+      graph.sizes_of(input_b), vkapi::kInt8x4, utils::kBuffer, quant_layout);
 
   TmpTensor packed_int8_output(
       &graph,
@@ -54,12 +55,19 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       utils::kBuffer,
       quant_layout);
 
-  // Quantize: FP -> int8x4 with specified layout
+  // Quantize input A: FP -> int8x4
   add_q8ta_quantize_node(
       graph, fp_input_a, input_a_scale, input_a_zp, packed_int8_input_a);
-  add_q8ta_quantize_node(
-      graph, fp_input_b, input_b_scale, input_b_zp, packed_int8_input_b);
+  if (input_b_is_int8) {
+    // Input B is a pre-quantized int8 TensorRef; prepack directly into packed
+    // int8x4 format
+    add_staging_to_int8x4_buffer_node(graph, input_b, packed_int8_input_b);
+  } else {
+    // Input B is a float tensor; quantize at runtime
+    add_q8ta_quantize_node(
+        graph, input_b, input_b_scale, input_b_zp, packed_int8_input_b);
+  }
 
   // Binary add: int8x4 -> int8x4 (same layout for all tensors)
   add_q8ta_binary_node(
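
Before the test harness changes below, here is a minimal per-element C++ sketch of the affine quantize/dequantize arithmetic that the reference implementation in test_q8ta_binary.cpp models (quantize with scale and zero-point, clamp to [-128, 127], dequantize, add). The scales, zero points, and the final requantize step are illustrative assumptions; only the quantize/clamp/dequantize/add steps mirror code shown in this diff.

```cpp
// Sketch of the per-element q8ta add reference math, with made-up parameters.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t quantize(float x, float scale, int32_t zp) {
  float q = std::round(x / scale) + zp;
  q = std::min(std::max(q, -128.0f), 127.0f);
  return static_cast<int8_t>(q);
}

float dequantize(int8_t q, float scale, int32_t zp) {
  return (static_cast<int32_t>(q) - zp) * scale;
}

int main() {
  const float a = 0.42f, b = -0.13f;
  const float a_scale = 0.007843f, b_scale = 0.007843f;  // ~2/255
  const int32_t a_zp = 0, b_zp = 0;
  const float out_scale = 0.015686f;                     // ~4/255 (assumed)
  const int32_t out_zp = 0;

  const int8_t qa = quantize(a, a_scale, a_zp);
  const int8_t qb = quantize(b, b_scale, b_zp);
  // Dequantize both inputs to float, add, then requantize the sum.
  const float sum = dequantize(qa, a_scale, a_zp) + dequantize(qb, b_scale, b_zp);
  const int8_t qout = quantize(sum, out_scale, out_zp);

  std::printf("qa=%d qb=%d sum=%f qout=%d\n", qa, qb, sum, qout);
  return 0;
}
```

When input B is a pre-quantized constant (the new const_b path), the qb value is read directly from the int8 data instead of being computed by quantize(), which is exactly the branch the reference implementation adds.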
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp
index 1100eb4d5f0..86725ca8fb8 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_binary.cpp
@@ -29,13 +29,17 @@ TestCase create_test_case_from_config(
     utils::StorageType storage_type,
     vkapi::ScalarType input_dtype,
     utils::GPUMemoryLayout fp_memory_layout,
-    utils::GPUMemoryLayout quant_layout) {
+    utils::GPUMemoryLayout quant_layout,
+    bool const_b = false) {
   TestCase test_case;
 
   // Create a descriptive name for the test case
   std::string shape_str = shape_string(config.shape);
   std::string test_name = config.test_case_name + " I=" + shape_str + " " +
       repr_str(utils::kBuffer, quant_layout);
+  if (const_b) {
+    test_name += " const_b";
+  }
   test_case.set_name(test_name);
 
   // Set the operator name for the test case
@@ -50,13 +54,16 @@ TestCase create_test_case_from_config(
       fp_memory_layout,
       DataGenType::RANDOM);
 
-  // Input tensor B (float/half)
+  // Input tensor B (float/half, or pre-quantized int8 for const_b)
   ValueSpec input_b(
       config.shape,
-      input_dtype,
+      const_b ? vkapi::kChar : input_dtype,
      storage_type,
       fp_memory_layout,
-      DataGenType::RANDOM);
+      const_b ? DataGenType::RANDINT8 : DataGenType::RANDOM);
+  if (const_b) {
+    input_b.set_constant(true);
+  }
 
   // Quantization parameters for input A
   float input_a_scale_val = 0.007843; // 2/255 approximately
@@ -148,6 +155,13 @@ std::vector<TestCase> generate_q8ta_add_easy_cases() {
         /*input_dtype=*/vkapi::kFloat,
         /*fp_memory_layout=*/utils::kWidthPacked,
         quant_layout));
+    test_cases.push_back(create_test_case_from_config(
+        config,
+        /*fp_storage_type=*/utils::kBuffer,
+        /*input_dtype=*/vkapi::kFloat,
+        /*fp_layout=*/utils::kWidthPacked,
+        quant_layout,
+        /*const_b=*/true));
   }
 
   return test_cases;
@@ -215,6 +229,13 @@ std::vector<TestCase> generate_q8ta_add_test_cases() {
           /*input_dtype=*/vkapi::kFloat,
           /*fp_memory_layout=*/utils::kWidthPacked,
           quant_layout));
+      test_cases.push_back(create_test_case_from_config(
+          config,
+          /*fp_storage_type=*/utils::kBuffer,
+          /*fp_input_dtype=*/vkapi::kFloat,
+          /*fp_layout=*/utils::kWidthPacked,
+          quant_layout,
+          /*const_b=*/true));
     }
   }
 
@@ -261,9 +282,10 @@ void q8ta_add_reference_impl(TestCase& test_case) {
     throw std::invalid_argument("Unsupported dtype");
   }
 
+  bool input_b_is_int8 = (input_b_spec.dtype == vkapi::kChar);
+
   // Get raw data pointers
   auto& input_a_data = input_a_spec.get_float_data();
-  auto& input_b_data = input_b_spec.get_float_data();
 
   const float input_a_scale = input_a_scale_spec.get_float_value();
   const int32_t input_a_zero_point = input_a_zero_point_spec.get_int_value();
@@ -284,11 +306,17 @@ void q8ta_add_reference_impl(TestCase& test_case) {
     float quant_a_f =
         std::round(input_a_data[i] / input_a_scale) + input_a_zero_point;
     quant_a_f = std::min(std::max(quant_a_f, -128.0f), 127.0f);
     int8_t quantized_a = static_cast<int8_t>(quant_a_f);
-    // Quantize input B to int8
-    float quant_b_f =
-        std::round(input_b_data[i] / input_b_scale) + input_b_zero_point;
-    quant_b_f = std::min(std::max(quant_b_f, -128.0f), 127.0f);
-    int8_t quantized_b = static_cast<int8_t>(quant_b_f);
+    // Get quantized input B (either from pre-quantized int8 or by quantizing)
+    int8_t quantized_b;
+    if (input_b_is_int8) {
+      quantized_b = input_b_spec.get_int8_data()[i];
+    } else {
+      float quant_b_f =
+          std::round(input_b_spec.get_float_data()[i] / input_b_scale) +
+          input_b_zero_point;
+      quant_b_f = std::min(std::max(quant_b_f, -128.0f), 127.0f);
+      quantized_b = static_cast<int8_t>(quant_b_f);
+    }
 
     // Dequantize both inputs to a common scale for addition
     float dequant_a =