Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 108 additions & 23 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ enum shader_reduction_mode {
SHADER_REDUCTION_MODE_COUNT,
};

static constexpr uint32_t num_argsort_pipelines = 11;
// Arbitrary limit for argsort size (about a million columns).
static constexpr uint32_t num_argsort_pipelines = 21;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;

Expand Down Expand Up @@ -526,6 +527,7 @@ struct vk_device_struct {
bool multi_add;
bool shader_int64;
bool buffer_device_address;
bool vulkan_memory_model;

bool add_rms_fusion;
uint32_t partials_binding_alignment;
Expand All @@ -539,6 +541,9 @@ struct vk_device_struct {
uint32_t subgroup_max_size;
bool subgroup_require_full_support;

// floor(log2(maxComputeWorkGroupInvocations))
uint32_t max_workgroup_size_log2 {};

bool coopmat_support;
bool coopmat_acc_f32_support {};
bool coopmat_acc_f16_support {};
Expand Down Expand Up @@ -683,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
Expand Down Expand Up @@ -1174,7 +1180,11 @@ struct vk_op_soft_max_push_constants {
// Push constants for the argsort (bitonic sort) pipelines.
// outer_*/inner_* select which bitonic merge passes a dispatch executes; the
// "large" multi-dispatch path runs a sub-range of passes per dispatch, while
// the single-workgroup path runs all passes in one dispatch.
// NOTE: the stale pre-change `int32_t order;` diff line was removed — `order`
// is uint32_t so the struct matches the shader's `uint order;` declaration.
struct vk_op_argsort_push_constants {
    uint32_t ncols;        // number of valid columns per row
    uint32_t nrows;        // number of rows to sort
    uint32_t order;        // sort order (0 = ascending, see shader's ASC)
    uint32_t outer_start;  // first outer (k) pass, log2 domain
    uint32_t outer_end;    // one past the last outer pass
    uint32_t inner_start;  // first inner (j) pass within the first outer pass
    uint32_t inner_end;    // one past the last inner pass (large value = run to end)
};

struct vk_op_im2col_push_constants {
Expand Down Expand Up @@ -3885,7 +3895,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
}

for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
const uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
const uint32_t NCOLS_PADDED = 1u << i;
const uint32_t NCOLS_PADDED_LOG2 = i;
if (i <= device->max_workgroup_size_log2 &&
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
}
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED, NCOLS_PADDED_LOG2}, 1, true);
}

ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
Expand Down Expand Up @@ -4286,6 +4303,8 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;

device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));

std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();

// Try to find a non-graphics compute queue and transfer-focused queues
Expand Down Expand Up @@ -4425,6 +4444,7 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress;
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;

if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
Expand Down Expand Up @@ -8339,17 +8359,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}

if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
}
GGML_ASSERT(0);
return nullptr;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
Expand Down Expand Up @@ -8742,8 +8752,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
GGML_ASSERT(0);
break;
case GGML_OP_IM2COL:
{
Expand Down Expand Up @@ -9859,16 +9868,81 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
}

