Skip to content

Commit 4547daa

Browse files
author
ssjia
committed
[ET-VK] High dim tensor support for view, unsqueeze, squeeze, clone
Pull Request resolved: #13597. Adds support for high-dimensional tensors to the view, unsqueeze, squeeze, and clone operators. @imported-using-ghimport Differential Revision: [D80800084](https://our.internmc.facebook.com/intern/diff/D80800084/) ghstack-source-id: 305143842
1 parent 2491d54 commit 4547daa

File tree

10 files changed

+332
-34
lines changed

10 files changed

+332
-34
lines changed

backends/vulkan/op_registry.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,8 @@ def register_rotary_emb_op():
489489

490490
@update_features(
491491
[
492-
exir_ops.edge.aten.clone.default,
493492
exir_ops.edge.aten.permute.default,
494493
exir_ops.edge.aten.permute_copy.default,
495-
exir_ops.edge.aten.view_copy.default,
496494
]
497495
)
498496
def register_view_ops():
@@ -502,6 +500,21 @@ def register_view_ops():
502500
)
503501

504502

503+
@update_features(
    [
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.clone.default,
    ]
)
def register_view_ops_with_buffer_meta():
    # View-family operators whose buffer implementations read BufferMetadata
    # UBOs; they accept any storage type and support dynamic shape resizing.
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )
516+
517+
505518
# Fully featured transfer operators (i.e. operators that copy data from the input
506519
# tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
507520
# for both texture and buffer storage types.
@@ -562,9 +575,6 @@ def register_ported_op():
562575
# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions
563576
@update_features(
564577
[
565-
# Shape Manipulation
566-
exir_ops.edge.aten.squeeze_copy.dims,
567-
exir_ops.edge.aten.unsqueeze_copy.default,
568578
# Tensor combination
569579
exir_ops.edge.aten.repeat.default,
570580
exir_ops.edge.aten.split_with_sizes_copy.default,
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_outp", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * The key insight behind the view operation is that the contiguous index of
 * each tensor element is the same in the input and output tensors.
 *
 * One invocation handles one output element, indexed by its linear buffer
 * position along gl_GlobalInvocationID.x.
 */
void main() {
  const uint outp_bufi = gl_GlobalInvocationID.x;
  // Guard against the workgroup grid overshooting the output element count.
  if (outp_bufi >= numel(outp)) {
    return;
  }

  // Recover the N-dimensional index of this output element from its linear
  // buffer index.
  TensorIndex outp_tidx;
  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

  // Map output -> input by finding the input element that shares the same
  // contiguous index as this output element.
  const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

  TensorIndex inp_tidx;
  contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);

  t_outp[outp_bufi] = t_inp[inp_bufi];
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen spec for the buffer-storage view_copy shader; one variant is
# generated per supported dtype.
view_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: double
      - VALUE: int8
      - VALUE: uint8
      - VALUE: int32
  shader_variants:
    - NAME: view_buffer

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,11 @@ void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
143143
if (src_storage == utils::kBuffer && dst_storage == utils::kTexture3D) {
144144
return add_buffer_to_image_node(graph, src, dst);
145145
}
146-
VK_THROW("Buffer to buffer memory layout transition not supported yet!");
146+
147+
std::vector<ValueRef> extra_args = {};
148+
// Buffer to buffer copy
149+
return add_view_copy_buffer_node(
150+
graph, src, dst, extra_args, resize_clone_node);
147151
}
148152

149153
// Clone node is not the most efficient implementation for the aten.clone

backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Clone.h>
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1516

@@ -55,8 +56,49 @@ void add_squeeze_copy_dims_node(
5556
}
5657
}
5758

59+
/*
 * Resize function for squeeze_copy.dims: recomputes the output sizes from the
 * (possibly resized) input by dropping each requested dimension whose extent
 * is 1, then virtually resizes the output tensor.
 *
 * args: [0] = output tensor, [1] = input tensor(s) (ArgGroup refs)
 * extra_args: [0] = ValueRef to the int list of dims to squeeze
 */
void resize_squeeze_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef in = args.at(1).refs.at(0);
  const ValueRef dims_ref = extra_args.at(0);

  const IntListPtr dims = graph->get_int_list(dims_ref);

  std::vector<int64_t> out_sizes = graph->sizes_of(in);
  const int64_t in_rank = static_cast<int64_t>(out_sizes.size());

  // Work on a local copy of the dim list. Mutating the graph-owned int list
  // through IntListPtr (as the dims are shifted below) would corrupt it for
  // every subsequent resize invocation.
  std::vector<int64_t> squeeze_dims;
  squeeze_dims.reserve(dims->size());
  for (const int64_t d : *dims) {
    // Normalize negative dims relative to the input rank, matching
    // aten.squeeze_copy.dims semantics.
    squeeze_dims.push_back(d < 0 ? d + in_rank : d);
  }

  // Remove each requested dimension if its size is 1
  for (size_t i = 0; i < squeeze_dims.size(); ++i) {
    const int64_t dim = squeeze_dims[i];
    if (dim >= 0 && dim < static_cast<int64_t>(out_sizes.size()) &&
        out_sizes[dim] == 1) {
      out_sizes.erase(out_sizes.begin() + dim);
      // After erasing, all subsequent dims shift left by one, so decrement
      // every remaining dim that referred to a higher position.
      for (size_t j = 0; j < squeeze_dims.size(); ++j) {
        if (squeeze_dims[j] > dim) {
          --squeeze_dims[j];
        }
      }
    }
  }

  graph->virtual_resize(out, out_sizes);
}
88+
5889
/*
 * Entry point for aten.squeeze_copy.dims. Buffer-backed tensors are handled
 * by the generic view_copy buffer shader (squeeze preserves the contiguous
 * index of every element); texture-backed tensors use the dedicated
 * squeeze_copy implementation.
 */
