@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
-  VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_map.size() == 4);
-
   utils::uvec3 extents({1, 1, 1});
+
+  // For high dimensional tensors, buffer storage must be used. No need to
+  // compute image extents in this case.
+  if (padded_sizes.size() > 4) {
+    return extents;
+  }
+
   // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
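For reference, below is a minimal standalone sketch of the early-return pattern this hunk introduces. All names are illustrative: uvec3 stands in for utils::uvec3, and calculate_image_extents_sketch is a toy that assumes padded_sizes has already been padded to exactly 4 entries and omits the real code's packed-dim handling.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for utils::uvec3 from the Vulkan utils headers.
using uvec3 = std::array<uint32_t, 3>;

// Toy version of the guarded computation: tensors with more than 4 dims use
// buffer storage, so the default {1, 1, 1} extents are returned untouched.
uvec3 calculate_image_extents_sketch(
    const std::vector<int64_t>& padded_sizes,
    const std::vector<int64_t>& axis_map) {
  uvec3 extents = {1, 1, 1};
  if (padded_sizes.size() > 4) {
    return extents; // high-dim tensor: image extents are never consulted
  }
  // Map the width, height, and channels dims (WHCN order) onto the X/Y/Z
  // image axes named by the first three axis_map entries.
  for (size_t whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const size_t img_axis = static_cast<size_t>(axis_map.at(whcn_dim));
    extents.at(img_axis) = static_cast<uint32_t>(
        padded_sizes.at(padded_sizes.size() - 1 - whcn_dim));
  }
  return extents;
}

int main() {
  // Rank-5 sizes trip the new guard; rank-4 sizes are mapped as before.
  const uvec3 hi = calculate_image_extents_sketch({2, 3, 4, 5, 6}, {0, 1, 2});
  const uvec3 lo = calculate_image_extents_sketch({2, 3, 4, 5}, {0, 1, 2});
  std::cout << hi[0] << hi[1] << hi[2] << " / "  // prints "111"
            << lo[0] << lo[1] << lo[2] << "\n";  // prints "543"
}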
@@ -577,12 +581,15 @@ vTensor::vTensor(
           sizes,
           dtype_,
           allocate_memory)) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      numel_,
-      sizes_,
-      dim_order_,
-      strides_,
-      calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  // uniform_data_ only valid for low dim tensors
+  if (sizes.size() <= 4) {
+    uniform_data_ = std::make_shared<UniformData>(UniformData{
+        numel_,
+        sizes_,
+        dim_order_,
+        strides_,
+        calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  }

   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
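As a companion sketch, the snippet below isolates the ownership pattern this hunk adopts: uniform_data_ stays null for high-dimensional tensors, so every consumer has to treat it as optional. All names here (UniformDataSketch, TensorSketch, numel_from_ubo) are hypothetical stand-ins, not the real vTensor API.

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

// Single-field stand-in for the real UniformData (sizes, strides, limits...).
struct UniformDataSketch {
  int32_t numel;
};

class TensorSketch {
 public:
  explicit TensorSketch(const std::vector<int64_t>& sizes) {
    // Mirrors the patch: only tensors of rank <= 4 get uniform_data_, since
    // the GPU-side metadata layout only has room for four dims.
    if (sizes.size() <= 4) {
      int64_t numel = 1;
      for (const int64_t s : sizes) {
        numel *= s;
      }
      uniform_data_ = std::make_shared<UniformDataSketch>(
          UniformDataSketch{static_cast<int32_t>(numel)});
    }
  }

  // Consumers must check before dereferencing; compare the guards added to
  // the *_ubo() accessors in the next hunk.
  int32_t numel_from_ubo() const {
    if (!uniform_data_) {
      throw std::runtime_error("uniform_data_ unset for high-dim tensor");
    }
    return uniform_data_->numel;
  }

 private:
  std::shared_ptr<UniformDataSketch> uniform_data_; // null when rank > 4
};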
@@ -814,24 +821,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
 }

 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
 }

 const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &dim_order_uniform_offset_, uniform_data_->dim_order_v);
 }

 const vkapi::BufferBindInfo vTensor::strides_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
 }

 const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &logical_limits_uniform_offset_, uniform_data_->logical_limits);
 }

 const vkapi::BufferBindInfo vTensor::numel_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
 }
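The repeated one-line guard added to each accessor can be pictured as below. This is a sketch only: CHECK_COND_SKETCH merely approximates what VK_CHECK_COND does (the real macro reports through the Vulkan backend's error machinery, not a C++ exception), and MetaSketch / sizes_ubo_sketch are invented names.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Rough stand-in for VK_CHECK_COND: fail loudly when a precondition does not
// hold, instead of reading through the null uniform_data_ of a high-dim tensor.
#define CHECK_COND_SKETCH(cond)                                      \
  if (!(cond)) {                                                     \
    throw std::runtime_error(std::string("Check failed: ") + #cond); \
  }

struct MetaSketch {
  std::vector<int64_t> sizes;

  // Each *_ubo() accessor in the patch gains the same rank guard before it
  // touches uniform_data_; a placeholder value stands in for BufferBindInfo.
  size_t sizes_ubo_sketch() const {
    CHECK_COND_SKETCH(sizes.size() <= 4);
    return sizes.size();
  }
};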
@@ -894,31 +906,33 @@ void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);

   // Update uniform data if it has been modified
-  uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
-  uniform_data_->sizes_v =
-      flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
-  uniform_data_->dim_order_v =
-      flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
-  uniform_data_->strides_v =
-      flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
-  uniform_data_->logical_limits.limits =
-      calculate_logical_limits(sizes_, axis_map_, packed_dim_);
-
-  if (sizes_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
-  }
-  if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
-  }
-  if (strides_uniform_offset != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
-  }
-  if (numel_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(numel_, numel_uniform_offset_);
-  }
-  if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(
-        uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+  if (sizes_.size() <= 4) {
+    uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
+    uniform_data_->sizes_v =
+        flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
+    uniform_data_->dim_order_v =
+        flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
+    uniform_data_->strides_v =
+        flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
+    uniform_data_->logical_limits.limits =
+        calculate_logical_limits(sizes_, axis_map_, packed_dim_);
+
+    if (sizes_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
+    }
+    if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
+    }
+    if (strides_uniform_offset != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
+    }
+    if (numel_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(numel_, numel_uniform_offset_);
+    }
+    if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(
+          uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+    }
   }

   if (buffer_meta_.buffer()) {
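Finally, a condensed sketch of the control flow update_metadata() ends up with: host-side metadata is recomputed unconditionally, while the 4-dim uniform mirror is refreshed only for low-dimensional tensors, with high-dim tensors relying on the buffer-backed path (buffer_meta_ in the real code). The struct and its one-line stride "recompute" are illustrative, not the real implementation.

#include <cstdint>
#include <iostream>
#include <vector>

// Condensed shape of the update_metadata() change.
struct MetadataSketch {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  bool uniform_refreshed = false;

  void update_metadata_sketch() {
    strides.assign(sizes.size(), 1); // placeholder for the stride recompute
    // Only low-dim tensors have a uniform mirror to refresh; this branch
    // stands in for the uniforms_.update() calls in the patch.
    if (sizes.size() <= 4) {
      uniform_refreshed = true;
    }
  }
};

int main() {
  MetadataSketch low{{2, 3, 4, 5}, {}};
  MetadataSketch high{{2, 3, 4, 5, 6}, {}};
  low.update_metadata_sketch();
  high.update_metadata_sketch();
  std::cout << low.uniform_refreshed << " "
            << high.uniform_refreshed << "\n"; // prints "1 0"
}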