@@ -669,6 +669,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_round[2];
     vk_pipeline pipeline_ceil[2];
     vk_pipeline pipeline_floor[2];
+    vk_pipeline pipeline_trunc[2];

     vk_pipeline pipeline_add1_f16_f16;
     vk_pipeline pipeline_add1_f16_f32;
@@ -3838,6 +3839,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(round)
     CREATE_UNARY(ceil)
     CREATE_UNARY(floor)
+    CREATE_UNARY(trunc)
 #undef CREATE_UNARY

 #define CREATE_UNARY_RTE(name) \
@@ -8272,6 +8274,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_FLOOR:
             return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
+        case GGML_UNARY_OP_TRUNC:
+            return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
         default:
             break;
         }
@@ -11298,6 +11302,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             break;
         default:
             return false;
@@ -11655,6 +11660,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         default:
@@ -11936,6 +11942,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             buf = tensor->buffer;
             break;
         default:
@@ -13543,6 +13550,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_UNARY_OP_ROUND:
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_TRUNC:
             return ggml_is_contiguous(op->src[0]) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14470,6 +14478,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_FLOOR:
             tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_TRUNC:
+            tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");