From 7c461f8fe6672701d1516dc1208575bf9923cdf5 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 11 Feb 2026 15:12:57 -0800
Subject: [PATCH] Back out "[Diff Train][pytorch/executorch] Apply fixup patch to fbsource"

Pull Request resolved: https://github.com/pytorch/executorch/pull/17399

Revert D92897428, which accidentally erased recent changes to the ExecuTorch
Vulkan backend when syncing open source and fbsource.

ghstack-source-id: 340490031
@exported-using-ghexport

Differential Revision: [D93012332](https://our.internmc.facebook.com/intern/diff/D93012332/)
---
 backends/vulkan/op_registry.py                | 21 +++++++++++++++++++
 .../vulkan/partitioner/vulkan_partitioner.py  |  4 ++++
 .../runtime/api/containers/StagingBuffer.cpp  |  5 +++++
 .../runtime/api/containers/StagingBuffer.h    | 10 +++++++++
 backends/vulkan/utils.py                      | 16 ++++++++++++++
 5 files changed, 56 insertions(+)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 9a802b48d9c..e2f305ca0e0 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -37,6 +37,8 @@ class OpFeatures:
         # bool indicating if the operator has a resize function, which allows it to
         # support models with dynamic shape
         "supports_resize",
+        # bool indicating if the operator supports tensors with more than 4 dimensions
+        "supports_highdim",
         # bool indicating if the operator handles its own prepacking. If this is True,
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
@@ -60,6 +62,7 @@ def __init__(
             Union[utils.TensorRepSet, List[utils.TensorRepSet]]
         ] = None,
         supports_resize: bool = False,
+        supports_highdim: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
         pick_io_storage_fn: Optional[Callable] = None,
@@ -85,6 +88,7 @@ def __init__(
             self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0])
 
         self.supports_resize = supports_resize
+        self.supports_highdim = supports_highdim
         self.supports_prepacking = supports_prepacking
 
         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -630,6 +636,7 @@ def register_reduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -651,6 +658,7 @@ def register_argreduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -811,6 +819,7 @@ def register_apply_rotary_emb():
         inputs_storage=utils.CONTIGUOUS_ANY,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -834,6 +843,7 @@ def register_permute_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -848,6 +858,7 @@ def register_view_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -857,6 +868,7 @@ def register_to_dim_order_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -871,6 +883,7 @@ def register_squeeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -885,6 +898,7 @@ def register_unsqueeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -899,6 +913,7 @@ def register_clone():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -908,6 +923,7 @@ def register_clone_dim_order():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -922,6 +938,7 @@ def register_gather():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -936,6 +953,7 @@ def register_expand_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=False,
+        supports_highdim=True,
     )
 
 
@@ -964,6 +982,7 @@ def register_select_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -978,6 +997,7 @@ def register_slice_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
@@ -992,6 +1012,7 @@ def register_split_with_sizes_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )
 
 
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 82e92a1bc81..75475643100 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -266,6 +266,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "op args not supported")
             return False
 
+        if not features.supports_highdim and utils.op_contains_high_dim_tensor(node):
+            self.log_skip(node, "op does not support high dim tensors")
+            return False
+
         if self.require_dynamic_shapes and not features.supports_resize:
             self.log_skip(node, "no dynamic shape support")
             return False
diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.cpp b/backends/vulkan/runtime/api/containers/StagingBuffer.cpp
index ca020ab6572..499f0b43d05 100644
--- a/backends/vulkan/runtime/api/containers/StagingBuffer.cpp
+++ b/backends/vulkan/runtime/api/containers/StagingBuffer.cpp
@@ -159,6 +159,11 @@ void StagingBuffer::cast_half_to_float_and_copy_from(
   for (size_t i = 0; i < numel; ++i) {
     dst[i] = half_to_float(src[i]);
   }
+  vmaFlushAllocation(
+      vulkan_buffer_.vma_allocator(),
+      vulkan_buffer_.allocation(),
+      0u,
+      VK_WHOLE_SIZE);
 }
 
 void StagingBuffer::cast_float_to_half_and_copy_to(
diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h
index d786d030b39..19060804693 100644
--- a/backends/vulkan/runtime/api/containers/StagingBuffer.h
+++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -88,6 +88,11 @@ class StagingBuffer final {
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);
     }
+    vmaFlushAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
   }
 
   void cast_half_to_float_and_copy_from(
@@ -109,6 +114,11 @@ class StagingBuffer final {
   template <typename SRC_T, typename DST_T>
   void cast_and_copy_to(DST_T* dst, const size_t numel) {
     VK_CHECK_COND(numel <= this->numel());
+    vmaInvalidateAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
     const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index bdb7308a0e7..b9195265398 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -468,6 +468,22 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
     return False
 
 
+def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the operator used to compute the given node contains a tensor
+    with more than 4 dimensions
+    """
+    if is_tensor_node(node) and tensor_node_is_high_dim(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_high_dim(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):
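
Note for reviewers: the following is a minimal, self-contained sketch (not part of the patch) of the restored gating semantics, which the partitioner applies above: a node is skipped when the op's supports_highdim is False and the node or any of its tensor args has more than 4 dimensions. The helpers node_is_high_dim / contains_high_dim_tensor and the ShapeProp-based metadata lookup are illustrative stand-ins for the backend's tensor_node_is_high_dim / op_contains_high_dim_tensor utilities in backends/vulkan/utils.py.

import operator

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


def node_is_high_dim(node: torch.fx.Node) -> bool:
    # A node is "high dim" if shape propagation recorded a tensor with rank > 4.
    meta = node.meta.get("tensor_meta", None)
    return meta is not None and len(meta.shape) > 4


def contains_high_dim_tensor(node: torch.fx.Node) -> bool:
    # True if the node itself, or any tensor argument of the node, has more
    # than 4 dimensions (mirrors the check restored by this diff).
    if node_is_high_dim(node):
        return True
    return any(
        isinstance(arg, torch.fx.Node) and node_is_high_dim(arg) for arg in node.args
    )


class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = torch.fx.symbolic_trace(AddModule())
# Propagate shapes for 5-D inputs so node.meta["tensor_meta"] is populated.
ShapeProp(gm).propagate(torch.randn(2, 3, 4, 5, 6), torch.randn(2, 3, 4, 5, 6))

for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is operator.add:
        # With supports_highdim=False, the partitioner would log a skip for
        # this node; with supports_highdim=True (as registered for binary ops
        # above), it remains eligible for the Vulkan partition.
        print("add node contains high dim tensor:", contains_high_dim_tensor(n))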