Skip to content

Commit a6f38e1

Browse files
committed
vulkan: support larger argsort
This is an extension of the original bitonic sorting shader that puts the temporary values in global memory. When more than 1024 threads are needed, it runs multiple workgroups and synchronizes through a pipeline barrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared-memory version of the shader, which is still used when ncols <= 1024.
1 parent 2376b77 commit a6f38e1

File tree

5 files changed

+240
-46
lines changed

5 files changed

+240
-46
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ enum shader_reduction_mode {
406406
SHADER_REDUCTION_MODE_COUNT,
407407
};
408408

409-
static constexpr uint32_t num_argsort_pipelines = 11;
409+
// Arbitrary limit for argsort size (about a million columns).
410+
static constexpr uint32_t num_argsort_pipelines = 21;
410411
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
411412
static constexpr uint32_t num_topk_moe_pipelines = 10;
412413

@@ -526,6 +527,7 @@ struct vk_device_struct {
526527
bool multi_add;
527528
bool shader_int64;
528529
bool buffer_device_address;
530+
bool vulkan_memory_model;
529531

530532
bool add_rms_fusion;
531533
uint32_t partials_binding_alignment;
@@ -539,6 +541,9 @@ struct vk_device_struct {
539541
uint32_t subgroup_max_size;
540542
bool subgroup_require_full_support;
541543

544+
// floor(log2(maxComputeWorkGroupInvocations))
545+
uint32_t max_workgroup_size_log2 {};
546+
542547
bool coopmat_support;
543548
bool coopmat_acc_f32_support {};
544549
bool coopmat_acc_f16_support {};
@@ -683,6 +688,7 @@ struct vk_device_struct {
683688
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
684689
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
685690
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
691+
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
686692
vk_pipeline pipeline_sum_rows_f32;
687693
vk_pipeline pipeline_argmax_f32;
688694
vk_pipeline pipeline_count_equal_i32;
@@ -1174,7 +1180,11 @@ struct vk_op_soft_max_push_constants {
11741180
struct vk_op_argsort_push_constants {
11751181
uint32_t ncols;
11761182
uint32_t nrows;
1177-
int32_t order;
1183+
uint32_t order;
1184+
uint32_t outer_start;
1185+
uint32_t outer_end;
1186+
uint32_t inner_start;
1187+
uint32_t inner_end;
11781188
};
11791189

11801190
struct vk_op_im2col_push_constants {
@@ -3885,7 +3895,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
38853895
}
38863896

38873897
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3888-
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3898+
const uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
3899+
const uint32_t NCOLS_PADDED = 1u << i;
3900+
const uint32_t NCOLS_PADDED_LOG2 = i;
3901+
if (i <= device->max_workgroup_size_log2 &&
3902+
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
3903+
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
3904+
}
3905+
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
38893906
}
38903907

38913908
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4286,6 +4303,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
42864303

42874304
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
42884305

4306+
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
4307+
42894308
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
42904309

42914310
// Try to find a non-graphics compute queue and transfer-focused queues
@@ -4425,6 +4444,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
44254444

44264445
device->shader_int64 = device_features2.features.shaderInt64;
44274446
device->buffer_device_address = vk12_features.bufferDeviceAddress;
4447+
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
44284448

44294449
if (device->subgroup_size_control) {
44304450
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -8339,17 +8359,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
83398359
return nullptr;
83408360
}
83418361
case GGML_OP_ARGSORT:
8342-
if (ctx->num_additional_fused_ops) {
8343-
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8344-
GGML_ASSERT(idx < num_topk_moe_pipelines);
8345-
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8346-
return ctx->device->pipeline_topk_moe[idx][mode];
8347-
}
8348-
8349-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
8350-
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8351-
return ctx->device->pipeline_argsort_f32[idx];
8352-
}
8362+
GGML_ASSERT(0);
83538363
return nullptr;
83548364
case GGML_OP_SUM:
83558365
case GGML_OP_SUM_ROWS:
@@ -8742,8 +8752,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
87428752
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
87438753
break;
87448754
case GGML_OP_ARGSORT:
8745-
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
8746-
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
8755+
GGML_ASSERT(0);
87478756
break;
87488757
case GGML_OP_IM2COL:
87498758
{
@@ -9859,16 +9868,81 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
98599868
}
98609869

98619870
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9862-
int32_t * op_params = (int32_t *)dst->op_params;
9871+
const uint32_t * op_params = (const uint32_t *)dst->op_params;
98639872

98649873
uint32_t ncols = src0->ne[0];
98659874
uint32_t nrows = ggml_nrows(src0);
98669875

9867-
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
9868-
ncols,
9869-
nrows,
9870-
op_params[0],
9871-
});
9876+
uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
9877+
uint32_t ncolsp2 = 1 << ncols_pad_log2;
9878+
9879+
vk_op_argsort_push_constants pc { ncols, nrows, op_params[0], 0, 0, 0, 0, };
9880+
9881+
// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
9882+
bool use_small = ctx->device->pipeline_argsort_f32[ncols_pad_log2] != nullptr;
9883+
9884+
vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[ncols_pad_log2]
9885+
: ctx->device->pipeline_argsort_large_f32[ncols_pad_log2];
9886+
9887+
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
9888+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
9889+
vk_subbuffer subbuf1 = dst_buf;
9890+
9891+
// Reserve space for ivec2 per element, with rows padded to a power of two
9892+
if (!use_small) {
9893+
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
9894+
9895+
if (ctx->prealloc_size_x < x_sz) {
9896+
ctx->prealloc_size_x = x_sz;
9897+
ggml_vk_preallocate_buffers(ctx, subctx);
9898+
}
9899+
if (ctx->prealloc_x_need_sync) {
9900+
ggml_vk_sync_buffers(ctx, subctx);
9901+
}
9902+
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
9903+
}
9904+
9905+
std::array<uint32_t, 3> elements;
9906+
9907+
elements[0] = ncolsp2;
9908+
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9909+
elements[2] = 1;
9910+
9911+
// First dispatch initializes tmp_idx and does the first N passes where
9912+
// there is only communication between threads in the same workgroup.
9913+
{
9914+
vk_op_argsort_push_constants pc2 = pc;
9915+
pc2.outer_start = 0;
9916+
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
9917+
pc2.inner_start = 0;
9918+
pc2.inner_end = 100;
9919+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9920+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
9921+
}
9922+
if (!use_small) {
9923+
ggml_vk_sync_buffers(ctx, subctx);
9924+
// Loop over outer/inner passes, synchronizing between each pass.
9925+
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
9926+
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
9927+
vk_op_argsort_push_constants pc2 = pc;
9928+
pc2.outer_start = outer;
9929+
pc2.outer_end = outer + 1;
9930+
pc2.inner_start = inner;
9931+
pc2.inner_end = inner + 1;
9932+
// When the inner idx is large enough, there's only communication
9933+
// within a workgroup. So the remaining inner iterations can all
9934+
// run in the same dispatch.
9935+
if (outer - inner < ctx->device->max_workgroup_size_log2) {
9936+
pc2.inner_end = 100;
9937+
inner = outer;
9938+
}
9939+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9940+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
9941+
ggml_vk_sync_buffers(ctx, subctx);
9942+
}
9943+
}
9944+
ctx->prealloc_x_need_sync = true;
9945+
}
98729946
}
98739947

