@@ -669,6 +669,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_add1_f16_f32;
     vk_pipeline pipeline_add1_f32_f32;
 
+    vk_pipeline pipeline_arange_f32;
+
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
@@ -3847,6 +3849,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
 #define CREATE_GLU(name) \
     if (device->float_controls_rte_fp16) { \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
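Note on the push-constant size above: the pipeline is created with `sizeof(vk_op_push_constants)` so that it matches the struct actually pushed in `ggml_vk_arange` further down (the original line passed `sizeof(vk_op_unary_push_constants)`, which does not match the dispatch). A minimal sketch of the assumed layout, with field names following the existing ggml-vulkan convention; treat them as illustrative:

```cpp
#include <cstdint>

// Sketch of the assumed vk_op_push_constants layout consumed by the
// arange_f32 shader: four small fields, well inside the 128-byte
// push-constant budget Vulkan guarantees.
struct vk_op_push_constants {
    uint32_t KX;    // element count: ggml_nelements(dst)
    uint32_t KY;    // unused second dimension, fixed to 1 for arange
    float param1;   // start value (op param 0)
    float param2;   // step size   (op param 2)
};
```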
@@ -8546,6 +8550,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_add1_f32_f32;
         }
         return nullptr;
+    case GGML_OP_ARANGE:
+        if (dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_arange_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -8837,6 +8846,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_ADD1:
+    case GGML_OP_ARANGE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_SQRT:
@@ -9458,6 +9468,27 @@ static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     });
 }
 
+static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
+
+    vk_op_push_constants pc = {
+        (uint32_t)ggml_nelements(dst),  // number of elements to generate
+        1,                              // unused second dimension
+        ggml_get_op_params_f32(dst, 0), // op param 0: start
+        ggml_get_op_params_f32(dst, 2), // op param 2: step (stop is implied by ne)
+    };
+
+    // arange has no inputs, so all three src slots are null.
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
+
+    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
+}
+
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
 }
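The `arange_f32` shader itself is not part of this hunk. For review purposes, here is a CPU sketch of what the kernel is expected to compute; the element-count formula follows `ggml_arange`, which is an assumption worth double-checking against ggml.c:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference for the arange kernel: element i of the output is
// start + i * step. ggml_arange sizes the tensor to
// ceil((stop - start) / step) elements up front, which is why the push
// constants only need to carry start and step; stop never reaches the shader.
static std::vector<float> arange_reference(float start, float stop, float step) {
    const int64_t ne = (int64_t) std::ceil((stop - start) / step);
    std::vector<float> dst((size_t) ne);
    for (int64_t i = 0; i < ne; ++i) {
        // Same expression a single shader invocation with global index i evaluates.
        dst[(size_t) i] = start + (float) i * step;
    }
    return dst;
}
```

For example, `arange_reference(0.0f, 10.0f, 2.0f)` yields 5 elements: 0, 2, 4, 6, 8.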
@@ -11259,6 +11290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_ADD1:
+    case GGML_OP_ARANGE:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -11476,6 +11508,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
 
         break;
+    case GGML_OP_ARANGE:
+        ggml_vk_arange(ctx, compute_ctx, node);
+
         break;
     case GGML_OP_SCALE:
         ggml_vk_scale(ctx, compute_ctx, src0, node);
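For context, a hypothetical host-side snippet that would exercise this dispatch path (context setup abbreviated; `ggml_arange` is the public API that records start/stop/step into the node's op params):

```cpp
#include "ggml.h"

// Hypothetical usage sketch: create an arange node that the Vulkan backend
// now executes via ggml_vk_arange. ggml_arange stores start/stop/step as
// op params 0/1/2 and sizes the tensor to ceil((stop - start) / step)
// F32 elements.
static struct ggml_tensor * build_arange_node(struct ggml_context * ggml_ctx) {
    // Produces [0, 2, 4, 6, 8]: ceil((10 - 0) / 2) = 5 elements.
    return ggml_arange(ggml_ctx, 0.0f, 10.0f, 2.0f);
}
```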
@@ -11763,6 +11798,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_ADD1:
+    case GGML_OP_ARANGE:
     case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -13742,6 +13778,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_ACC:
     case GGML_OP_CONCAT:
     case GGML_OP_ADD1:
+    case GGML_OP_ARANGE:
     case GGML_OP_SCALE:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -14226,6 +14263,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_ADD1) {
         tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_ARANGE) {
+        const float start = ggml_get_op_params_f32(tensor, 0);
+        const float stop  = ggml_get_op_params_f32(tensor, 1);
+        const float step  = ggml_get_op_params_f32(tensor, 2);
+        tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SQRT) {