From 599e5fb5c94ea07110f30eb58ab5d764da68a86a Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 27 Mar 2026 11:21:57 -0700 Subject: [PATCH 1/3] [ET-VK][conv1d] Implement height-packed pointwise conv1d operator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/18332 Implement a new conv1d pointwise (kernel_size=1) operator using height-packed layout where channels are the packed dimension (WHCN dim 1). This enables dot-product reduction over input channels: each vec4 load gives 4 consecutive channel values, yielding 4 MACs per dot() instruction. Uses tiled computation with the FP tile infrastructure from linear/matmul (FPInputTile, FPWeightTile, FPOutTile, fp_accumulate_with_fp_weight) and 4OC×4IC blocked weight packing via pack_fp_linear_weight.glsl for cache-friendly texture2d weight reads. Adaptive tile_m selection (4/2/1 rows) based on GPU occupancy. Thread mapping: X=OC4 tiles, Y=L tiles, Z=batch. Each thread computes TILE_M×TILE_N4×4 output elements. Inner loop loads input tiles and packed weight tiles, then calls fp_accumulate_with_fp_weight for tiled FMA. Supports both buffer and texture3d storage for input/output, texture2d or buffer for packed weights, fp32/fp16, and optional bias. Registered as et_vk.conv1d_pw.default (standalone custom op for testing/benchmarking). Performance on Adreno 750 (S24): - [1,256,1024]x[512,256,1] texture f16: 908 GFLOP/s - [1,512,2048]x[256,512,1] texture f16: 865 GFLOP/s - [1,128,4096]x[128,128,1] texture f16: 781 GFLOP/s - [1,256,1024]x[512,256,1] buffer f16: 491 GFLOP/s ghstack-source-id: 358903218 @exported-using-ghexport Differential Revision: [D97344092](https://our.internmc.facebook.com/intern/diff/D97344092/) --- .../runtime/graph/ops/glsl/conv1d_pw.glsl | 274 ++++++++++++++++++ .../runtime/graph/ops/glsl/conv1d_pw.yaml | 31 ++ .../runtime/graph/ops/impl/Conv1dPW.cpp | 238 +++++++++++++++ .../test/custom_ops/impl/TestConv1dPW.cpp | 46 +++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../vulkan/test/custom_ops/test_conv1d_pw.cpp | 248 ++++++++++++++++ 6 files changed, 838 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp create mode 100644 backends/vulkan/test/custom_ops/impl/TestConv1dPW.cpp create mode 100644 backends/vulkan/test/custom_ops/test_conv1d_pw.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl new file mode 100644 index 00000000000..4ebca126e03 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl @@ -0,0 +1,274 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +$if STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER + #define SCALAR_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER +$if HAS_BIAS: + #define HAS_BIAS +$if STORAGE == "buffer" and HAS_BIAS: + #define BIAS_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(STORAGE, DTYPE)} +$if WEIGHT_STORAGE != STORAGE: + ${define_required_extensions(WEIGHT_STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +$if STORAGE == "buffer": + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=True)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=True)} +$else: + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, WEIGHT_STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + $if STORAGE == "buffer": + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)} + $else: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)} + +// in_sizes: {L, C_in, N, 1} in WHCN order +${layout_declare_ubo(B, "ivec4", "in_sizes")} +// out_sizes: {L, C_out, N, 1} in WHCN order +${layout_declare_ubo(B, "ivec4", "out_sizes")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + +layout(push_constant) uniform restrict Block { + int weight_B; + float output_min; + float output_max; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_fp_input_tile.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_packed_weight_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" + +// Conv1d pointwise is matrix multiplication with swapped texture coordinates. +// Linear: input ivec3(k4, m, b), output ivec3(n4, m, b) [width-packed] +// Conv1d: input ivec3(m, k4, b), output ivec3(m, n4, b) [height-packed] +// +// For buffer storage, height-packed tensors have packed_dim_block_size=1 (no +// vec4 grouping). Data is stored as contiguous scalars with strides based on +// logical sizes, so scalar indexing is required: (b * M + m) * C + c. +// For texture storage, 4 channels are packed per texel as usual. + +#ifndef SCALAR_BUFFER +VEC4_T load_input_x4( + const int k4, + const int m, + const int b, + const int K4, + const int M) { +#ifdef INPUT_BUFFER + return t_in[(b * M + m) * K4 + k4]; +#else + return texelFetch(t_in, ivec3(m, k4, b), 0); +#endif +} + +void load_input_tile_with_checks( + out FPInputTile tile, + const int k4_start, + const int m_start, + const int b, + const int K4, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (k4_start + k4 < K4 && m_start + m < M) { + tile.data[m][k4] = + load_input_x4(k4_start + k4, m_start + m, b, K4, M); + } else { + tile.data[m][k4] = VEC4_T(0.0); + } + } + } +} + +void store_output_x4( + const VEC4_T texel, + const int n4, + const int m, + const int b, + const int N4, + const int M) { +#ifdef OUTPUT_BUFFER + t_out[(b * M + m) * N4 + n4] = texel; +#else + imageStore(t_out, ivec3(m, n4, b), texel); +#endif +} + +void store_output_tile_with_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int b, + const int N4, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + store_output_x4( + out_tile.data[m][n4], n4_start + n4, m_start + m, b, N4, M); + } + } + } +} +#endif // !SCALAR_BUFFER + +#ifdef SCALAR_BUFFER +void load_input_tile_scalar( + out FPInputTile tile, + const int k4_start, + const int m_start, + const int b, + const int K4, + const int K, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (k4_start + k4 < K4 && m_start + m < M) { + const int base = (b * M + m_start + m) * K + mul_4(k4_start + k4); + T s0 = t_in[base]; + T s1 = (mul_4(k4_start + k4) + 1 < K) ? t_in[base + 1] : T(0); + T s2 = (mul_4(k4_start + k4) + 2 < K) ? t_in[base + 2] : T(0); + T s3 = (mul_4(k4_start + k4) + 3 < K) ? t_in[base + 3] : T(0); + tile.data[m][k4] = VEC4_T(s0, s1, s2, s3); + } else { + tile.data[m][k4] = VEC4_T(0.0); + } + } + } +} + +void store_output_tile_scalar( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int b, + const int N4, + const int N, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + const int base = (b * M + m_start + m) * N + mul_4(n4_start + n4); + const VEC4_T val = out_tile.data[m][n4]; + t_out[base] = val.x; + if (mul_4(n4_start + n4) + 1 < N) t_out[base + 1] = val.y; + if (mul_4(n4_start + n4) + 2 < N) t_out[base + 2] = val.z; + if (mul_4(n4_start + n4) + 3 < N) t_out[base + 3] = val.w; + } + } + } +} +#endif // SCALAR_BUFFER + +void main() { + // Thread mapping: X=OC4 (N4), Y=L/tile_m (M tiles), Z=batch + const int tile_idx_n = int(gl_GlobalInvocationID.x); + const int tile_idx_m = int(gl_GlobalInvocationID.y); + + const int n4_start = tile_idx_n * TILE_N4; + const int m_start = tile_idx_m * TILE_M; + + // in_sizes: {L, C_in, N, 1} in WHCN + const int K = in_sizes.y; // C_in + const int M = in_sizes.x; // L + const int K4 = div_up_4(K); + // out_sizes: {L, C_out, N, 1} in WHCN + const int N_out = out_sizes.y; // C_out + const int N4 = div_up_4(N_out); + + if (n4_start >= N4 || m_start >= M) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + FPWeightTile w_tile; + + const int b = int(gl_GlobalInvocationID.z); + + for (int k4 = 0; k4 < K4; k4++) { +#ifdef SCALAR_BUFFER + load_input_tile_scalar(in_tile, k4, m_start, b, K4, K, M); +#else + load_input_tile_with_checks(in_tile, k4, m_start, b, K4, M); +#endif + load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4); + fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); + } + +#ifdef HAS_BIAS + // Load bias (per output channel) and apply + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + VEC4_T bias_val = VEC4_T(0.0); + if (n4_start + n4 < N4) { +#ifdef BIAS_BUFFER + // Bias is a 1D tensor [C_out], width-packed. + // For buffer storage, width-packed has packed_dim_block_size=1, so data + // is stored as contiguous scalars. Read 4 with bounds checking. + const int bias_base = mul_4(n4_start + n4); + T b0 = t_bias[bias_base]; + T b1 = (bias_base + 1 < N_out) ? t_bias[bias_base + 1] : T(0); + T b2 = (bias_base + 2 < N_out) ? t_bias[bias_base + 2] : T(0); + T b3 = (bias_base + 3 < N_out) ? t_bias[bias_base + 3] : T(0); + bias_val = VEC4_T(b0, b1, b2, b3); +#else + bias_val = texelFetch(t_bias, ivec3(n4_start + n4, 0, 0), 0); +#endif + } + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][n4] = out_tile.data[m][n4] + bias_val; + } + } +#endif + + // Apply activation clamp + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + clamp(out_tile.data[m][n4], VEC4_T(output_min), VEC4_T(output_max)); + } + } + +#ifdef SCALAR_BUFFER + store_output_tile_scalar(out_tile, n4_start, m_start, b, N4, N_out, M); +#else + store_output_tile_with_checks(out_tile, n4_start, m_start, b, N4, M); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml new file mode 100644 index 00000000000..c69f48deb46 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv1d_pw: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + WEIGHT_STORAGE: texture2d + HAS_BIAS: false + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + TILE_M: 4 + generate_variant_forall: + combination: + parameter_names: [STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [texture3d, texture2d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, texture2d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: conv1d_pw + - NAME: conv1d_pw_bias + HAS_BIAS: true diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp new file mode 100644 index 00000000000..f6db56fc581 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include + +#include + +#include + +namespace vkcompute { + +// Must match TILE_M default in conv1d_pw.yaml. +static constexpr uint32_t kTileM = 4; + +// Prepack conv1d_pw weight [C_out, C_in, 1] into 4OC x 4IC blocked layout. +// This is equivalent to prepack_fp_linear_weight with N=C_out, K=C_in, +// is_transposed=true, but extracts dimensions from the conv weight shape. +static ValueRef prepack_conv1d_pw_weight( + ComputeGraph& graph, + const ValueRef weight_data) { + std::vector weight_sizes = graph.sizes_of(weight_data); + // weight is [C_out, C_in, 1] + int64_t N = weight_sizes.at(0); // C_out + int64_t K = weight_sizes.at(1); // C_in + + int64_t K4 = utils::div_up(K, int64_t(4)); + int64_t N4 = utils::div_up(N, int64_t(4)); + + // Packed tensor: K4 rows, N4*4 vec4 elements per row. + int64_t output_height = K4; + int64_t output_width = N4 * 4 * 4; + + utils::StorageType weight_storage = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width / 4 > max_extent || + static_cast(output_height) > max_extent) { + weight_storage = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + {output_height, output_width}, + graph.dtype_of(weight_data), + weight_storage, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(N4), + utils::safe_downcast(K4), + 1u}; + + struct PackParams { + int32_t N; + int32_t K; + int32_t B; + int32_t is_transposed; + }; + PackParams pack_params{ + utils::safe_downcast(N), utils::safe_downcast(K), 1, 1}; + + std::string kernel_name = "pack_fp_linear_weight"; + add_storage_type_suffix(kernel_name, weight_storage); + add_dtype_suffix(kernel_name, graph.dtype_of(weight_data)); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + weight_data, + packed_weight, + {}, + {}, + {PushConstantDataInfo(&pack_params, sizeof(PackParams))})); + + return packed_weight; +} + +void resize_conv1d_pw_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + + const int64_t C_out = graph->get_int(extra_args.at(0)); + + const std::vector in_sizes = graph->sizes_of(self); + const int64_t N_batch = in_sizes.at(0); + const int64_t L = in_sizes.at(2); + + graph->virtual_resize(out, {N_batch, C_out, L}); +} + +struct Conv1dPWParams final { + int32_t weight_B; + float output_min; + float output_max; +}; + +vkapi::ShaderInfo pick_conv1d_pw_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef packed_weight = args.at(1).refs.at(1); + bool has_bias = graph->get_bool(resize_args.at(1)); + + std::string kernel_name = has_bias ? "conv1d_pw_bias" : "conv1d_pw"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 pick_conv1d_pw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + // out is [N_batch, C_out, L]; in WHCN: {L, C_out, N_batch, 1} + uint32_t C_out = graph->size_at(-2, out); + uint32_t L = graph->size_at(-1, out); + uint32_t N_batch = + graph->dim_of(out) >= 3 ? graph->size_at(-3, out) : 1; + + // X=OC4 (div_up_4(C_out)), Y=L/tile_m, Z=N_batch + return {utils::div_up_4(C_out), utils::div_up(L, kTileM), N_batch}; +} + +void add_conv1d_pw_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef out, + const float output_min = std::numeric_limits::lowest(), + const float output_max = std::numeric_limits::max()) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim); + + ValueRef packed_weight = prepack_conv1d_pw_weight(graph, weight_data); + + bool has_bias = graph.val_is_not_none(bias); + ValueRef packed_bias = kDummyValueRef; + if (has_bias) { + packed_bias = prepack_standard( + graph, bias, graph.storage_type_of(out), utils::kWidthPacked); + } + + std::vector out_sizes = graph.sizes_of(out); + int64_t C_out = out_sizes.at(1); + ValueRef C_out_ref = graph.add_scalar(C_out); + ValueRef has_bias_ref = graph.add_scalar(has_bias); + + Conv1dPWParams params{1, output_min, output_max}; + + std::vector read_inputs = {in, packed_weight}; + if (has_bias) { + read_inputs.push_back(packed_bias); + } + + std::vector push_constants; + push_constants.push_back( + PushConstantDataInfo(¶ms, sizeof(Conv1dPWParams))); + + vkapi::ParamsBindList shader_params = { + graph.sizes_ubo(in), graph.sizes_ubo(out)}; + if (has_bias) { + shader_params.append(graph.sizes_ubo(packed_bias)); + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_conv1d_pw_shader, + pick_conv1d_pw_global_wg_size, + pick_hw_square_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {read_inputs, vkapi::kRead}}, + // Shader params buffers + shader_params, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {C_out_ref, has_bias_ref}, + // Resizing Logic + resize_conv1d_pw_node)); +} + +// Args: in, weight, bias, stride, padding, dilation, groups, +// output_min, output_max, out +// output_min and output_max may be kDummyValueRef (no clamp). +void conv1d_pw(ComputeGraph& graph, const std::vector& args) { + ValueRef in = args[0]; + ValueRef weight = args[1]; + ValueRef bias = args[2]; + ValueRef out = args[9]; + + const std::vector weight_sizes = graph.sizes_of(weight); + VK_CHECK_COND( + weight_sizes.at(2) == 1, "conv1d_pw only supports kernel_size=1"); + VK_CHECK_COND( + graph.get_int(args[6]) == 1, "conv1d_pw only supports groups=1"); + + float output_min = std::numeric_limits::lowest(); + float output_max = std::numeric_limits::max(); + if (is_valid(args[7])) { + output_min = graph.extract_scalar(args[7]); + } + if (is_valid(args[8])) { + output_max = graph.extract_scalar(args[8]); + } + + add_conv1d_pw_node(graph, in, weight, bias, out, output_min, output_max); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.conv1d_pw.default, conv1d_pw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestConv1dPW.cpp b/backends/vulkan/test/custom_ops/impl/TestConv1dPW.cpp new file mode 100644 index 00000000000..99ccf082b1e --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestConv1dPW.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +void test_conv1d_pw(ComputeGraph& graph, const std::vector& args) { + // args: in, weight, bias, stride, padding, dilation, groups, out + const ValueRef input = args.at(0); + const ValueRef weight = args.at(1); + const ValueRef bias = args.at(2); + const ValueRef stride = args.at(3); + const ValueRef padding = args.at(4); + const ValueRef dilation = args.at(5); + const ValueRef groups = args.at(6); + const ValueRef out = args.at(7); + + // conv1d_pw expects: in, weight, bias, stride, padding, dilation, groups, + // output_min, output_max, out + VK_GET_OP_FN("et_vk.conv1d_pw.default") + (graph, + {input, + weight, + bias, + stride, + padding, + dilation, + groups, + kDummyValueRef, + kDummyValueRef, + out}); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.test_conv1d_pw.default, test_conv1d_pw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index e8cdb3a4bf8..d535ca2661c 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -103,3 +103,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_conv2d_pw") define_custom_op_test_binary("test_conv2d_dw") define_custom_op_test_binary("test_embedding_q4gsw") + define_custom_op_test_binary("test_conv1d_pw") diff --git a/backends/vulkan/test/custom_ops/test_conv1d_pw.cpp b/backends/vulkan/test/custom_ops/test_conv1d_pw.cpp new file mode 100644 index 00000000000..632224c478d --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_conv1d_pw.cpp @@ -0,0 +1,248 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 256; + +struct Conv1dPWConfig { + int64_t N; + int64_t C_in; + int64_t C_out; + int64_t L; + bool has_bias; +}; + +static TestCase create_conv1d_pw_test_case( + const Conv1dPWConfig& config, + vkapi::ScalarType dtype, + utils::StorageType storage_type) { + TestCase test_case; + + bool is_perf = config.C_in > kRefDimSizeLimit || + config.C_out > kRefDimSizeLimit || config.L > kRefDimSizeLimit; + + std::string prefix = is_perf ? "PERF" : "ACCU"; + std::string storage_str = storage_type_abbrev(storage_type); + std::string dtype_str = (dtype == vkapi::kHalf) ? "f16" : "f32"; + + std::string bias_str = config.has_bias ? "+bias" : ""; + + std::string name = prefix + " conv1d_pw" + bias_str + " [" + + std::to_string(config.N) + "," + std::to_string(config.C_in) + "," + + std::to_string(config.L) + "]x[" + std::to_string(config.C_out) + "," + + std::to_string(config.C_in) + ",1] " + storage_str + "(HP) " + dtype_str; + + test_case.set_name(name); + test_case.set_operator_name("test_etvk.test_conv1d_pw.default"); + + // Input: [N, C_in, L] height-packed + ValueSpec input( + {config.N, config.C_in, config.L}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::RANDOM); + test_case.add_input_spec(input); + + // Weight: [C_out, C_in, 1] height-packed, constant + ValueSpec weight( + {config.C_out, config.C_in, 1}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::RANDOM); + weight.set_constant(true); + test_case.add_input_spec(weight); + + // Bias: [C_out] or None + if (config.has_bias) { + ValueSpec bias( + {config.C_out}, + dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + test_case.add_input_spec(bias); + } else { + ValueSpec none_bias(static_cast(0)); + none_bias.set_none(true); + test_case.add_input_spec(none_bias); + } + + // stride = [1] + test_case.add_input_spec(ValueSpec(std::vector{1})); + // padding = [0] + test_case.add_input_spec(ValueSpec(std::vector{0})); + // dilation = [1] + test_case.add_input_spec(ValueSpec(std::vector{1})); + // groups = 1 + test_case.add_input_spec(ValueSpec(static_cast(1))); + + // Output: [N, C_out, L] height-packed + ValueSpec output( + {config.N, config.C_out, config.L}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::ZEROS); + test_case.add_output_spec(output); + + if (dtype == vkapi::kHalf) { + test_case.set_abs_tolerance(1e-1f); + test_case.set_rel_tolerance(1e-2f); + } else { + test_case.set_abs_tolerance(1e-3f); + test_case.set_rel_tolerance(1e-3f); + } + + test_case.set_shader_filter({"nchw_to", "to_nchw", "view_copy"}); + + return test_case; +} + +static void conv1d_pw_reference_impl(TestCase& test_case) { + const auto& input_spec = test_case.inputs()[0]; + const auto& weight_spec = test_case.inputs()[1]; + const auto& bias_spec = test_case.inputs()[2]; + ValueSpec& output = test_case.outputs()[0]; + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Reference only supports float"); + } + + auto in_sizes = input_spec.get_tensor_sizes(); + auto w_sizes = weight_spec.get_tensor_sizes(); + + int64_t N = in_sizes[0]; + int64_t C_in = in_sizes[1]; + int64_t L = in_sizes[2]; + int64_t C_out = w_sizes[0]; + + const auto& in_data = input_spec.get_float_data(); + const auto& w_data = weight_spec.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(N * C_out * L, 0.0f); + + // input is NCHW-contiguous: [N, C_in, L] + // weight is [C_out, C_in, 1] + for (int64_t n = 0; n < N; ++n) { + for (int64_t oc = 0; oc < C_out; ++oc) { + for (int64_t l = 0; l < L; ++l) { + float sum = 0.0f; + for (int64_t ic = 0; ic < C_in; ++ic) { + sum += in_data[n * C_in * L + ic * L + l] * w_data[oc * C_in + ic]; + } + ref_data[n * C_out * L + oc * L + l] = sum; + } + } + } + + if (!bias_spec.is_none()) { + const auto& bias_data = bias_spec.get_float_data(); + for (int64_t n = 0; n < N; ++n) { + for (int64_t oc = 0; oc < C_out; ++oc) { + for (int64_t l = 0; l < L; ++l) { + ref_data[n * C_out * L + oc * L + l] += bias_data[oc]; + } + } + } + } +} + +static std::vector generate_conv1d_pw_test_cases() { + std::vector test_cases; + + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Accuracy shapes (float, small) + std::vector accu_configs = { + {1, 16, 32, 64, false}, + {1, 16, 32, 64, true}, + {1, 32, 16, 128, false}, + {1, 32, 16, 128, true}, + {1, 64, 64, 32, false}, + {1, 128, 256, 16, true}, + {2, 16, 32, 64, false}, + {2, 16, 32, 64, true}, + // Non-aligned channel counts (not a multiple of 4) + {1, 5, 7, 64, false}, + {1, 5, 7, 64, true}, + {1, 13, 17, 48, false}, + {1, 13, 17, 48, true}, + {1, 7, 5, 32, false}, + {2, 5, 13, 64, true}, + }; + + for (const auto& cfg : accu_configs) { + for (auto st : storage_types) { + test_cases.push_back(create_conv1d_pw_test_case(cfg, vkapi::kFloat, st)); + } + } + + // Performance shapes (half + float) + std::vector perf_configs = { + {1, 256, 512, 1024, false}, + {1, 256, 512, 1024, true}, + {1, 512, 256, 2048, false}, + {1, 128, 128, 4096, true}, + }; + + for (const auto& cfg : perf_configs) { + for (auto st : storage_types) { + test_cases.push_back(create_conv1d_pw_test_case(cfg, vkapi::kFloat, st)); + test_cases.push_back(create_conv1d_pw_test_case(cfg, vkapi::kHalf, st)); + } + } + + return test_cases; +} + +static int64_t conv1d_pw_flop_calculator(const TestCase& test_case) { + auto in_sizes = test_case.inputs()[0].get_tensor_sizes(); + auto w_sizes = test_case.inputs()[1].get_tensor_sizes(); + + int64_t N = in_sizes[0]; + int64_t C_in = in_sizes[1]; + int64_t L = in_sizes[2]; + int64_t C_out = w_sizes[0]; + + return 2 * N * C_in * C_out * L; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Conv1d Pointwise (Height-Packed) Benchmark" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = conv1d_pw_reference_impl; + + auto results = execute_test_cases( + generate_conv1d_pw_test_cases, + conv1d_pw_flop_calculator, + "Conv1dPW", + 3, + 10, + ref_fn); + + return 0; +} From 8843c4837f6e32d9598fd97f72ca8c0f72c218cb Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 27 Mar 2026 11:21:59 -0700 Subject: [PATCH 2/3] [ET-VK][conv1d] Implement height-packed depthwise conv1d operator Pull Request resolved: https://github.com/pytorch/executorch/pull/18333 Implement a depthwise conv1d operator using height-packed layout where channels are the packed dimension (WHCN dim 1). Depthwise conv applies a separate filter to each channel independently (groups=C), so 4 channels can be processed in parallel using element-wise vec4 FMA over kernel positions. Thread mapping: X=C/4, Y=L_out, Z=N. Each thread computes one output texel (4 channels at one spatial position). Inner loop iterates over kernel positions K with bounds-checked input access for padding. Weight [C,1,K] is prepacked as channels-packed so each vec4 load gives 4 channels' weights at one kernel position. Supports both buffer and texture3d storage, fp32/fp16, optional bias, and arbitrary stride/padding/dilation. Registered as et_vk.conv1d_dw.default (standalone custom op). Performance on Adreno 750 (S24): - [1,128,4096] K=31 buffer f16: 231 GFLOP/s - [1,128,4096] K=31 buffer f32: 155 GFLOP/s - [1,512,2048] K=5 buffer f32: 66 GFLOP/s ghstack-source-id: 358903219 @exported-using-ghexport Differential Revision: [D97344091](https://our.internmc.facebook.com/intern/diff/D97344091/) --- .../runtime/graph/ops/glsl/conv1d_dw.glsl | 127 +++++++++ .../runtime/graph/ops/glsl/conv1d_dw.yaml | 22 ++ .../runtime/graph/ops/impl/Conv1dDW.cpp | 188 ++++++++++++ .../test/custom_ops/impl/TestConv1dDW.cpp | 46 +++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../vulkan/test/custom_ops/test_conv1d_dw.cpp | 267 ++++++++++++++++++ 6 files changed, 651 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp create mode 100644 backends/vulkan/test/custom_ops/impl/TestConv1dDW.cpp create mode 100644 backends/vulkan/test/custom_ops/test_conv1d_dw.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl new file mode 100644 index 00000000000..7ea068af93c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl @@ -0,0 +1,127 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +$if STORAGE == "buffer": + #define BUFFER + #define SCALAR_BUFFER +$if HAS_BIAS: + #define HAS_BIAS + +${define_required_extensions(STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +$if STORAGE == "buffer": + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=True)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=True)} + ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE, is_scalar_array=True)} +$else: + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + $if STORAGE == "buffer": + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)} + $else: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)} + +// in_sizes: {L_in, C, N, 1} in WHCN order +${layout_declare_ubo(B, "ivec4", "in_sizes")} +// out_sizes: {L_out, C, N, 1} in WHCN order +${layout_declare_ubo(B, "ivec4", "out_sizes")} + +layout(push_constant) uniform restrict Block { + int kernel_size; + int stride; + int padding; + int dilation; + float output_min; + float output_max; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Thread mapping: X = C/4, Y = L_out, Z = N +// Each thread computes 4 output channels at one spatial position. +// Depthwise: each channel has its own filter, so 4 channels can be computed +// independently with element-wise vec4 FMA. + +void main() { + const int c4 = int(gl_GlobalInvocationID.x); + const int l_out = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); + + const int L_in = in_sizes.x; + const int C = in_sizes.y; + const int C4 = div_up_4(C); + const int L_out = out_sizes.x; + + if (c4 >= C4 || l_out >= L_out) { + return; + } + + VEC4_T sum = VEC4_T(0); + + for (int k = 0; k < kernel_size; k++) { + const int l_in = l_out * stride - padding + k * dilation; + if (l_in >= 0 && l_in < L_in) { +#ifdef BUFFER + const int in_base = (n * L_in + l_in) * C + c4 * 4; + T in_s0 = t_in[in_base]; + T in_s1 = (c4 * 4 + 1 < C) ? t_in[in_base + 1] : T(0); + T in_s2 = (c4 * 4 + 2 < C) ? t_in[in_base + 2] : T(0); + T in_s3 = (c4 * 4 + 3 < C) ? t_in[in_base + 3] : T(0); + const VEC4_T in_val = VEC4_T(in_s0, in_s1, in_s2, in_s3); + + const int w_base = k * C + c4 * 4; + T w_s0 = t_weight[w_base]; + T w_s1 = (c4 * 4 + 1 < C) ? t_weight[w_base + 1] : T(0); + T w_s2 = (c4 * 4 + 2 < C) ? t_weight[w_base + 2] : T(0); + T w_s3 = (c4 * 4 + 3 < C) ? t_weight[w_base + 3] : T(0); + const VEC4_T w_val = VEC4_T(w_s0, w_s1, w_s2, w_s3); +#else + const VEC4_T in_val = texelFetch(t_in, ivec3(l_in, c4, n), 0); + const VEC4_T w_val = texelFetch(t_weight, ivec3(k, 0, c4), 0); +#endif + sum = fma(w_val, in_val, sum); + } + } + +#ifdef HAS_BIAS +#ifdef BUFFER + const int bias_base = c4 * 4; + T b0 = t_bias[bias_base]; + T b1 = (bias_base + 1 < C) ? t_bias[bias_base + 1] : T(0); + T b2 = (bias_base + 2 < C) ? t_bias[bias_base + 2] : T(0); + T b3 = (bias_base + 3 < C) ? t_bias[bias_base + 3] : T(0); + sum += VEC4_T(b0, b1, b2, b3); +#else + sum += texelFetch(t_bias, ivec3(c4, 0, 0), 0); +#endif +#endif + + sum = clamp(sum, VEC4_T(output_min), VEC4_T(output_max)); + +#ifdef BUFFER + const int out_base = (n * L_out + l_out) * C + c4 * 4; + t_out[out_base] = sum.x; + if (c4 * 4 + 1 < C) t_out[out_base + 1] = sum.y; + if (c4 * 4 + 2 < C) t_out[out_base + 2] = sum.z; + if (c4 * 4 + 3 < C) t_out[out_base + 3] = sum.w; +#else + imageStore(t_out, ivec3(l_out, c4, n), sum); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml new file mode 100644 index 00000000000..883ad8899ea --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv1d_dw: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + HAS_BIAS: false + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: conv1d_dw + - NAME: conv1d_dw_bias + HAS_BIAS: true diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp new file mode 100644 index 00000000000..88d421e6994 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +void resize_conv1d_dw_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + + TensorRefPtr weight_ref = graph->get_tref(extra_args.at(0)); + + const int64_t stride = graph->get_int_list(extra_args.at(1))->at(0); + const int64_t padding = graph->get_int_list(extra_args.at(2))->at(0); + const int64_t dilation = graph->get_int_list(extra_args.at(3))->at(0); + + const std::vector in_sizes = graph->sizes_of(self); + const int64_t kernel_size = weight_ref->sizes.at(2); + const int64_t L_in = in_sizes.at(2); + + const int64_t L_out = + calc_out_size(L_in, kernel_size, stride, padding, dilation, false); + + graph->virtual_resize(out, {in_sizes.at(0), in_sizes.at(1), L_out}); +} + +struct Conv1dDWParams final { + int32_t kernel_size; + int32_t stride; + int32_t padding; + int32_t dilation; +}; + +struct Conv1dDWClampParams final { + float output_min; + float output_max; +}; + +utils::uvec3 pick_conv1d_dw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + // out is [N, C, L_out]; in WHCN: {L_out, C, N, 1} + const uint32_t C = graph->size_at(-2, out); + const uint32_t L_out = graph->size_at(-1, out); + const uint32_t N = + graph->dim_of(out) >= 3 ? graph->size_at(-3, out) : 1; + + return {utils::div_up_4(C), L_out, N}; +} + +void add_conv1d_dw_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride_ref, + const ValueRef padding_ref, + const ValueRef dilation_ref, + const ValueRef out, + const float output_min = std::numeric_limits::lowest(), + const float output_max = std::numeric_limits::max()) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim); + + const utils::StorageType storage_type = graph.storage_type_of(out); + + // Weight [C, 1, K] prepacked as channels-packed so each vec4 load gives + // 4 channels at one kernel position. + ValueRef packed_weight = prepack_standard( + graph, weight_data, storage_type, utils::kChannelsPacked); + + bool has_bias = graph.val_is_not_none(bias); + ValueRef packed_bias = kDummyValueRef; + if (has_bias) { + packed_bias = + prepack_standard(graph, bias, storage_type, utils::kWidthPacked); + } + + const auto stride_val = graph.get_int_list(stride_ref)->at(0); + const auto padding_val = graph.get_int_list(padding_ref)->at(0); + const auto dilation_val = graph.get_int_list(dilation_ref)->at(0); + + Conv1dDWParams params{ + utils::safe_downcast(graph.get_tref(weight_data)->sizes.at(2)), + utils::safe_downcast(stride_val), + utils::safe_downcast(padding_val), + utils::safe_downcast(dilation_val), + }; + + Conv1dDWClampParams clamp_params{ + output_min, + output_max, + }; + + std::string kernel_name = has_bias ? "conv1d_dw_bias" : "conv1d_dw"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, storage_type); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + std::vector read_inputs = {in, packed_weight}; + if (has_bias) { + read_inputs.push_back(packed_bias); + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_conv1d_dw_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {read_inputs, vkapi::kRead}}, + // Shader params buffers + {graph.sizes_ubo(in), graph.sizes_ubo(out)}, + // Push Constants + {PushConstantDataInfo(¶ms, sizeof(Conv1dDWParams)), + PushConstantDataInfo(&clamp_params, sizeof(Conv1dDWClampParams))}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride_ref, padding_ref, dilation_ref}, + // Resizing Logic + resize_conv1d_dw_node)); +} + +// Args: in, weight, bias, stride, padding, dilation, groups, +// output_min, output_max, out +// output_min and output_max may be kDummyValueRef (no clamp). +void conv1d_dw(ComputeGraph& graph, const std::vector& args) { + ValueRef in = args[0]; + ValueRef weight = args[1]; + ValueRef bias = args[2]; + ValueRef stride = args[3]; + ValueRef padding = args[4]; + ValueRef dilation = args[5]; + ValueRef out = args[9]; + + float output_min = std::numeric_limits::lowest(); + float output_max = std::numeric_limits::max(); + if (is_valid(args[7])) { + output_min = graph.extract_scalar(args[7]); + } + if (is_valid(args[8])) { + output_max = graph.extract_scalar(args[8]); + } + + add_conv1d_dw_node( + graph, + in, + weight, + bias, + stride, + padding, + dilation, + out, + output_min, + output_max); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.conv1d_dw.default, conv1d_dw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestConv1dDW.cpp b/backends/vulkan/test/custom_ops/impl/TestConv1dDW.cpp new file mode 100644 index 00000000000..15923462a20 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestConv1dDW.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +void test_conv1d_dw(ComputeGraph& graph, const std::vector& args) { + // args: in, weight, bias, stride, padding, dilation, groups, out + const ValueRef input = args.at(0); + const ValueRef weight = args.at(1); + const ValueRef bias = args.at(2); + const ValueRef stride = args.at(3); + const ValueRef padding = args.at(4); + const ValueRef dilation = args.at(5); + const ValueRef groups = args.at(6); + const ValueRef out = args.at(7); + + // conv1d_dw expects: in, weight, bias, stride, padding, dilation, groups, + // output_min, output_max, out + VK_GET_OP_FN("et_vk.conv1d_dw.default") + (graph, + {input, + weight, + bias, + stride, + padding, + dilation, + groups, + kDummyValueRef, + kDummyValueRef, + out}); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.test_conv1d_dw.default, test_conv1d_dw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index d535ca2661c..5fb0f7f4cbf 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -104,3 +104,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_conv2d_dw") define_custom_op_test_binary("test_embedding_q4gsw") define_custom_op_test_binary("test_conv1d_pw") + define_custom_op_test_binary("test_conv1d_dw") diff --git a/backends/vulkan/test/custom_ops/test_conv1d_dw.cpp b/backends/vulkan/test/custom_ops/test_conv1d_dw.cpp new file mode 100644 index 00000000000..2438847036e --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_conv1d_dw.cpp @@ -0,0 +1,267 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 256; + +struct Conv1dDWConfig { + int64_t N; + int64_t C; + int64_t L; + int64_t K; + int64_t stride; + int64_t padding; + int64_t dilation; + bool has_bias; +}; + +static TestCase create_conv1d_dw_test_case( + const Conv1dDWConfig& config, + vkapi::ScalarType dtype, + utils::StorageType storage_type) { + TestCase test_case; + + bool is_perf = config.C > kRefDimSizeLimit || config.L > kRefDimSizeLimit; + + std::string prefix = is_perf ? "PERF" : "ACCU"; + std::string storage_str = storage_type_abbrev(storage_type); + std::string dtype_str = (dtype == vkapi::kHalf) ? "f16" : "f32"; + std::string bias_str = config.has_bias ? "+bias" : ""; + + int64_t L_out = + (config.L + 2 * config.padding - config.dilation * (config.K - 1) - 1) / + config.stride + + 1; + + std::string name = prefix + " conv1d_dw" + bias_str + " [" + + std::to_string(config.N) + "," + std::to_string(config.C) + "," + + std::to_string(config.L) + "] K=" + std::to_string(config.K) + + " s=" + std::to_string(config.stride) + + " p=" + std::to_string(config.padding) + + " d=" + std::to_string(config.dilation) + " " + storage_str + "(HP) " + + dtype_str; + + test_case.set_name(name); + test_case.set_operator_name("test_etvk.test_conv1d_dw.default"); + + // Input: [N, C, L] height-packed + ValueSpec input( + {config.N, config.C, config.L}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::RANDOM); + test_case.add_input_spec(input); + + // Weight: [C, 1, K] height-packed, constant + ValueSpec weight( + {config.C, 1, config.K}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::RANDOM); + weight.set_constant(true); + test_case.add_input_spec(weight); + + // Bias: [C] or None + if (config.has_bias) { + ValueSpec bias( + {config.C}, + dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + test_case.add_input_spec(bias); + } else { + ValueSpec none_bias(static_cast(0)); + none_bias.set_none(true); + test_case.add_input_spec(none_bias); + } + + // stride + test_case.add_input_spec( + ValueSpec(std::vector{static_cast(config.stride)})); + // padding + test_case.add_input_spec( + ValueSpec(std::vector{static_cast(config.padding)})); + // dilation + test_case.add_input_spec( + ValueSpec(std::vector{static_cast(config.dilation)})); + // groups = C (depthwise) + test_case.add_input_spec(ValueSpec(static_cast(config.C))); + + // Output: [N, C, L_out] height-packed + ValueSpec output( + {config.N, config.C, L_out}, + dtype, + storage_type, + utils::kHeightPacked, + DataGenType::ZEROS); + test_case.add_output_spec(output); + + if (dtype == vkapi::kHalf) { + test_case.set_abs_tolerance(1e-1f); + test_case.set_rel_tolerance(1e-2f); + } else { + test_case.set_abs_tolerance(1e-3f); + test_case.set_rel_tolerance(1e-3f); + } + + test_case.set_shader_filter({"nchw_to", "to_nchw", "view_copy"}); + + return test_case; +} + +static void conv1d_dw_reference_impl(TestCase& test_case) { + const auto& input_spec = test_case.inputs()[0]; + const auto& weight_spec = test_case.inputs()[1]; + const auto& bias_spec = test_case.inputs()[2]; + const auto& stride_spec = test_case.inputs()[3]; + const auto& padding_spec = test_case.inputs()[4]; + const auto& dilation_spec = test_case.inputs()[5]; + ValueSpec& output = test_case.outputs()[0]; + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Reference only supports float"); + } + + auto in_sizes = input_spec.get_tensor_sizes(); + auto w_sizes = weight_spec.get_tensor_sizes(); + auto out_sizes = output.get_tensor_sizes(); + + const int64_t N = in_sizes[0]; + const int64_t C = in_sizes[1]; + const int64_t L_in = in_sizes[2]; + const int64_t K = w_sizes[2]; + const int64_t L_out = out_sizes[2]; + + const int64_t stride = stride_spec.get_int_list()[0]; + const int64_t padding = padding_spec.get_int_list()[0]; + const int64_t dilation = dilation_spec.get_int_list()[0]; + + const auto& in_data = input_spec.get_float_data(); + const auto& w_data = weight_spec.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(N * C * L_out, 0.0f); + + for (int64_t n = 0; n < N; ++n) { + for (int64_t c = 0; c < C; ++c) { + for (int64_t l = 0; l < L_out; ++l) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + const int64_t l_in = l * stride - padding + k * dilation; + if (l_in >= 0 && l_in < L_in) { + sum += in_data[n * C * L_in + c * L_in + l_in] * w_data[c * K + k]; + } + } + ref_data[n * C * L_out + c * L_out + l] = sum; + } + } + } + + if (!bias_spec.is_none()) { + const auto& bias_data = bias_spec.get_float_data(); + for (int64_t n = 0; n < N; ++n) { + for (int64_t c = 0; c < C; ++c) { + for (int64_t l = 0; l < L_out; ++l) { + ref_data[n * C * L_out + c * L_out + l] += bias_data[c]; + } + } + } + } +} + +static std::vector generate_conv1d_dw_test_cases() { + std::vector test_cases; + + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Accuracy shapes + std::vector accu_configs = { + // {N, C, L, K, stride, padding, dilation, has_bias} + {1, 16, 64, 3, 1, 1, 1, false}, + {1, 32, 128, 5, 1, 2, 1, true}, + {1, 64, 32, 3, 2, 1, 1, false}, + {2, 16, 64, 3, 1, 1, 1, true}, + {1, 16, 64, 7, 1, 3, 2, false}, + // Non-aligned channel counts (not a multiple of 4) + {1, 5, 64, 3, 1, 1, 1, false}, + {1, 5, 64, 3, 1, 1, 1, true}, + {1, 7, 32, 5, 1, 2, 1, false}, + {1, 13, 48, 3, 2, 1, 1, true}, + {2, 7, 64, 3, 1, 1, 1, false}, + }; + + for (const auto& cfg : accu_configs) { + for (auto st : storage_types) { + test_cases.push_back(create_conv1d_dw_test_case(cfg, vkapi::kFloat, st)); + } + } + + // Performance shapes (half + float) + std::vector perf_configs = { + {1, 256, 1024, 3, 1, 1, 1, false}, + {1, 512, 2048, 5, 1, 2, 1, true}, + {1, 128, 4096, 31, 1, 15, 1, false}, + }; + + for (const auto& cfg : perf_configs) { + for (auto st : storage_types) { + test_cases.push_back(create_conv1d_dw_test_case(cfg, vkapi::kFloat, st)); + test_cases.push_back(create_conv1d_dw_test_case(cfg, vkapi::kHalf, st)); + } + } + + return test_cases; +} + +static int64_t conv1d_dw_flop_calculator(const TestCase& test_case) { + auto in_sizes = test_case.inputs()[0].get_tensor_sizes(); + auto w_sizes = test_case.inputs()[1].get_tensor_sizes(); + auto out_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const int64_t N = in_sizes[0]; + const int64_t C = in_sizes[1]; + const int64_t K = w_sizes[2]; + const int64_t L_out = out_sizes[2]; + + return 2 * N * C * L_out * K; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Conv1d Depthwise (Height-Packed) Benchmark" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = conv1d_dw_reference_impl; + + auto results = execute_test_cases( + generate_conv1d_dw_test_cases, + conv1d_dw_flop_calculator, + "Conv1dDW", + 3, + 10, + ref_fn); + + return 0; +} From 8d23481ddaaa650d95eff7f69e2dd3bcb0a0b80c Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 27 Mar 2026 11:22:01 -0700 Subject: [PATCH 3/3] [ET-VK][conv1d] Route conv1d to height-packed implementations in export pipeline Pull Request resolved: https://github.com/pytorch/executorch/pull/18334 Integrate the new height-packed conv1d_pw and conv1d_dw operators into the aten.convolution.default dispatch path so they are automatically used during model export. In op_registry.py, add a pick_conv_storage function that inspects the convolution node at partition time. For 1D convolutions where the op is pointwise (kernel_size=1) or depthwise (groups=C_in) and channels are 4-aligned, it selects HEIGHT_PACKED_TEXTURE for input/output instead of the default CHANNELS_PACKED_TEXTURE. All other cases (conv2d, grouped conv1d with K>1, unaligned channels) retain channels-packed behavior. In Convolution.cpp, add a height-packed routing block at the top of the conv1d path. When the input tensor is height-packed, it dispatches to et_vk.conv1d_pw.default or et_vk.conv1d_dw.default via VK_GET_OP_FN. Falls through to the existing channels-packed add_conv1d_node path otherwise. ghstack-source-id: 358903217 @exported-using-ghexport Differential Revision: [D97344090](https://our.internmc.facebook.com/intern/diff/D97344090/) --- backends/vulkan/op_registry.py | 42 ++++++++++++++++ .../runtime/graph/ops/impl/Convolution.cpp | 50 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index ddb843e2335..38215c2d827 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -802,6 +802,47 @@ def check_conv_node(node: torch.fx.Node) -> bool: return True + def pick_conv_storage( + node: torch.fx.Node, + ) -> Tuple[List[utils.TensorRepSet], utils.TensorRepSet]: + x = node.args[0] + assert isinstance(x, torch.fx.Node) + x_shape = x.meta["val"].size() + + # Default: channels-packed texture (conv2d and fallback conv1d) + input_storage = utils.CHANNELS_PACKED_TEXTURE + output_storage = utils.CHANNELS_PACKED_TEXTURE + + if len(x_shape) == 3: + # Conv1d: check if we can use height-packed + weight = node.args[1] + assert isinstance(weight, torch.fx.Node) + w_shape = weight.meta["val"].size() + groups = node.args[8] + + c_in = x_shape[1] + c_out = w_shape[0] + kernel_size = w_shape[2] + + is_pointwise = kernel_size == 1 + is_depthwise = ( + isinstance(groups, int) + and groups == c_in + and c_out == c_in + and w_shape[1] == 1 + ) + if is_pointwise or is_depthwise: + input_storage = utils.HEIGHT_PACKED_TEXTURE + output_storage = utils.HEIGHT_PACKED_TEXTURE + + # Build per-input storage list. The convolution op has variable args: + # aten.convolution.default: input, weight, bias, stride, padding, + # dilation, transposed, output_padding, groups + # et_vk.conv_with_clamp.default: + output_min, output_max + # All args after input are NO_STORAGE (prepacked or non-tensor) + inputs = [input_storage] + [utils.NO_STORAGE] * 10 + return inputs, output_storage + return OpFeatures( inputs_storage=[ utils.CHANNELS_PACKED_TEXTURE, # input @@ -820,6 +861,7 @@ def check_conv_node(node: torch.fx.Node) -> bool: supports_resize=True, supports_prepacking=True, are_node_inputs_supported_fn=check_conv_node, + pick_io_storage_fn=pick_conv_storage, ) diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 2da98926fad..9c518678502 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -686,6 +686,56 @@ void conv(ComputeGraph& graph, const std::vector& args) { true); } } else { + // Conv1d path + if (graph.packed_dim_of(args[0]) == WHCN::kHeightDim) { + // Height-packed: route to optimized conv1d implementations + const auto weight_sizes = graph.sizes_of(args[1]); + const int64_t groups_val = graph.get_int(args[8]); + const bool is_pointwise = weight_sizes.at(2) == 1; + const bool is_depthwise = + groups_val == weight_sizes.at(0) && weight_sizes.at(1) == 1; + + // Build unified 10-arg vector: + // in, weight, bias, stride, padding, dilation, groups, + // output_min, output_max, out + // For non-clamp (args.size() == 10): output_min/max = kDummyValueRef + // For clamp (args.size() == 12): output_min/max from args[9]/args[10] + ValueRef output_min = kDummyValueRef; + ValueRef output_max = kDummyValueRef; + ValueRef out; + if (args.size() == 10) { + out = args[9]; + } else { + output_min = args[9]; + output_max = args[10]; + out = args[11]; + } + + std::vector conv1d_args = { + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + output_min, + output_max, + out}; + + if (is_pointwise) { + VK_GET_OP_FN("et_vk.conv1d_pw.default")(graph, conv1d_args); + } else if (is_depthwise) { + VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, conv1d_args); + } else { + VK_THROW( + "Height-packed conv1d only supports pointwise (K=1) or " + "depthwise (groups=C)"); + } + return; + } + + // Existing channels-packed fallback if (args.size() == 10) { // ordinary conv1d return add_conv1d_node(