21 changes: 21 additions & 0 deletions backends/vulkan/op_registry.py
@@ -37,6 +37,8 @@ class OpFeatures:
# bool indicating if the operator has a resize function, which allows it to
# support models with dynamic shape
"supports_resize",
# bool indicating if the operator supports tensors with more than 4 dimensions
"supports_highdim",
# bool indicating if the operator handles its own prepacking. If this is True,
# then the insert_prepack_nodes pass will not insert prepack nodes for the args
# of the op.
@@ -60,6 +62,7 @@ def __init__(
Union[utils.TensorRepSet, List[utils.TensorRepSet]]
] = None,
supports_resize: bool = False,
supports_highdim: bool = False,
supports_prepacking: bool = False,
are_node_inputs_supported_fn: Optional[Callable] = allow_node,
pick_io_storage_fn: Optional[Callable] = None,
@@ -85,6 +88,7 @@ def __init__(
self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0])

self.supports_resize = supports_resize
self.supports_highdim = supports_highdim
self.supports_prepacking = supports_prepacking

self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
supports_resize=True,
supports_highdim=True,
)


@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
supports_resize=True,
supports_highdim=True,
)


@@ -630,6 +636,7 @@ def register_reduce_cpp_ops():
inputs_storage=utils.ANY_TEXTURE,
inputs_dtypes=utils.FP_T,
supports_resize=True,
supports_highdim=True,
are_node_inputs_supported_fn=is_reduce_node_supported,
pick_io_storage_fn=pick_storage_for_reduce,
)
@@ -651,6 +658,7 @@ def register_argreduce_cpp_ops():
inputs_storage=utils.ANY_TEXTURE,
inputs_dtypes=utils.FP_T,
supports_resize=True,
supports_highdim=True,
are_node_inputs_supported_fn=is_reduce_node_supported,
pick_io_storage_fn=pick_storage_for_reduce,
)
@@ -811,6 +819,7 @@ def register_apply_rotary_emb():
inputs_storage=utils.CONTIGUOUS_ANY,
inputs_dtypes=utils.FP_T,
supports_resize=True,
supports_highdim=True,
)


@@ -834,6 +843,7 @@ def register_permute_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -848,6 +858,7 @@ def register_view_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -857,6 +868,7 @@ def register_to_dim_order_copy():
inputs_storage=utils.ANY_BUFFER,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -871,6 +883,7 @@ def register_squeeze_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -885,6 +898,7 @@ def register_unsqueeze_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -899,6 +913,7 @@ def register_clone():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -908,6 +923,7 @@ def register_clone_dim_order():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -922,6 +938,7 @@ def register_gather():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -936,6 +953,7 @@ def register_expand_copy():
inputs_storage=utils.ANY_BUFFER,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=False,
supports_highdim=True,
)


@@ -964,6 +982,7 @@ def register_select_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -978,6 +997,7 @@ def register_slice_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@@ -992,6 +1012,7 @@ def register_split_with_sizes_copy():
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
)


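Note on the new flag: `supports_highdim` follows the same opt-in pattern as `supports_resize`, so registrations that leave it at the default `False` keep the previous behavior of being skipped whenever a node touches a tensor with more than 4 dimensions. A minimal sketch of an opted-in registration is shown below; the `update_features` decorator name and the `aten.add.Tensor` target are assumptions for illustration, and only the `OpFeatures` keyword arguments mirror this diff.

```python
# Sketch only: decorator name and op target are assumed, not taken from this PR.
@update_features(exir_ops.edge.aten.add.Tensor)
def register_add():
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        supports_resize=True,
        # New in this PR: allow the partitioner to accept tensors with rank > 4.
        supports_highdim=True,
    )
```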
4 changes: 4 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -266,6 +266,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
self.log_skip(node, "op args not supported")
return False

if not features.supports_highdim and utils.op_contains_high_dim_tensor(node):
self.log_skip(node, "op does not support high dim tensors")
return False

if self.require_dynamic_shapes and not features.supports_resize:
self.log_skip(node, "no dynamic shape support")
return False
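The effect of this gate: a node that produces or consumes a tensor with more than 4 dimensions is only handed to Vulkan when the op's `OpFeatures` sets `supports_highdim=True`; otherwise it is skipped and left to the fallback. A toy example of the kind of graph this change makes partitionable (illustration only, not part of the PR):

```python
import torch

class HighDimPermute(torch.nn.Module):
    # A rank-5 permute: previously the Vulkan partitioner skipped any op
    # touching a tensor with more than 4 dims; permute_copy now opts in.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(0, 2, 1, 4, 3)

x = torch.randn(2, 3, 4, 5, 6)   # 5-D input
out = HighDimPermute()(x)        # shape (2, 4, 3, 6, 5)
```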
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/api/containers/StagingBuffer.cpp
@@ -159,6 +159,11 @@ void StagingBuffer::cast_half_to_float_and_copy_from(
for (size_t i = 0; i < numel; ++i) {
dst[i] = half_to_float(src[i]);
}
vmaFlushAllocation(
vulkan_buffer_.vma_allocator(),
vulkan_buffer_.allocation(),
0u,
VK_WHOLE_SIZE);
}

void StagingBuffer::cast_float_to_half_and_copy_to(
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -88,6 +88,11 @@ class StagingBuffer final {
for (size_t i = 0; i < numel; ++i) {
dst[i] = static_cast<DST_T>(src[i]);
}
vmaFlushAllocation(
vulkan_buffer_.vma_allocator(),
vulkan_buffer_.allocation(),
0u,
VK_WHOLE_SIZE);
}

void cast_half_to_float_and_copy_from(
@@ -109,6 +114,11 @@
template <typename SRC_T, typename DST_T>
void cast_and_copy_to(DST_T* dst, const size_t numel) {
VK_CHECK_COND(numel <= this->numel());
vmaInvalidateAllocation(
vulkan_buffer_.vma_allocator(),
vulkan_buffer_.allocation(),
0u,
VK_WHOLE_SIZE);
const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
for (size_t i = 0; i < numel; ++i) {
dst[i] = static_cast<DST_T>(src[i]);
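Context for the StagingBuffer changes: these staging buffers may be allocated in host-visible memory that is not HOST_COHERENT. On non-coherent memory, CPU writes have to be flushed (vmaFlushAllocation) before the GPU can see them, and the CPU has to invalidate its view (vmaInvalidateAllocation) before reading data the GPU produced; on coherent memory VMA turns both calls into no-ops, so adding them unconditionally around the cast-and-copy loops should be safe.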
16 changes: 16 additions & 0 deletions backends/vulkan/utils.py
@@ -468,6 +468,22 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
return False


def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
"""
Returns true if the operator used to compute the given node contains a tensor
with more than 4 dimensions
"""
if is_tensor_node(node) and tensor_node_is_high_dim(node):
return True

for arg_node in node.args:
# pyre-ignore[6]
if is_tensor_node(arg_node) and tensor_node_is_high_dim(arg_node):
return True

return False


def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
primary_arg_idx: Optional[int] = None
for i, arg_node in enumerate(node.args):
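Note: `op_contains_high_dim_tensor` delegates to a `tensor_node_is_high_dim` helper that is not shown in this diff. A plausible sketch of what that check does, assuming tensor metadata is stored under `node.meta["val"]` as is conventional for exported FX graphs (the real helper may differ, e.g. it may also handle nodes whose value is a list of tensors):

```python
import torch

def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
    # Sketch only: True if the tensor recorded for this node has rank > 4.
    val = node.meta.get("val", None)
    return isinstance(val, torch.Tensor) and val.dim() > 4
```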