Skip to content

Commit b7f5a07

Browse files
committed
vulkan: support larger argsort
This is an extension of the original bitonic sorting shader that puts the temporary values in global memory and, when more than 1024 threads are needed, runs multiple workgroups and synchronizes through a pipeline barrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared memory version of the shader, which is still used when ncols <= 1024.
1 parent 10e9780 commit b7f5a07

File tree

5 files changed

+240
-46
lines changed

5 files changed

+240
-46
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 108 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -406,7 +406,8 @@ enum shader_reduction_mode {
406406
SHADER_REDUCTION_MODE_COUNT,
407407
};
408408

409-
static constexpr uint32_t num_argsort_pipelines = 11;
409+
// Arbitrary limit for argsort size (about a million columns).
410+
static constexpr uint32_t num_argsort_pipelines = 21;
410411
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
411412
static constexpr uint32_t num_topk_moe_pipelines = 10;
412413

@@ -526,6 +527,7 @@ struct vk_device_struct {
526527
bool multi_add;
527528
bool shader_int64;
528529
bool buffer_device_address;
530+
bool vulkan_memory_model;
529531

530532
bool add_rms_fusion;
531533
uint32_t partials_binding_alignment;
@@ -539,6 +541,9 @@ struct vk_device_struct {
539541
uint32_t subgroup_max_size;
540542
bool subgroup_require_full_support;
541543

544+
// floor(log2(maxComputeWorkGroupInvocations))
545+
uint32_t max_workgroup_size_log2 {};
546+
542547
bool coopmat_support;
543548
bool coopmat_acc_f32_support {};
544549
bool coopmat_acc_f16_support {};
@@ -683,6 +688,7 @@ struct vk_device_struct {
683688
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
684689
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
685690
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
691+
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
686692
vk_pipeline pipeline_sum_rows_f32;
687693
vk_pipeline pipeline_argmax_f32;
688694
vk_pipeline pipeline_count_equal_i32;
@@ -1174,7 +1180,11 @@ struct vk_op_soft_max_push_constants {
11741180
struct vk_op_argsort_push_constants {
11751181
uint32_t ncols;
11761182
uint32_t nrows;
1177-
int32_t order;
1183+
uint32_t order;
1184+
uint32_t outer_start;
1185+
uint32_t outer_end;
1186+
uint32_t inner_start;
1187+
uint32_t inner_end;
11781188
};
11791189

11801190
struct vk_op_im2col_push_constants {
@@ -3891,7 +3901,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
38913901
}
38923902

38933903
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3894-
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3904+
const uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
3905+
const uint32_t NCOLS_PADDED = 1u << i;
3906+
const uint32_t NCOLS_PADDED_LOG2 = i;
3907+
if (i <= device->max_workgroup_size_log2 &&
3908+
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
3909+
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
3910+
}
3911+
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
38953912
}
38963913

38973914
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4292,6 +4309,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
42924309

42934310
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
42944311

4312+
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
4313+
42954314
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
42964315

42974316
// Try to find a non-graphics compute queue and transfer-focused queues
@@ -4431,6 +4450,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
44314450

44324451
device->shader_int64 = device_features2.features.shaderInt64;
44334452
device->buffer_device_address = vk12_features.bufferDeviceAddress;
4453+
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
44344454

44354455
if (device->subgroup_size_control) {
44364456
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -8345,17 +8365,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
83458365
return nullptr;
83468366
}
83478367
case GGML_OP_ARGSORT:
8348-
if (ctx->num_additional_fused_ops) {
8349-
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8350-
GGML_ASSERT(idx < num_topk_moe_pipelines);
8351-
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8352-
return ctx->device->pipeline_topk_moe[idx][mode];
8353-
}
8354-
8355-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
8356-
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8357-
return ctx->device->pipeline_argsort_f32[idx];
8358-
}
8368+
GGML_ASSERT(0);
83598369
return nullptr;
83608370
case GGML_OP_SUM:
83618371
case GGML_OP_SUM_ROWS:
@@ -8748,8 +8758,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
87488758
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
87498759
break;
87508760
case GGML_OP_ARGSORT:
8751-
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
8752-
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
8761+
GGML_ASSERT(0);
87538762
break;
87548763
case GGML_OP_IM2COL:
87558764
{
@@ -9865,16 +9874,81 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
98659874
}
98669875

98679876
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9868-
int32_t * op_params = (int32_t *)dst->op_params;
9877+
const uint32_t * op_params = (const uint32_t *)dst->op_params;
98699878

98709879
uint32_t ncols = src0->ne[0];
98719880
uint32_t nrows = ggml_nrows(src0);
98729881

9873-
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
9874-
ncols,
9875-
nrows,
9876-
op_params[0],
9877-
});
9882+
uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
9883+
uint32_t ncolsp2 = 1 << ncols_pad_log2;
9884+
9885+
vk_op_argsort_push_constants pc { ncols, nrows, op_params[0], 0, 0, 0, 0, };
9886+
9887+
// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
9888+
bool use_small = ctx->device->pipeline_argsort_f32[ncols_pad_log2] != nullptr;
9889+
9890+
vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[ncols_pad_log2]
9891+
: ctx->device->pipeline_argsort_large_f32[ncols_pad_log2];
9892+
9893+
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
9894+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
9895+
vk_subbuffer subbuf1 = dst_buf;
9896+
9897+
// Reserve space for ivec2 per element, with rows padded to a power of two
9898+
if (!use_small) {
9899+
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
9900+
9901+
if (ctx->prealloc_size_x < x_sz) {
9902+
ctx->prealloc_size_x = x_sz;
9903+
ggml_vk_preallocate_buffers(ctx, subctx);
9904+
}
9905+
if (ctx->prealloc_x_need_sync) {
9906+
ggml_vk_sync_buffers(ctx, subctx);
9907+
}
9908+
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
9909+
}
9910+
9911+
std::array<uint32_t, 3> elements;
9912+
9913+
elements[0] = ncolsp2;
9914+
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9915+
elements[2] = 1;
9916+
9917+
// First dispatch initializes tmp_idx and does the first N passes where
9918+
// there is only communication between threads in the same workgroup.
9919+
{
9920+
vk_op_argsort_push_constants pc2 = pc;
9921+
pc2.outer_start = 0;
9922+
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
9923+
pc2.inner_start = 0;
9924+
pc2.inner_end = 100;
9925+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9926+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
9927+
}
9928+
if (!use_small) {
9929+
ggml_vk_sync_buffers(ctx, subctx);
9930+
// Loop over outer/inner passes, synchronizing between each pass.
9931+
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
9932+
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
9933+
vk_op_argsort_push_constants pc2 = pc;
9934+
pc2.outer_start = outer;
9935+
pc2.outer_end = outer + 1;
9936+
pc2.inner_start = inner;
9937+
pc2.inner_end = inner + 1;
9938+
// When the inner idx is large enough, there's only communication
9939+
// within a workgroup. So the remaining inner iterations can all
9940+
// run in the same dispatch.
9941+
if (outer - inner < ctx->device->max_workgroup_size_log2) {
9942+
pc2.inner_end = 100;
9943+
inner = outer;
9944+
}
9945+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9946+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
9947+
ggml_vk_sync_buffers(ctx, subctx);
9948+
}
9949+
}
9950+
ctx->prealloc_x_need_sync = true;
9951+
}
98789952
}
98799953