98749948
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13688,7 +13762,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1368813762
case GGML_OP_LOG:
1368913763
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
1369013764
case GGML_OP_ARGSORT:
13691-
return op->ne[0] <= max_argsort_cols;
13765+
{
13766+
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
13767+
return false;
13768+
}
13769+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13770+
auto device = ggml_vk_get_device(ctx->device);
13771+
if (device->vulkan_memory_model) {
13772+
return op->ne[0] <= max_argsort_cols;
13773+
} else {
13774+
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
13775+
}
13776+
}
1369213777
case GGML_OP_UPSCALE:
1369313778
case GGML_OP_ACC:
1369413779
case GGML_OP_CONCAT:

ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,26 @@
44
#include "types.glsl"
55

66
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
7-
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
7+
layout(constant_id = 1) const int NCOLS_PADDED = 1024;
8+
layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10;
89
#define ASC 0
910

1011
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1112

1213
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
13-
layout (binding = 1) buffer D {int data_d[];};
14+
layout (binding = 2) writeonly buffer D {int data_d[];};
1415

1516
layout (push_constant) uniform parameter {
1617
uint ncols;
1718
uint nrows;
1819
uint order;
20+
uint outer_start;
21+
uint outer_end;
22+
uint inner_start;
23+
uint inner_end;
1924
} p;
2025

21-
shared int dst_row[BLOCK_SIZE];
22-
shared A_TYPE a_sh[BLOCK_SIZE];
23-
24-
void swap(uint idx0, uint idx1) {
25-
int tmp = dst_row[idx0];
26-
dst_row[idx0] = dst_row[idx1];
27-
dst_row[idx1] = tmp;
28-
}
26+
shared ivec2 dst_row[BLOCK_SIZE];
2927

3028
void argsort(bool needs_bounds_check, const uint row) {
3129
// bitonic sort
@@ -34,11 +32,10 @@ void argsort(bool needs_bounds_check, const uint row) {
3432
const uint row_offset = row * p.ncols;
3533

3634
// initialize indices
37-
dst_row[col] = col;
38-
a_sh[col] = data_a[row_offset + col];
35+
dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
3936
barrier();
4037

41-
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
38+
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
4239
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
4340
uint num_inner_loop_iters = outer_idx + 1;
4441
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
@@ -47,14 +44,15 @@ void argsort(bool needs_bounds_check, const uint row) {
4744
int idx_0 = (col & k) == 0 ? col : ixj;
4845
int idx_1 = (col & k) == 0 ? ixj : col;
4946

50-
int sh_idx_0 = dst_row[idx_0];
51-
int sh_idx_1 = dst_row[idx_1];
52-
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
53-
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
47+
ivec2 sh_idx_0 = dst_row[idx_0];
48+
ivec2 sh_idx_1 = dst_row[idx_1];
49+
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
50+
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
5451

5552
if ((idx_0_oob ||
56-
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
57-
swap(idx_0, idx_1);
53+
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
54+
dst_row[idx_0] = sh_idx_1;
55+
dst_row[idx_1] = sh_idx_0;
5856
}
5957

6058
barrier();
@@ -63,9 +61,9 @@ void argsort(bool needs_bounds_check, const uint row) {
6361

6462
if (col < p.ncols) {
6563
if (p.order == ASC) {
66-
data_d[row_offset + col] = dst_row[col];
64+
data_d[row_offset + col] = dst_row[col].x;
6765
} else {
68-
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
66+
data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
6967
}
7068
}
7169
}

0 commit comments

Comments
 (0)