Skip to content

Commit 637558c

Browse files
committed
vulkan: implement ROUND
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent fee74c9 commit 637558c

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,7 @@ struct vk_device_struct {
666666
vk_pipeline pipeline_abs[2];
667667
vk_pipeline pipeline_softplus[2];
668668
vk_pipeline pipeline_step[2];
669+
vk_pipeline pipeline_round[2];
669670

670671
vk_pipeline pipeline_add1_f16_f16;
671672
vk_pipeline pipeline_add1_f16_f32;
@@ -3838,6 +3839,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38383839
CREATE_UNARY(abs)
38393840
CREATE_UNARY(softplus)
38403841
CREATE_UNARY(step)
3842+
CREATE_UNARY(round)
38413843
#undef CREATE_UNARY
38423844

38433845
#define CREATE_UNARY_RTE(name) \
@@ -8266,6 +8268,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82668268
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
82678269
case GGML_UNARY_OP_STEP:
82688270
return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
8271+
case GGML_UNARY_OP_ROUND:
8272+
return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
82698273
default:
82708274
break;
82718275
}
@@ -11289,6 +11293,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128911293
case GGML_UNARY_OP_ABS:
1129011294
case GGML_UNARY_OP_SOFTPLUS:
1129111295
case GGML_UNARY_OP_STEP:
11296+
case GGML_UNARY_OP_ROUND:
1129211297
break;
1129311298
default:
1129411299
return false;
@@ -11643,6 +11648,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1164311648
case GGML_UNARY_OP_ABS:
1164411649
case GGML_UNARY_OP_SOFTPLUS:
1164511650
case GGML_UNARY_OP_STEP:
11651+
case GGML_UNARY_OP_ROUND:
1164611652
ggml_vk_unary(ctx, compute_ctx, src0, node);
1164711653
break;
1164811654
default:
@@ -11921,6 +11927,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1192111927
case GGML_UNARY_OP_ABS:
1192211928
case GGML_UNARY_OP_SOFTPLUS:
1192311929
case GGML_UNARY_OP_STEP:
11930+
case GGML_UNARY_OP_ROUND:
1192411931
buf = tensor->buffer;
1192511932
break;
1192611933
default:
@@ -13525,6 +13532,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1352513532
case GGML_UNARY_OP_ABS:
1352613533
case GGML_UNARY_OP_SOFTPLUS:
1352713534
case GGML_UNARY_OP_STEP:
13535+
case GGML_UNARY_OP_ROUND:
1352813536
return ggml_is_contiguous(op->src[0]) &&
1352913537
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1353013538
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14444,6 +14452,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1444414452
case GGML_UNARY_OP_STEP:
1444514453
tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
1444614454
break;
14455+
case GGML_UNARY_OP_ROUND:
14456+
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
14457+
break;
1444714458
default:
1444814459
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1444914460
GGML_ABORT("fatal error");
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
const float x = float(data_a[i]);
21+
float result;
22+
// Round halfway cases away from zero as roundf does.
23+
if (x >= 0.0) {
24+
result = floor(x + 0.5);
25+
} else {
26+
result = ceil(x - 0.5);
27+
}
28+
data_d[i] = D_TYPE(result);
29+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,8 @@ void process_shaders() {
853853
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
854854
string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
855855
string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
856+
string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
857+
string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
856858

857859
for (auto rte : {false, true}) {
858860
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments
 (0)