Skip to content
180 changes: 73 additions & 107 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@

#define PACKED_INT8_OUTPUT_BUFFER

#define TILE_M4 1
#define TILE_N4 1
#define TILE_K4 1

#define TILE_M 4
#define TILE_N 4
#define TILE_K 4

layout(std430) buffer;

#include "indexing.glslh"
Expand All @@ -36,129 +28,103 @@ ${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
${layout_declare_spec_const(C, "int", "apply_bias", "1")}

// Layout specialization constants
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "im2col_outp_layout", "CONTIG_LAYOUT_INT")}

layout(push_constant) uniform restrict Block {
int zp;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "conv2d_int8_output_tile_store.glslh"

// Compute input tensor index from im2col coordinates
TensorIndex4D get_input_tidx(
const int im2col_w,
const int im2col_h,
const int k_in_group,
const int group_idx) {
TensorIndex4D tidx;
tidx.data.w = 0;
void main() {
const int out_buf_idx = int(gl_GlobalInvocationID.x);

const int c_in_group = k_in_group % conv2d_params.in_channels_per_group;
const int row = k_in_group / conv2d_params.in_channels_per_group;
const int kernel_x = row % conv2d_params.kernel_size.x;
const int kernel_y = row / conv2d_params.kernel_size.x;
// Extract sizes from BufferMetadata
const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]);
const ivec4 input_sizes = ivec4(inp.sizes[0]);

tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group;
// im2col block extents
const int im2col_W4 = div_up_4(im2col_sizes.x);
const int im2col_H = im2col_sizes.y;
const int im2col_Z4 = div_up_4(im2col_sizes.z);

tidx.data.x = (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x +
(kernel_x * conv2d_params.dilation.x);
tidx.data.y = (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y +
(kernel_y * conv2d_params.dilation.y);
// im2col block index from linear output buffer index
const int c4_idx = out_buf_idx % im2col_Z4;
const int row = out_buf_idx / im2col_Z4;
const int w4_idx = row % im2col_W4;
const int h_idx = row / im2col_W4;

return tidx;
}

// Load a single int8 value from the input tensor using layout-agnostic indexing
int load_input_element(const TensorIndex4D tidx, const int input_zp) {
// Bounds checking
if (any(lessThan(tidx.data, ivec4(0))) ||
any(greaterThanEqual(tidx.data, ivec4(inp.sizes[0])))) {
return input_zp;
// out of bounds check
if (w4_idx >= im2col_W4 || h_idx >= im2col_H || c4_idx >= im2col_Z4) {
return;
}

// Use layout-agnostic indexing to get buffer position
int texel_idx;
if (get_outer_packed_dim_block_size(inp_layout) == 1) {
// For 4C or 4C1W layouts: use tensor4d_idx_to_texel_idx
texel_idx = tensor4d_idx_to_texel_idx(inp, tidx, inp_layout);
} else {
// For 4W4C layout: compute index directly
const int w4 = div_4(tidx.data[0]);
const int c4 = div_4(tidx.data[2]);
const int h_stride = int(inp.strides[0][1]);
const int w_stride = int(inp.strides[0][0]);
texel_idx = (tidx.data[1] * h_stride + w4 * w_stride + c4) * 4 + mod_4(tidx.data[0]);
}
const int im2col_w = mul_4(w4_idx);
const int im2col_h = h_idx;
const int im2col_k = mul_4(c4_idx);

// Load packed int32 containing 4 int8 values
const int packed_input = t_packed_int8_input[texel_idx];
const int group_idx = im2col_k / conv2d_params.K_per_group;
const int k_in_group = im2col_k % conv2d_params.K_per_group;

// Extract the appropriate int8 value based on channel offset within texel
const int c_offset = mod_4(tidx.data[2]);
return extract_8bit_from_packed_int_le(packed_input, c_offset);
}
const int c_in_group = k_in_group % conv2d_params.in_channels_per_group;
const int krow = k_in_group / conv2d_params.in_channels_per_group;
const int kernel_x = krow % conv2d_params.kernel_size.x;
const int kernel_y = krow / conv2d_params.kernel_size.x;

// Load a 4x4 im2col block (4 widths × 4 channels)
ivec4 load_im2col_block(
const int im2col_w_start,
const int im2col_h,
const int k_in_group_start,
const int group_idx) {
ivec4 im2col_block;
// Base input position
const int input_x_base =
(im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x +
(kernel_x * conv2d_params.dilation.x);
const int input_y =
(im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y +
(kernel_y * conv2d_params.dilation.y);
const int input_z =
group_idx * conv2d_params.in_channels_per_group + c_in_group;

for (int r = 0; r < 4; r++) {
const int im2col_w = im2col_w_start + r;
ivec4 row_values;
for (int c = 0; c < 4; c++) {
const int k_in_group = k_in_group_start + c;
// Input tensor extents
const int input_W = input_sizes.x;
const int input_H = input_sizes.y;
const int input_Z4 = div_up_4(input_sizes.z);

if (k_in_group >= conv2d_params.logical_K_per_group) {
row_values[c] = zp;
continue;
}
const int zp_packed = pack_into_int32(ivec4(zp));
const int z4 = div_4(input_z);

TensorIndex4D input_tidx =
get_input_tidx(im2col_w, im2col_h, k_in_group, group_idx);
// Check if y and z are in bounds (constant for all 4 width elements)
const bool y_z_in_bounds =
(input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4);

row_values[c] = load_input_element(input_tidx, zp);
// Load 4 elements from input, one for each output width position.
// Each loaded int contains 4 packed int8 channel values.
ivec4 im2col_block;
for (int i = 0; i < 4; i++) {
const int x = input_x_base + i * conv2d_params.stride.x;
if (!y_z_in_bounds || x < 0 || x >= input_W) {
im2col_block[i] = zp_packed;
} else {
const int x4 = div_4(x);
const int x_mod = mod_4(x);
int scalar_idx;
if (get_outer_packed_dim_block_size(inp_layout) == 1) {
scalar_idx = input_y * int(inp.strides[0][1])
+ x * int(inp.strides[0][0])
+ z4 * int(inp.strides[0][2]);
} else {
scalar_idx = mul_4(
input_y * int(inp.strides[0][1])
+ x4 * int(inp.strides[0][0])
+ z4) + x_mod;
}
im2col_block[i] = t_packed_int8_input[scalar_idx];
}

im2col_block[r] = pack_into_int32(row_values);
}
return im2col_block;
}

void main() {
const int out_buf_idx = int(gl_GlobalInvocationID.x);

const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]);
Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes);

Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx(
out_buf_idx, im2col_block_extents);
// Inlined equivalent of store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1)
const int buffer_idx = h_idx * int(im2col_outp.strides[0][1])
+ w4_idx * int(im2col_outp.strides[0][0])
+ c4_idx;

if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) {
return;
if (w4_idx < im2col_W4 && c4_idx < im2col_Z4) {
t_packed_int8_output[buffer_idx] = im2col_block;
}

// Convert block index to im2col coordinates
const int im2col_w = mul_4(im2col_block_idx.data.x);
const int im2col_h = im2col_block_idx.data.y;
const int im2col_k = mul_4(im2col_block_idx.data.z);

// Compute group and k offset within group
const int group_idx = im2col_k / conv2d_params.K_per_group;
const int k_in_group = im2col_k % conv2d_params.K_per_group;

// Load the im2col block using layout-agnostic input access
Int8OutTile int8_im2col_tile;
int8_im2col_tile.data[0][0] = load_im2col_block(
im2col_w, im2col_h, k_in_group, group_idx);

// Store to output (4W4C format)
store_packed_int8_output_tile(
int8_im2col_tile, im2col_block_idx, im2col_block_extents);
}
130 changes: 0 additions & 130 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl

This file was deleted.

11 changes: 0 additions & 11 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void add_q8ta_im2col_node(
// The implementation also requires that the input channels per group be a multiple of 4
VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0);

std::string kernel_name = "q8ta_im2col_4w4c";
std::string kernel_name = "q8ta_im2col";

vkapi::ParamsBindList param_buffers = {
graph.buffer_meta_ubo(packed_int8_im2col),
Expand Down
6 changes: 2 additions & 4 deletions backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ std::vector<TestCase> generate_quantized_conv2d_easy_cases() {

// Test im2col implementation for non-grouped convolutions with input
// channels that are a multiple of 4
if (config.groups == 1 && config.channels.in % 4 == 0 &&
config.stride.w == 1) {
if (config.groups == 1 && config.channels.in % 4 == 0) {
test_cases.push_back(create_test_case_from_config(
config,
vkapi::kFloat,
Expand Down Expand Up @@ -417,8 +416,7 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {

// Test im2col implementation for non-grouped convolutions with input
// channels that are a multiple of 4 and stride_w == 1
if (config.groups == 1 && config.channels.in % 4 == 0 &&
config.stride.w == 1) {
if (config.groups == 1 && config.channels.in % 4 == 0) {
test_cases.push_back(create_test_case_from_config(
config,
vkapi::kFloat,
Expand Down
Loading