Skip to content

Commit 89bf53a

Browse files
committed
Make captured tensor
1 parent 897ff06 commit 89bf53a

File tree

9 files changed

+142
-79
lines changed

9 files changed

+142
-79
lines changed

include/luisa/tensor/expression.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,10 @@ class LUISA_TENSOR_EXPR_CLASS_INHERIT(SetValueExpr) {
9999
SetValueExpr(
100100
uint64_t idx,
101101
TensorData *tensor_data,
102-
uint64_t value) noexcept
102+
ValueType const& value) noexcept
103103
: BaseClass(idx),
104104
tensor_data(tensor_data),
105105
value(value) {}
106-
SetValueExpr(
107-
uint64_t idx,
108-
TensorData *tensor_data,
109-
void *ptr,
110-
vstd::func_ptr_t<void(void *)> disposer) noexcept
111-
: BaseClass(idx),
112-
tensor_data(tensor_data),
113-
value(BinaryBlob{ptr, disposer}) {}
114106
~SetValueExpr() noexcept;
115107
void get_tensors(vstd::FuncRef<void(TensorData *, Usage usage)> callback) noexcept override {
116108
callback(tensor_data, Usage::WRITE);

include/luisa/tensor/kernel.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,16 @@ struct member_func_meta<Object, ExtraTypes, Ret (Class::*)(Args...) noexcept> :
3333
}// namespace detail
3434

