Skip to content

Commit 32ff590

Browse files
committed
vulkan: implement SOFTPLUS
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 71f9ff7 commit 32ff590

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ struct vk_device_struct {
664664
vk_pipeline pipeline_hardsigmoid[2];
665665
vk_pipeline pipeline_hardswish[2];
666666
vk_pipeline pipeline_abs[2];
667+
vk_pipeline pipeline_softplus[2];
667668

668669
vk_pipeline pipeline_add1_f16_f16;
669670
vk_pipeline pipeline_add1_f16_f32;
@@ -3834,6 +3835,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38343835
CREATE_UNARY(hardsigmoid)
38353836
CREATE_UNARY(hardswish)
38363837
CREATE_UNARY(abs)
3838+
CREATE_UNARY(softplus)
38373839
#undef CREATE_UNARY
38383840

38393841
#define CREATE_UNARY_RTE(name) \
@@ -8258,6 +8260,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82588260
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
82598261
case GGML_UNARY_OP_ABS:
82608262
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
8263+
case GGML_UNARY_OP_SOFTPLUS:
8264+
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
82618265
default:
82628266
break;
82638267
}
@@ -11279,6 +11283,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1127911283
case GGML_UNARY_OP_HARDSIGMOID:
1128011284
case GGML_UNARY_OP_HARDSWISH:
1128111285
case GGML_UNARY_OP_ABS:
11286+
case GGML_UNARY_OP_SOFTPLUS:
1128211287
break;
1128311288
default:
1128411289
return false;
@@ -11631,6 +11636,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1163111636
case GGML_UNARY_OP_HARDSIGMOID:
1163211637
case GGML_UNARY_OP_HARDSWISH:
1163311638
case GGML_UNARY_OP_ABS:
11639+
case GGML_UNARY_OP_SOFTPLUS:
1163411640
ggml_vk_unary(ctx, compute_ctx, src0, node);
1163511641
break;
1163611642
default:
@@ -11907,6 +11913,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1190711913
case GGML_UNARY_OP_HARDSIGMOID:
1190811914
case GGML_UNARY_OP_HARDSWISH:
1190911915
case GGML_UNARY_OP_ABS:
11916+
case GGML_UNARY_OP_SOFTPLUS:
1191011917
buf = tensor->buffer;
1191111918
break;
1191211919
default:
@@ -13509,6 +13516,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1350913516
case GGML_UNARY_OP_HARDSIGMOID:
1351013517
case GGML_UNARY_OP_HARDSWISH:
1351113518
case GGML_UNARY_OP_ABS:
13519+
case GGML_UNARY_OP_SOFTPLUS:
1351213520
return ggml_is_contiguous(op->src[0]) &&
1351313521
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1351413522
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14422,6 +14430,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1442214430
case GGML_UNARY_OP_ABS:
1442314431
tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
1442414432
break;
14433+
case GGML_UNARY_OP_SOFTPLUS:
14434+
tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
14435+
break;
1442514436
default:
1442614437
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1442714438
GGML_ABORT("fatal error");
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
const float x = float(data_a[i]);
21+
const float result = (x > 20.0f) ? x : log(1.0f + exp(x));
22+
data_d[i] = D_TYPE(result);
23+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,10 @@ void process_shaders() {
842842
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
843843
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
844844
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
845+
846+
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
847+
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
848+
845849
string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
846850
string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
847851
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

0 commit comments

Comments
 (0)