From 69a530692ada76873133421c6be88b231a2397af Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 14 Nov 2025 21:51:34 -0600 Subject: [PATCH 1/4] vulkan: support larger argsort This is an extension of the original bitonic sorting shader that puts the temporary values in global memory and when more than 1024 threads are needed it runs multiple workgroups and synchronizes through a pipelinebarrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared memory version of the shader, which is still used when ncols <= 1024. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 132 ++++++++++++++---- .../ggml-vulkan/vulkan-shaders/argsort.comp | 40 +++--- .../vulkan-shaders/argsort_large.comp | 108 ++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + tests/test-backend-ops.cpp | 6 +- 5 files changed, 239 insertions(+), 48 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 11262c19894c8..0dd922fed3102 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -406,7 +406,8 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; -static constexpr uint32_t num_argsort_pipelines = 11; +// Arbitrary limit for argsort size (about a million columns). +static constexpr uint32_t num_argsort_pipelines = 21; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; @@ -526,6 +527,7 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + bool vulkan_memory_model; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -539,6 +541,9 @@ struct vk_device_struct { uint32_t subgroup_max_size; bool subgroup_require_full_support; + // floor(log2(maxComputeWorkGroupInvocations)) + uint32_t max_workgroup_size_log2 {}; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -683,6 +688,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; @@ -1174,7 +1180,11 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; uint32_t nrows; - int32_t order; + uint32_t order; + uint32_t outer_start; + uint32_t outer_end; + uint32_t inner_start; + uint32_t inner_end; }; struct vk_op_im2col_push_constants { @@ -3891,7 +3901,14 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<max_workgroup_size_log2); + const uint32_t NCOLS_PADDED = 1u << i; + const uint32_t NCOLS_PADDED_LOG2 = i; + if (i <= device->max_workgroup_size_log2 && + 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true); + } + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -4292,6 +4309,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -4431,6 +4450,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_int64 = device_features2.features.shaderInt64; device->buffer_device_address = vk12_features.bufferDeviceAddress; + device->vulkan_memory_model = vk12_features.vulkanMemoryModel; if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; @@ -8344,19 +8364,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } - case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; - } - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - return ctx->device->pipeline_argsort_f32[idx]; - } - return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -8748,8 +8755,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + GGML_ASSERT(0); break; case GGML_OP_IM2COL: { @@ -9865,16 +9871,81 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int32_t * op_params = (int32_t *)dst->op_params; + const uint32_t * op_params = (const uint32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; uint32_t nrows = ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { - ncols, - nrows, - op_params[0], - }); + uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); + uint32_t ncolsp2 = 1 << ncols_pad_log2; + + vk_op_argsort_push_constants pc { ncols, nrows, op_params[0], 0, 0, 0, 0, }; + + // Use the "small" argsort shader if the whole sort can be done by a single workgroup. + bool use_small = ctx->device->pipeline_argsort_f32[ncols_pad_log2] != nullptr; + + vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[ncols_pad_log2] + : ctx->device->pipeline_argsort_large_f32[ncols_pad_log2]; + + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer subbuf1 = dst_buf; + + // Reserve space for ivec2 per element, with rows padded to a power of two + if (!use_small) { + const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int); + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size }; + } + + std::array elements; + + elements[0] = ncolsp2; + elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + // First dispatch initializes tmp_idx and does the first N passes where + // there is only communication between threads in the same workgroup. + { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = 0; + pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2); + pc2.inner_start = 0; + pc2.inner_end = 100; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + } + if (!use_small) { + ggml_vk_sync_buffers(ctx, subctx); + // Loop over outer/inner passes, synchronizing between each pass. + for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) { + for (uint32_t inner = 0; inner < outer + 1; ++inner) { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = outer; + pc2.outer_end = outer + 1; + pc2.inner_start = inner; + pc2.inner_end = inner + 1; + // When the inner idx is large enough, there's only communication + // within a workgroup. So the remaining inner iterations can all + // run in the same dispatch. + if (outer - inner < ctx->device->max_workgroup_size_log2) { + pc2.inner_end = 100; + inner = outer; + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; + } } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -13695,7 +13766,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LOG: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: - return op->ne[0] <= max_argsort_cols; + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + if (device->vulkan_memory_model) { + return op->ne[0] <= max_argsort_cols; + } else { + return op->ne[0] <= (1 << device->max_workgroup_size_log2); + } + } case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c4e68bc02370a..6db2f2b0addd0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,28 +4,26 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED = 1024; +layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; uint nrows; uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; } p; -shared int dst_row[BLOCK_SIZE]; -shared A_TYPE a_sh[BLOCK_SIZE]; - -void swap(uint idx0, uint idx1) { - int tmp = dst_row[idx0]; - dst_row[idx0] = dst_row[idx1]; - dst_row[idx1] = tmp; -} +shared ivec2 dst_row[BLOCK_SIZE]; void argsort(bool needs_bounds_check, const uint row) { // bitonic sort @@ -34,11 +32,10 @@ void argsort(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; - a_sh[col] = data_a[row_offset + col]; + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); barrier(); - uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { uint num_inner_loop_iters = outer_idx + 1; [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { @@ -47,14 +44,15 @@ void argsort(bool needs_bounds_check, const uint row) { int idx_0 = (col & k) == 0 ? col : ixj; int idx_1 = (col & k) == 0 ? ixj : col; - int sh_idx_0 = dst_row[idx_0]; - int sh_idx_1 = dst_row[idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; if ((idx_0_oob || - (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { - swap(idx_0, idx_1); + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; } barrier(); @@ -63,9 +61,9 @@ void argsort(bool needs_bounds_check, const uint row) { if (col < p.ncols) { if (p.order == ASC) { - data_d[row_offset + col] = dst_row[col]; + data_d[row_offset + col] = dst_row[col].x; } else { - data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp new file mode 100644 index 0000000000000..ea136ddc5b45d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -0,0 +1,108 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_memory_scope_semantics : enable +#pragma use_vulkan_memory_model + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED = 1024; +layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];}; +layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + const int col = int(gl_GlobalInvocationID.x); + + const uint row_offset = row * p.ncols; + uint idx_offset = row * NCOLS_PADDED; + + bool need_barrier = false; + + // initialize indices + if (p.outer_start == 0 && p.inner_start == 0) { + if (col < NCOLS_PADDED) { + ivec2 v = ivec2(col, floatBitsToInt(data_a[row_offset + col])); + tmp_idx[idx_offset + col] = v; + } + need_barrier = true; + } + + [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + if (inner_idx >= p.inner_start && inner_idx < p.inner_end) { + + if (need_barrier) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + const int c = col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } + } + } + } + + if (p.outer_end == NCOLS_PADDED_LOG2 && + p.inner_end >= NCOLS_PADDED_LOG2+1) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + if (col < p.ncols) { + int c = col; + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } + } + } +} + +void main() { + if (p.ncols == NCOLS_PADDED) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9c207f1e46cff..4210d736c5990 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -889,6 +889,7 @@ void process_shaders() { string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 267bead8c4ab7..ecd85976f0af3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7501,13 +7501,15 @@ static std::vector> make_test_cases_eval() { } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); + for (uint32_t i = 4; i <= 1024*1024; i *= 2) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1})); + } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order)); From 4cf884b22279ca892cbd8769b8b1c3911e5ccf10 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 18 Nov 2025 21:12:11 -0600 Subject: [PATCH 2/4] Reduce the number of shader variants. Use smaller workgroups when doing a single pass, for a modest perf boost --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 ++++++++++++------- .../ggml-vulkan/vulkan-shaders/argsort.comp | 5 +-- .../vulkan-shaders/argsort_large.comp | 14 ++++---- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0dd922fed3102..ee58838b9e7e3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -406,9 +406,8 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; -// Arbitrary limit for argsort size (about a million columns). -static constexpr uint32_t num_argsort_pipelines = 21; -static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); +// argsort pipelines for up to 1<<10 invocations per workgroup +static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, @@ -1179,6 +1178,8 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t ncols_padded; + uint32_t ncols_padded_log2; uint32_t nrows; uint32_t order; uint32_t outer_start; @@ -3902,13 +3903,12 @@ static void ggml_vk_load_shaders(vk_device& device) { for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { const uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); - const uint32_t NCOLS_PADDED = 1u << i; - const uint32_t NCOLS_PADDED_LOG2 = i; if (i <= device->max_workgroup_size_log2 && 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true); + const uint32_t NCOLS_PADDED_LOG2 = i; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); } - ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true); + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -9879,13 +9879,17 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); uint32_t ncolsp2 = 1 << ncols_pad_log2; - vk_op_argsort_push_constants pc { ncols, nrows, op_params[0], 0, 0, 0, 0, }; + vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; + + // Pick the largest workgroup size <= ncolsp2 + uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1); // Use the "small" argsort shader if the whole sort can be done by a single workgroup. - bool use_small = ctx->device->pipeline_argsort_f32[ncols_pad_log2] != nullptr; + bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 && + ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr; - vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[ncols_pad_log2] - : ctx->device->pipeline_argsort_large_f32[ncols_pad_log2]; + vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx] + : ctx->device->pipeline_argsort_large_f32[pipeline_idx]; vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0); vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); @@ -9935,9 +9939,13 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c // When the inner idx is large enough, there's only communication // within a workgroup. So the remaining inner iterations can all // run in the same dispatch. - if (outer - inner < ctx->device->max_workgroup_size_log2) { + if (outer - inner < pipeline_idx) { pc2.inner_end = 100; inner = outer; + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + } else { + // Smaller workgroup empirically seems to perform better + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2]; } ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); @@ -13772,8 +13780,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); + // pipeline_argsort_large_f32 requires vulkan memory model. if (device->vulkan_memory_model) { - return op->ne[0] <= max_argsort_cols; + return true; } else { return op->ne[0] <= (1 << device->max_workgroup_size_log2); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index 6db2f2b0addd0..0fc2b9b725350 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,8 +4,7 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int NCOLS_PADDED = 1024; -layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -15,6 +14,8 @@ layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; uint outer_start; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp index ea136ddc5b45d..4ad4a0c23414e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -6,8 +6,6 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int NCOLS_PADDED = 1024; -layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -18,6 +16,8 @@ layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; uint outer_start; @@ -31,13 +31,13 @@ void argsort(bool needs_bounds_check, const uint row) { const int col = int(gl_GlobalInvocationID.x); const uint row_offset = row * p.ncols; - uint idx_offset = row * NCOLS_PADDED; + uint idx_offset = row * p.ncols_padded; bool need_barrier = false; // initialize indices if (p.outer_start == 0 && p.inner_start == 0) { - if (col < NCOLS_PADDED) { + if (col < p.ncols_padded) { ivec2 v = ivec2(col, floatBitsToInt(data_a[row_offset + col])); tmp_idx[idx_offset + col] = v; } @@ -77,8 +77,8 @@ void argsort(bool needs_bounds_check, const uint row) { } } - if (p.outer_end == NCOLS_PADDED_LOG2 && - p.inner_end >= NCOLS_PADDED_LOG2+1) { + if (p.outer_end == p.ncols_padded_log2 && + p.inner_end >= p.ncols_padded_log2 + 1) { controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); if (col < p.ncols) { int c = col; @@ -92,7 +92,7 @@ void argsort(bool needs_bounds_check, const uint row) { } void main() { - if (p.ncols == NCOLS_PADDED) { + if (p.ncols == p.ncols_padded) { uint row = gl_WorkGroupID.y; while (row < p.nrows) { argsort(false, row); From 68bd013ab8b5b171af853b35be582069e36f8b40 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 18 Nov 2025 21:29:04 -0600 Subject: [PATCH 3/4] reduce loop overhead --- .../vulkan-shaders/argsort_large.comp | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp index 4ad4a0c23414e..a357b6a4a1da1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -45,34 +45,31 @@ void argsort(bool needs_bounds_check, const uint row) { } [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { - uint num_inner_loop_iters = outer_idx + 1; - [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { - if (inner_idx >= p.inner_start && inner_idx < p.inner_end) { - - if (need_barrier) { - controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); - } - need_barrier = true; - const int c = col; - const int ixj = int(c ^ j); - - if (ixj < c) { - continue; - } - - int idx_0 = (c & k) == 0 ? c : ixj; - int idx_1 = (c & k) == 0 ? ixj : c; - - ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; - ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; - - if ((idx_0_oob || - (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { - tmp_idx[idx_offset + idx_0] = sh_idx_1; - tmp_idx[idx_offset + idx_1] = sh_idx_0; - } + uint inner_end = min(p.inner_end, outer_idx + 1); + for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) { + if (need_barrier) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + const int c = col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; } } } From a19c81b0c2d4cf43fafce9fa7b99d8c7f7609778 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 18 Nov 2025 23:06:30 -0600 Subject: [PATCH 4/4] run multiple cols per invocation, to reduce barrier overhead --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +- .../vulkan-shaders/argsort_large.comp | 67 +++++++++++-------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ee58838b9e7e3..53d212cf79b6f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3902,13 +3902,15 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - const uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); + uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); if (i <= device->max_workgroup_size_log2 && 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { const uint32_t NCOLS_PADDED_LOG2 = i; ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); } - ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE}, 1, true); + const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1; + BLOCK_SIZE /= WG_UNROLL_FACTOR; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp index a357b6a4a1da1..920bac6bb8996 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -6,6 +6,7 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -28,7 +29,8 @@ layout (push_constant) uniform parameter { void argsort(bool needs_bounds_check, const uint row) { // bitonic sort - const int col = int(gl_GlobalInvocationID.x); + int col = int(gl_GlobalInvocationID.x); + col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR; const uint row_offset = row * p.ncols; uint idx_offset = row * p.ncols_padded; @@ -37,9 +39,12 @@ void argsort(bool needs_bounds_check, const uint row) { // initialize indices if (p.outer_start == 0 && p.inner_start == 0) { - if (col < p.ncols_padded) { - ivec2 v = ivec2(col, floatBitsToInt(data_a[row_offset + col])); - tmp_idx[idx_offset + col] = v; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols_padded) { + ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c])); + tmp_idx[idx_offset + c] = v; + } } need_barrier = true; } @@ -51,25 +56,27 @@ void argsort(bool needs_bounds_check, const uint row) { controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); } need_barrier = true; - const int c = col; - const int ixj = int(c ^ j); - - if (ixj < c) { - continue; - } - - int idx_0 = (c & k) == 0 ? c : ixj; - int idx_1 = (c & k) == 0 ? ixj : c; - - ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; - ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; - - if ((idx_0_oob || - (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { - tmp_idx[idx_offset + idx_0] = sh_idx_1; - tmp_idx[idx_offset + idx_1] = sh_idx_0; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + int c = u*BLOCK_SIZE + col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } } } } @@ -77,12 +84,14 @@ void argsort(bool needs_bounds_check, const uint row) { if (p.outer_end == p.ncols_padded_log2 && p.inner_end >= p.ncols_padded_log2 + 1) { controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); - if (col < p.ncols) { - int c = col; - if (p.order == ASC) { - data_d[row_offset + c] = tmp_idx[idx_offset + c].x; - } else { - data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } } } }