Skip to content

Commit 8d23481

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK][conv1d] Route conv1d to height-packed implementations in export pipeline
Pull Request resolved: #18334 Integrate the new height-packed conv1d_pw and conv1d_dw operators into the aten.convolution.default dispatch path so they are automatically used during model export. In op_registry.py, add a pick_conv_storage function that inspects the convolution node at partition time. For 1D convolutions where the op is pointwise (kernel_size=1) or depthwise (groups=C_in) and channels are 4-aligned, it selects HEIGHT_PACKED_TEXTURE for input/output instead of the default CHANNELS_PACKED_TEXTURE. All other cases (conv2d, grouped conv1d with K>1, unaligned channels) retain channels-packed behavior. In Convolution.cpp, add a height-packed routing block at the top of the conv1d path. When the input tensor is height-packed, it dispatches to et_vk.conv1d_pw.default or et_vk.conv1d_dw.default via VK_GET_OP_FN. Falls through to the existing channels-packed add_conv1d_node path otherwise. ghstack-source-id: 358903217 @exported-using-ghexport Differential Revision: [D97344090](https://our.internmc.facebook.com/intern/diff/D97344090/)
1 parent 8843c48 commit 8d23481

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

backends/vulkan/op_registry.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,47 @@ def check_conv_node(node: torch.fx.Node) -> bool:
802802

803803
return True
804804

def pick_conv_storage(
    node: torch.fx.Node,
) -> Tuple[List[utils.TensorRepSet], utils.TensorRepSet]:
    """Choose I/O storage for an aten.convolution.default node at partition time.

    Conv1d nodes that are pointwise (kernel_size == 1) or depthwise
    (groups == C_in with one input channel per filter) are routed to
    height-packed textures so the optimized conv1d shaders can be used;
    every other convolution keeps the channels-packed default.

    NOTE(review): the commit description mentions a 4-aligned channel
    requirement for the height-packed path, but no alignment check appears
    here — presumably enforced elsewhere; TODO confirm.
    """
    inp = node.args[0]
    assert isinstance(inp, torch.fx.Node)
    inp_sizes = inp.meta["val"].size()

    # Channels-packed is the default for conv2d and the conv1d fallback.
    storage = utils.CHANNELS_PACKED_TEXTURE

    if len(inp_sizes) == 3:
        # Conv1d: decide whether the height-packed fast path applies.
        weight = node.args[1]
        assert isinstance(weight, torch.fx.Node)
        weight_sizes = weight.meta["val"].size()
        groups = node.args[8]

        in_channels = inp_sizes[1]
        out_channels = weight_sizes[0]

        pointwise = weight_sizes[2] == 1
        depthwise = (
            isinstance(groups, int)
            and groups == in_channels
            and out_channels == in_channels
            and weight_sizes[1] == 1
        )
        if pointwise or depthwise:
            storage = utils.HEIGHT_PACKED_TEXTURE

    # Per-input storage list for the op's variable arg schema:
    #   aten.convolution.default: input, weight, bias, stride, padding,
    #   dilation, transposed, output_padding, groups
    #   et_vk.conv_with_clamp.default adds output_min, output_max.
    # Everything after the input is prepacked or a non-tensor argument,
    # so it carries NO_STORAGE.
    return [storage] + [utils.NO_STORAGE] * 10, storage
805846
return OpFeatures(
806847
inputs_storage=[
807848
utils.CHANNELS_PACKED_TEXTURE, # input
@@ -820,6 +861,7 @@ def check_conv_node(node: torch.fx.Node) -> bool:
820861
supports_resize=True,
821862
supports_prepacking=True,
822863
are_node_inputs_supported_fn=check_conv_node,
864+
pick_io_storage_fn=pick_conv_storage,
823865
)
824866

825867

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,56 @@ void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
686686
true);
687687
}
688688
} else {
689+
// Conv1d path
690+
if (graph.packed_dim_of(args[0]) == WHCN::kHeightDim) {
691+
// Height-packed: route to optimized conv1d implementations
692+
const auto weight_sizes = graph.sizes_of(args[1]);
693+
const int64_t groups_val = graph.get_int(args[8]);
694+
const bool is_pointwise = weight_sizes.at(2) == 1;
695+
const bool is_depthwise =
696+
groups_val == weight_sizes.at(0) && weight_sizes.at(1) == 1;
697+
698+
// Build unified 10-arg vector:
699+
// in, weight, bias, stride, padding, dilation, groups,
700+
// output_min, output_max, out
701+
// For non-clamp (args.size() == 10): output_min/max = kDummyValueRef
702+
// For clamp (args.size() == 12): output_min/max from args[9]/args[10]
703+
ValueRef output_min = kDummyValueRef;
704+
ValueRef output_max = kDummyValueRef;
705+
ValueRef out;
706+
if (args.size() == 10) {
707+
out = args[9];
708+
} else {
709+
output_min = args[9];
710+
output_max = args[10];
711+
out = args[11];
712+
}
713+
714+
std::vector<ValueRef> conv1d_args = {
715+
args[0],
716+
args[1],
717+
args[2],
718+
args[3],
719+
args[4],
720+
args[5],
721+
args[8],
722+
output_min,
723+
output_max,
724+
out};
725+
726+
if (is_pointwise) {
727+
VK_GET_OP_FN("et_vk.conv1d_pw.default")(graph, conv1d_args);
728+
} else if (is_depthwise) {
729+
VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, conv1d_args);
730+
} else {
731+
VK_THROW(
732+
"Height-packed conv1d only supports pointwise (K=1) or "
733+
"depthwise (groups=C)");
734+
}
735+
return;
736+
}
737+
738+
// Existing channels-packed fallback
689739
if (args.size() == 10) {
690740
// ordinary conv1d
691741
return add_conv1d_node(

0 commit comments

Comments
 (0)