@@ -666,6 +666,7 @@ struct vk_device_struct {
666666 vk_pipeline pipeline_abs[2];
667667 vk_pipeline pipeline_softplus[2];
668668 vk_pipeline pipeline_step[2];
669+ vk_pipeline pipeline_round[2];
669670
670671 vk_pipeline pipeline_add1_f16_f16;
671672 vk_pipeline pipeline_add1_f16_f32;
@@ -3832,6 +3833,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38323833 CREATE_UNARY(abs)
38333834 CREATE_UNARY(softplus)
38343835 CREATE_UNARY(step)
3836+ CREATE_UNARY(round)
38353837#undef CREATE_UNARY
38363838
38373839#define CREATE_UNARY_RTE(name) \
@@ -8260,6 +8262,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82608262 return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
82618263 case GGML_UNARY_OP_STEP:
82628264 return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
8265+ case GGML_UNARY_OP_ROUND:
8266+ return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
82638267 default:
82648268 break;
82658269 }
@@ -11283,6 +11287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128311287 case GGML_UNARY_OP_ABS:
1128411288 case GGML_UNARY_OP_SOFTPLUS:
1128511289 case GGML_UNARY_OP_STEP:
11290+ case GGML_UNARY_OP_ROUND:
1128611291 break;
1128711292 default:
1128811293 return false;
@@ -11637,6 +11642,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1163711642 case GGML_UNARY_OP_ABS:
1163811643 case GGML_UNARY_OP_SOFTPLUS:
1163911644 case GGML_UNARY_OP_STEP:
11645+ case GGML_UNARY_OP_ROUND:
1164011646 ggml_vk_unary(ctx, compute_ctx, src0, node);
1164111647 break;
1164211648 default:
@@ -11915,6 +11921,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1191511921 case GGML_UNARY_OP_ABS:
1191611922 case GGML_UNARY_OP_SOFTPLUS:
1191711923 case GGML_UNARY_OP_STEP:
11924+ case GGML_UNARY_OP_ROUND:
1191811925 buf = tensor->buffer;
1191911926 break;
1192011927 default:
@@ -13519,6 +13526,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1351913526 case GGML_UNARY_OP_ABS:
1352013527 case GGML_UNARY_OP_SOFTPLUS:
1352113528 case GGML_UNARY_OP_STEP:
13529+ case GGML_UNARY_OP_ROUND:
1352213530 return ggml_is_contiguous(op->src[0]) &&
1352313531 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1352413532 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14437,6 +14445,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1443714445 case GGML_UNARY_OP_STEP:
1443814446 tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
1443914447 break;
14448+ case GGML_UNARY_OP_ROUND:
14449+ tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
14450+ break;
1444014451 default:
1444114452 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1444214453 GGML_ABORT("fatal error");
0 commit comments