From d936ff326144b0fe1302fe29160443e8958e1a83 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 11 Feb 2026 12:15:47 -0800
Subject: [PATCH] [ET-VK][ez] Add AOT support for PackedInt8_4C1W memory layout

This adds end-to-end support for the PackedInt8_4C1W memory layout
throughout the serialization and AOT pipeline. The 4C1W layout packs 4
channels into a single texel with width-major ordering, which is the
natural output layout for convolutions that produce channel-packed
results.

- Adds PACKED_INT8_4C1W = 8 to the FlatBuffers schema and the Python
  schema class
- Adds the corresponding deserialization mapping in VulkanBackend.cpp
- Updates the quantize/dequantize per-tensor op registrations to accept
  any PackedInt8 layout (not just 4W4C), enabling the layout propagation
  pass to choose the optimal layout
- Adds new TensorRepSet constants: PACKED_INT8_BUFFER (all quantized
  layouts), PACKED_INT8_4C1W_BUFFER, and
  PACKED_INT8_CHANNELS_PACKED_BUFFER (4W4C + 4C1W)

Differential Revision: [D93000167](https://our.internmc.facebook.com/intern/diff/D93000167/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                       | 9 +++++++--
 backends/vulkan/runtime/VulkanBackend.cpp            | 2 ++
 backends/vulkan/serialization/schema.fbs             | 1 +
 backends/vulkan/serialization/vulkan_graph_schema.py | 1 +
 backends/vulkan/utils.py                             | 8 ++++++++
 5 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index fe9271a7396..bc5b83b1b55 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -425,6 +425,11 @@ def register_torchao_quantize_dequantize():
     )
 
 
+# =============================================================================
+# Q8taQuantizeDequantize.cpp
+# =============================================================================
+
+
 @update_features(
     [
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -437,7 +442,7 @@ def register_quantize_per_tensor():
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
         outputs_storage=[
-            utils.PACKED_INT8_4W4C_BUFFER,
+            utils.PACKED_INT8_BUFFER,
         ],
     )
 
@@ -451,7 +456,7 @@ def register_dequantize_per_tensor():
     return OpFeatures(
         inputs_storage=[
-            utils.PACKED_INT8_4W4C_BUFFER,
+            utils.PACKED_INT8_BUFFER,
         ],
         outputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 677b042beb6..261585c381b 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -144,6 +144,8 @@ utils::GPUMemoryLayout get_memory_layout(
       return utils::kPackedInt8_4W4C;
     case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W:
       return utils::kPackedInt8_4H4W;
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W:
+      return utils::kPackedInt8_4C1W;
     default:
       break;
   }
diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs
index 9d738bc386f..8218ee3387f 100644
--- a/backends/vulkan/serialization/schema.fbs
+++ b/backends/vulkan/serialization/schema.fbs
@@ -42,6 +42,7 @@ enum VkMemoryLayout : ubyte {
   TENSOR_CHANNELS_PACKED = 2,
   PACKED_INT8_4W4C = 3,
   PACKED_INT8_4H4W = 4,
+  PACKED_INT8_4C1W = 8,
   DEFAULT_LAYOUT = 255,
 }
 
diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py
index 236183ce42f..d14428d3b66 100644
--- a/backends/vulkan/serialization/vulkan_graph_schema.py
+++ b/backends/vulkan/serialization/vulkan_graph_schema.py
@@ -50,6 +50,7 @@ class VkMemoryLayout(IntEnum):
     TENSOR_CHANNELS_PACKED = 2
     PACKED_INT8_4W4C = 3
     PACKED_INT8_4H4W = 4
+    PACKED_INT8_4C1W = 8
     DEFAULT_LAYOUT = 255
 
     def __str__(self) -> str:
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index bdb7308a0e7..9e1f0ce2956 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -578,6 +578,7 @@ def node_has_target(node: Any, target: str):
 all_quantized_memory_layouts: Set[VkMemoryLayout] = {
     VkMemoryLayout.PACKED_INT8_4W4C,
     VkMemoryLayout.PACKED_INT8_4H4W,
+    VkMemoryLayout.PACKED_INT8_4C1W,
 }
 
 universal_memory_layout_set: Set[VkMemoryLayout] = {
@@ -967,7 +968,14 @@ def make_filtered_tensor_repset(
 
 # Only includes memory layouts that can be used by quantized tensors
+PACKED_INT8_BUFFER = TensorRepSet(all_quantized_memory_layouts, set())
 PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set())
+PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set())
+
+PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet(
+    {VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4C1W}, set()
+)
+
 # Special use RepSets
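
Note on the 4C1W packing described in the commit message: the sketch below is
a minimal illustration of one plausible (channel, width) -> packed-element
mapping, assuming each packed int8x4 element holds 4 consecutive channels at
a single width position and elements within a channel group are laid out
contiguously along the width dimension. The helper name and signature are
hypothetical; the actual packed indexing lives in the runtime's compute
shaders and is not part of this diff.

    # Illustration only; not part of this diff.
    def packed_4c1w_coords(c: int, w: int, width: int) -> tuple[int, int]:
        """Map a logical (channel, width) coordinate to (element, lane),
        where element indexes the int8x4 texel in the packed buffer and
        lane selects one of its 4 int8 components."""
        channel_group = c // 4  # 4 consecutive channels share one texel
        lane = c % 4            # component within the int8x4 texel
        element = channel_group * width + w  # texels adjacent along width
        return element, lane

    # e.g. for width=8: (c=5, w=3) -> element 1*8+3 = 11, lane 1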