10 changes: 5 additions & 5 deletions backends/vulkan/custom_ops_lib.py
@@ -564,11 +564,11 @@ def apply_rotary_emb_impl(
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)

 ########################
-## add_q8ta_q8ta_q8to ##
+## q8ta_add ##
 ########################


-def add_q8ta_q8ta_q8to_impl(
+def q8ta_add_impl(
     input_a: torch.Tensor,
     input_b: torch.Tensor,
     input_a_scale: float,
@@ -598,12 +598,12 @@ def add_q8ta_q8ta_q8to_impl(
     return quantized_result


-name = "add_q8ta_q8ta_q8to"
+name = "q8ta_add"
 lib.define(
     f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor"
 )
-lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
-add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
+lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
+q8ta_add_op = getattr(getattr(torch.ops, namespace), name)

 #############################
 ## select_as_symint ##
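Note: only the operator name changes here; the schema and implementation are untouched. A minimal call sketch, assuming backends/vulkan/custom_ops_lib has been imported so the op is registered under the et_vk namespace (shapes, scales, and zero points are illustrative placeholders):

import torch

# Placeholder int8 per-tensor-quantized inputs.
a = torch.randint(-128, 128, (1, 8, 16), dtype=torch.int8)
b = torch.randint(-128, 128, (1, 8, 16), dtype=torch.int8)

# Argument order follows the schema defined above.
out = torch.ops.et_vk.q8ta_add(
    a, b,
    0.05, 0,   # input_a_scale, input_a_zero_point
    0.04, 2,   # input_b_scale, input_b_zero_point
    0.08, -1,  # output_scale, output_zero_point
    1.0,       # alpha
)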
8 changes: 4 additions & 4 deletions backends/vulkan/op_registry.py
@@ -501,14 +501,14 @@ def register_torchao_choose_qparams_affine():


 # =============================================================================
-# QuantizedBinary.cpp
+# Q8taBinary.cpp
 # =============================================================================


-@update_features(exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default)
-def register_add_q8ta_q8ta_q8to():
+@update_features(exir_ops.edge.et_vk.q8ta_add.default)
+def register_q8ta_add():
     return OpFeatures(
-        inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
+        inputs_storage=utils.PACKED_INT8_BUFFER,
         supports_resize=False,
         supports_prepacking=True,
     )
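The OpFeatures body is the extension point for further quantized binary ops. A hypothetical future registration would follow the same shape (q8ta_mul is illustrative only and is not defined in this change):

# Hypothetical: registering a future q8ta multiply the same way.
@update_features(exir_ops.edge.et_vk.q8ta_mul.default)
def register_q8ta_mul():
    return OpFeatures(
        inputs_storage=utils.PACKED_INT8_BUFFER,
        supports_resize=False,
        supports_prepacking=True,
    )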
2 changes: 1 addition & 1 deletion backends/vulkan/patterns/quantized_binary.py
@@ -133,7 +133,7 @@ def make_add_q8ta_q8ta_q8to_custom_op(
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.add_.Tensor,
     }:
-        op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default
+        op_target = exir_ops.edge.et_vk.q8ta_add.default
     else:
         # For future binary operations, add more mappings here
         raise NotImplementedError(

This file was deleted.

91 changes: 91 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl
@@ -0,0 +1,91 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

${define_active_storage_type("buffer")}

#define op(X, Y) ${OPERATOR}

layout(std430) buffer;

#include "indexing.glslh"
#include "common.glslh"
#include "block_indexing.glslh"
#include "block_int8x4_load.glslh"
#include "block_int8x4_store.glslh"

// Output buffer: packed int8x4 values
${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
// Input buffers: packed int8x4 values
${layout_declare_tensor(B, "r", "t_in_a", "int", "buffer")}
${layout_declare_tensor(B, "r", "t_in_b", "int", "buffer")}

// Metadata for output and input tensors
${layout_declare_ubo(B, "BufferMetadata", "out_meta")}
${layout_declare_ubo(B, "BufferMetadata", "in_a_meta")}
${layout_declare_ubo(B, "BufferMetadata", "in_b_meta")}

layout(push_constant) uniform restrict Block {
  float input_a_scale;
  int input_a_zp;
  float input_b_scale;
  int input_b_zp;
  float output_inv_scale;
  int output_zp;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "block_config", "0")}

// Generate loading functions for input buffers
define_load_int8x4_buffer_fns(t_in_a)
define_load_int8x4_buffer_fns(t_in_b)

// Generate storing functions for output buffer
define_store_int8x4_buffer_fns(t_out)

void main() {
  // Buffer storage: use linear dispatch
  const uint contig_block_idx = gl_GlobalInvocationID.x;
  TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config(
      out_meta, contig_block_idx, block_config);

  if (out_of_bounds(tidx, out_meta)) {
    return;
  }

  const int block_outer_dim = get_block_outer_dim(block_config);

  // Load int8x4 blocks from both inputs
  ivec4 in_block_a = load_int8x4_block_from_t_in_a(
      in_a_meta, tidx, in_layout, block_outer_dim);
  ivec4 in_block_b = load_int8x4_block_from_t_in_b(
      in_b_meta, tidx, in_layout, block_outer_dim);

  ivec4 out_block;

  for (int row = 0; row < 4; row++) {
    vec4 in_texel_a = unpack_and_dequantize(
        in_block_a[row], input_a_scale, input_a_zp);
    vec4 in_texel_b = unpack_and_dequantize(
        in_block_b[row], input_b_scale, input_b_zp);

    vec4 out_texel = op(in_texel_a, in_texel_b);
    out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
  }

  // Store to output buffer
  store_int8x4_block_to_t_out(
      out_meta, tidx, out_layout, block_outer_dim, out_block);
}
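For reference, the shader's inner loop dequantizes each packed int8x4 row, applies op, and requantizes using the precomputed inverse output scale. A scalar NumPy sketch of that per-element math (the exact rounding and saturation behavior of quantize_and_pack are assumptions here):

import numpy as np

def q8ta_add_reference(qa, qb, a_scale, a_zp, b_scale, b_zp,
                       out_inv_scale, out_zp):
    # Dequantize both int8 inputs to float.
    fa = (qa.astype(np.float32) - a_zp) * a_scale
    fb = (qb.astype(np.float32) - b_zp) * b_scale
    # OPERATOR: X + Y
    fo = fa + fb
    # Requantize: scale by the inverse output scale, shift by the zero
    # point, then round and saturate to int8 (assumed behavior).
    q = np.rint(fo * out_inv_scale) + out_zp
    return np.clip(q, -128, 127).astype(np.int8)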
@@ -4,16 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-binary_q8ta_q8ta_q8to:
+q8ta_binary:
   parameter_names_with_default_values:
     OPERATOR: X + Y
-    NDIM: 3
-    DTYPE: float
-    PACKING: C_packed
-    IO_STORAGE: buffer
-  generate_variant_forall:
-    IO_STORAGE:
-      - VALUE: buffer
   shader_variants:
-    - NAME: add_q8ta_q8ta_q8to
+    - NAME: q8ta_add_buffer
       OPERATOR: X + Y