Skip to content

vulkan: optimize rms_norm, and allow the work to spread across multiple SMs #15281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 163 additions & 31 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

42 changes: 41 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/add.comp
Original file line number Diff line number Diff line change
@@ -1,29 +1,69 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "types.comp"
#include "generic_binary_head.comp"

const uint num_threads = 256;

layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

void main() {
    // Flat global element index for this invocation (computed by get_idx()
    // from generic_binary_head.comp). orig_idx is preserved so the
    // workgroup's slot in the partial-sums buffer can be derived below.
    uint idx = get_idx();
    uint orig_idx = idx;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    // Per-invocation accumulator of squared output values; only consumed
    // when the ADD_RMS variant is compiled and partial sums are requested.
    FLOAT_TYPE sum_sq = 0;

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        if (idx >= p.ne) {
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03);

        // Elementwise add of the two sources; square it into the RMS
        // accumulator before storing the result.
        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
        sum_sq += sum*sum;

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

        idx += num_threads;
    }

#if ADD_RMS
    // p.param3 != 0 requests partial sums of squares (presumably set by the
    // host when this add feeds a following rms_norm -- TODO confirm against
    // ggml-vulkan.cpp).
    if (p.param3 != 0) {
        // reduce the sum within each subgroup, then across subgroups
        const uint NumSubgroups = num_threads / gl_SubgroupSize;
        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();
        // Tree reduction over the per-subgroup sums in shared memory.
        // Assumes NumSubgroups is a power of two (num_threads is 256 and
        // subgroup sizes are powers of two). barrier() is reached by all
        // invocations each iteration since the loop bounds are uniform.
        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
                sum_sq += sumsh[gl_SubgroupID + s];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
        }

        // One partial sum per workgroup: each workgroup covers
        // num_iter * num_threads consecutive elements.
        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
        }
    }
#endif
}
42 changes: 41 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "rte.comp"
#include "types.comp"
Expand All @@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2
uint ne20; uint ne21; uint ne22; uint ne23;

// strides for srcs+dst
uint nb[8][4];
uint nb[12][4];

uint rms_partials;
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];

layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];

layout(constant_id = 0) const uint num_srcs = 2;

uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
Expand All @@ -42,14 +50,22 @@ const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

void main() {
uint idx = get_idx();
uint orig_idx = idx;

uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;

FLOAT_TYPE sum_sq = 0;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= ne) {
continue;
Expand All @@ -61,8 +77,32 @@ void main() {
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
}
sum_sq += sum*sum;
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

idx += num_threads;
}

#if ADD_RMS
if (p.rms_partials != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}

if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}
60 changes: 49 additions & 11 deletions ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE sum[BLOCK_SIZE];
shared FLOAT_TYPE sumsh[BLOCK_SIZE];

void main() {
void rms_norm(uint num_iters) {
const uint ncols = p.ne00;
const uint nrows = gl_NumWorkGroups.x;
const uint nchannels = gl_NumWorkGroups.y;
Expand All @@ -30,38 +30,76 @@ void main() {
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp

[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
sum[tid] += xi * xi;
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
FLOAT_TYPE xi = FLOAT_TYPE(0);
if (col < ncols) {
xi = FLOAT_TYPE(data_a[a_offset + col]);
}
sum += xi * xi;
}

sumsh[tid] = sum;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
sum += sumsh[tid + s];
sumsh[tid] = sum;
}
barrier();
}
sum = sumsh[0];

const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

if (do_multiply) {
if (ncols > p.ne10) {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
}
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
}

void main() {
    // Pick an instantiation of rms_norm whose per-thread column loop has a
    // compile-time-known iteration count, so the [[unroll]] hints can take
    // effect. Counts 1..4 use the exact value; larger counts round up to
    // the next bucket (8/16/32); anything above 32 runs unspecialized.
    const uint nblk = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;

    if (nblk > 32) {
        rms_norm(nblk);
    } else if (nblk > 16) {
        rms_norm(32);
    } else if (nblk > 8) {
        rms_norm(16);
    } else if (nblk > 4) {
        rms_norm(8);
    } else if (nblk > 0) {
        // 1..4 blocks: one dedicated instantiation per count.
        rms_norm(nblk);
    }
}
65 changes: 65 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#version 450

#include "generic_binary_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

#define BLOCK_SIZE 128

layout (constant_id = 1) const bool do_multiply = false;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};

shared FLOAT_TYPE sumsh[BLOCK_SIZE];

void main() {
    const uint ncols = p.ne00;
    // NOTE(review): with the work split across workgroups along x,
    // gl_NumWorkGroups.x counts column chunks rather than rows; row is fixed
    // to 0 below, so nrows only scales d_offset -- confirm against the
    // host-side dispatch in ggml-vulkan.cpp.
    const uint nrows = gl_NumWorkGroups.x;
    const uint nchannels = gl_NumWorkGroups.y;

    const uint row = 0;
    const uint channel = gl_WorkGroupID.y;
    const uint samp = gl_WorkGroupID.z;
    // The work is split across multiple workgroups in the x dimension. Each invocation
    // processes one element
    const uint tid = gl_GlobalInvocationID.x;

    const uint stride_row = p.nb01;
    const uint stride_channel = p.nb02;
    const uint stride_sample = p.nb03;

    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

    // Total of the per-workgroup partial sums of squares written earlier
    // (binding 3, PartialsBuf). Every subgroup redundantly computes the
    // full total: the strided loop plus subgroupAdd leaves the same value
    // in every invocation of each subgroup.
    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

    // param3 carries the number of partial sums -- TODO confirm against the
    // host-side push-constant setup.
    uint32_t num_partials = p.param3;
    for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
        sum += partial_sums[i];
    }
    sum = subgroupAdd(sum);

    // Out-of-range invocations exit only AFTER the reduction above, so all
    // invocations of a subgroup participate in subgroupAdd.
    uint col = tid;
    if (col >= ncols) {
        return;
    }

    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

    if (do_multiply) {
        // src1 may be narrower than dst along the column dimension, in which
        // case it is broadcast via fastmod.
        if (ncols > p.ne10) {
            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
        } else {
            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
        }
    } else {
        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
    }
}
12 changes: 8 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

Expand Down Expand Up @@ -534,13 +535,15 @@ void process_shaders() {
s += std::string(dst_f16 ? "_f16" : "_f32");
return s;
};
for (std::string op : {"add", "sub", "mul", "div"}) {
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
for (auto rte : {false, true}) {
auto source = op == "add_rms" ? std::string("add") : op;
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
auto add_rms = op == "add_rms" ? "1" : "0";
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
}
}
}
Expand Down Expand Up @@ -677,7 +680,8 @@ void process_shaders() {

string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

for (auto &c : compiles) {
c.wait();
Expand Down Expand Up @@ -735,7 +739,7 @@ void write_output_files() {
}

std::string suffixes[2] = {"_f32", "_f16"};
for (const char *op : {"add", "sub", "mul", "div"}) {
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
Expand Down
Loading
Loading