Skip to content

Commit aeff922

Browse files
committed
vulkan: implement ARANGE
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 1941f4d commit aeff922

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -669,6 +669,8 @@ struct vk_device_struct {
669669
vk_pipeline pipeline_add1_f16_f32;
670670
vk_pipeline pipeline_add1_f32_f32;
671671

672+
vk_pipeline pipeline_arange_f32;
673+
672674
vk_pipeline pipeline_geglu[2];
673675
vk_pipeline pipeline_reglu[2];
674676
vk_pipeline pipeline_swiglu[2];
@@ -3847,6 +3849,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
38473849
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
38483850
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
38493851

3852+
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3853+
38503854
#define CREATE_GLU(name) \
38513855
if (device->float_controls_rte_fp16) { \
38523856
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -8546,6 +8550,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85468550
return ctx->device->pipeline_add1_f32_f32;
85478551
}
85488552
return nullptr;
8553+
case GGML_OP_ARANGE:
8554+
if (dst->type == GGML_TYPE_F32) {
8555+
return ctx->device->pipeline_arange_f32;
8556+
}
8557+
return nullptr;
85498558
default:
85508559
return nullptr;
85518560
}
@@ -8837,6 +8846,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88378846
case GGML_OP_DIV:
88388847
case GGML_OP_MUL:
88398848
case GGML_OP_ADD1:
8849+
case GGML_OP_ARANGE:
88408850
case GGML_OP_SCALE:
88418851
case GGML_OP_SQR:
88428852
case GGML_OP_SQRT:
@@ -9458,6 +9468,27 @@ static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, cons
94589468
});
94599469
}
94609470

9471+
static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9472+
VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
9473+
9474+
vk_op_push_constants pc = {
9475+
(uint32_t)ggml_nelements(dst),
9476+
1,
9477+
ggml_get_op_params_f32(dst, 0),
9478+
ggml_get_op_params_f32(dst, 2),
9479+
};
9480+
9481+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
9482+
GGML_ASSERT(pipeline != nullptr);
9483+
9484+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9485+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
9486+
9487+
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
9488+
9489+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
9490+
}
9491+
94619492
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
94629493
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
94639494
}
@@ -11259,6 +11290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1125911290
case GGML_OP_MUL:
1126011291
case GGML_OP_DIV:
1126111292
case GGML_OP_ADD1:
11293+
case GGML_OP_ARANGE:
1126211294
case GGML_OP_CONCAT:
1126311295
case GGML_OP_UPSCALE:
1126411296
case GGML_OP_SCALE:
@@ -11476,6 +11508,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1147611508
ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
1147711509

1147811510
break;
11511+
case GGML_OP_ARANGE:
11512+
ggml_vk_arange(ctx, compute_ctx, node);
11513+
1147911514
break;
1148011515
case GGML_OP_SCALE:
1148111516
ggml_vk_scale(ctx, compute_ctx, src0, node);
@@ -11763,6 +11798,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1176311798
case GGML_OP_MUL:
1176411799
case GGML_OP_DIV:
1176511800
case GGML_OP_ADD1:
11801+
case GGML_OP_ARANGE:
1176611802
case GGML_OP_ADD_ID:
1176711803
case GGML_OP_CONCAT:
1176811804
case GGML_OP_UPSCALE:
@@ -13742,6 +13778,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1374213778
case GGML_OP_ACC:
1374313779
case GGML_OP_CONCAT:
1374413780
case GGML_OP_ADD1:
13781+
case GGML_OP_ARANGE:
1374513782
case GGML_OP_SCALE:
1374613783
case GGML_OP_PAD:
1374713784
case GGML_OP_ROLL:
@@ -14226,6 +14263,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1422614263
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
1422714264
} else if (tensor->op == GGML_OP_ADD1) {
1422814265
tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
14266+
} else if (tensor->op == GGML_OP_ARANGE) {
14267+
const float start = ggml_get_op_params_f32(tensor, 0);
14268+
const float stop = ggml_get_op_params_f32(tensor, 1);
14269+
const float step = ggml_get_op_params_f32(tensor, 2);
14270+
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
1422914271
} else if (tensor->op == GGML_OP_SQR) {
1423014272
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
1423114273
} else if (tensor->op == GGML_OP_SQRT) {
ggml/src/ggml-vulkan/vulkan-shaders/arange.comp

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};
9+
10+
void main() {
11+
const uint i = gl_GlobalInvocationID.x;
12+
13+
if (i >= p.KX) {
14+
return;
15+
}
16+
17+
// p.param1 = start, p.param2 = step
18+
float value = p.param1 + p.param2 * float(i);
19+
data_d[i] = D_TYPE(value);
20+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -845,6 +845,7 @@ void process_shaders() {
845845
string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
846846
string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
847847
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
848+
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
848849

849850
for (auto rte : {false, true}) {
850851
std::string suffix = rte ? "_rte" : "";

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7557,6 +7557,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75577557
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
75587558
test_cases.emplace_back(new test_roll());
75597559
test_cases.emplace_back(new test_arange());
7560+
test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));
75607561
test_cases.emplace_back(new test_timestep_embedding());
75617562
test_cases.emplace_back(new test_leaky_relu());
75627563

0 commit comments

Comments (0)