From 237e551e734c274f7885c47f2d5ec231246aa86d Mon Sep 17 00:00:00 2001
From: ssjia
Date: Sat, 23 Aug 2025 10:06:51 -0400
Subject: [PATCH] [ET-VK][ez] Allow high dimensional tensors (for buffer
 storage)

Pull Request resolved: https://github.com/pytorch/executorch/pull/13596

As title; now that we can represent the metadata of higher dimensional
tensors, we can remove the checks that constrain Vulkan to only work with
tensors of dim <= 4.

@imported-using-ghimport

Differential Revision: [D80800083](https://our.internmc.facebook.com/intern/diff/D80800083/)

ghstack-source-id: 305143838
---
 .../vulkan/runtime/api/containers/Tensor.cpp | 82 +++++++++++--------
 .../vulkan/runtime/api/containers/Tensor.h   |  1 +
 backends/vulkan/test/op_tests/cases.py       | 22 +++--
 backends/vulkan/utils.py                     |  4 +-
 4 files changed, 68 insertions(+), 41 deletions(-)

diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp
index fedb0d7f173..433ae15db4e 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.cpp
+++ b/backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
-  VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_map.size() == 4);
-
   utils::uvec3 extents({1, 1, 1});
+
+  // For high dimensional tensors, buffer storage must be used. No need to
+  // compute image extents in this case.
+  if (padded_sizes.size() > 4) {
+    return extents;
+  }
+
   // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
@@ -577,12 +581,15 @@ vTensor::vTensor(
         sizes,
         dtype_,
         allocate_memory)) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      numel_,
-      sizes_,
-      dim_order_,
-      strides_,
-      calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  // uniform_data_ only valid for low dim tensors
+  if (sizes.size() <= 4) {
+    uniform_data_ = std::make_shared<UniformData>(UniformData{
+        numel_,
+        sizes_,
+        dim_order_,
+        strides_,
+        calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  }
 
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
@@ -814,24 +821,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
 }
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
 }
 
 const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &dim_order_uniform_offset_, uniform_data_->dim_order_v);
 }
 
 const vkapi::BufferBindInfo vTensor::strides_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
 }
 
 const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &logical_limits_uniform_offset_, uniform_data_->logical_limits);
 }
 
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
 }
 
@@ -894,31 +906,33 @@ void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
 
   // Update uniform data if it has been modified
-  uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
-  uniform_data_->sizes_v =
-      flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
-  uniform_data_->dim_order_v =
-      flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
-  uniform_data_->strides_v =
-      flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
-  uniform_data_->logical_limits.limits =
-      calculate_logical_limits(sizes_, axis_map_, packed_dim_);
-
-  if (sizes_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
-  }
-  if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
-  }
-  if (strides_uniform_offset != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
-  }
-  if (numel_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(numel_, numel_uniform_offset_);
-  }
-  if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(
-        uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+  if (sizes_.size() <= 4) {
+    uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
+    uniform_data_->sizes_v =
+        flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
+    uniform_data_->dim_order_v =
+        flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
+    uniform_data_->strides_v =
+        flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
+    uniform_data_->logical_limits.limits =
+        calculate_logical_limits(sizes_, axis_map_, packed_dim_);
+
+    if (sizes_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
+    }
+    if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
+    }
+    if (strides_uniform_offset != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
+    }
+    if (numel_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(numel_, numel_uniform_offset_);
+    }
+    if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(
+          uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+    }
   }
 
   if (buffer_meta_.buffer()) {
diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h
index eb0e09dbd81..66c1fd1e4da 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.h
+++ b/backends/vulkan/runtime/api/containers/Tensor.h
@@ -676,6 +676,7 @@ class vTensor final {
   }
 
   const std::shared_ptr<UniformData>& get_uniform_data() const {
+    VK_CHECK_COND(sizes_.size() <= 4);
     return uniform_data_;
   }
 };
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 5aaf00fe8bc..f03b9a50737 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -55,16 +55,28 @@ def get_binary_elementwise_inputs():
             ((3, 64, 1), (1, 64, 1)),
         ]
     )
-    test_suite.layouts = [
-        "utils::kWidthPacked",
-        "utils::kChannelsPacked",
-    ]
     test_suite.storage_types = [
         "utils::kBuffer",
         "utils::kTexture3D",
     ]
-    return test_suite
+
+    highdim_test_suite = VkTestSuite(
+        [
+            ((4, 5, 8, 1, 2, 1), (4, 5, 8, 1, 1, 1)),
+        ]
+    )
+    highdim_test_suite.storage_types = [
+        "utils::kBuffer",
+    ]
+    highdim_test_suite.test_name_suffix = "highdim"
+
+    for suite in [test_suite, highdim_test_suite]:
+        suite.layouts = [
+            "utils::kWidthPacked",
+            "utils::kChannelsPacked",
+        ]
+
+    return [test_suite, highdim_test_suite]
 
 
 # Eq requires a different test generator so it was split from the other test case.
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index bc03860ed3f..d1feeb0f5ce 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -599,9 +599,9 @@ def make_filtered_tensor_repset(
         if extents_are_valid(extents, texture_limits):
             valid_texture_layouts.add(memory_layout)
 
-    # High dimensional tensors are currently not supported
+    # High dimensional tensors require buffer storage
     if len(tensor_val.shape) > 4:
-        return NO_STORAGE
+        return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
 
     # Bool tensors are currently not supported
     if tensor_val.dtype == torch.bool:
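
As an illustrative sketch (not part of this patch), the new "highdim" test
entry corresponds to a rank-6 broadcasted elementwise op like the one below.
The shapes are taken verbatim from the added test suite entry; running such
an op through the Vulkan delegate with buffer storage is what this change
enables.

    # Hypothetical eager-mode equivalent of the new highdim test case.
    import torch

    a = torch.randn(4, 5, 8, 1, 2, 1)
    b = torch.randn(4, 5, 8, 1, 1, 1)  # broadcasts against `a` on dim 4
    out = a + b  # rank-6 tensors were previously rejected (dim <= 4 checks)
    assert out.shape == (4, 5, 8, 1, 2, 1)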