diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 64d7261f576..8e564bcf6b9 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -431,6 +431,11 @@ def register_torchao_quantize_dequantize(): ) +# ============================================================================= +# Q8taQuantizeDequantize.cpp +# ============================================================================= + + @update_features( [ exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -443,7 +448,7 @@ def register_quantize_per_tensor(): utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], outputs_storage=[ - utils.PACKED_INT8_4W4C_BUFFER, + utils.PACKED_INT8_BUFFER, ], ) @@ -457,7 +462,7 @@ def register_quantize_per_tensor(): def register_dequantize_per_tensor(): return OpFeatures( inputs_storage=[ - utils.PACKED_INT8_4W4C_BUFFER, + utils.PACKED_INT8_BUFFER, ], outputs_storage=[ utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 677b042beb6..261585c381b 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -144,6 +144,8 @@ utils::GPUMemoryLayout get_memory_layout( return utils::kPackedInt8_4W4C; case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W: return utils::kPackedInt8_4H4W; + case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W: + return utils::kPackedInt8_4C1W; default: break; } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 9d738bc386f..8218ee3387f 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -42,6 +42,7 @@ enum VkMemoryLayout : ubyte { TENSOR_CHANNELS_PACKED = 2, PACKED_INT8_4W4C = 3, PACKED_INT8_4H4W = 4, + PACKED_INT8_4C1W = 8, DEFAULT_LAYOUT = 255, } diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 236183ce42f..d14428d3b66 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -50,6 +50,7 @@ class VkMemoryLayout(IntEnum): TENSOR_CHANNELS_PACKED = 2 PACKED_INT8_4W4C = 3 PACKED_INT8_4H4W = 4 + PACKED_INT8_4C1W = 8 DEFAULT_LAYOUT = 255 def __str__(self) -> str: diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index b9195265398..88d8bb00c6c 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -594,6 +594,7 @@ def node_has_target(node: Any, target: str): all_quantized_memory_layouts: Set[VkMemoryLayout] = { VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4H4W, + VkMemoryLayout.PACKED_INT8_4C1W, } universal_memory_layout_set: Set[VkMemoryLayout] = { @@ -983,7 +984,14 @@ def make_filtered_tensor_repset( # Only includes memory layouts that can be used by quantized tensors +PACKED_INT8_BUFFER = TensorRepSet(all_quantized_memory_layouts, set()) PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set()) +PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set()) + +PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet( + {VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4C1W}, set() +) + # Special use RepSets