Skip to content

Commit fee74c9

Browse files
committed
vulkan: implement STEP
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 32ff590 commit fee74c9

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ struct vk_device_struct {
665665
vk_pipeline pipeline_hardswish[2];
666666
vk_pipeline pipeline_abs[2];
667667
vk_pipeline pipeline_softplus[2];
668+
vk_pipeline pipeline_step[2];
668669

669670
vk_pipeline pipeline_add1_f16_f16;
670671
vk_pipeline pipeline_add1_f16_f32;
@@ -3836,6 +3837,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38363837
CREATE_UNARY(hardswish)
38373838
CREATE_UNARY(abs)
38383839
CREATE_UNARY(softplus)
3840+
CREATE_UNARY(step)
38393841
#undef CREATE_UNARY
38403842

38413843
#define CREATE_UNARY_RTE(name) \
@@ -8262,6 +8264,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82628264
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
82638265
case GGML_UNARY_OP_SOFTPLUS:
82648266
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
8267+
case GGML_UNARY_OP_STEP:
8268+
return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
82658269
default:
82668270
break;
82678271
}
@@ -11284,6 +11288,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128411288
case GGML_UNARY_OP_HARDSWISH:
1128511289
case GGML_UNARY_OP_ABS:
1128611290
case GGML_UNARY_OP_SOFTPLUS:
11291+
case GGML_UNARY_OP_STEP:
1128711292
break;
1128811293
default:
1128911294
return false;
@@ -11637,6 +11642,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1163711642
case GGML_UNARY_OP_HARDSWISH:
1163811643
case GGML_UNARY_OP_ABS:
1163911644
case GGML_UNARY_OP_SOFTPLUS:
11645+
case GGML_UNARY_OP_STEP:
1164011646
ggml_vk_unary(ctx, compute_ctx, src0, node);
1164111647
break;
1164211648
default:
@@ -11914,6 +11920,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1191411920
case GGML_UNARY_OP_HARDSWISH:
1191511921
case GGML_UNARY_OP_ABS:
1191611922
case GGML_UNARY_OP_SOFTPLUS:
11923+
case GGML_UNARY_OP_STEP:
1191711924
buf = tensor->buffer;
1191811925
break;
1191911926
default:
@@ -13517,6 +13524,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1351713524
case GGML_UNARY_OP_HARDSWISH:
1351813525
case GGML_UNARY_OP_ABS:
1351913526
case GGML_UNARY_OP_SOFTPLUS:
13527+
case GGML_UNARY_OP_STEP:
1352013528
return ggml_is_contiguous(op->src[0]) &&
1352113529
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1352213530
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14433,6 +14441,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1443314441
case GGML_UNARY_OP_SOFTPLUS:
1443414442
tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
1443514443
break;
14444+
case GGML_UNARY_OP_STEP:
14445+
tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
14446+
break;
1443614447
default:
1443714448
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1443814449
GGML_ABORT("fatal error");
ggml/src/ggml-vulkan/vulkan-shaders/step.comp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#version 450

// Element-wise STEP compute shader (GGML_UNARY_OP_STEP).
//
// For each element i < p.KX:  data_d[i] = (data_a[i] > 0) ? 1 : 0
//
// A_TYPE / D_TYPE are compile-time defines supplied by the shader generator
// (float for the step_f32 variant, float16_t for the step_f16 variant).

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    // Flatten the 3D global invocation id into one linear element index:
    // x covers 512 elements, each y step adds 512, each z step adds 262144 (= 512 * 512).
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    // Invocations beyond the element count have nothing to write.
    // NOTE(review): `p` is not declared in this file; presumably the push-constant
    // block from generic_head.glsl - confirm.
    if (i >= p.KX) {
        return;
    }

    // Compare in float regardless of A_TYPE so both variants share one code path.
    const float x = float(data_a[i]);

    // Strict '>' (not '>='): an input of exactly zero (including -0.0) must yield 0
    // so the result agrees with the ggml_step reference that this op is checked against.
    data_d[i] = D_TYPE(x > 0.0f ? 1.0f : 0.0f);
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,8 @@ void process_shaders() {
851851
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
852852
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
853853
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
854+
string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
855+
string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
854856

855857
for (auto rte : {false, true}) {
856858
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments
 (0)