98809954
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13695,7 +13769,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1369513769
case GGML_OP_LOG:
1369613770
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
1369713771
case GGML_OP_ARGSORT:
13698-
return op->ne[0] <= max_argsort_cols;
13772+
{
13773+
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
13774+
return false;
13775+
}
13776+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13777+
auto device = ggml_vk_get_device(ctx->device);
13778+
if (device->vulkan_memory_model) {
13779+
return op->ne[0] <= max_argsort_cols;
13780+
} else {
13781+
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
13782+
}
13783+
}
1369913784
case GGML_OP_UPSCALE:
1370013785
case GGML_OP_ACC:
1370113786
case GGML_OP_CONCAT:

ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp

Lines changed: 19 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,28 +4,26 @@
44
#include "types.glsl"
55

66
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
7-
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
7+
layout(constant_id = 1) const int NCOLS_PADDED = 1024;
8+
layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10;
89
#define ASC 0
910

1011
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1112

1213
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
13-
layout (binding = 1) buffer D {int data_d[];};
14+
layout (binding = 2) writeonly buffer D {int data_d[];};
1415

1516
layout (push_constant) uniform parameter {
1617
uint ncols;
1718
uint nrows;
1819
uint order;
20+
uint outer_start;
21+
uint outer_end;
22+
uint inner_start;
23+
uint inner_end;
1924
} p;
2025

21-
shared int dst_row[BLOCK_SIZE];
22-
shared A_TYPE a_sh[BLOCK_SIZE];
23-
24-
void swap(uint idx0, uint idx1) {
25-
int tmp = dst_row[idx0];
26-
dst_row[idx0] = dst_row[idx1];
27-
dst_row[idx1] = tmp;
28-
}
26+
shared ivec2 dst_row[BLOCK_SIZE];
2927

3028
void argsort(bool needs_bounds_check, const uint row) {
3129
// bitonic sort
@@ -34,11 +32,10 @@ void argsort(bool needs_bounds_check, const uint row) {
3432
const uint row_offset = row * p.ncols;
3533

3634
// initialize indices
37-
dst_row[col] = col;
38-
a_sh[col] = data_a[row_offset + col];
35+
dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
3936
barrier();
4037

41-
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
38+
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
4239
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
4340
uint num_inner_loop_iters = outer_idx + 1;
4441
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
@@ -47,14 +44,15 @@ void argsort(bool needs_bounds_check, const uint row) {
4744
int idx_0 = (col & k) == 0 ? col : ixj;
4845
int idx_1 = (col & k) == 0 ? ixj : col;
4946

50-
int sh_idx_0 = dst_row[idx_0];
51-
int sh_idx_1 = dst_row[idx_1];
52-
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
53-
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
47+
ivec2 sh_idx_0 = dst_row[idx_0];
48+
ivec2 sh_idx_1 = dst_row[idx_1];
49+
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
50+
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
5451

5552
if ((idx_0_oob ||
56-
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
57-
swap(idx_0, idx_1);
53+
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
54+
dst_row[idx_0] = sh_idx_1;
55+
dst_row[idx_1] = sh_idx_0;
5856
}
5957

6058
barrier();
@@ -63,9 +61,9 @@ void argsort(bool needs_bounds_check, const uint row) {
6361

6462
if (col < p.ncols) {
6563
if (p.order == ASC) {
66-
data_d[row_offset + col] = dst_row[col];
64+
data_d[row_offset + col] = dst_row[col].x;
6765
} else {
68-
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
66+
data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
6967
}
7068
}
7169
}

0 commit comments

Comments (0)