// Record Vulkan dispatches that argsort each row of src0 (F32) into dst (I32).
//
// Two strategies, chosen per column count:
//  - "small": the whole bitonic sort of one row fits in a single workgroup's
//    shared memory, so one dispatch does everything.
//  - "large": rows exceed one workgroup; a temporary buffer (prealloc_x) holds
//    an ivec2 (index, value-bits) per element, and the outer/inner bitonic
//    passes that cross workgroup boundaries are issued as separate dispatches
//    with buffer synchronization between them.
//
// NOTE: diff artifacts from the pre-change version (the old `int32_t *
// op_params` declaration and the deleted ggml_vk_op_f32 call) were removed;
// only the new-side implementation remains.
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    const uint32_t * op_params = (const uint32_t *)dst->op_params;

    uint32_t ncols = src0->ne[0];
    uint32_t nrows = ggml_nrows(src0);

    // Pad the column count to the next power of two — bitonic sort requires it.
    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
    uint32_t ncolsp2 = 1 << ncols_pad_log2;

    // op_params[0] is the sort order (ascending/descending).
    vk_op_argsort_push_constants pc { ncols, nrows, op_params[0], 0, 0, 0, 0, };

    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
    // (The small pipeline is only created when it fits workgroup/shared-memory limits.)
    bool use_small = ctx->device->pipeline_argsort_f32[ncols_pad_log2] != nullptr;

    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[ncols_pad_log2]
                                     : ctx->device->pipeline_argsort_large_f32[ncols_pad_log2];

    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer subbuf1 = dst_buf;

    // Reserve space for ivec2 per element, with rows padded to a power of two
    if (!use_small) {
        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);

        if (ctx->prealloc_size_x < x_sz) {
            ctx->prealloc_size_x = x_sz;
            ggml_vk_preallocate_buffers(ctx, subctx);
        }
        if (ctx->prealloc_x_need_sync) {
            ggml_vk_sync_buffers(ctx, subctx);
        }
        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
    }

    std::array<uint32_t, 3> elements;

    elements[0] = ncolsp2;
    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
    elements[2] = 1;

    // First dispatch initializes tmp_idx and does the first N passes where
    // there is only communication between threads in the same workgroup.
    {
        vk_op_argsort_push_constants pc2 = pc;
        pc2.outer_start = 0;
        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
        pc2.inner_start = 0;
        // 100 is effectively "run all inner passes" (inner index never reaches it).
        pc2.inner_end = 100;
        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
    }
    if (!use_small) {
        ggml_vk_sync_buffers(ctx, subctx);
        // Loop over outer/inner passes, synchronizing between each pass.
        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
                vk_op_argsort_push_constants pc2 = pc;
                pc2.outer_start = outer;
                pc2.outer_end = outer + 1;
                pc2.inner_start = inner;
                pc2.inner_end = inner + 1;
                // When the inner idx is large enough, there's only communication
                // within a workgroup. So the remaining inner iterations can all
                // run in the same dispatch.
                if (outer - inner < ctx->device->max_workgroup_size_log2) {
                    pc2.inner_end = 100;
                    inner = outer;
                }
                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
                ggml_vk_sync_buffers(ctx, subctx);
            }
        }
        // The temp buffer was written this submission; later users must sync first.
        ctx->prealloc_x_need_sync = true;
    }
}

static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
Expand Down Expand Up @@ -13688,7 +13762,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->vulkan_memory_model) {
return op->ne[0] <= max_argsort_cols;
} else {
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
}
}
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
Expand Down
40 changes: 19 additions & 21 deletions ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@
#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
layout(constant_id = 1) const int NCOLS_PADDED = 1024;
layout(constant_id = 2) const int NCOLS_PADDED_LOG2 = 10;
#define ASC 0

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (binding = 2) writeonly buffer D {int data_d[];};

// Push constants — must match vk_op_argsort_push_constants on the host side.
// outer_*/inner_* select which bitonic merge passes this dispatch executes;
// the multi-dispatch ("large") path runs a sub-range of passes per dispatch.
layout (push_constant) uniform parameter {
uint ncols;        // number of valid columns per row (may be < padded width)
uint nrows;        // number of rows to sort
uint order;        // sort order; ASC (0) writes indices forward, else reversed
uint outer_start;  // first outer (k) pass to run, log2 domain
uint outer_end;    // one past the last outer pass
uint inner_start;  // first inner (j) pass within the first outer pass
uint inner_end;    // one past the last inner pass (large value = run to end)
} p;

shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];

void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
shared ivec2 dst_row[BLOCK_SIZE];

void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
Expand All @@ -34,11 +32,10 @@ void argsort(bool needs_bounds_check, const uint row) {
const uint row_offset = row * p.ncols;

// initialize indices
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
barrier();

uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
Expand All @@ -47,14 +44,15 @@ void argsort(bool needs_bounds_check, const uint row) {
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;

int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}

barrier();
Expand All @@ -63,9 +61,9 @@ void argsort(bool needs_bounds_check, const uint row) {

if (col < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
data_d[row_offset + col] = dst_row[col].x;
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
}
}
}
Expand Down
Loading
Loading