Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@ def linear_q8ta_q8csw(
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)

############################
## conv2d_q8ta_q8csw_q8to ##
############################
###################
## q8ta_conv2d_* ##
###################


def conv2d_q8ta_q8csw_q8to(
def q8ta_conv2d(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
Expand Down Expand Up @@ -425,7 +425,7 @@ def conv2d_q8ta_q8csw_q8to(
return out


name = "conv2d_q8ta_q8csw_q8to"
name = "q8ta_conv2d"
lib.define(
f"""
{name}(
Expand All @@ -445,11 +445,35 @@ def conv2d_q8ta_q8csw_q8to(
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
q8ta_conv2d_op = getattr(getattr(torch.ops, namespace), name)


def conv2d_q8ta_q8csw_q8to_dw(
name = "q8ta_conv2d_pw"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
q8ta_conv2d_pw_op = getattr(getattr(torch.ops, namespace), name)


def q8ta_conv2d_dw(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
Expand Down Expand Up @@ -497,7 +521,7 @@ def conv2d_q8ta_q8csw_q8to_dw(
return out


name = "conv2d_q8ta_q8csw_q8to_dw"
name = "q8ta_conv2d_dw"
lib.define(
f"""
{name}(
Expand All @@ -517,7 +541,7 @@ def conv2d_q8ta_q8csw_q8to_dw(
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd")
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)

######################
Expand Down
37 changes: 33 additions & 4 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,17 +745,46 @@ def check_conv_node(node: torch.fx.Node) -> bool:


# =============================================================================
# QuantizedConvolution.cpp
# Q8taConv2d*.cpp
# =============================================================================


@update_features(
[
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
exir_ops.edge.et_vk.q8ta_conv2d_pw.default,
]
)
def register_quantizedconvolution_cpp_ops():
def register_q8ta_conv_pw_op():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4W4C_BUFFER, # input
utils.NO_STORAGE, # input_scale (non tensor)
utils.NO_STORAGE, # input_zero_point (non tensor)
utils.NO_STORAGE, # weight (prepacked)
utils.NO_STORAGE, # weight_sums (prepacked)
utils.NO_STORAGE, # weight_scales (prepacked)
utils.NO_STORAGE, # output_scale (non tensor)
utils.NO_STORAGE, # output_zero_point (non tensor)
utils.NO_STORAGE, # bias (prepacked)
utils.NO_STORAGE, # kernel_size (non tensor)
utils.NO_STORAGE, # stride (non tensor)
utils.NO_STORAGE, # padding (non tensor)
utils.NO_STORAGE, # dilation (non tensor)
utils.NO_STORAGE, # groups (non tensor)
utils.NO_STORAGE, # original OC count (non tensor)
],
supports_resize=False,
supports_prepacking=True,
)


@update_features(
[
exir_ops.edge.et_vk.q8ta_conv2d.default,
exir_ops.edge.et_vk.q8ta_conv2d_dw.default,
]
)
def register_q8ta_conv2d_ops():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4W4C_BUFFER, # input
Expand Down
16 changes: 13 additions & 3 deletions backends/vulkan/patterns/quantized_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def find_quantized_convolution_patterns(


@register_pattern_replacement("quantized_convolution")
def make_conv2d_q8ta_q8csw_custom_op(
def make_q8ta_conv2d_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: QuantizedConvolutionMatch,
Expand Down Expand Up @@ -230,10 +230,20 @@ def make_conv2d_q8ta_q8csw_custom_op(
data=sum_per_output_channel,
)

is_pointwise_conv = (
H == 1
and W == 1
and list(match.stride) == [1, 1]
and list(match.dilation) == [1, 1]
and list(match.padding) == [0, 0]
)

with graph_module.graph.inserting_before(match.output_node):
op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default
op_target = exir_ops.edge.et_vk.q8ta_conv2d.default
if is_depthwise_conv:
op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default
op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default
elif is_pointwise_conv:
op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default

qconv_node = graph_module.graph.create_node(
"call_function",
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
}

REGISTER_OPERATORS {
VK_REGISTER_OP(etvk.q8ta_conv2d.default, q8ta_conv2d);
VK_REGISTER_OP(etvk.q8ta_conv2d_general.default, q8ta_conv2d_general);
VK_REGISTER_OP(et_vk.q8ta_conv2d.default, q8ta_conv2d);
VK_REGISTER_OP(et_vk.q8ta_conv2d_general.default, q8ta_conv2d_general);
}

} // namespace vkcompute
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
}

REGISTER_OPERATORS {
VK_REGISTER_OP(etvk.q8ta_conv2d_dw.default, q8ta_conv2d_dw);
VK_REGISTER_OP(et_vk.q8ta_conv2d_dw.default, q8ta_conv2d_dw);
}

} // namespace vkcompute
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void q8ta_conv2d_im2col(
}

REGISTER_OPERATORS {
VK_REGISTER_OP(etvk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col);
VK_REGISTER_OP(et_vk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col);
}

} // namespace vkcompute
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
}

REGISTER_OPERATORS {
VK_REGISTER_OP(etvk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
VK_REGISTER_OP(et_vk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
}

} // namespace vkcompute
10 changes: 5 additions & 5 deletions backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void test_q8ta_conv2d_dw(
VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
} else {
// Use the dedicated depthwise conv2d operator
VK_GET_OP_FN("etvk.q8ta_conv2d_dw.default")(graph, conv_args);
VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.default")(graph, conv_args);
}

// Dequantize packed int8 output to floating point
Expand Down Expand Up @@ -156,13 +156,13 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
} else if (impl_selector == "im2col") {
// Use the im2col-based conv2d operator
VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args);
VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args);
} else if (impl_selector == "general") {
// Use the general q8ta_conv2d operator (no im2col dispatch)
VK_GET_OP_FN("etvk.q8ta_conv2d_general.default")(graph, conv_args);
VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args);
} else {
// Use the new general q8ta_conv2d operator
VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args);
VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args);
}

// Dequantize packed int8 output to floating point
Expand Down Expand Up @@ -240,7 +240,7 @@ void test_q8ta_conv2d_pw(
if (impl_selector == "legacy_4w4c") {
VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
} else {
VK_GET_OP_FN("etvk.q8ta_conv2d_pw.default")(graph, conv_args);
VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.default")(graph, conv_args);
}

// Dequantize packed int8 output to floating point
Expand Down
Loading