@@ -665,6 +665,7 @@ struct vk_device_struct {
665665 vk_pipeline pipeline_hardswish[2];
666666 vk_pipeline pipeline_abs[2];
667667 vk_pipeline pipeline_softplus[2];
668+ vk_pipeline pipeline_step[2];
668669
669670 vk_pipeline pipeline_add1_f16_f16;
670671 vk_pipeline pipeline_add1_f16_f32;
@@ -3836,6 +3837,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
38363837 CREATE_UNARY(hardswish)
38373838 CREATE_UNARY(abs)
38383839 CREATE_UNARY(softplus)
3840+ CREATE_UNARY(step)
38393841#undef CREATE_UNARY
38403842
38413843#define CREATE_UNARY_RTE(name) \
@@ -8262,6 +8264,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82628264 return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
82638265 case GGML_UNARY_OP_SOFTPLUS:
82648266 return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
8267+ case GGML_UNARY_OP_STEP:
8268+ return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
82658269 default:
82668270 break;
82678271 }
@@ -11284,6 +11288,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128411288 case GGML_UNARY_OP_HARDSWISH:
1128511289 case GGML_UNARY_OP_ABS:
1128611290 case GGML_UNARY_OP_SOFTPLUS:
11291+ case GGML_UNARY_OP_STEP:
1128711292 break;
1128811293 default:
1128911294 return false;
@@ -11637,6 +11642,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1163711642 case GGML_UNARY_OP_HARDSWISH:
1163811643 case GGML_UNARY_OP_ABS:
1163911644 case GGML_UNARY_OP_SOFTPLUS:
11645+ case GGML_UNARY_OP_STEP:
1164011646 ggml_vk_unary(ctx, compute_ctx, src0, node);
1164111647 break;
1164211648 default:
@@ -11914,6 +11920,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1191411920 case GGML_UNARY_OP_HARDSWISH:
1191511921 case GGML_UNARY_OP_ABS:
1191611922 case GGML_UNARY_OP_SOFTPLUS:
11923+ case GGML_UNARY_OP_STEP:
1191711924 buf = tensor->buffer;
1191811925 break;
1191911926 default:
@@ -13517,6 +13524,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1351713524 case GGML_UNARY_OP_HARDSWISH:
1351813525 case GGML_UNARY_OP_ABS:
1351913526 case GGML_UNARY_OP_SOFTPLUS:
13527+ case GGML_UNARY_OP_STEP:
1352013528 return ggml_is_contiguous(op->src[0]) &&
1352113529 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1352213530 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14433,6 +14441,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1443314441 case GGML_UNARY_OP_SOFTPLUS:
1443414442 tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
1443514443 break;
14444+ case GGML_UNARY_OP_STEP:
14445+ tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
14446+ break;
1443614447 default:
1443714448 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1443814449 GGML_ABORT("fatal error");
0 commit comments