void squeeze_copy_dims(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args.at(0);
  const ValueRef dims = args.at(1);
  const ValueRef out = args.at(2);

  const std::vector<ValueRef> resize_args = {dims};

  if (graph.is_buffer_storage(in)) {
    return add_view_copy_buffer_node(
        graph, in, out, resize_args, resize_squeeze_node);
  }
  return add_squeeze_copy_dims_node(graph, in, dims, out);
}
61103

62104
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1415

@@ -45,8 +46,42 @@ void add_unsqueeze_node(
4546
add_permute_node(graph, in, permute_dims_ref, out);
4647
}
4748

49+
/*
 * Resize function for unsqueeze_copy: recomputes the output sizes from the
 * (possibly resized) input by inserting a singleton dimension at each
 * requested position, then virtually resizes the output tensor.
 *
 * args: [0] = output tensor, [1] = input tensor(s) (ArgGroup refs)
 * extra_args: [0] = ValueRef to the int list of insertion positions
 */
void resize_unsqueeze_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef in = args.at(1).refs.at(0);
  const ValueRef dims_ref = extra_args.at(0);

  const IntListPtr dims = graph->get_int_list(dims_ref);

  std::vector<int64_t> new_sizes = graph->sizes_of(in);

  // Insert a singleton dimension at each requested position. A negative
  // position is interpreted relative to the rank *after* insertion.
  // NOTE(review): positions are applied in list order against the growing
  // size vector — assumes dims arrive in ascending order; confirm with the
  // operator's callers.
  for (const int64_t dim : *dims) {
    int64_t insert_at = dim;
    if (insert_at < 0) {
      insert_at += static_cast<int64_t>(new_sizes.size()) + 1;
    }
    new_sizes.insert(new_sizes.begin() + insert_at, 1);
  }

  graph->virtual_resize(out, new_sizes);
}
72+
4873
/*
 * Entry point for aten.unsqueeze_copy. Buffer-backed tensors are handled by
 * the generic view_copy buffer shader (unsqueeze preserves the contiguous
 * index of every element); texture-backed tensors use the permute-based
 * unsqueeze implementation.
 */
void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args.at(0);
  const ValueRef dims = args.at(1);
  const ValueRef out = args.at(2);

  const std::vector<ValueRef> resize_args = {dims};

  if (graph.is_buffer_storage(in)) {
    return add_view_copy_buffer_node(
        graph, in, out, resize_args, resize_unsqueeze_node);
  }
  return add_unsqueeze_node(graph, in, dims, out);
}
5186

5287
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,47 @@ void add_view_node(
8989
resize_view_node));
9090
}
9191

92+
/*
 * Dispatches the view_buffer compute shader, which copies elements from the
 * input buffer to the output buffer such that each element keeps its
 * contiguous index. Callers supply the resize function appropriate for the
 * op being implemented (view, squeeze, unsqueeze, clone).
 */
void add_view_copy_buffer_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef out,
    const std::vector<ValueRef>& resize_args,
    const ExecuteNode::ResizeFunction& resize_fn) {
  // Shader variant is selected by the output dtype, e.g. "view_buffer_float".
  std::string kernel_name("view_buffer");
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Parameter Buffers: BufferMetadata UBOs for output, then input
      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      resize_args,
      // Resizing Logic
      resize_fn));
}
119+
92120
/*
 * Entry point for aten.view_copy. Buffer-backed outputs go through the
 * dedicated view_buffer shader; texture-backed outputs fall back to the
 * existing view implementation.
 */
void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args.at(0);
  const ValueRef sizes = args.at(1);
  const ValueRef out = args.at(2);

  const std::vector<ValueRef> resize_args = {sizes};

  if (graph.is_buffer_storage(out)) {
    return add_view_copy_buffer_node(
        graph, in, out, resize_args, resize_view_node);
  }
  return add_view_node(graph, in, sizes, out);
}
95134

96135
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/View.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212

1313
namespace vkcompute {
1414

15+
/*
16+
* Dispatches the view_copy compute shader. This can be used to implement ops
17+
* that preserve the "contiguous" indexes of elements between the input and
18+
* output such as view_copy, squeeze_copy, unsqueeze_copy, etc.
19+
*/
20+
void add_view_copy_buffer_node(
21+
ComputeGraph& graph,
22+
ValueRef in,
23+
ValueRef out,
24+
const std::vector<ValueRef>& resize_args,
25+
const ExecuteNode::ResizeFunction& resize_fn);
26+
1527
void add_view_node(
1628
ComputeGraph& graph,
1729
ValueRef in,

0 commit comments

Comments
 (0)