Skip to content

Commit 2491d54

Browse files
author
ssjia
committed
[ET-VK][ez] Allow high dimensional tensors (for buffer storage)
Pull Request resolved: #13596. As the title states: now that we can represent the metadata of higher-dimensional tensors, we can remove the checks that constrained Vulkan to only work with tensors of dim <= 4. @imported-using-ghimport Differential Revision: [D80800083](https://our.internmc.facebook.com/intern/diff/D80800083/) ghstack-source-id: 305143838
1 parent 1859f7d commit 2491d54

File tree

4 files changed

+68
-41
lines changed

4 files changed

+68
-41
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
189189
const std::vector<int64_t>& padded_sizes,
190190
const std::vector<int64_t>& axis_map,
191191
const int32_t packed_dim) {
192-
VK_CHECK_COND(padded_sizes.size() == 4);
193-
VK_CHECK_COND(axis_map.size() == 4);
194-
195192
utils::uvec3 extents({1, 1, 1});
193+
194+
// For high dimensional tensors, buffer storage must be used. No need to
195+
// compute image extents in this case.
196+
if (padded_sizes.size() > 4) {
197+
return extents;
198+
}
199+
196200
// First three elements of axis_map indicate which (X,Y,Z) image axis the
197201
// width, height, and channels dim of the tensor maps to.
198202
for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
@@ -577,12 +581,15 @@ vTensor::vTensor(
577581
sizes,
578582
dtype_,
579583
allocate_memory)) {
580-
uniform_data_ = std::make_shared<UniformData>(UniformData{
581-
numel_,
582-
sizes_,
583-
dim_order_,
584-
strides_,
585-
calculate_logical_limits(storage_->image_extents_, axis_map_)});
584+
// uniform_data_ only valid for low dim tensors
585+
if (sizes.size() <= 4) {
586+
uniform_data_ = std::make_shared<UniformData>(UniformData{
587+
numel_,
588+
sizes_,
589+
dim_order_,
590+
strides_,
591+
calculate_logical_limits(storage_->image_extents_, axis_map_)});
592+
}
586593

587594
VK_CHECK_COND(
588595
dim_order_is_valid(dim_order_), "computed dim order is invalid");
@@ -814,24 +821,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
814821
}
815822

816823
const vkapi::BufferBindInfo vTensor::sizes_ubo() {
824+
VK_CHECK_COND(sizes_.size() <= 4);
817825
return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
818826
}
819827

820828
const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
829+
VK_CHECK_COND(sizes_.size() <= 4);
821830
return metadata_ubo_impl(
822831
&dim_order_uniform_offset_, uniform_data_->dim_order_v);
823832
}
824833

825834
const vkapi::BufferBindInfo vTensor::strides_ubo() {
835+
VK_CHECK_COND(sizes_.size() <= 4);
826836
return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
827837
}
828838

829839
const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
840+
VK_CHECK_COND(sizes_.size() <= 4);
830841
return metadata_ubo_impl(
831842
&logical_limits_uniform_offset_, uniform_data_->logical_limits);
832843
}
833844

834845
const vkapi::BufferBindInfo vTensor::numel_ubo() {
846+
VK_CHECK_COND(sizes_.size() <= 4);
835847
return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
836848
}
837849

@@ -894,31 +906,33 @@ void vTensor::update_metadata() {
894906
strides_ = calculate_strides(sizes_, dim_order_);
895907

896908
// Update uniform data if it has been modified
897-
uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
898-
uniform_data_->sizes_v =
899-
flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
900-
uniform_data_->dim_order_v =
901-
flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
902-
uniform_data_->strides_v =
903-
flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
904-
uniform_data_->logical_limits.limits =
905-
calculate_logical_limits(sizes_, axis_map_, packed_dim_);
906-
907-
if (sizes_uniform_offset_ != kUniformOffsetUnset) {
908-
uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
909-
}
910-
if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
911-
uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
912-
}
913-
if (strides_uniform_offset != kUniformOffsetUnset) {
914-
uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
915-
}
916-
if (numel_uniform_offset_ != kUniformOffsetUnset) {
917-
uniforms_.update(numel_, numel_uniform_offset_);
918-
}
919-
if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
920-
uniforms_.update(
921-
uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
909+
if (sizes_.size() <= 4) {
910+
uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
911+
uniform_data_->sizes_v =
912+
flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
913+
uniform_data_->dim_order_v =
914+
flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
915+
uniform_data_->strides_v =
916+
flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
917+
uniform_data_->logical_limits.limits =
918+
calculate_logical_limits(sizes_, axis_map_, packed_dim_);
919+
920+
if (sizes_uniform_offset_ != kUniformOffsetUnset) {
921+
uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
922+
}
923+
if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
924+
uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
925+
}
926+
if (strides_uniform_offset != kUniformOffsetUnset) {
927+
uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
928+
}
929+
if (numel_uniform_offset_ != kUniformOffsetUnset) {
930+
uniforms_.update(numel_, numel_uniform_offset_);
931+
}
932+
if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
933+
uniforms_.update(
934+
uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
935+
}
922936
}
923937

924938
if (buffer_meta_.buffer()) {

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ class vTensor final {
676676
}
677677

678678
const std::shared_ptr<UniformData>& get_uniform_data() const {
679+
VK_CHECK_COND(sizes_.size() <= 4);
679680
return uniform_data_;
680681
}
681682
};

backends/vulkan/test/op_tests/cases.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,28 @@ def get_binary_elementwise_inputs():
5555
((3, 64, 1), (1, 64, 1)),
5656
]
5757
)
58-
test_suite.layouts = [
59-
"utils::kWidthPacked",
60-
"utils::kChannelsPacked",
61-
]
6258
test_suite.storage_types = [
6359
"utils::kBuffer",
6460
"utils::kTexture3D",
6561
]
6662

67-
return test_suite
63+
highdim_test_suite = VkTestSuite(
64+
[
65+
((4, 5, 8, 1, 2, 1), (4, 5, 8, 1, 1, 1)),
66+
]
67+
)
68+
highdim_test_suite.storage_types = [
69+
"utils::kBuffer",
70+
]
71+
highdim_test_suite.test_name_suffix = "highdim"
72+
73+
for suite in [test_suite, highdim_test_suite]:
74+
suite.layouts = [
75+
"utils::kWidthPacked",
76+
"utils::kChannelsPacked",
77+
]
78+
79+
return [test_suite, highdim_test_suite]
6880

6981

7082
# Eq requires a different test generator so it was split from the other test case.

backends/vulkan/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,9 @@ def make_filtered_tensor_repset(
599599
if extents_are_valid(extents, texture_limits):
600600
valid_texture_layouts.add(memory_layout)
601601

602-
# High dimensional tensors are currently not supported
602+
# High dimensional tensors require buffer storage
603603
if len(tensor_val.shape) > 4:
604-
return NO_STORAGE
604+
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
605605

606606
# Bool tensors are currently not supported
607607
if tensor_val.dtype == torch.bool:

0 commit comments

Comments (0)