From ed64f3f933a5cd904d9d2b7ab014382a84fd9247 Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 23 Jan 2026 17:59:50 -0800 Subject: [PATCH] optimized logcumsumexp --- profile_kernels.cu | 90 +++++++++++++++++-- pufferlib/extensions/cuda/kernels.cu | 124 +++++++++++++++++++++++++++ pufferlib/extensions/vecenv.h | 6 ++ 3 files changed, 211 insertions(+), 9 deletions(-) diff --git a/profile_kernels.cu b/profile_kernels.cu index 506538740..494a9340d 100644 --- a/profile_kernels.cu +++ b/profile_kernels.cu @@ -18,7 +18,7 @@ #ifdef USE_TORCH #include "pufferlib/extensions/pufferlib.cpp" -#include "pufferlib/extensions/cuda/modules.cu" +#include "pufferlib/extensions/cuda/kernels.cu" using namespace pufferlib; #else #include "pufferlib/extensions/cuda/kernels.cu" @@ -448,8 +448,11 @@ void profile_logcoeffsandvalues(int batch, int seq, int hidden) { typedef struct { float* x; float* out; - double* s_buf; + double* s_buf; // original (double) + float* out_opt; + float* s_buf_opt; // optimized (float) float* grad_x; + float* grad_x_opt; float* grad_out; int B; int T; @@ -467,7 +470,10 @@ LogcumsumexpArgs* create_logcumsumexpargs(int batch, int seq, int hidden) { cudaMalloc(&args->x, args->N * sizeof(float)); cudaMalloc(&args->out, args->N * sizeof(float)); cudaMalloc(&args->s_buf, args->N * sizeof(double)); + cudaMalloc(&args->out_opt, args->N * sizeof(float)); + cudaMalloc(&args->s_buf_opt, args->N * sizeof(float)); cudaMalloc(&args->grad_x, args->N * sizeof(float)); + cudaMalloc(&args->grad_x_opt, args->N * sizeof(float)); cudaMalloc(&args->grad_out, args->N * sizeof(float)); float* buf = (float*)malloc(args->N * sizeof(float) * 2); @@ -489,7 +495,10 @@ void free_logcumsumexpargs(LogcumsumexpArgs* args) { cudaFree(args->x); cudaFree(args->out); cudaFree(args->s_buf); + cudaFree(args->out_opt); + cudaFree(args->s_buf_opt); cudaFree(args->grad_x); + cudaFree(args->grad_x_opt); cudaFree(args->grad_out); free(args); } @@ -504,6 +513,16 @@ void run_logcumsumexp_backward(LogcumsumexpArgs* args) { args->grad_x, args->grad_out, args->x, args->s_buf, args->T, args->H, args->B, 0); } +void run_logcumsumexp_forward_opt(LogcumsumexpArgs* args) { + launch_logcumsumexp_forward_opt( + args->out_opt, args->s_buf_opt, args->x, args->T, args->H, args->B, 0); +} + +void run_logcumsumexp_backward_opt(LogcumsumexpArgs* args) { + launch_logcumsumexp_backward_opt( + args->grad_x_opt, args->grad_out, args->x, args->s_buf_opt, args->T, args->H, args->B, 0); +} + #ifdef USE_TORCH typedef struct { @@ -539,6 +558,44 @@ void run_logcumsumexp_forward_cpp(LogcumsumexpArgsTorch* args) { logcumsumexp_cpp(args->x); } +void test_logcumsumexp_opt_correct(LogcumsumexpArgs* args) { + // Run original forward + run_logcumsumexp_forward(args); + + // Run optimized forward + run_logcumsumexp_forward_opt(args); + + cudaDeviceSynchronize(); + + // Compare outputs using torch + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto out_orig = torch::from_blob(args->out, {args->B, args->T, args->H}, opts); + auto out_opt = torch::from_blob(args->out_opt, {args->B, args->T, args->H}, opts); + + float rtol = 1e-4f, atol = 1e-5f; + bool fwd_match = torch::allclose(out_orig, out_opt, rtol, atol); + float fwd_max_diff = (out_orig - out_opt).abs().max().item(); + + // Run backward passes (need to run forward first for s_buf) + run_logcumsumexp_forward(args); + run_logcumsumexp_backward(args); + + run_logcumsumexp_forward_opt(args); + run_logcumsumexp_backward_opt(args); + + cudaDeviceSynchronize(); + + auto grad_x_orig = torch::from_blob(args->grad_x, {args->B, args->T, args->H}, opts); + auto grad_x_opt = torch::from_blob(args->grad_x_opt, {args->B, args->T, args->H}, opts); + + bool bwd_match = torch::allclose(grad_x_orig, grad_x_opt, rtol, atol); + float bwd_max_diff = (grad_x_orig - grad_x_opt).abs().max().item(); + + printf(" optimized correctness: forward=%s(%.2e), backward=%s(%.2e)\n", + fwd_match ? "\033[32mok\033[0m" : "\033[31mFAIL\033[0m", fwd_max_diff, + bwd_match ? "\033[32mok\033[0m" : "\033[31mFAIL\033[0m", bwd_max_diff); +} + #endif void profile_logcumsumexp(int batch, int seq, int hidden) { @@ -546,33 +603,48 @@ void profile_logcumsumexp(int batch, int seq, int hidden) { printf("logcumsumexp (N=%d, %dx%dx%d)\n", args->N, batch, seq, hidden); +#ifdef USE_TORCH + // Test correctness first + test_logcumsumexp_opt_correct(args); +#endif + + // Profile original (double precision) float fwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward, args); - print_timing("\tforward", fwd_ms, batch*seq); + print_timing(" forward", fwd_ms, batch*seq); float bwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward, args); - print_timing("\tbackward", bwd_ms, batch*seq); + print_timing(" backward", bwd_ms, batch*seq); + + // Profile optimized (float32, branch-free) + float fwd_opt_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_opt, args); + printf(" forward (opt) %6.1f us %6.2f M elem/s \033[32m%.2fx speedup\033[0m\n", + fwd_opt_ms * 1000, (batch*seq) / fwd_opt_ms / 1e3, fwd_ms / fwd_opt_ms); + + float bwd_opt_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_opt, args); + printf(" backward (opt) %6.1f us %6.2f M elem/s \033[32m%.2fx speedup\033[0m\n", + bwd_opt_ms * 1000, (batch*seq) / bwd_opt_ms / 1e3, bwd_ms / bwd_opt_ms); #ifdef USE_TORCH LogcumsumexpArgsTorch* args_torch = create_logcumsumexpargs_torch(args); float fwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_torch, args_torch); - print_timing("\tforward (torch)", fwd_torch_ms, batch*seq); + print_timing(" forward (torch)", fwd_torch_ms, batch*seq); args_torch->out = logcumsumexp_cuda(args_torch->x); float bwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch); - print_timing("\tbackward (torch)", bwd_torch_ms, batch*seq); + print_timing(" backward (torch)", bwd_torch_ms, batch*seq); float fwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_cpp, args_torch); - print_timing("\tforward (cpp)", fwd_cpp_ms, batch*seq); + print_timing(" forward (cpp)", fwd_cpp_ms, batch*seq); args_torch->out = logcumsumexp_cpp(args_torch->x); float bwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch); - print_timing("\tbackward (cpp)", bwd_cpp_ms, batch*seq); + print_timing(" backward (cpp)", bwd_cpp_ms, batch*seq); float fwd_graph_ms = profile_graph((kernel_fn)run_logcumsumexp_forward_cpp, args_torch); - print_timing("\tforward (graph)", fwd_graph_ms, batch*seq); + print_timing(" forward (graph)", fwd_graph_ms, batch*seq); delete args_torch; #endif diff --git a/pufferlib/extensions/cuda/kernels.cu b/pufferlib/extensions/cuda/kernels.cu index cfa596c8b..78b7ff9b8 100644 --- a/pufferlib/extensions/cuda/kernels.cu +++ b/pufferlib/extensions/cuda/kernels.cu @@ -433,6 +433,19 @@ __device__ __forceinline__ double logcumsumexp_backward(double x, double* acc, d return *acc * exp(x - s); } +// float32 + branch free +__device__ __forceinline__ float logcumsumexp_forward_opt(float x, float acc) { + float min_val = fminf(acc, x); + float max_val = fmaxf(acc, x); + return max_val + log1pf(__expf(min_val - max_val)); +} + +__device__ __forceinline__ float logcumsumexp_backward_opt(float x, float* acc, float grad, float s, float* s_nxt) { + *acc = fmaf(*acc, __expf(s - *s_nxt), grad); + *s_nxt = s; + return *acc * __expf(x - s); +} + // Fully fused forward: chunk + log_coeffs_and_values + scan + sigmoid(proj)*out // Takes combined (B, T, 3*H) = [hidden, gate, proj] and outputs gated result template @@ -1142,6 +1155,111 @@ void launch_logcumsumexp_backward( fprintf(stderr, "Backward kernel error: %s\n", cudaGetErrorString(err)); } +// logcumsumexp in float32 +template +__global__ void logcumsumexp_forward_kernel_opt( + T* __restrict__ out, + float* __restrict__ s_buf, // FLOAT32 + const T* __restrict__ x, + int T_total, + int H, + int B +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * H) return; + + int b = idx / H; + int h = idx % H; + + int base = b * T_total * H + h; + + float s = -INFINITY; + + for (int t = 0; t < T_total; t++) { + int curr = base + t * H; + float x_val = float(x[curr]); + s = logcumsumexp_forward_opt(x_val, s); + out[curr] = T(s); + s_buf[curr] = s; + } +} + +template +__global__ void logcumsumexp_backward_kernel_opt( + T* __restrict__ grad_x, + const T* __restrict__ grad_out, + const T* __restrict__ x, + const float* __restrict__ s_buf, + int T_total, + int H, + int B +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * H) return; + + int b = idx / H; + int h = idx % H; + + int base = b * T_total * H + h; + + float acc = 0.0f; + float s_val_next = 0.0f; + + for (int t = T_total - 1; t >= 0; --t) { + int curr = base + t * H; + + float x_val = float(x[curr]); + float s_val = s_buf[curr]; + float g_val = float(grad_out[curr]); + grad_x[curr] = T(logcumsumexp_backward_opt(x_val, &acc, g_val, s_val, &s_val_next)); + } +} + +template +void launch_logcumsumexp_forward_opt( + T* out, + float* s_buf, + const T* x, + int T_total, + int H, + int B, + cudaStream_t stream +) { + int total = B * H; + int grid = grid_size(total); + + logcumsumexp_forward_kernel_opt<<>>( + out, s_buf, x, T_total, H, B + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + fprintf(stderr, "Forward opt kernel error: %s\n", cudaGetErrorString(err)); +} + +template +void launch_logcumsumexp_backward_opt( + T* grad_x, + const T* grad_out, + const T* x, + const float* s_buf, + int T_total, + int H, + int B, + cudaStream_t stream +) { + int total = B * H; + int grid = grid_size(total); + + logcumsumexp_backward_kernel_opt<<>>( + grad_x, grad_out, x, s_buf, T_total, H, B + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + fprintf(stderr, "Backward opt kernel error: %s\n", cudaGetErrorString(err)); +} + template __global__ void ppo_loss_forward_kernel( float* __restrict__ loss, @@ -1682,6 +1800,12 @@ void launch_logcumsumexp_forward_float(float* out, double* s_buf, const float* x void launch_logcumsumexp_backward_float(float* grad_x, const float* grad_out, const float* x, const double* s_buf, int T_total, int H, int B, cudaStream_t stream) { launch_logcumsumexp_backward(grad_x, grad_out, x, s_buf, T_total, H, B, stream); } +void launch_logcumsumexp_forward_opt_float(float* out, float* s_buf, const float* x, int T_total, int H, int B, cudaStream_t stream) { + launch_logcumsumexp_forward_opt(out, s_buf, x, T_total, H, B, stream); +} +void launch_logcumsumexp_backward_opt_float(float* grad_x, const float* grad_out, const float* x, const float* s_buf, int T_total, int H, int B, cudaStream_t stream) { + launch_logcumsumexp_backward_opt(grad_x, grad_out, x, s_buf, T_total, H, B, stream); +} void launch_ppo_loss_forward_float(float* loss_output, double* saved_for_backward, const float* logits, const float* values_pred, const int64_t* actions, const float* old_logprobs, const float* advantages, const float* prio, const float* values, const float* returns, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream) { launch_ppo_loss_forward(loss_output, saved_for_backward, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_std, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, stream); } diff --git a/pufferlib/extensions/vecenv.h b/pufferlib/extensions/vecenv.h index b80cb3e30..b4789a055 100644 --- a/pufferlib/extensions/vecenv.h +++ b/pufferlib/extensions/vecenv.h @@ -5,7 +5,9 @@ #include #include #include +#ifndef __cplusplus #include +#endif #include #define FLOAT 1 @@ -80,6 +82,10 @@ void dict_set(Dict* dict, const char* key, double value) { dict->size++; } +void dict_set_int(Dict* dict, const char* key, int value) { + dict_set(dict, key, (double)value); +} + void dict_set_ptr(Dict* dict, const char* key, void* ptr) { assert(dict->size < dict->capacity); DictItem* item = dict_get_unsafe(dict, key);