@@ -669,6 +669,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_round[2];
     vk_pipeline pipeline_ceil[2];
     vk_pipeline pipeline_floor[2];
+    vk_pipeline pipeline_trunc[2];
 
     vk_pipeline pipeline_add1_f16_f16;
     vk_pipeline pipeline_add1_f16_f32;
@@ -3844,6 +3845,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(round)
     CREATE_UNARY(ceil)
     CREATE_UNARY(floor)
+    CREATE_UNARY(trunc)
 #undef CREATE_UNARY
 
 #define CREATE_UNARY_RTE(name) \
@@ -8278,6 +8280,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_FLOOR:
             return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
+        case GGML_UNARY_OP_TRUNC:
+            return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
         default:
             break;
         }
@@ -11304,6 +11308,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             break;
         default:
             return false;
@@ -11661,6 +11666,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         default:
@@ -11942,6 +11948,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             buf = tensor->buffer;
             break;
         default:
@@ -13549,6 +13556,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_UNARY_OP_ROUND:
             case GGML_UNARY_OP_CEIL:
             case GGML_UNARY_OP_FLOOR:
+            case GGML_UNARY_OP_TRUNC:
                 return ggml_is_contiguous(op->src[0]) &&
                        (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                        (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14477,6 +14485,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_FLOOR:
             tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_TRUNC:
+            tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");