3535
struct TensorDescriptor {
36-
luisa::fixed_vector<size_t, 4> _sizes;
36+
luisa::fixed_vector<uint64_t, 4> _sizes;
3737
TensorElementType _type;
3838
TensorDescriptor(
39-
std::initializer_list<size_t> dimensions,
39+
std::initializer_list<uint64_t> dimensions,
4040
TensorElementType type) noexcept : _type(type) {
4141
_sizes.resize_uninitialized(dimensions.size());
4242
std::memcpy(_sizes.data(), dimensions.begin(), _sizes.size_bytes());
4343
}
4444
TensorDescriptor(
45-
luisa::span<size_t const> dimensions,
45+
luisa::span<uint64_t const> dimensions,
4646
TensorElementType type) noexcept : _type(type) {
4747
_sizes.resize_uninitialized(dimensions.size());
4848
std::memcpy(_sizes.data(), dimensions.data(), _sizes.size_bytes());
@@ -117,7 +117,7 @@ struct TensorKernel {
117117
_kernel_ptr = _tensor_interface->compile_kernel(std::move(r));
118118
}
119119
template<typename... ExecArgs>
120-
requires(sizeof...(Args) == sizeof...(ExecArgs) && !(detail::AnyMap<luisa::compute::is_buffer_or_view, true>::template Run<std::remove_cvref_t<ExecArgs>...>()))
120+
requires(sizeof...(Args) == sizeof...(ExecArgs) && (sizeof...(Args) == 0 || !(detail::AnyMap<luisa::compute::is_buffer_or_view, true>::template Run<std::remove_cvref_t<ExecArgs>...>())))
121121
luisa::vector<BufferView<float4>> execute(CommandList &cmdlist, ExecArgs const &...args) noexcept {
122122
auto to_desc = []<typename T>(T const &arg) -> Argument::Buffer {
123123
if constexpr (is_buffer_view_v<T>) {
@@ -132,10 +132,14 @@ struct TensorKernel {
132132
arg.size_bytes()};
133133
}
134134
};
135-
auto tensors = {to_desc(args)...};
136135
luisa::vector<BufferCreationInfo> buffer_handles;
137136
luisa::vector<BufferView<float4>> buffers;
138-
_tensor_interface->execute(cmdlist, _kernel_ptr, luisa::span<Argument::Buffer const>{tensors.begin(), tensors.size()}, buffer_handles);
137+
if constexpr (sizeof...(Args) > 0) {
138+
auto tensors = {to_desc(args)...};
139+
_tensor_interface->execute(cmdlist, _kernel_ptr, luisa::span<Argument::Buffer const>{tensors.begin(), tensors.size()}, buffer_handles);
140+
} else {
141+
_tensor_interface->execute(cmdlist, _kernel_ptr, luisa::span<Argument::Buffer const>{}, buffer_handles);
142+
}
139143
buffers.reserve(buffer_handles.size());
140144
for (auto &i : buffer_handles) {
141145
buffers.emplace_back(

include/luisa/tensor/tensor.h

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,23 @@ enum struct TensorElementType : uint8_t {
1212
Float32,
1313
Float64,
1414
};
15-
constexpr size_t tensor_element_size(TensorElementType type) {
15+
template<typename T>
16+
constexpr TensorElementType get_tensor_elem_type() {
17+
if constexpr (std::is_same_v<T, half>) {
18+
return TensorElementType::Float16;
19+
} else if constexpr (std::is_same_v<T, float>) {
20+
return TensorElementType::Float32;
21+
} else if constexpr (std::is_same_v<T, double>) {
22+
return TensorElementType::Float64;
23+
} else {
24+
static_assert(luisa::always_false_v<T>, "Bad type.");
25+
}
26+
}
27+
template<typename T>
28+
concept valid_tensor_elem_type = requires() {
29+
get_tensor_elem_type<T>();
30+
};
31+
constexpr uint64_t tensor_element_size(TensorElementType type) {
1632
switch (type) {
1733
case TensorElementType::Float16:
1834
return 2;
@@ -24,7 +40,7 @@ constexpr size_t tensor_element_size(TensorElementType type) {
2440
return 0;
2541
}
2642
}
27-
constexpr size_t tensor_element_align(TensorElementType type) {
43+
constexpr uint64_t tensor_element_align(TensorElementType type) {
2844
switch (type) {
2945
case TensorElementType::Float16:
3046
return 2;
@@ -37,13 +53,13 @@ constexpr size_t tensor_element_align(TensorElementType type) {
3753
}
3854
}
3955
class LC_TENSOR_API TensorData {
40-
luisa::span<size_t const> _sizes;
56+
luisa::span<uint64_t const> _sizes;
4157
TensorElementType _type;
4258
uint64_t _idx;
43-
size_t _size_bytes;
59+
uint64_t _size_bytes;
4460

4561
public:
46-
TensorData(luisa::span<size_t const> sizes,
62+
TensorData(luisa::span<uint64_t const> sizes,
4763
TensorElementType element_type,
4864
uint64_t uid) noexcept;
4965
TensorData(TensorData &&rhs) noexcept;
@@ -53,14 +69,14 @@ class LC_TENSOR_API TensorData {
5369
[[nodiscard]] uint64_t idx() const noexcept {
5470
return _idx;
5571
}
56-
[[nodiscard]] size_t get_size(uint dimension) const noexcept {
72+
[[nodiscard]] uint64_t get_size(uint dimension) const noexcept {
5773
if (dimension >= _sizes.size()) return 1;
5874
return _sizes[dimension];
5975
}
60-
[[nodiscard]] size_t dimension() const noexcept {
76+
[[nodiscard]] uint64_t dimension() const noexcept {
6177
return _sizes.size();
6278
}
63-
[[nodiscard]] size_t size_bytes() const noexcept {
79+
[[nodiscard]] uint64_t size_bytes() const noexcept {
6480
return _size_bytes;
6581
}
6682
[[nodiscard]] TensorElementType element_type() const noexcept {
@@ -75,6 +91,7 @@ class LC_TENSOR_API Tensor {
7591

7692
TensorData *_data;
7793
bool _contained;
94+
[[nodiscard]] void _create(TensorElementType element_type, luisa::span<const uint64_t> sizes, Argument::Buffer buffer) noexcept;
7895

7996
public:
8097
explicit Tensor(TensorData *data,
@@ -87,12 +104,41 @@ class LC_TENSOR_API Tensor {
87104
std::destroy_at(this);
88105
new (this) Tensor(std::move(rhs));
89106
}
107+
template<typename T>
108+
requires(luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t<T>>)
109+
Tensor(T const &t, luisa::span<const uint64_t> sizes) noexcept {
110+
Argument::Buffer bf{
111+
.handle = t.handle(),
112+
.size = t.size_bytes()};
113+
if constexpr (luisa::compute::is_buffer_v<T>) {
114+
bf.offset = 0;
115+
} else {
116+
bf.offset = t.offset_bytes();
117+
}
118+
_create(get_tensor_elem_type<luisa::compute::buffer_element_t<T>>(), sizes, bf);
119+
}
120+
template<typename T>
121+
requires(luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t<T>>)
122+
Tensor(T const &t, std::initializer_list<const uint64_t> sizes) noexcept
123+
: Tensor(t, luisa::span{sizes.begin(), sizes.size()}) {
124+
}
125+
template<typename T>
126+
requires(luisa::compute::is_buffer_or_view_v<T> && valid_tensor_elem_type<luisa::compute::buffer_element_t<T>>)
127+
Tensor(T const &t) noexcept {
128+
Argument::Buffer bf{
129+
.handle = t.handle(),
130+
.size = t.size_bytes()};
131+
if constexpr (luisa::compute::is_buffer_v<T>) {
132+
bf.offset = 0;
133+
} else {
134+
bf.offset = t.offset_bytes();
135+
}
136+
uint64_t size = t.size_bytes();
137+
_create(get_tensor_elem_type<luisa::compute::buffer_element_t<T>>(), {&size, 1});
138+
}
90139
[[nodiscard]] auto data() const noexcept { return _data; }
91140
void dispose() noexcept;
92141

93-
[[nodiscard]] static Tensor one(TensorElementType element_type, luisa::span<const size_t> sizes) noexcept;
94-
[[nodiscard]] static Tensor zero(TensorElementType element_type, luisa::span<const size_t> sizes) noexcept;
95-
96142
[[nodiscard]] static Tensor matmul(
97143
Tensor const &lhs,
98144
Tensor const &rhs,
@@ -141,17 +187,17 @@ class LC_TENSOR_API Tensor {
141187
// bool _requires_grad = false;
142188
// bool _reserve_memory = false;
143189
// bool _dirty = false;
144-
// std::array<size_t, 3> _shape;
145-
// std::array<size_t, 3> _stride;
190+
// std::array<uint64_t, 3> _shape;
191+
// std::array<uint64_t, 3> _stride;
146192
// luisa::optional<Buffer<T>> _buffer;
147193
// luisa::optional<Var<T>> _var;
148194
// public:
149-
// using shape_type = std::array<size_t, 3>;
195+
// using shape_type = std::array<uint64_t, 3>;
150196
// using value_type = T;
151197

152198
// private:
153-
// static size_t compute_size(shape_type s) {
154-
// size_t size = 1;
199+
// static uint64_t compute_size(shape_type s) {
200+
// uint64_t size = 1;
155201
// for (auto i : s) {
156202
// size *= i;
157203
// }
@@ -186,14 +232,14 @@ class LC_TENSOR_API Tensor {
186232
// return DTensor<T>{device};
187233
// }
188234

189-
// // template<size_t... Is>
235+
// // template<uint64_t... Is>
190236
// // [[nodiscard]] Tensor<T, Dim + sizeof...(Is)> repeat(Is...) {
191237
// // // TODO: implement
192238
// // return Tensor<T, Dim>{device};
193239
// // }
194240
// };
195241

196-
// template<class R, size_t Dim, class... Ts>
242+
// template<class R, uint64_t Dim, class... Ts>
197243
// Tensor<R, Dim> map(const Tensor<Ts, Dim> &... ts) noexcept {
198244
// // TODO: implement
199245
// }

include/luisa/tensor/tensor_builder.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,22 @@ class LC_TENSOR_API TensorBuilder {
99
luisa::Pool<TensorData, false, false> _tensor_pool;
1010
luisa::vector<TensorData *> _allocated_tensor;
1111
luisa::vector<TensorData *> _arguments;
12+
luisa::vector<std::pair<TensorData *, Argument::Buffer>> _captured_arguments;
1213
luisa::vector<TensorData *> _outputs;
1314
luisa::vector<ScopeExpr *> _expr_stack;
1415
struct Stack {
1516
luisa::unique_ptr<std::byte> ptr;
16-
size_t size;
17-
size_t offset;
17+
uint64_t size;
18+
uint64_t offset;
1819
Stack(
1920
std::byte *ptr,
20-
size_t size,
21-
size_t offset) noexcept
21+
uint64_t size,
22+
uint64_t offset) noexcept
2223
: ptr(ptr), size(size), offset(offset) {}
2324
};
2425
luisa::vector<Stack> _stack_allocator;
2526
ScopeExpr _root_expr;
26-
size_t _allocated_absolute_size = 0;
27+
uint64_t _allocated_absolute_size = 0;
2728
void deallocate_tensor(TensorData *tensor) noexcept;
2829

2930
public:
@@ -35,24 +36,30 @@ class LC_TENSOR_API TensorBuilder {
3536
[[nodiscard]] luisa::span<TensorData *const> arguments() const noexcept {
3637
return {_arguments};
3738
}
39+
[[nodiscard]] luisa::span<std::pair<TensorData *, Argument::Buffer> const> captured_arguments() const noexcept {
40+
return {_captured_arguments};
41+
}
3842
[[nodiscard]] luisa::span<TensorData *const> outputs() const noexcept {
3943
return {_outputs};
4044
}
4145
[[nodiscard]] auto current_scope() noexcept { return _expr_stack.back(); }
4246
static void set_thd_local(TensorBuilder *builder);
4347
static TensorBuilder *get_thd_local();
44-
void *allocate_stack(size_t size_bytes, size_t alignment = 8) noexcept;
48+
void *allocate_stack(uint64_t size_bytes, uint64_t alignment = 8) noexcept;
4549
template<typename T>
4650
requires(std::is_trivially_destructible_v<T>)
47-
T *allocate_array(size_t size) noexcept {
51+
T *allocate_array(uint64_t size) noexcept {
4852
return std::launder(reinterpret_cast<T *>(allocate_stack(size * sizeof(T), alignof(T))));
4953
}
5054
TensorData *allocate_tensor(
51-
luisa::span<size_t const> sizes,
55+
luisa::span<uint64_t const> sizes,
5256
TensorElementType element_type) noexcept;
5357
void push_argument(TensorData *data) noexcept {
5458
_arguments.emplace_back(data);
5559
}
60+
void push_captured_arguments(TensorData *data, Argument::Buffer buffer) noexcept {
61+
_captured_arguments.emplace_back(data, buffer);
62+
}
5663
void push_output(TensorData *data) noexcept {
5764
_outputs.emplace_back(data);
5865
}

src/tensor/expression.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ GEMMExpr::GEMMExpr(
6565
if (lhs_tensor->element_type() != rhs_tensor->element_type()) [[unlikely]] {
6666
LUISA_ERROR("Element type mismatch.");
6767
}
68-
auto sizes = {(size_t)desire_out_size.x, (size_t)desire_out_size.y, (size_t)group_size};
68+
auto sizes = {(uint64_t)desire_out_size.x, (uint64_t)desire_out_size.y, (uint64_t)group_size};
6969
output_tensor = TensorBuilder::get_thd_local()->allocate_tensor(sizes, lhs_tensor->element_type());
7070
}
7171

@@ -93,12 +93,12 @@ ConvExpr::ConvExpr(
9393
vstd::push_back_all(this->dilation, dilation);
9494
vstd::push_back_all(this->start_paddings, start_paddings);
9595
vstd::push_back_all(this->end_paddings, end_paddings);
96-
size_t filter_len = 1;
96+
uint64_t filter_len = 1;
9797
for (auto &i : filter_size) {
9898
filter_len *= i;
9999
}
100100
LUISA_ERROR("Weight tensor width {} must be same as input filter size {}", weight_tensor->get_size(0), filter_len);
101-
luisa::fixed_vector<size_t, 4> out_tensor_sizes;
101+
luisa::fixed_vector<uint64_t, 4> out_tensor_sizes;
102102
for (auto i : vstd::range(dim)) {
103103
if (start_paddings[i] > filter_size[i] ||
104104
end_paddings[i] > filter_size[i]) [[unlikely]] {
@@ -219,8 +219,7 @@ void Tensor::init_tensor(
219219
void (*disposer)(void *)) noexcept {
220220
TensorBuilder::get_thd_local()->current_scope()->allocate_expr<SetValueExpr>(
221221
input.data(),
222-
ptr,
223-
disposer);
222+
SetValueExpr::BinaryBlob{ptr, disposer});
224223
}
225224

226225
SetValueExpr::~SetValueExpr() noexcept {

0 commit comments

Comments
 (0)