@@ -664,6 +664,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_hardsigmoid[2];
     vk_pipeline pipeline_hardswish[2];
     vk_pipeline pipeline_abs[2];
+    vk_pipeline pipeline_softplus[2];
 
     vk_pipeline pipeline_add1_f16_f16;
     vk_pipeline pipeline_add1_f16_f32;
@@ -3834,6 +3835,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(hardsigmoid)
     CREATE_UNARY(hardswish)
     CREATE_UNARY(abs)
+    CREATE_UNARY(softplus)
 #undef CREATE_UNARY
 
 #define CREATE_UNARY_RTE(name) \
@@ -8258,6 +8260,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_ABS:
             return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
+        case GGML_UNARY_OP_SOFTPLUS:
+            return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
         default:
             break;
         }
@@ -11279,6 +11283,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_ABS:
+        case GGML_UNARY_OP_SOFTPLUS:
             break;
         default:
             return false;
@@ -11631,6 +11636,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_ABS:
+        case GGML_UNARY_OP_SOFTPLUS:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         default:
@@ -11907,6 +11913,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_ABS:
+        case GGML_UNARY_OP_SOFTPLUS:
             buf = tensor->buffer;
             break;
         default:
@@ -13509,6 +13516,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_UNARY_OP_HARDSIGMOID:
             case GGML_UNARY_OP_HARDSWISH:
             case GGML_UNARY_OP_ABS:
+            case GGML_UNARY_OP_SOFTPLUS:
                 return ggml_is_contiguous(op->src[0]) &&
                        (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                        (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14422,6 +14430,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_ABS:
             tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
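
Note: these hunks only wire GGML_UNARY_OP_SOFTPLUS into the Vulkan backend's pipeline tables and dispatch paths; the compute shader that CREATE_UNARY(softplus) loads lives alongside the other unary shaders and is not shown here. As a rough illustration of what such a shader evaluates, here is a minimal standalone GLSL sketch; the binding layout, push-constant struct, and workgroup size are assumptions made for the example, not the repo's actual shared unary-shader scaffolding.

#version 450

// Hypothetical standalone sketch, not the repo's shader: one thread per element.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout(std430, binding = 0) readonly  buffer A { float data_a[]; };
layout(std430, binding = 1) writeonly buffer D { float data_d[]; };

// Assumed push constant: total element count.
layout(push_constant) uniform Params { uint n; } p;

void main() {
    const uint i = gl_GlobalInvocationID.x;
    if (i >= p.n) {
        return;
    }
    const float x = data_a[i];
    // softplus(x) = log(1 + exp(x)), evaluated in the numerically stable
    // form max(x, 0) + log(1 + exp(-|x|)) so exp() cannot overflow for large x.
    data_d[i] = max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

The stable form matters because exp(x) overflows float32 for x beyond roughly 88, while the rewritten expression stays finite; the ggml_softplus clone added in ggml_vk_check_results_0 above gives the CPU reference that the GPU output is validated against.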