90 changes: 81 additions & 9 deletions profile_kernels.cu
@@ -18,7 +18,7 @@

#ifdef USE_TORCH
#include "pufferlib/extensions/pufferlib.cpp"
#include "pufferlib/extensions/cuda/modules.cu"
#include "pufferlib/extensions/cuda/kernels.cu"
using namespace pufferlib;
#else
#include "pufferlib/extensions/cuda/kernels.cu"
@@ -448,8 +448,11 @@ void profile_logcoeffsandvalues(int batch, int seq, int hidden) {
typedef struct {
float* x;
float* out;
double* s_buf;
double* s_buf; // original (double)
float* out_opt;
float* s_buf_opt; // optimized (float)
float* grad_x;
float* grad_x_opt;
float* grad_out;
int B;
int T;
@@ -467,7 +470,10 @@ LogcumsumexpArgs* create_logcumsumexpargs(int batch, int seq, int hidden) {
cudaMalloc(&args->x, args->N * sizeof(float));
cudaMalloc(&args->out, args->N * sizeof(float));
cudaMalloc(&args->s_buf, args->N * sizeof(double));
cudaMalloc(&args->out_opt, args->N * sizeof(float));
cudaMalloc(&args->s_buf_opt, args->N * sizeof(float));
cudaMalloc(&args->grad_x, args->N * sizeof(float));
cudaMalloc(&args->grad_x_opt, args->N * sizeof(float));
cudaMalloc(&args->grad_out, args->N * sizeof(float));

float* buf = (float*)malloc(args->N * sizeof(float) * 2);
@@ -489,7 +495,10 @@ void free_logcumsumexpargs(LogcumsumexpArgs* args) {
cudaFree(args->x);
cudaFree(args->out);
cudaFree(args->s_buf);
cudaFree(args->out_opt);
cudaFree(args->s_buf_opt);
cudaFree(args->grad_x);
cudaFree(args->grad_x_opt);
cudaFree(args->grad_out);
free(args);
}
@@ -504,6 +513,16 @@ void run_logcumsumexp_backward(LogcumsumexpArgs* args) {
args->grad_x, args->grad_out, args->x, args->s_buf, args->T, args->H, args->B, 0);
}

void run_logcumsumexp_forward_opt(LogcumsumexpArgs* args) {
launch_logcumsumexp_forward_opt<float>(
args->out_opt, args->s_buf_opt, args->x, args->T, args->H, args->B, 0);
}

void run_logcumsumexp_backward_opt(LogcumsumexpArgs* args) {
launch_logcumsumexp_backward_opt<float>(
args->grad_x_opt, args->grad_out, args->x, args->s_buf_opt, args->T, args->H, args->B, 0);
}

#ifdef USE_TORCH

typedef struct {
@@ -539,40 +558,93 @@ void run_logcumsumexp_forward_cpp(LogcumsumexpArgsTorch* args) {
logcumsumexp_cpp(args->x);
}

void test_logcumsumexp_opt_correct(LogcumsumexpArgs* args) {
// Run original forward
run_logcumsumexp_forward(args);

// Run optimized forward
run_logcumsumexp_forward_opt(args);

cudaDeviceSynchronize();

// Compare outputs using torch
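// from_blob wraps the existing device buffers as tensors without copying or taking ownership.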
auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto out_orig = torch::from_blob(args->out, {args->B, args->T, args->H}, opts);
auto out_opt = torch::from_blob(args->out_opt, {args->B, args->T, args->H}, opts);

float rtol = 1e-4f, atol = 1e-5f;
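// Tolerances allow for the optimized path running in float32 with __expf while the
// reference accumulates its running log-sum in double.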
bool fwd_match = torch::allclose(out_orig, out_opt, rtol, atol);
float fwd_max_diff = (out_orig - out_opt).abs().max().item<float>();

// Run backward passes (need to run forward first for s_buf)
run_logcumsumexp_forward(args);
run_logcumsumexp_backward(args);

run_logcumsumexp_forward_opt(args);
run_logcumsumexp_backward_opt(args);

cudaDeviceSynchronize();

auto grad_x_orig = torch::from_blob(args->grad_x, {args->B, args->T, args->H}, opts);
auto grad_x_opt = torch::from_blob(args->grad_x_opt, {args->B, args->T, args->H}, opts);

bool bwd_match = torch::allclose(grad_x_orig, grad_x_opt, rtol, atol);
float bwd_max_diff = (grad_x_orig - grad_x_opt).abs().max().item<float>();

printf(" optimized correctness: forward=%s(%.2e), backward=%s(%.2e)\n",
fwd_match ? "\033[32mok\033[0m" : "\033[31mFAIL\033[0m", fwd_max_diff,
bwd_match ? "\033[32mok\033[0m" : "\033[31mFAIL\033[0m", bwd_max_diff);
}

#endif

void profile_logcumsumexp(int batch, int seq, int hidden) {
LogcumsumexpArgs* args = create_logcumsumexpargs(batch, seq, hidden);

printf("logcumsumexp (N=%d, %dx%dx%d)\n", args->N, batch, seq, hidden);

#ifdef USE_TORCH
// Test correctness first
test_logcumsumexp_opt_correct(args);
#endif

// Profile original (double precision)
float fwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward, args);
print_timing("\tforward", fwd_ms, batch*seq);
print_timing(" forward", fwd_ms, batch*seq);

float bwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward, args);
print_timing("\tbackward", bwd_ms, batch*seq);
print_timing(" backward", bwd_ms, batch*seq);

// Profile optimized (float32, branch-free)
float fwd_opt_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_opt, args);
printf(" forward (opt) %6.1f us %6.2f M elem/s \033[32m%.2fx speedup\033[0m\n",
fwd_opt_ms * 1000, (batch*seq) / fwd_opt_ms / 1e3, fwd_ms / fwd_opt_ms);

float bwd_opt_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_opt, args);
printf(" backward (opt) %6.1f us %6.2f M elem/s \033[32m%.2fx speedup\033[0m\n",
bwd_opt_ms * 1000, (batch*seq) / bwd_opt_ms / 1e3, bwd_ms / bwd_opt_ms);

#ifdef USE_TORCH
LogcumsumexpArgsTorch* args_torch = create_logcumsumexpargs_torch(args);

float fwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_torch, args_torch);
print_timing("\tforward (torch)", fwd_torch_ms, batch*seq);
print_timing(" forward (torch)", fwd_torch_ms, batch*seq);

args_torch->out = logcumsumexp_cuda(args_torch->x);

float bwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch);
print_timing("\tbackward (torch)", bwd_torch_ms, batch*seq);
print_timing(" backward (torch)", bwd_torch_ms, batch*seq);

float fwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_cpp, args_torch);
print_timing("\tforward (cpp)", fwd_cpp_ms, batch*seq);
print_timing(" forward (cpp)", fwd_cpp_ms, batch*seq);

args_torch->out = logcumsumexp_cpp(args_torch->x);

float bwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch);
print_timing("\tbackward (cpp)", bwd_cpp_ms, batch*seq);
print_timing(" backward (cpp)", bwd_cpp_ms, batch*seq);

float fwd_graph_ms = profile_graph((kernel_fn)run_logcumsumexp_forward_cpp, args_torch);
print_timing("\tforward (graph)", fwd_graph_ms, batch*seq);
print_timing(" forward (graph)", fwd_graph_ms, batch*seq);

delete args_torch;
#endif
124 changes: 124 additions & 0 deletions pufferlib/extensions/cuda/kernels.cu
@@ -433,6 +433,19 @@ __device__ __forceinline__ double logcumsumexp_backward(double x, double* acc, d
return *acc * exp(x - s);
}

// float32 + branch free
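// Uses the stable logaddexp identity: logaddexp(a, b) = max(a, b) + log1p(exp(min(a, b) - max(a, b)));
// fminf/fmaxf replace the comparison branch and __expf is the fast-math exponential.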
__device__ __forceinline__ float logcumsumexp_forward_opt(float x, float acc) {
float min_val = fminf(acc, x);
float max_val = fmaxf(acc, x);
return max_val + log1pf(__expf(min_val - max_val));
}
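
// Backward of out = logcumsumexp(x): grad_x[t] = sum_{j >= t} grad_out[j] * exp(x[t] - s[j]).
// Scanning t from the end, acc holds sum_{j >= t} grad_out[j] * exp(s[t] - s[j]); the
// exp(s - s_nxt) factor rebases it from s[t+1] to s[t] before grad_out[t] is added.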

__device__ __forceinline__ float logcumsumexp_backward_opt(float x, float* acc, float grad, float s, float* s_nxt) {
*acc = fmaf(*acc, __expf(s - *s_nxt), grad);
*s_nxt = s;
return *acc * __expf(x - s);
}

// Fully fused forward: chunk + log_coeffs_and_values + scan + sigmoid(proj)*out
// Takes combined (B, T, 3*H) = [hidden, gate, proj] and outputs gated result
template<typename T>
@@ -1142,6 +1155,111 @@ void launch_logcumsumexp_backward(
fprintf(stderr, "Backward kernel error: %s\n", cudaGetErrorString(err));
}

// logcumsumexp in float32
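// One thread per (b, h) pair scans sequentially over T_total timesteps, keeping the running
// log-sum in a register and mirroring it into s_buf in float32 for the backward pass.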
template<typename T>
__global__ void logcumsumexp_forward_kernel_opt(
T* __restrict__ out,
float* __restrict__ s_buf, // FLOAT32
const T* __restrict__ x,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;

int b = idx / H;
int h = idx % H;

int base = b * T_total * H + h;

float s = -INFINITY;
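// -INFINITY is the identity for logaddexp, so the first iteration reduces to s = x at t = 0.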

for (int t = 0; t < T_total; t++) {
int curr = base + t * H;
float x_val = float(x[curr]);
s = logcumsumexp_forward_opt(x_val, s);
out[curr] = T(s);
s_buf[curr] = s;
}
}

template<typename T>
__global__ void logcumsumexp_backward_kernel_opt(
T* __restrict__ grad_x,
const T* __restrict__ grad_out,
const T* __restrict__ x,
const float* __restrict__ s_buf,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;

int b = idx / H;
int h = idx % H;

int base = b * T_total * H + h;

float acc = 0.0f;
float s_val_next = 0.0f;
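// acc starts at 0, so the exp(s - s_nxt) term vanishes at t = T_total - 1 and the
// initial value of s_val_next never enters the result.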

for (int t = T_total - 1; t >= 0; --t) {
int curr = base + t * H;

float x_val = float(x[curr]);
float s_val = s_buf[curr];
float g_val = float(grad_out[curr]);
grad_x[curr] = T(logcumsumexp_backward_opt(x_val, &acc, g_val, s_val, &s_val_next));
}
}

template<typename T>
void launch_logcumsumexp_forward_opt(
T* out,
float* s_buf,
const T* x,
int T_total,
int H,
int B,
cudaStream_t stream
) {
int total = B * H;
int grid = grid_size(total);

logcumsumexp_forward_kernel_opt<T><<<grid, BLOCK_SIZE, 0, stream>>>(
out, s_buf, x, T_total, H, B
);

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
fprintf(stderr, "Forward opt kernel error: %s\n", cudaGetErrorString(err));
}

template<typename T>
void launch_logcumsumexp_backward_opt(
T* grad_x,
const T* grad_out,
const T* x,
const float* s_buf,
int T_total,
int H,
int B,
cudaStream_t stream
) {
int total = B * H;
int grid = grid_size(total);

logcumsumexp_backward_kernel_opt<T><<<grid, BLOCK_SIZE, 0, stream>>>(
grad_x, grad_out, x, s_buf, T_total, H, B
);

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
fprintf(stderr, "Backward opt kernel error: %s\n", cudaGetErrorString(err));
}

template<typename T>
__global__ void ppo_loss_forward_kernel(
float* __restrict__ loss,
@@ -1682,6 +1800,12 @@ void launch_logcumsumexp_forward_float(float* out, double* s_buf, const float* x
void launch_logcumsumexp_backward_float(float* grad_x, const float* grad_out, const float* x, const double* s_buf, int T_total, int H, int B, cudaStream_t stream) {
launch_logcumsumexp_backward<float>(grad_x, grad_out, x, s_buf, T_total, H, B, stream);
}
void launch_logcumsumexp_forward_opt_float(float* out, float* s_buf, const float* x, int T_total, int H, int B, cudaStream_t stream) {
launch_logcumsumexp_forward_opt<float>(out, s_buf, x, T_total, H, B, stream);
}
void launch_logcumsumexp_backward_opt_float(float* grad_x, const float* grad_out, const float* x, const float* s_buf, int T_total, int H, int B, cudaStream_t stream) {
launch_logcumsumexp_backward_opt<float>(grad_x, grad_out, x, s_buf, T_total, H, B, stream);
}
void launch_ppo_loss_forward_float(float* loss_output, double* saved_for_backward, const float* logits, const float* values_pred, const int64_t* actions, const float* old_logprobs, const float* advantages, const float* prio, const float* values, const float* returns, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream) {
launch_ppo_loss_forward<float>(loss_output, saved_for_backward, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_std, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, stream);
}
6 changes: 6 additions & 0 deletions pufferlib/extensions/vecenv.h
@@ -5,7 +5,9 @@
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
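// C11 atomics are only usable from C translation units; this header is presumably also
// pulled into C++/CUDA builds now, which would rely on <atomic> instead of <stdatomic.h>.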
#ifndef __cplusplus
#include <stdatomic.h>
#endif
#include <cuda_runtime.h>

#define FLOAT 1
@@ -80,6 +82,10 @@ void dict_set(Dict* dict, const char* key, double value) {
dict->size++;
}

void dict_set_int(Dict* dict, const char* key, int value) {
dict_set(dict, key, (double)value);
}

void dict_set_ptr(Dict* dict, const char* key, void* ptr) {
assert(dict->size < dict->capacity);
DictItem* item = dict_get_unsafe(dict, key);