Skip to content

Commit 703a171

Browse files
committed
softmax shader
1 parent 191343c commit 703a171

File tree

9 files changed

+268
-55
lines changed

9 files changed

+268
-55
lines changed

include/luisa/tensor/expression.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace luisa::compute {
1212
LCGRandExpr, \
1313
GEMMExpr, \
1414
ConvExpr, \
15-
TestExpr
15+
SoftmaxExpr
1616
// clang-format on
1717

1818
#define LUISA_MAKE_TENSOR_EXPR_DECL(CMD) class CMD;
@@ -173,19 +173,14 @@ class LC_TENSOR_API LUISA_TENSOR_EXPR_CLASS_INHERIT(ConvExpr) {
173173
static luisa::variant<TensorData *, luisa::string> get_output_tensor(TensorData *input_tensor, TensorData *filter_tensor) noexcept;
174174
};
175175

176-
class LC_TENSOR_API LUISA_TENSOR_EXPR_CLASS_INHERIT(TestExpr) {
176+
class LC_TENSOR_API LUISA_TENSOR_EXPR_CLASS_INHERIT(SoftmaxExpr) {
177177
public:
178178
TensorData *input;
179-
TensorData *output;
180-
luisa::string_view name;
181-
TestExpr(
179+
SoftmaxExpr(
182180
uint64_t idx,
183-
TensorData *input,
184-
TensorData *output,
185-
luisa::string_view name) noexcept;
181+
TensorData *input) noexcept;
186182
void get_tensors(vstd::FuncRef<void(TensorData *, Usage usage)> callback) noexcept override {
187-
callback(input, Usage::READ);
188-
callback(output, Usage::WRITE);
183+
callback(input, Usage::WRITE);
189184
}
190185
};
191186

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once
2+
#include "i_tensor_expr_executor.h"
3+
#include <luisa/core/stl/variant.h>
4+
namespace luisa::compute {
5+
// Executor for a softmax tensor expression (new file added by this commit).
// Declaration-only: the constructor/destructor/execute bodies live in a .cpp
// not shown on this page.
struct SoftmaxImpl : ITensorExprExecutor {
6+
// Expression node this executor serves (non-owning raw pointer;
// presumably owned by the expression graph — confirm in the .cpp).
SoftmaxExpr *expr;
7+
// Pair of dispatches used for the large-batch path: names suggest a
// two-pass scheme (a `sum` reduction pass, then a `final` pass) —
// TODO confirm against the shader implementation.
struct LargeBatchShader {
8+
ShaderManager::ShaderDispatch sum;
9+
ShaderManager::ShaderDispatch final;
10+
};
11+
// Either the two-pass large-batch pair or a single-dispatch shader;
// which alternative is selected is decided in the constructor (not visible here).
luisa::variant<
12+
LargeBatchShader,
13+
ShaderManager::ShaderDispatch>
14+
shaders;
15+
16+
// Builds/fetches the shader(s) via ShaderManager for the given device.
SoftmaxImpl(
17+
DeviceInterface *device,
18+
ShaderManager *shader_manager,
19+
SoftmaxExpr *expr);
20+
~SoftmaxImpl();
21+
// Records the softmax dispatch(es) into `cmdlist`; overrides the
// ITensorExprExecutor interface.
void execute(FallbackTensorCallback *callback, CommandList &cmdlist) const override;
22+
};
23+
};// namespace luisa::compute

include/luisa/tensor/pass/shader_manager.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <luisa/tensor/expression.h>
44
#include <luisa/core/stl/unordered_map.h>
55
#include <luisa/vstl/md5.h>
6+
#include <luisa/vstl/hash_map.h>
67
namespace luisa::compute {
78
class LC_TENSOR_API ShaderManager {
89
public:
@@ -11,11 +12,11 @@ class LC_TENSOR_API ShaderManager {
1112
std::array<uint, 4> user_ids;
1213
};
1314
struct KeyHashEqual {
14-
size_t operator()(Key const& k) const noexcept {
15+
size_t operator()(Key const &k) const noexcept {
1516
return luisa::hash64(&k, sizeof(Key), luisa::hash64_default_seed);
1617
}
17-
bool operator()(Key const& a, Key const& b) const noexcept {
18-
return std::memcmp(&a, &b, sizeof(Key)) == 0;
18+
int operator()(Key const &a, Key const &b) const noexcept {
19+
return std::memcmp(&a, &b, sizeof(Key));
1920
}
2021
};
2122
struct ShaderDispatch {
@@ -24,25 +25,30 @@ class LC_TENSOR_API ShaderManager {
2425
size_t uniform_size;
2526
};
2627
private:
27-
luisa::unordered_map<Key, ShaderDispatch, KeyHashEqual, KeyHashEqual> _shaders;
28+
vstd::HashMap<Key, std::pair<luisa::spin_mutex, ShaderDispatch>, KeyHashEqual, KeyHashEqual> _shaders;
29+
luisa::spin_mutex global_mtx;
2830
DeviceInterface *_device;
2931
public:
3032
ShaderManager(DeviceInterface *device) noexcept;
3133
template<typename Lambda>
3234
requires(std::is_invocable_r_v<ShaderDispatch, Lambda>)
33-
ShaderDispatch add_shader(
35+
ShaderDispatch const &add_shader(
3436
TensorExpr::Tag tag,
3537
vstd::MD5 hash,
3638
Lambda &&lambda) noexcept {
3739
std::array<uint, 4> arr;
3840
static_assert(sizeof(arr) == sizeof(vstd::MD5));
3941
std::memcpy(arr.data(), &hash, sizeof(vstd::MD5));
40-
auto iter = _shaders.try_emplace(Key{tag, arr}, luisa::lazy_construct([&]() -> ShaderDispatch {
41-
auto handle = lambda();
42-
LUISA_ASSERT(handle.shader_handle != invalid_resource_handle);
43-
return handle;
44-
}));
45-
return iter.first->second;
42+
global_mtx.lock();
43+
auto iter = _shaders.try_emplace(Key{tag, arr});
44+
auto &v = iter.first;
45+
global_mtx.unlock();
46+
std::lock_guard lck{v.value().first};
47+
if (iter.second) {
48+
v.value().second = lambda();
49+
LUISA_ASSERT(v.value().second.shader_handle != invalid_resource_handle);
50+
}
51+
return v.value().second;
4652
}
4753
~ShaderManager() noexcept;
4854
};

include/luisa/tensor/tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <luisa/tensor/fused_activation.h>
88
namespace luisa::compute {
99

10-
enum struct TensorElementType : uint8_t {
10+
enum struct TensorElementType : uint32_t {
1111
Float16,
1212
Float32,
1313
Float64,

src/tensor/expression.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,11 @@ ConvExpr::ConvExpr(
119119
out_tensor = TensorBuilder::get_thd_local()->allocate_tensor(out_tensor_sizes, out_type);
120120
}
121121

122-
TestExpr::TestExpr(
122+
SoftmaxExpr::SoftmaxExpr(
123123
uint64_t idx,
124-
TensorData *input,
125-
TensorData *output,
126-
luisa::string_view name) noexcept
124+
TensorData *input) noexcept
127125
: BaseClass(idx),
128-
input(input),
129-
output(output),
130-
name(name) {
126+
input(input) {
131127
}
132128

133129
Tensor Tensor::matmul(

src/tensor/fallback/matmul_impl.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,36 +108,28 @@ MatMulImpl::MatMulImpl(
108108
.min_batch_size = min_batch_size,
109109
.batch = uint(lhs_batch ? 1 : 0) | uint(rhs_batch ? 2 : 0),
110110
.activation = expr->fused_activation};
111+
auto bind_shader = [&]<typename T>() {
112+
auto disp_pack = shader_manager->add_shader(
113+
TensorExpr::Tag::EGEMMExpr,
114+
vstd::MD5{{reinterpret_cast<uint8_t const *>(&key), sizeof(key)}},
115+
[&]() {
116+
auto disp_pack = gemm_detail::gemm_kernel<T>(lhs_matrix_size, rhs_matrix_size, min_batch_size, lhs_batch, rhs_batch, expr->fused_activation);
117+
auto create_info = device->create_shader(ShaderOption{}, Function{disp_pack.kernel.function().get()});
118+
return ShaderManager::ShaderDispatch{
119+
create_info.handle,
120+
disp_pack.dispatch_size,
121+
ShaderDispatchCmdEncoder::compute_uniform_size(disp_pack.kernel.function()->unbound_arguments())};
122+
});
123+
set_disp_pack(disp_pack);
124+
};
111125
switch (expr->lhs_tensor->element_type()) {
112126
case TensorElementType::Float16: {
113127
key.type = 0;
114-
auto disp_pack = shader_manager->add_shader(
115-
TensorExpr::Tag::EGEMMExpr,
116-
vstd::MD5{{reinterpret_cast<uint8_t const *>(&key), sizeof(key)}},
117-
[&]() {
118-
auto disp_pack = gemm_detail::gemm_kernel<half>(lhs_matrix_size, rhs_matrix_size, min_batch_size, lhs_batch, rhs_batch, expr->fused_activation);
119-
auto create_info = device->create_shader(ShaderOption{}, Function{disp_pack.kernel.function().get()});
120-
return ShaderManager::ShaderDispatch{
121-
create_info.handle,
122-
disp_pack.dispatch_size,
123-
ShaderDispatchCmdEncoder::compute_uniform_size(disp_pack.kernel.function()->unbound_arguments())};
124-
});
125-
set_disp_pack(disp_pack);
128+
bind_shader.template operator()<half>();
126129
} break;
127130
case TensorElementType::Float32: {
128131
key.type = 1;
129-
auto disp_pack = shader_manager->add_shader(
130-
TensorExpr::Tag::EGEMMExpr,
131-
vstd::MD5{{reinterpret_cast<uint8_t const *>(&key), sizeof(key)}},
132-
[&]() {
133-
auto disp_pack = gemm_detail::gemm_kernel<float>(lhs_matrix_size, rhs_matrix_size, min_batch_size, lhs_batch, rhs_batch, expr->fused_activation);
134-
auto create_info = device->create_shader(ShaderOption{}, Function{disp_pack.kernel.function().get()});
135-
return ShaderManager::ShaderDispatch{
136-
create_info.handle,
137-
disp_pack.dispatch_size,
138-
ShaderDispatchCmdEncoder::compute_uniform_size(disp_pack.kernel.function()->unbound_arguments())};
139-
});
140-
set_disp_pack(disp_pack);
132+
bind_shader.template operator()<float>();
141133
} break;
142134
default: {
143135
LUISA_ERROR("Only float 16 and float 32 supported.");

0 commit comments

Comments
 (0)