Commit 71f9ff7 — "vulkan: implement FILL"

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Parent commit: aeff922

File tree: 4 files changed (+62 / −0 lines).

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,8 @@ struct vk_device_struct {
671671

672672
vk_pipeline pipeline_arange_f32;
673673

674+
vk_pipeline pipeline_fill_f32;
675+
674676
vk_pipeline pipeline_geglu[2];
675677
vk_pipeline pipeline_reglu[2];
676678
vk_pipeline pipeline_swiglu[2];
@@ -3851,6 +3853,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
38513853

38523854
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38533855

3856+
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3857+
38543858
#define CREATE_GLU(name) \
38553859
if (device->float_controls_rte_fp16) { \
38563860
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -8555,6 +8559,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85558559
return ctx->device->pipeline_arange_f32;
85568560
}
85578561
return nullptr;
8562+
case GGML_OP_FILL:
8563+
if (dst->type == GGML_TYPE_F32) {
8564+
return ctx->device->pipeline_fill_f32;
8565+
}
8566+
return nullptr;
85588567
default:
85598568
return nullptr;
85608569
}
@@ -8847,6 +8856,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88478856
case GGML_OP_MUL:
88488857
case GGML_OP_ADD1:
88498858
case GGML_OP_ARANGE:
8859+
case GGML_OP_FILL:
88508860
case GGML_OP_SCALE:
88518861
case GGML_OP_SQR:
88528862
case GGML_OP_SQRT:
@@ -9489,6 +9499,27 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
94899499
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
94909500
}
94919501

9502+
// Record a FILL dispatch that writes a single constant into every element of
// `dst`. The op has no source tensors, so only the destination buffer is
// bound. The fill value comes from the tensor's op params (param index 0).
static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");

    const uint32_t n_elems = (uint32_t)ggml_nelements(dst);

    // Push constants: KX = element count (bounds check in the shader),
    // param1 = fill value. KY and param2 are unused by this kernel.
    vk_op_push_constants pc = { n_elems, 1, ggml_get_op_params_f32(dst, 0), 0.0f };

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
    GGML_ASSERT(pipeline != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    // Destination-only binding; no read access needed.
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);

    const std::array<uint32_t, 3> elements = { n_elems, 1, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
}
9522+
94929523
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
94939524
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
94949525
}
@@ -11291,6 +11322,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1129111322
case GGML_OP_DIV:
1129211323
case GGML_OP_ADD1:
1129311324
case GGML_OP_ARANGE:
11325+
case GGML_OP_FILL:
1129411326
case GGML_OP_CONCAT:
1129511327
case GGML_OP_UPSCALE:
1129611328
case GGML_OP_SCALE:
@@ -11511,6 +11543,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1151111543
case GGML_OP_ARANGE:
1151211544
ggml_vk_arange(ctx, compute_ctx, node);
1151311545

11546+
break;
11547+
case GGML_OP_FILL:
11548+
ggml_vk_fill(ctx, compute_ctx, node);
11549+
1151411550
break;
1151511551
case GGML_OP_SCALE:
1151611552
ggml_vk_scale(ctx, compute_ctx, src0, node);
@@ -11799,6 +11835,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1179911835
case GGML_OP_DIV:
1180011836
case GGML_OP_ADD1:
1180111837
case GGML_OP_ARANGE:
11838+
case GGML_OP_FILL:
1180211839
case GGML_OP_ADD_ID:
1180311840
case GGML_OP_CONCAT:
1180411841
case GGML_OP_UPSCALE:
@@ -13779,6 +13816,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1377913816
case GGML_OP_CONCAT:
1378013817
case GGML_OP_ADD1:
1378113818
case GGML_OP_ARANGE:
13819+
case GGML_OP_FILL:
1378213820
case GGML_OP_SCALE:
1378313821
case GGML_OP_PAD:
1378413822
case GGML_OP_ROLL:
@@ -14268,6 +14306,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1426814306
const float stop = ggml_get_op_params_f32(tensor, 1);
1426914307
const float step = ggml_get_op_params_f32(tensor, 2);
1427014308
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
14309+
} else if (tensor->op == GGML_OP_FILL) {
14310+
const float value = ggml_get_op_params_f32(tensor, 0);
14311+
tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
1427114312
} else if (tensor->op == GGML_OP_SQR) {
1427214313
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
1427314314
} else if (tensor->op == GGML_OP_SQRT) {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#version 450

#include "generic_head.glsl"
#include "types.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};

// FILL kernel: each invocation writes the constant fill value (push-constant
// param1) to one element of the destination buffer. p.KX is the total
// element count; out-of-range invocations from the last workgroup exit early.
void main() {
    const uint idx = gl_GlobalInvocationID.x;

    if (idx >= p.KX) {
        return;
    }

    data_d[idx] = D_TYPE(p.param1);
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ void process_shaders() {
846846
string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
847847
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
848848
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
849+
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
849850

850851
for (auto rte : {false, true}) {
851852
std::string suffix = rte ? "_rte" : "";

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7585,6 +7585,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75857585
test_cases.emplace_back(new test_fill(0.0f));
75867586
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
75877587
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
7588+
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
75887589

75897590
test_cases.emplace_back(new test_solve_tri());
75907591
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));

0 commit comments

Comments
 (0)