@@ -12,7 +12,23 @@ enum struct TensorElementType : uint8_t {
12
12
Float32 ,
13
13
Float64 ,
14
14
};
15
- constexpr size_t tensor_element_size (TensorElementType type) {
15
+ template <typename T>
16
+ constexpr TensorElementType get_tensor_elem_type () {
17
+ if constexpr (std::is_same_v<T, half>) {
18
+ return TensorElementType::Float16;
19
+ } else if constexpr (std::is_same_v<T, float >) {
20
+ return TensorElementType::Float32 ;
21
+ } else if constexpr (std::is_same_v<T, double >) {
22
+ return TensorElementType::Float64 ;
23
+ } else {
24
+ static_assert (luisa::always_false_v<T>, " Bad type." );
25
+ }
26
+ }
27
+ template <typename T>
28
+ concept valid_tensor_elem_type = requires() {
29
+ get_tensor_elem_type<T>();
30
+ };
31
+ constexpr uint64_t tensor_element_size (TensorElementType type) {
16
32
switch (type) {
17
33
case TensorElementType::Float16:
18
34
return 2 ;
@@ -24,7 +40,7 @@ constexpr size_t tensor_element_size(TensorElementType type) {
24
40
return 0 ;
25
41
}
26
42
}
27
- constexpr size_t tensor_element_align (TensorElementType type) {
43
+ constexpr uint64_t tensor_element_align (TensorElementType type) {
28
44
switch (type) {
29
45
case TensorElementType::Float16:
30
46
return 2 ;
@@ -37,13 +53,13 @@ constexpr size_t tensor_element_align(TensorElementType type) {
37
53
}
38
54
}
39
55
class LC_TENSOR_API TensorData {
40
- luisa::span<size_t const > _sizes;
56
+ luisa::span<uint64_t const > _sizes;
41
57
TensorElementType _type;
42
58
uint64_t _idx;
43
- size_t _size_bytes;
59
+ uint64_t _size_bytes;
44
60
45
61
public:
46
- TensorData (luisa::span<size_t const > sizes,
62
+ TensorData (luisa::span<uint64_t const > sizes,
47
63
TensorElementType element_type,
48
64
uint64_t uid) noexcept ;
49
65
TensorData (TensorData &&rhs) noexcept ;
@@ -53,14 +69,14 @@ class LC_TENSOR_API TensorData {
53
69
[[nodiscard]] uint64_t idx () const noexcept {
54
70
return _idx;
55
71
}
56
- [[nodiscard]] size_t get_size (uint dimension) const noexcept {
72
+ [[nodiscard]] uint64_t get_size (uint dimension) const noexcept {
57
73
if (dimension >= _sizes.size ()) return 1 ;
58
74
return _sizes[dimension];
59
75
}
60
- [[nodiscard]] size_t dimension () const noexcept {
76
+ [[nodiscard]] uint64_t dimension () const noexcept {
61
77
return _sizes.size ();
62
78
}
63
- [[nodiscard]] size_t size_bytes () const noexcept {
79
+ [[nodiscard]] uint64_t size_bytes () const noexcept {
64
80
return _size_bytes;
65
81
}
66
82
[[nodiscard]] TensorElementType element_type () const noexcept {
@@ -75,6 +91,7 @@ class LC_TENSOR_API Tensor {
75
91
76
92
TensorData *_data;
77
93
bool _contained;
94
+ [[nodiscard]] void _create (TensorElementType element_type, luisa::span<const uint64_t > sizes, Argument::Buffer buffer) noexcept ;
78
95
79
96
public:
80
97
explicit Tensor (TensorData *data,
@@ -87,12 +104,41 @@ class LC_TENSOR_API Tensor {
87
104
std::destroy_at (this );
88
105
new (this ) Tensor (std::move (rhs));
89
106
}
107
+ template <typename T>
108
+ requires (luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t <T>>)
109
+ Tensor (T const &t, luisa::span<const uint64_t > sizes) noexcept {
110
+ Argument::Buffer bf{
111
+ .handle = t.handle (),
112
+ .size = t.size_bytes ()};
113
+ if constexpr (luisa::compute::is_buffer_v<T>) {
114
+ bf.offset = 0 ;
115
+ } else {
116
+ bf.offset = t.offset_bytes ();
117
+ }
118
+ _create (get_tensor_elem_type<luisa::compute::buffer_element_t <T>>(), sizes, bf);
119
+ }
120
+ template <typename T>
121
+ requires (luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t <T>>)
122
+ Tensor (T const &t, std::initializer_list<const uint64_t > sizes) noexcept
123
+ : Tensor(t, luisa::span{sizes.begin (), sizes.size ()}) {
124
+ }
125
+ template <typename T>
126
+ requires (luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t <T>>)
127
+ Tensor (T const &t) noexcept {
128
+ Argument::Buffer bf{
129
+ .handle = t.handle (),
130
+ .size = t.size_bytes ()};
131
+ if constexpr (luisa::compute::is_buffer_v<T>) {
132
+ bf.offset = 0 ;
133
+ } else {
134
+ bf.offset = t.offset_bytes ();
135
+ }
136
+ uint64_t size = t.size_bytes ();
137
+ _create (get_tensor_elem_type<luisa::compute::buffer_element_t <T>>(), {&size, 1 });
138
+ }
90
139
[[nodiscard]] auto data () const noexcept { return _data; }
91
140
void dispose () noexcept ;
92
141
93
- [[nodiscard]] static Tensor one (TensorElementType element_type, luisa::span<const size_t > sizes) noexcept ;
94
- [[nodiscard]] static Tensor zero (TensorElementType element_type, luisa::span<const size_t > sizes) noexcept ;
95
-
96
142
[[nodiscard]] static Tensor matmul (
97
143
Tensor const &lhs,
98
144
Tensor const &rhs,
@@ -141,17 +187,17 @@ class LC_TENSOR_API Tensor {
141
187
// bool _requires_grad = false;
142
188
// bool _reserve_memory = false;
143
189
// bool _dirty = false;
144
- // std::array<size_t , 3> _shape;
145
- // std::array<size_t , 3> _stride;
190
+ // std::array<uint64_t , 3> _shape;
191
+ // std::array<uint64_t , 3> _stride;
146
192
// luisa::optional<Buffer<T>> _buffer;
147
193
// luisa::optional<Var<T>> _var;
148
194
// public:
149
- // using shape_type = std::array<size_t , 3>;
195
+ // using shape_type = std::array<uint64_t , 3>;
150
196
// using value_type = T;
151
197
152
198
// private:
153
- // static size_t compute_size(shape_type s) {
154
- // size_t size = 1;
199
+ // static uint64_t compute_size(shape_type s) {
200
+ // uint64_t size = 1;
155
201
// for (auto i : s) {
156
202
// size *= i;
157
203
// }
@@ -186,14 +232,14 @@ class LC_TENSOR_API Tensor {
186
232
// return DTensor<T>{device};
187
233
// }
188
234
189
- // // template<size_t ... Is>
235
+ // // template<uint64_t ... Is>
190
236
// // [[nodiscard]] Tensor<T, Dim + sizeof...(Is)> repeat(Is...) {
191
237
// // // TODO: implement
192
238
// // return Tensor<T, Dim>{device};
193
239
// // }
194
240
// };
195
241
196
- // template<class R, size_t Dim, class... Ts>
242
+ // template<class R, uint64_t Dim, class... Ts>
197
243
// Tensor<R, Dim> map(const Tensor<Ts, Dim> &... ts) noexcept {
198
244
// // TODO: implement
199
245
// }
0 commit comments