Skip to content

Commit 1291e69

Browse files
committed
vulkan: implement TRUNC
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 7942c39 commit 1291e69

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -669,6 +669,7 @@ struct vk_device_struct {
669669
vk_pipeline pipeline_round[2];
670670
vk_pipeline pipeline_ceil[2];
671671
vk_pipeline pipeline_floor[2];
672+
vk_pipeline pipeline_trunc[2];
672673

673674
vk_pipeline pipeline_add1_f16_f16;
674675
vk_pipeline pipeline_add1_f16_f32;
@@ -3838,6 +3839,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38383839
CREATE_UNARY(round)
38393840
CREATE_UNARY(ceil)
38403841
CREATE_UNARY(floor)
3842+
CREATE_UNARY(trunc)
38413843
#undef CREATE_UNARY
38423844

38433845
#define CREATE_UNARY_RTE(name) \
@@ -8272,6 +8274,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82728274
return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
82738275
case GGML_UNARY_OP_FLOOR:
82748276
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
8277+
case GGML_UNARY_OP_TRUNC:
8278+
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
82758279
default:
82768280
break;
82778281
}
@@ -11298,6 +11302,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1129811302
case GGML_UNARY_OP_ROUND:
1129911303
case GGML_UNARY_OP_CEIL:
1130011304
case GGML_UNARY_OP_FLOOR:
11305+
case GGML_UNARY_OP_TRUNC:
1130111306
break;
1130211307
default:
1130311308
return false;
@@ -11655,6 +11660,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1165511660
case GGML_UNARY_OP_ROUND:
1165611661
case GGML_UNARY_OP_CEIL:
1165711662
case GGML_UNARY_OP_FLOOR:
11663+
case GGML_UNARY_OP_TRUNC:
1165811664
ggml_vk_unary(ctx, compute_ctx, src0, node);
1165911665
break;
1166011666
default:
@@ -11936,6 +11942,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1193611942
case GGML_UNARY_OP_ROUND:
1193711943
case GGML_UNARY_OP_CEIL:
1193811944
case GGML_UNARY_OP_FLOOR:
11945+
case GGML_UNARY_OP_TRUNC:
1193911946
buf = tensor->buffer;
1194011947
break;
1194111948
default:
@@ -13543,6 +13550,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1354313550
case GGML_UNARY_OP_ROUND:
1354413551
case GGML_UNARY_OP_CEIL:
1354513552
case GGML_UNARY_OP_FLOOR:
13553+
case GGML_UNARY_OP_TRUNC:
1354613554
return ggml_is_contiguous(op->src[0]) &&
1354713555
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1354813556
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14470,6 +14478,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1447014478
case GGML_UNARY_OP_FLOOR:
1447114479
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
1447214480
break;
14481+
case GGML_UNARY_OP_TRUNC:
14482+
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
14483+
break;
1447314484
default:
1447414485
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1447514486
GGML_ABORT("fatal error");
ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,22 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
const float x = float(data_a[i]);
21+
data_d[i] = D_TYPE(trunc(x));
22+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -859,6 +859,8 @@ void process_shaders() {
859859
string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
860860
string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
861861
string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
862+
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
863+
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
862864

863865
for (auto rte : {false, true}) {
864866
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments
 (0)