@@ -666,6 +666,7 @@ struct vk_device_struct {
666666 vk_pipeline pipeline_abs[2];
667667 vk_pipeline pipeline_softplus[2];
668668 vk_pipeline pipeline_step[2];
669+ vk_pipeline pipeline_round[2];
669670
670671 vk_pipeline pipeline_add1_f16_f16;
671672 vk_pipeline pipeline_add1_f16_f32;
@@ -3838,6 +3839,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38383839 CREATE_UNARY(abs)
38393840 CREATE_UNARY(softplus)
38403841 CREATE_UNARY(step)
3842+ CREATE_UNARY(round)
38413843#undef CREATE_UNARY
38423844
38433845#define CREATE_UNARY_RTE(name) \
@@ -8266,6 +8268,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82668268 return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
82678269 case GGML_UNARY_OP_STEP:
82688270 return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
8271+ case GGML_UNARY_OP_ROUND:
8272+ return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
82698273 default:
82708274 break;
82718275 }
@@ -11289,6 +11293,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128911293 case GGML_UNARY_OP_ABS:
1129011294 case GGML_UNARY_OP_SOFTPLUS:
1129111295 case GGML_UNARY_OP_STEP:
11296+ case GGML_UNARY_OP_ROUND:
1129211297 break;
1129311298 default:
1129411299 return false;
@@ -11643,6 +11648,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1164311648 case GGML_UNARY_OP_ABS:
1164411649 case GGML_UNARY_OP_SOFTPLUS:
1164511650 case GGML_UNARY_OP_STEP:
11651+ case GGML_UNARY_OP_ROUND:
1164611652 ggml_vk_unary(ctx, compute_ctx, src0, node);
1164711653 break;
1164811654 default:
@@ -11921,6 +11927,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1192111927 case GGML_UNARY_OP_ABS:
1192211928 case GGML_UNARY_OP_SOFTPLUS:
1192311929 case GGML_UNARY_OP_STEP:
11930+ case GGML_UNARY_OP_ROUND:
1192411931 buf = tensor->buffer;
1192511932 break;
1192611933 default:
@@ -13525,6 +13532,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1352513532 case GGML_UNARY_OP_ABS:
1352613533 case GGML_UNARY_OP_SOFTPLUS:
1352713534 case GGML_UNARY_OP_STEP:
13535+ case GGML_UNARY_OP_ROUND:
1352813536 return ggml_is_contiguous(op->src[0]) &&
1352913537 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1353013538 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14444,6 +14452,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1444414452 case GGML_UNARY_OP_STEP:
1444514453 tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
1444614454 break;
14455+ case GGML_UNARY_OP_ROUND:
14456+ tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
14457+ break;
1444714458 default:
1444814459 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1444914460 GGML_ABORT("fatal error");
0 commit comments