Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pufferlib/extensions/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pybind11::dict log_environments(pybind11::object pufferl_obj) {

Tensor initial_state(pybind11::object pufferl_obj, int64_t batch_size, torch::Device device) {
auto& pufferl = pufferl_obj.cast<PuffeRL&>();
return pufferl.policy->initial_state(batch_size, device);
return pufferl.policy_bf16->initial_state(batch_size, device);
}

void python_vec_recv(pybind11::object pufferl_obj, int buf) {
Expand Down Expand Up @@ -129,6 +129,7 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl(pybind11::dict kwargs) {
hypers.kernels = get_config(kwargs, "kernels");
hypers.profile = get_config(kwargs, "profile");
hypers.use_omp = get_config(kwargs, "use_omp");
hypers.bf16 = get_config(kwargs, "bf16");

std::string env_name = kwargs["env_name"].cast<std::string>();
Dict* vec_kwargs = py_dict_to_c_dict(kwargs["vec_kwargs"].cast<py::dict>());
Expand Down Expand Up @@ -228,7 +229,8 @@ PYBIND11_MODULE(_C, m) {

m.def("create_pufferl", &create_pufferl);
py::class_<PuffeRL, std::unique_ptr<PuffeRL>>(m, "PuffeRL")
.def_readwrite("policy", &PuffeRL::policy)
.def_readwrite("policy_bf16", &PuffeRL::policy_bf16)
.def_readwrite("policy_fp32", &PuffeRL::policy_fp32)
.def_readwrite("muon", &PuffeRL::muon)
.def_readwrite("hypers", &PuffeRL::hypers)
.def_readwrite("rollouts", &PuffeRL::rollouts);
Expand Down
83 changes: 59 additions & 24 deletions pufferlib/extensions/cuda/advantage.cu
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/util/BFloat16.h>

namespace pufferlib {

__host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
float* importance, float* advantages, float gamma, float lambda,
// TIn = input type (bf16 or float), TOut = output type (always float for precision)
template<typename TIn, typename TOut>
__host__ __device__ void puff_advantage_row_cuda(const TIn* values, const TIn* rewards, const TIn* dones,
const TIn* importance, TOut* advantages, float gamma, float lambda,
float rho_clip, float c_clip, int horizon) {
float lastpufferlam = 0;
for (int t = horizon-2; t >= 0; t--) {
int t_next = t + 1;
float nextnonterminal = 1.0 - dones[t_next];
float rho_t = fminf(importance[t], rho_clip);
float c_t = fminf(importance[t], c_clip);
float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
float nextnonterminal = 1.0f - float(dones[t_next]);
float imp = float(importance[t]);
float rho_t = fminf(imp, rho_clip);
float c_t = fminf(imp, c_clip);
float delta = rho_t*(float(rewards[t_next]) + gamma*float(values[t_next])*nextnonterminal - float(values[t]));
lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
advantages[t] = lastpufferlam;
advantages[t] = TOut(lastpufferlam);
}
}

Expand All @@ -25,32 +29,42 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,

// Validate input tensors
torch::Device device = values.device();
for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
auto input_dtype = values.dtype();
for (const torch::Tensor& t : {values, rewards, dones, importance}) {
TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
TORCH_CHECK(t.device() == device, "All tensors must be on same device");
TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon");
TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
TORCH_CHECK(t.dtype() == input_dtype, "Input tensors must have matching dtype");
if (!t.is_contiguous()) {
t.contiguous();
}
}
// advantages can be different dtype (fp32 for precision)
TORCH_CHECK(advantages.dim() == 2, "Advantages must be 2D");
TORCH_CHECK(advantages.device() == device, "Advantages must be on same device");
TORCH_CHECK(advantages.size(0) == num_steps, "Advantages first dimension must match");
TORCH_CHECK(advantages.size(1) == horizon, "Advantages second dimension must match");
if (!advantages.is_contiguous()) {
advantages.contiguous();
}
}

// [num_steps, horizon]
__global__ void puff_advantage_kernel(float* values, float* rewards,
float* dones, float* importance, float* advantages, float gamma,
template<typename TIn, typename TOut>
__global__ void puff_advantage_kernel(const TIn* values, const TIn* rewards,
const TIn* dones, const TIn* importance, TOut* advantages, float gamma,
float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
int row = blockIdx.x*blockDim.x + threadIdx.x;
if (row >= num_steps) {
return;
}
int offset = row*horizon;
puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
puff_advantage_row_cuda<TIn, TOut>(values + offset, rewards + offset, dones + offset,
importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
}

void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
template<typename TIn, typename TOut>
void compute_puff_advantage_cuda_impl(torch::Tensor values, torch::Tensor rewards,
torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
double gamma, double lambda, double rho_clip, double c_clip) {
int num_steps = values.size(0);
Expand All @@ -61,16 +75,16 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
int threads_per_block = 256;
int blocks = (num_steps + threads_per_block - 1) / threads_per_block;

puff_advantage_kernel<<<blocks, threads_per_block>>>(
values.data_ptr<float>(),
rewards.data_ptr<float>(),
dones.data_ptr<float>(),
importance.data_ptr<float>(),
advantages.data_ptr<float>(),
gamma,
lambda,
rho_clip,
c_clip,
puff_advantage_kernel<TIn, TOut><<<blocks, threads_per_block>>>(
values.data_ptr<TIn>(),
rewards.data_ptr<TIn>(),
dones.data_ptr<TIn>(),
importance.data_ptr<TIn>(),
advantages.data_ptr<TOut>(),
static_cast<float>(gamma),
static_cast<float>(lambda),
static_cast<float>(rho_clip),
static_cast<float>(c_clip),
num_steps,
horizon
);
Expand All @@ -81,6 +95,27 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
}
}

// Dtype dispatcher for the puff-advantage CUDA kernel.
// Inputs (values, rewards, dones, importance) must all share one dtype
// (fp32 or bf16, enforced in vtrace_check_cuda); advantages may be a
// different dtype so bf16 runs can accumulate into fp32 for precision.
// Supported (input, advantages) combinations:
//   (float32, float32), (bfloat16, float32), (bfloat16, bfloat16).
void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
    torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
    double gamma, double lambda, double rho_clip, double c_clip) {
    auto input_dtype = values.dtype();
    auto output_dtype = advantages.dtype();

    if (input_dtype == torch::kFloat32 && output_dtype == torch::kFloat32) {
        compute_puff_advantage_cuda_impl<float, float>(values, rewards, dones, importance, advantages,
            gamma, lambda, rho_clip, c_clip);
    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kFloat32) {
        // Preferred bf16 path: bf16 inputs, fp32 advantages for precision.
        compute_puff_advantage_cuda_impl<at::BFloat16, float>(values, rewards, dones, importance, advantages,
            gamma, lambda, rho_clip, c_clip);
    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kBFloat16) {
        compute_puff_advantage_cuda_impl<at::BFloat16, at::BFloat16>(values, rewards, dones, importance, advantages,
            gamma, lambda, rho_clip, c_clip);
    } else {
        // NOTE: fp32 inputs with bf16 advantages are deliberately unsupported
        // (it would silently discard precision). The message enumerates the
        // actual supported combinations and echoes the offending dtypes,
        // instead of the previous message that implied any fp32/bf16 mix works.
        TORCH_CHECK(false, "compute_puff_advantage: unsupported dtype combination (inputs=",
            input_dtype, ", advantages=", output_dtype,
            "); supported: (float32, float32), (bfloat16, float32), (bfloat16, bfloat16)");
    }
}

TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
m.impl("compute_puff_advantage", &compute_puff_advantage_cuda);
}
Expand Down
40 changes: 20 additions & 20 deletions pufferlib/extensions/cuda/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ __global__ void ppo_loss_forward_kernel_optimized(
const T* __restrict__ values_pred,
const int64_t* __restrict__ actions,
const T* __restrict__ old_logprobs,
const T* __restrict__ advantages,
const float* __restrict__ advantages,
const T* __restrict__ prio,
const T* __restrict__ values,
const T* __restrict__ returns,
Expand Down Expand Up @@ -1541,8 +1541,8 @@ __global__ void ppo_loss_forward_kernel_optimized(
float old_logp = float(old_logprobs[nt]);
float adv = float(advantages[nt]);
float w = float(prio[n]);
float adv_std = sqrtf(adv_var[0]);
float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
float adv_std = sqrtf(float(adv_var[0]));
float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);

float logratio = new_logp - old_logp;
float ratio = __expf(logratio);
Expand Down Expand Up @@ -1588,14 +1588,14 @@ __global__ void ppo_loss_forward_kernel_optimized(

template<typename T>
__global__ void ppo_loss_backward_kernel_optimized(
T* __restrict__ grad_logits,
T* __restrict__ grad_values_pred,
float* __restrict__ grad_logits,
float* __restrict__ grad_values_pred,
const float* __restrict__ grad_loss,
const T* __restrict__ logits,
const T* __restrict__ values_pred,
const int64_t* __restrict__ actions,
const T* __restrict__ old_logprobs,
const T* __restrict__ advantages,
const float* __restrict__ advantages,
const T* __restrict__ prio,
const T* __restrict__ values,
const T* __restrict__ returns,
Expand Down Expand Up @@ -1665,8 +1665,8 @@ __global__ void ppo_loss_backward_kernel_optimized(
float v_clipped = val + fmaxf(-vf_clip_coef, fminf(vf_clip_coef, v_error));

// normalize advantage
float adv_std = sqrtf(adv_var[0]);
float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
float adv_std = sqrtf(float(adv_var[0]));
float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);

// loss gradient scaling
float dL = grad_loss[0] * inv_NT;
Expand All @@ -1686,7 +1686,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
} else {
d_val_pred = val_pred - ret;
}
grad_values_pred[values_idx] = T(dL * vf_coef * d_val_pred);
grad_values_pred[values_idx] = dL * vf_coef * d_val_pred;

// policy loss gradient
float ratio_clipped = fmaxf(1.0f - clip_coef, fminf(1.0f + clip_coef, ratio));
Expand All @@ -1710,7 +1710,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
d_logit -= p * d_new_logp;

d_logit += d_entropy_term * p * (-entropy - logp);
grad_logits[logits_base + a * logits_stride_a] = T(d_logit);
grad_logits[logits_base + a * logits_stride_a] = d_logit;
}
}

Expand All @@ -1724,12 +1724,12 @@ inline void launch_ppo_loss_forward_optimized(
const T* values_pred,
const int64_t* actions,
const T* old_logprobs,
const T* advantages,
const float* advantages, // always fp32 for precision
const T* prio,
const T* values,
const T* returns,
const float* adv_mean,
const float* adv_var,
const float* adv_mean, // keep fp32
const float* adv_var, // keep fp32
float clip_coef,
float vf_clip_coef,
float vf_coef,
Expand Down Expand Up @@ -1784,19 +1784,19 @@ inline void launch_ppo_loss_forward_optimized(

template<typename T>
void launch_ppo_loss_backward_optimized(
T* grad_logits,
T* grad_values_pred,
float* grad_logits,
float* grad_values_pred,
const float* grad_loss,
const T* logits,
const T* values_pred, // added: need to read val_pred directly
const int64_t* actions,
const T* old_logprobs,
const T* advantages,
const float* advantages,
const T* prio,
const T* values,
const T* returns,
const float* adv_mean,
const float* adv_var, // variance, not std
const float* adv_var,
float clip_coef,
float vf_clip_coef,
float vf_coef,
Expand Down Expand Up @@ -2011,7 +2011,7 @@ __global__ void ppo_loss_backward_kernel(

// === Retrieve saved values from forward pass ===
const double* saved = saved_for_backward + idx * 5;
double new_logp = saved[0]; // new log prob of selected action
// double new_logp = saved[0]; // new log prob of selected action
double ratio = saved[1]; // exp(new_logp - old_logp)
double val_pred = saved[2]; // value prediction
double v_clipped = saved[3]; // clipped value target
Expand Down Expand Up @@ -2481,10 +2481,10 @@ void launch_ppo_loss_backward_optimized_float(float* grad_logits, float* grad_va
launch_ppo_loss_backward_optimized<float>(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
}

void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
// bf16 entry point for the optimized PPO forward loss: forwards all arguments
// unchanged to launch_ppo_loss_forward_optimized<at::BFloat16>. Per the
// signature, advantages and the adv_mean/adv_var statistics are fp32 pointers
// even on the bf16 path, and loss_output/saved_for_backward stay fp32/fp64.
void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
launch_ppo_loss_forward_optimized<at::BFloat16>(loss_output, saved_for_backward, ratio_out, newvalue_out, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
}
void launch_ppo_loss_backward_optimized_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
// bf16 entry point for the optimized PPO backward pass: forwards all arguments
// unchanged to launch_ppo_loss_backward_optimized<at::BFloat16>. Per the
// signature, the gradient outputs (grad_logits, grad_values_pred) and the
// advantages / adv statistics are fp32 pointers even though the forward
// activations (logits, values_pred, ...) are bf16.
void launch_ppo_loss_backward_optimized_bf16(float* grad_logits, float* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
launch_ppo_loss_backward_optimized<at::BFloat16>(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
}

Expand Down
4 changes: 2 additions & 2 deletions pufferlib/extensions/cuda/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ void launch_logcumsumexp_forward_bf16(at::BFloat16* out, double* s_buf, const at
void launch_logcumsumexp_backward_bf16(at::BFloat16* grad_x, const at::BFloat16* grad_out, const at::BFloat16* x, const double* s_buf, int T_total, int H, int B, cudaStream_t stream);
void launch_ppo_loss_forward_bf16(float* loss_output, double* saved_for_backward, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream);
void launch_ppo_loss_backward_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const double* saved_for_backward, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream);
void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
void launch_ppo_loss_backward_optimized_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
void launch_ppo_loss_backward_optimized_bf16(float* grad_logits, float* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
void launch_sample_logits_bf16(double* actions, at::BFloat16* logprobs, at::BFloat16* value_out, const at::BFloat16* logits, const at::BFloat16* value, const int* act_sizes, uint64_t seed, const int64_t* offset_ptr, int num_atns, int B, int logits_stride, int value_stride, cudaStream_t stream);

#endif // PUFFERLIB_KERNELS_H
Loading
Loading