@@ -671,6 +671,8 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_arange_f32;
 
+    vk_pipeline pipeline_fill_f32;
+
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
@@ -3851,6 +3853,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
 #define CREATE_GLU(name) \
     if (device->float_controls_rte_fp16) { \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
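
The fill_f32_len/fill_f32_data arrays registered above are SPIR-V generated from a GLSL compute shader that is not part of this diff. As a hedged sketch only: given the vk_op_push_constants layout pushed by ggml_vk_fill further down (element count first, fill value in param1) and the {512, 1, 1} workgroup size declared here, the shader presumably looks something like the following. The field names p.KX/p.param1 and the buffer binding follow the pattern of the backend's other generated shaders and are assumptions, not the committed source:

    #version 450

    // Assumed push-constant layout, mirroring vk_op_push_constants on the host side.
    layout (push_constant) uniform parameter {
        uint KX;
        uint KY;
        float param1;
        float param2;
    } p;

    layout (local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

    layout (binding = 0) writeonly buffer D { float data_d[]; };

    void main() {
        const uint i = gl_GlobalInvocationID.x;

        // Guard: the dispatch is rounded up to whole workgroups of 512.
        if (i >= p.KX) {
            return;
        }

        // Every invocation writes the same constant; the value arrives in param1.
        data_d[i] = p.param1;
    }
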
@@ -8555,6 +8559,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_arange_f32;
         }
         return nullptr;
+    case GGML_OP_FILL:
+        if (dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_fill_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -8847,6 +8856,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_MUL:
     case GGML_OP_ADD1:
     case GGML_OP_ARANGE:
+    case GGML_OP_FILL:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_SQRT:
@@ -9489,6 +9499,27 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
 }
 
+static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
+
+    vk_op_push_constants pc = {
+        (uint32_t)ggml_nelements(dst),
+        1,
+        ggml_get_op_params_f32(dst, 0),
+        0.0f,
+    };
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
+
+    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
+}
+
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
 }
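
For context on where the push constants come from: ggml_vk_fill reads the fill value with ggml_get_op_params_f32(dst, 0), i.e. from op_params[0] of the FILL node. A minimal sketch of producing such a node on the graph side, assuming the ggml_fill(ctx, tensor, value) signature that the check-results hunk at the end of this diff also uses (the tensor shape, value, and helper name here are arbitrary illustrations):

    #include "ggml.h"

    // Hypothetical helper: builds a 1-D F32 tensor and a FILL node over it.
    static struct ggml_tensor * build_fill_node(struct ggml_context * ctx) {
        struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        // ggml_fill stores 0.5f in op_params[0]; on the Vulkan side,
        // ggml_vk_fill reads it back via ggml_get_op_params_f32(dst, 0)
        // and forwards it as push-constant param1.
        return ggml_fill(ctx, t, 0.5f);
    }
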
@@ -11291,6 +11322,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DIV:
     case GGML_OP_ADD1:
     case GGML_OP_ARANGE:
+    case GGML_OP_FILL:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -11511,6 +11543,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_ARANGE:
         ggml_vk_arange(ctx, compute_ctx, node);
 
+        break;
+    case GGML_OP_FILL:
+        ggml_vk_fill(ctx, compute_ctx, node);
+
         break;
     case GGML_OP_SCALE:
         ggml_vk_scale(ctx, compute_ctx, src0, node);
@@ -11799,6 +11835,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_DIV:
     case GGML_OP_ADD1:
     case GGML_OP_ARANGE:
+    case GGML_OP_FILL:
     case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -13779,6 +13816,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_CONCAT:
     case GGML_OP_ADD1:
     case GGML_OP_ARANGE:
+    case GGML_OP_FILL:
     case GGML_OP_SCALE:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -14268,6 +14306,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const float stop = ggml_get_op_params_f32(tensor, 1);
         const float step = ggml_get_op_params_f32(tensor, 2);
         tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
+    } else if (tensor->op == GGML_OP_FILL) {
+        const float value = ggml_get_op_params_f32(tensor, 0);
+        tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SQRT) {
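
Assuming the test harness already carries a FILL case, the new path can be exercised with the op filter of test-backend-ops (e.g. test-backend-ops test -o FILL) and cross-checked against the CPU backend through the ggml_vk_check_results_0 path above when the backend is built with result checking enabled.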