Skip to content

Commit 979fef2

Browse files
ssjiarascani
authored and committed
[ET-VK][conv1d] Implement height-packed depthwise conv1d operator
Pull Request resolved: pytorch#18333 Implement a depthwise conv1d operator using height-packed layout where channels are the packed dimension (WHCN dim 1). Depthwise conv applies a separate filter to each channel independently (groups=C), so 4 channels can be processed in parallel using element-wise vec4 FMA over kernel positions. Thread mapping: X=C/4, Y=L_out, Z=N. Each thread computes one output texel (4 channels at one spatial position). Inner loop iterates over kernel positions K with bounds-checked input access for padding. Weight [C,1,K] is prepacked as channels-packed so each vec4 load gives 4 channels' weights at one kernel position. Supports both buffer and texture3d storage, fp32/fp16, optional bias, and arbitrary stride/padding/dilation. Registered as et_vk.conv1d_dw.default (standalone custom op). Performance on Adreno 750 (S24): - [1,128,4096] K=31 buffer f16: 231 GFLOP/s - [1,128,4096] K=31 buffer f32: 155 GFLOP/s - [1,512,2048] K=5 buffer f32: 66 GFLOP/s ghstack-source-id: 358903219 @exported-using-ghexport Differential Revision: [D97344091](https://our.internmc.facebook.com/intern/diff/D97344091/)
1 parent cf5ea70 commit 979fef2

File tree

6 files changed

+651
-0
lines changed

6 files changed

+651
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

// Codegen template parameters: PRECISION/DTYPE/STORAGE are substituted by the
// shader generator; VEC4_T / T resolve to the 4-component and scalar types for
// the chosen dtype and storage backend.
#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${texel_load_component_type(DTYPE, STORAGE)}

$if STORAGE == "buffer":
  #define BUFFER
  #define SCALAR_BUFFER
$if HAS_BIAS:
  #define HAS_BIAS

${define_required_extensions(STORAGE, DTYPE)}

layout(std430) buffer;

#include "common.glslh"

// Buffer variants bind tensors as scalar arrays; texture variants bind them as
// texel images (one texel = 4 packed channels).
$if STORAGE == "buffer":
  ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=True)}
  ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=True)}
  ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE, is_scalar_array=True)}
$else:
  ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
  ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
  ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE, is_scalar_array=False)}
$if HAS_BIAS:
  $if STORAGE == "buffer":
    ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)}
  $else:
    ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)}

// in_sizes: {L_in, C, N, 1} in WHCN order
${layout_declare_ubo(B, "ivec4", "in_sizes")}
// out_sizes: {L_out, C, N, 1} in WHCN order
${layout_declare_ubo(B, "ivec4", "out_sizes")}

