diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 11262c19894c8..53d212cf79b6f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -406,8 +406,8 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; +// argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; -static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, @@ -526,6 +526,7 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + bool vulkan_memory_model; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -539,6 +540,9 @@ struct vk_device_struct { uint32_t subgroup_max_size; bool subgroup_require_full_support; + // floor(log2(maxComputeWorkGroupInvocations)) + uint32_t max_workgroup_size_log2 {}; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -683,6 +687,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; @@ -1173,8 +1178,14 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t ncols_padded; + uint32_t ncols_padded_log2; uint32_t nrows; - int32_t order; + uint32_t order; + uint32_t outer_start; + uint32_t outer_end; + uint32_t inner_start; + uint32_t inner_end; }; struct vk_op_im2col_push_constants { @@ -3891,7 +3902,15 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<max_workgroup_size_log2); + if (i <= device->max_workgroup_size_log2 && + 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + const uint32_t NCOLS_PADDED_LOG2 = i; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1; + BLOCK_SIZE /= WG_UNROLL_FACTOR; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -4292,6 +4311,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -4431,6 +4452,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_int64 = device_features2.features.shaderInt64; device->buffer_device_address = vk12_features.bufferDeviceAddress; + device->vulkan_memory_model = vk12_features.vulkanMemoryModel; if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; @@ -8344,19 +8366,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } - case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; - } - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - return ctx->device->pipeline_argsort_f32[idx]; - } - return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -8748,8 +8757,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + GGML_ASSERT(0); break; case GGML_OP_IM2COL: { @@ -9865,16 +9873,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int32_t * op_params = (int32_t *)dst->op_params; + const uint32_t * op_params = (const uint32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; uint32_t nrows = ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { - ncols, - nrows, - op_params[0], - }); + uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); + uint32_t ncolsp2 = 1 << ncols_pad_log2; + + vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; + + // Pick the largest workgroup size <= ncolsp2 + uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1); + + // Use the "small" argsort shader if the whole sort can be done by a single workgroup. + bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 && + ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr; + + vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx] + : ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer subbuf1 = dst_buf; + + // Reserve space for ivec2 per element, with rows padded to a power of two + if (!use_small) { + const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int); + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size }; + } + + std::array elements; + + elements[0] = ncolsp2; + elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + // First dispatch initializes tmp_idx and does the first N passes where + // there is only communication between threads in the same workgroup. + { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = 0; + pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2); + pc2.inner_start = 0; + pc2.inner_end = 100; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + } + if (!use_small) { + ggml_vk_sync_buffers(ctx, subctx); + // Loop over outer/inner passes, synchronizing between each pass. + for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) { + for (uint32_t inner = 0; inner < outer + 1; ++inner) { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = outer; + pc2.outer_end = outer + 1; + pc2.inner_start = inner; + pc2.inner_end = inner + 1; + // When the inner idx is large enough, there's only communication + // within a workgroup. So the remaining inner iterations can all + // run in the same dispatch. + if (outer - inner < pipeline_idx) { + pc2.inner_end = 100; + inner = outer; + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + } else { + // Smaller workgroup empirically seems to perform better + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2]; + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; + } } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -13695,7 +13776,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LOG: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: - return op->ne[0] <= max_argsort_cols; + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // pipeline_argsort_large_f32 requires vulkan memory model. + if (device->vulkan_memory_model) { + return true; + } else { + return op->ne[0] <= (1 << device->max_workgroup_size_log2); + } + } case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c4e68bc02370a..0fc2b9b725350 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,28 +4,27 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; } p; -shared int dst_row[BLOCK_SIZE]; -shared A_TYPE a_sh[BLOCK_SIZE]; - -void swap(uint idx0, uint idx1) { - int tmp = dst_row[idx0]; - dst_row[idx0] = dst_row[idx1]; - dst_row[idx1] = tmp; -} +shared ivec2 dst_row[BLOCK_SIZE]; void argsort(bool needs_bounds_check, const uint row) { // bitonic sort @@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; - a_sh[col] = data_a[row_offset + col]; + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); barrier(); - uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { uint num_inner_loop_iters = outer_idx + 1; [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { @@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) { int idx_0 = (col & k) == 0 ? col : ixj; int idx_1 = (col & k) == 0 ? ixj : col; - int sh_idx_0 = dst_row[idx_0]; - int sh_idx_1 = dst_row[idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; if ((idx_0_oob || - (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { - swap(idx_0, idx_1); + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; } barrier(); @@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) { if (col < p.ncols) { if (p.order == ASC) { - data_d[row_offset + col] = dst_row[col]; + data_d[row_offset + col] = dst_row[col].x; } else { - data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp new file mode 100644 index 0000000000000..920bac6bb8996 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -0,0 +1,114 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_memory_scope_semantics : enable +#pragma use_vulkan_memory_model + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];}; +layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_padded; + uint ncols_padded_log2; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + int col = int(gl_GlobalInvocationID.x); + col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR; + + const uint row_offset = row * p.ncols; + uint idx_offset = row * p.ncols_padded; + + bool need_barrier = false; + + // initialize indices + if (p.outer_start == 0 && p.inner_start == 0) { + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols_padded) { + ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c])); + tmp_idx[idx_offset + c] = v; + } + } + need_barrier = true; + } + + [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { + uint inner_end = min(p.inner_end, outer_idx + 1); + for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) { + if (need_barrier) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + int c = u*BLOCK_SIZE + col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } + } + } + } + + if (p.outer_end == p.ncols_padded_log2 && + p.inner_end >= p.ncols_padded_log2 + 1) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } + } + } + } +} + +void main() { + if (p.ncols == p.ncols_padded) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9c207f1e46cff..4210d736c5990 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -889,6 +889,7 @@ void process_shaders() { string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 267bead8c4ab7..ecd85976f0af3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7501,13 +7501,15 @@ static std::vector> make_test_cases_eval() { } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); + for (uint32_t i = 4; i <= 1024*1024; i *= 2) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1})); + } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));