// Conv geometry plus the clamp range; the host passes the widest float range
// when no activation clamp is requested.
layout(push_constant) uniform restrict Block {
  int kernel_size;
  int stride;
  int padding;
  int dilation;
  float output_min;
  float output_max;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Thread mapping: X = C/4, Y = L_out, Z = N
// Each thread computes 4 output channels at one spatial position.
// Depthwise: each channel has its own filter, so 4 channels can be computed
// independently with element-wise vec4 FMA.
61+
62+
void main() {
  // One invocation per output texel: 4 adjacent channels (c4) at one output
  // position (l_out) of one batch element (n).
  const int c4 = int(gl_GlobalInvocationID.x);
  const int l_out = int(gl_GlobalInvocationID.y);
  const int n = int(gl_GlobalInvocationID.z);

  const int L_in = in_sizes.x;
  const int C = in_sizes.y;
  const int C4 = div_up_4(C);
  const int L_out = out_sizes.x;
  const int N = out_sizes.z;

  // Guard ALL three dispatch axes. The local workgroup size may round the
  // global extent up, so padded invocations can exist on any axis; the
  // original guard omitted the batch axis, which on the buffer path would
  // be an unchecked out-of-bounds store (imageStore discards OOB writes,
  // buffer stores do not).
  if (c4 >= C4 || l_out >= L_out || n >= N) {
    return;
  }

  VEC4_T sum = VEC4_T(0);

  // Accumulate over kernel taps; input positions that fall in the padding
  // region are skipped, which is equivalent to multiplying by zero.
  for (int k = 0; k < kernel_size; k++) {
    const int l_in = l_out * stride - padding + k * dilation;
    if (l_in >= 0 && l_in < L_in) {
#ifdef BUFFER
      // Channel-innermost buffer addressing: index = (n * L_in + l_in) * C + c.
      // The last channel group may be partial (C not divisible by 4), so the
      // tail components are bounds-checked and zero-filled.
      const int in_base = (n * L_in + l_in) * C + c4 * 4;
      T in_s0 = t_in[in_base];
      T in_s1 = (c4 * 4 + 1 < C) ? t_in[in_base + 1] : T(0);
      T in_s2 = (c4 * 4 + 2 < C) ? t_in[in_base + 2] : T(0);
      T in_s3 = (c4 * 4 + 3 < C) ? t_in[in_base + 3] : T(0);
      const VEC4_T in_val = VEC4_T(in_s0, in_s1, in_s2, in_s3);

      // Weight [C, 1, K] is prepacked so the 4 channels' taps for kernel
      // position k are contiguous: index = k * C + c.
      const int w_base = k * C + c4 * 4;
      T w_s0 = t_weight[w_base];
      T w_s1 = (c4 * 4 + 1 < C) ? t_weight[w_base + 1] : T(0);
      T w_s2 = (c4 * 4 + 2 < C) ? t_weight[w_base + 2] : T(0);
      T w_s3 = (c4 * 4 + 3 < C) ? t_weight[w_base + 3] : T(0);
      const VEC4_T w_val = VEC4_T(w_s0, w_s1, w_s2, w_s3);
#else
      // Texture path: one texel fetch already yields 4 packed channels.
      const VEC4_T in_val = texelFetch(t_in, ivec3(l_in, c4, n), 0);
      const VEC4_T w_val = texelFetch(t_weight, ivec3(k, 0, c4), 0);
#endif
      // Depthwise conv: element-wise multiply-accumulate per channel.
      sum = fma(w_val, in_val, sum);
    }
  }

#ifdef HAS_BIAS
#ifdef BUFFER
  const int bias_base = c4 * 4;
  T b0 = t_bias[bias_base];
  T b1 = (bias_base + 1 < C) ? t_bias[bias_base + 1] : T(0);
  T b2 = (bias_base + 2 < C) ? t_bias[bias_base + 2] : T(0);
  T b3 = (bias_base + 3 < C) ? t_bias[bias_base + 3] : T(0);
  sum += VEC4_T(b0, b1, b2, b3);
#else
  sum += texelFetch(t_bias, ivec3(c4, 0, 0), 0);
#endif
#endif

  // Unconditional clamp: when no activation clamp is requested, the host
  // passes (lowest float, max float) so this is a no-op.
  sum = clamp(sum, VEC4_T(output_min), VEC4_T(output_max));

#ifdef BUFFER
  const int out_base = (n * L_out + l_out) * C + c4 * 4;
  t_out[out_base] = sum.x;
  if (c4 * 4 + 1 < C) t_out[out_base + 1] = sum.y;
  if (c4 * 4 + 2 < C) t_out[out_base + 2] = sum.z;
  if (c4 * 4 + 3 < C) t_out[out_base + 3] = sum.w;
#else
  imageStore(t_out, ivec3(l_out, c4, n), sum);
#endif
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen config for the depthwise conv1d shader: variants are generated for
# every (STORAGE, DTYPE) combination, with and without a bias input.
conv1d_dw:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    HAS_BIAS: false
  generate_variant_forall:
    STORAGE:
      - VALUE: texture3d
      - VALUE: buffer
    DTYPE:
      - VALUE: float
      - VALUE: half
  shader_variants:
    - NAME: conv1d_dw
    # Bias-enabled variant; the host selects it when a bias tensor is present.
    - NAME: conv1d_dw_bias
      HAS_BIAS: true
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
17+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18+
19+
#include <limits>
20+
21+
namespace vkcompute {
22+
23+
void resize_conv1d_dw_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  // Recompute L_out from the current input length and the conv
  // hyper-parameters, then resize the output to [N, C, L_out].
  // NOTE(review): assumes a 3-D input [N, C, L] — sizes_of(...).at(2) would
  // throw for lower-rank inputs; confirm against the op's callers.
  const ValueRef out_tensor = args.at(0).refs.at(0);
  const ValueRef in_tensor = args.at(1).refs.at(0);

  const std::vector<int64_t> input_sizes = graph->sizes_of(in_tensor);
  const int64_t input_length = input_sizes.at(2);

  // Kernel size K comes from the weight tensor ref, shaped [C, 1, K].
  const int64_t kernel_size = graph->get_tref(extra_args.at(0))->sizes.at(2);

  // Each hyper-parameter is a single-element int list (1-D convolution).
  const int64_t stride_v = graph->get_int_list(extra_args.at(1))->at(0);
  const int64_t padding_v = graph->get_int_list(extra_args.at(2))->at(0);
  const int64_t dilation_v = graph->get_int_list(extra_args.at(3))->at(0);

  const int64_t output_length = calc_out_size(
      input_length, kernel_size, stride_v, padding_v, dilation_v, false);

  graph->virtual_resize(
      out_tensor, {input_sizes.at(0), input_sizes.at(1), output_length});
}
45+
46+
// Push-constant payload for the conv geometry. Field order and types must
// match the first four ints of the shader's push_constant Block.
struct Conv1dDWParams final {
  int32_t kernel_size;
  int32_t stride;
  int32_t padding;
  int32_t dilation;
};

// Push-constant payload for the output clamp range; matches the trailing two
// floats of the shader's push_constant Block.
struct Conv1dDWClampParams final {
  float output_min;
  float output_max;
};
57+
58+
utils::uvec3 pick_conv1d_dw_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  // Output is logically [N, C, L_out] ({L_out, C, N, 1} in WHCN). Dispatch
  // X over groups of 4 channels, Y over output positions, Z over batch.
  const ValueRef out = args.at(0).refs.at(0);

  const uint32_t num_channels = graph->size_at<uint32_t>(-2, out);
  const uint32_t out_length = graph->size_at<uint32_t>(-1, out);
  // Treat a rank-2 output as a single-batch tensor.
  const uint32_t batch =
      graph->dim_of(out) >= 3 ? graph->size_at<uint32_t>(-3, out) : 1;

  return {utils::div_up_4(num_channels), out_length, batch};
}
75+
76+
// Builds and enqueues the depthwise conv1d dispatch node.
// - in/out: activation tensors, required to be height-packed (channels are
//   WHCN dim 1) so the shader's texel addressing is valid.
// - weight_data: TensorRef for the [C, 1, K] depthwise weight.
// - bias: optional; may be a none Value.
// - stride_ref/padding_ref/dilation_ref: single-element int lists.
// - output_min/output_max: clamp range; defaults make the shader's clamp a
//   no-op.
void add_conv1d_dw_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight_data,
    const ValueRef bias,
    const ValueRef stride_ref,
    const ValueRef padding_ref,
    const ValueRef dilation_ref,
    const ValueRef out,
    const float output_min = std::numeric_limits<float>::lowest(),
    const float output_max = std::numeric_limits<float>::max()) {
  // The shader hard-codes height-packed addressing; fail fast otherwise.
  VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim);

  const utils::StorageType storage_type = graph.storage_type_of(out);

  // Weight [C, 1, K] prepacked as channels-packed so each vec4 load gives
  // 4 channels at one kernel position.
  ValueRef packed_weight = prepack_standard(
      graph, weight_data, storage_type, utils::kChannelsPacked);

  // Bias is optional; when absent, the bias-free shader variant is used and
  // packed_bias stays a dummy ref.
  bool has_bias = graph.val_is_not_none(bias);
  ValueRef packed_bias = kDummyValueRef;
  if (has_bias) {
    packed_bias =
        prepack_standard(graph, bias, storage_type, utils::kWidthPacked);
  }

  // conv1d has a single spatial dim, so only element 0 of each list is read.
  const auto stride_val = graph.get_int_list(stride_ref)->at(0);
  const auto padding_val = graph.get_int_list(padding_ref)->at(0);
  const auto dilation_val = graph.get_int_list(dilation_ref)->at(0);

  // Push-constant geometry; K is the last dim of the [C, 1, K] weight.
  Conv1dDWParams params{
      utils::safe_downcast<int32_t>(graph.get_tref(weight_data)->sizes.at(2)),
      utils::safe_downcast<int32_t>(stride_val),
      utils::safe_downcast<int32_t>(padding_val),
      utils::safe_downcast<int32_t>(dilation_val),
  };

  Conv1dDWClampParams clamp_params{
      output_min,
      output_max,
  };

  // Variant name = base (by bias presence) + storage suffix + dtype suffix,
  // matching the names produced by the shader codegen config.
  std::string kernel_name = has_bias ? "conv1d_dw_bias" : "conv1d_dw";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, storage_type);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  std::vector<ValueRef> read_inputs = {in, packed_weight};
  if (has_bias) {
    read_inputs.push_back(packed_bias);
  }

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_conv1d_dw_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {read_inputs, vkapi::kRead}},
      // Shader params buffers
      {graph.sizes_ubo(in), graph.sizes_ubo(out)},
      // Push Constants
      {PushConstantDataInfo(&params, sizeof(Conv1dDWParams)),
       PushConstantDataInfo(&clamp_params, sizeof(Conv1dDWClampParams))},
      // Specialization Constants
      {},
      // Resize Args
      {weight_data, stride_ref, padding_ref, dilation_ref},
      // Resizing Logic
      resize_conv1d_dw_node));
}
149+
150+
// Args: in, weight, bias, stride, padding, dilation, groups,
151+
// output_min, output_max, out
152+
// output_min and output_max may be kDummyValueRef (no clamp).
153+
void conv1d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
154+
ValueRef in = args[0];
155+
ValueRef weight = args[1];
156+
ValueRef bias = args[2];
157+
ValueRef stride = args[3];
158+
ValueRef padding = args[4];
159+
ValueRef dilation = args[5];
160+
ValueRef out = args[9];
161+
162+
float output_min = std::numeric_limits<float>::lowest();
163+
float output_max = std::numeric_limits<float>::max();
164+
if (is_valid(args[7])) {
165+
output_min = graph.extract_scalar<float>(args[7]);
166+
}
167+
if (is_valid(args[8])) {
168+
output_max = graph.extract_scalar<float>(args[8]);
169+
}
170+
171+
add_conv1d_dw_node(
172+
graph,
173+
in,
174+
weight,
175+
bias,
176+
stride,
177+
padding,
178+
dilation,
179+
out,
180+
output_min,
181+
output_max);
182+
}
183+
184+
// Register the depthwise conv1d as a standalone custom op.
REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.conv1d_dw.default, conv1d_dw);
}
187+
188+
} // namespace vkcompute
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
13+
namespace vkcompute {
14+
15+
void test_conv1d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // args: in, weight, bias, stride, padding, dilation, groups, out
  //
  // The custom op expects: in, weight, bias, stride, padding, dilation,
  // groups, output_min, output_max, out — so splice two dummy clamp refs
  // in before the output (dummy => no clamping applied).
  const ValueRef out = args.at(7);

  std::vector<ValueRef> op_args;
  op_args.reserve(10);
  for (int i = 0; i < 7; ++i) {
    op_args.push_back(args.at(i));
  }
  op_args.push_back(kDummyValueRef); // output_min
  op_args.push_back(kDummyValueRef); // output_max
  op_args.push_back(out);

  VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, op_args);
}
41+
42+
// Register the test entry point that forwards to the custom op.
REGISTER_OPERATORS {
  VK_REGISTER_OP(test_etvk.test_conv1d_dw.default, test_conv1d_dw);
}
45+
46+
} // namespace vkcompute

backends/vulkan/test/custom_ops/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,4 @@ def define_common_targets(is_fbcode = False):
104104
define_custom_op_test_binary("test_conv2d_dw")
105105
define_custom_op_test_binary("test_embedding_q4gsw")
106106
define_custom_op_test_binary("test_conv1d_pw")
107+
define_custom_op_test_binary("test_conv1d_dw")

0 commit comments

Comments
 (0)