diff --git a/pufferlib/extensions/bindings.cpp b/pufferlib/extensions/bindings.cpp
index 843eb4aa1..546e05683 100644
--- a/pufferlib/extensions/bindings.cpp
+++ b/pufferlib/extensions/bindings.cpp
@@ -22,7 +22,7 @@ pybind11::dict log_environments(pybind11::object pufferl_obj) {
 
 Tensor initial_state(pybind11::object pufferl_obj, int64_t batch_size, torch::Device device) {
     auto& pufferl = pufferl_obj.cast<PuffeRL&>();
-    return pufferl.policy->initial_state(batch_size, device);
+    return pufferl.policy_bf16->initial_state(batch_size, device);
 }
 
 void python_vec_recv(pybind11::object pufferl_obj, int buf) {
@@ -129,6 +129,7 @@ std::unique_ptr<PuffeRL> create_pufferl(pybind11::dict kwargs) {
     hypers.kernels = get_config(kwargs, "kernels");
     hypers.profile = get_config(kwargs, "profile");
     hypers.use_omp = get_config(kwargs, "use_omp");
+    hypers.bf16 = get_config(kwargs, "bf16");
 
     std::string env_name = kwargs["env_name"].cast<std::string>();
     Dict* vec_kwargs = py_dict_to_c_dict(kwargs["vec_kwargs"].cast());
@@ -228,7 +229,8 @@ PYBIND11_MODULE(_C, m) {
     m.def("create_pufferl", &create_pufferl);
 
     py::class_>(m, "PuffeRL")
-        .def_readwrite("policy", &PuffeRL::policy)
+        .def_readwrite("policy_bf16", &PuffeRL::policy_bf16)
+        .def_readwrite("policy_fp32", &PuffeRL::policy_fp32)
         .def_readwrite("muon", &PuffeRL::muon)
        .def_readwrite("hypers", &PuffeRL::hypers)
         .def_readwrite("rollouts", &PuffeRL::rollouts);
diff --git a/pufferlib/extensions/cuda/advantage.cu b/pufferlib/extensions/cuda/advantage.cu
index 9426a7e5e..6848e64cb 100644
--- a/pufferlib/extensions/cuda/advantage.cu
+++ b/pufferlib/extensions/cuda/advantage.cu
@@ -1,21 +1,25 @@
 #include
 #include
 #include
+#include
 
 namespace pufferlib {
 
-__host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
-    float* importance, float* advantages, float gamma, float lambda,
+// TIn = input type (bf16 or float), TOut = output type (always float for precision)
+template <typename TIn, typename TOut>
+__host__ __device__ void puff_advantage_row_cuda(const TIn* values, const TIn* rewards, const TIn* dones,
+    const TIn* importance, TOut* advantages, float gamma, float lambda,
     float rho_clip, float c_clip, int horizon) {
     float lastpufferlam = 0;
     for (int t = horizon-2; t >= 0; t--) {
         int t_next = t + 1;
-        float nextnonterminal = 1.0 - dones[t_next];
-        float rho_t = fminf(importance[t], rho_clip);
-        float c_t = fminf(importance[t], c_clip);
-        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
+        float nextnonterminal = 1.0f - float(dones[t_next]);
+        float imp = float(importance[t]);
+        float rho_t = fminf(imp, rho_clip);
+        float c_t = fminf(imp, c_clip);
+        float delta = rho_t*(float(rewards[t_next]) + gamma*float(values[t_next])*nextnonterminal - float(values[t]));
         lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
-        advantages[t] = lastpufferlam;
+        advantages[t] = TOut(lastpufferlam);
     }
 }
 
@@ -25,32 +29,42 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
 
     // Validate input tensors
     torch::Device device = values.device();
-    for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
+    auto input_dtype = values.dtype();
+    for (const torch::Tensor& t : {values, rewards, dones, importance}) {
         TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
         TORCH_CHECK(t.device() == device, "All tensors must be on same device");
         TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
         TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon");
-        TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
+        TORCH_CHECK(t.dtype() == input_dtype, "Input tensors must have matching dtype");
         if (!t.is_contiguous()) {
            t.contiguous();
         }
     }
+
+    // advantages can be different dtype (fp32 for precision)
+    TORCH_CHECK(advantages.dim() == 2, "Advantages must be 2D");
+    TORCH_CHECK(advantages.device() == device, "Advantages must be on same device");
+    TORCH_CHECK(advantages.size(0) == num_steps, "Advantages first dimension must match");
+    TORCH_CHECK(advantages.size(1) == horizon, "Advantages second dimension must match");
+    if (!advantages.is_contiguous()) {
+        advantages.contiguous();
+    }
 }
 
 // [num_steps, horizon]
-__global__ void puff_advantage_kernel(float* values, float* rewards,
-    float* dones, float* importance, float* advantages, float gamma,
+template <typename TIn, typename TOut>
+__global__ void puff_advantage_kernel(const TIn* values, const TIn* rewards,
+    const TIn* dones, const TIn* importance, TOut* advantages, float gamma,
     float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
         return;
     }
     int offset = row*horizon;
-    puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
+    puff_advantage_row_cuda<TIn, TOut>(values + offset, rewards + offset, dones + offset,
         importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
 }
 
-void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
+template <typename TIn, typename TOut>
+void compute_puff_advantage_cuda_impl(torch::Tensor values, torch::Tensor rewards,
     torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
     double gamma, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
@@ -61,16 +75,16 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
     int threads_per_block = 256;
     int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
 
-    puff_advantage_kernel<<<blocks, threads_per_block>>>(
-        values.data_ptr<float>(),
-        rewards.data_ptr<float>(),
-        dones.data_ptr<float>(),
-        importance.data_ptr<float>(),
-        advantages.data_ptr<float>(),
-        gamma,
-        lambda,
-        rho_clip,
-        c_clip,
+    puff_advantage_kernel<TIn, TOut><<<blocks, threads_per_block>>>(
+        values.data_ptr<TIn>(),
+        rewards.data_ptr<TIn>(),
+        dones.data_ptr<TIn>(),
+        importance.data_ptr<TIn>(),
+        advantages.data_ptr<TOut>(),
+        static_cast<float>(gamma),
+        static_cast<float>(lambda),
+        static_cast<float>(rho_clip),
+        static_cast<float>(c_clip),
         num_steps,
         horizon
     );
@@ -81,6 +95,27 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
     }
 }
 
+void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
+    torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
+    double gamma, double lambda, double rho_clip, double c_clip) {
+    auto input_dtype = values.dtype();
+    auto output_dtype = advantages.dtype();
+
+    // Support bf16 inputs with fp32 output for precision
+    if (input_dtype == torch::kFloat32 && output_dtype == torch::kFloat32) {
+        compute_puff_advantage_cuda_impl<float, float>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kFloat32) {
+        compute_puff_advantage_cuda_impl<at::BFloat16, float>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kBFloat16) {
+        compute_puff_advantage_cuda_impl<at::BFloat16, at::BFloat16>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else {
+        TORCH_CHECK(false, "Unsupported dtype combination: inputs must be float32 or bfloat16, advantages must be float32 or bfloat16");
+    }
+}
+
 TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
     m.impl("compute_puff_advantage", &compute_puff_advantage_cuda);
 }
diff --git a/pufferlib/extensions/cuda/kernels.cu b/pufferlib/extensions/cuda/kernels.cu
index f75d6c559..c006678e7 100644
--- a/pufferlib/extensions/cuda/kernels.cu
+++ b/pufferlib/extensions/cuda/kernels.cu
@@ -1474,7 +1474,7 @@ __global__ void ppo_loss_forward_kernel_optimized(
     const T* __restrict__ values_pred,
     const int64_t* __restrict__ actions,
     const T* __restrict__ old_logprobs,
-    const T* __restrict__ advantages,
+    const float* __restrict__ advantages,
     const T* __restrict__ prio,
     const T* __restrict__ values,
     const T* __restrict__ returns,
@@ -1541,8 +1541,8 @@ __global__ void ppo_loss_forward_kernel_optimized(
     float old_logp = float(old_logprobs[nt]);
     float adv = float(advantages[nt]);
     float w = float(prio[n]);
-    float adv_std = sqrtf(adv_var[0]);
-    float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
+    float adv_std = sqrtf(float(adv_var[0]));
+    float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);
 
     float logratio = new_logp - old_logp;
     float ratio = __expf(logratio);
@@ -1588,14 +1588,14 @@ __global__ void ppo_loss_forward_kernel_optimized(
 
 template <typename T>
 __global__ void ppo_loss_backward_kernel_optimized(
-    T* __restrict__ grad_logits,
-    T* __restrict__ grad_values_pred,
+    float* __restrict__ grad_logits,
+    float* __restrict__ grad_values_pred,
     const float* __restrict__ grad_loss,
     const T* __restrict__ logits,
     const T* __restrict__ values_pred,
     const int64_t* __restrict__ actions,
     const T* __restrict__ old_logprobs,
-    const T* __restrict__ advantages,
+    const float* __restrict__ advantages,
     const T* __restrict__ prio,
     const T* __restrict__ values,
     const T* __restrict__ returns,
@@ -1665,8 +1665,8 @@ __global__ void ppo_loss_backward_kernel_optimized(
     float v_clipped = val + fmaxf(-vf_clip_coef, fminf(vf_clip_coef, v_error));
 
     // normalize advantage
-    float adv_std = sqrtf(adv_var[0]);
-    float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
+    float adv_std = sqrtf(float(adv_var[0]));
+    float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);
 
     // loss gradient scaling
     float dL = grad_loss[0] * inv_NT;
@@ -1686,7 +1686,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
     } else {
         d_val_pred = val_pred - ret;
     }
-    grad_values_pred[values_idx] = T(dL * vf_coef * d_val_pred);
+    grad_values_pred[values_idx] = dL * vf_coef * d_val_pred;
 
     // policy loss gradient
     float ratio_clipped = fmaxf(1.0f - clip_coef, fminf(1.0f + clip_coef, ratio));
@@ -1710,7 +1710,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
         d_logit -= p * d_new_logp;
         d_logit += d_entropy_term * p * (-entropy - logp);
 
-        grad_logits[logits_base + a * logits_stride_a] = T(d_logit);
+        grad_logits[logits_base + a * logits_stride_a] = d_logit;
     }
 }
 
@@ -1724,12 +1724,12 @@ inline void launch_ppo_loss_forward_optimized(
     const T* values_pred,
     const int64_t* actions,
     const T* old_logprobs,
-    const T* advantages,
+    const float* advantages,  // always fp32 for precision
     const T* prio,
     const T* values,
     const T* returns,
-    const float* adv_mean,
-    const float* adv_var,
+    const float* adv_mean,  // keep fp32
+    const float* adv_var,   // keep fp32
     float clip_coef,
     float vf_clip_coef,
     float vf_coef,
@@ -1784,19 +1784,19 @@ inline void launch_ppo_loss_forward_optimized(
 
 template <typename T>
 void launch_ppo_loss_backward_optimized(
-    T* grad_logits,
-    T* grad_values_pred,
+    float* grad_logits,
+    float* grad_values_pred,
     const float* grad_loss,
     const T* logits,
     const T* values_pred,  // added: need to read val_pred directly
     const int64_t* actions,
     const T* old_logprobs,
-    const T* advantages,
+    const float* advantages,
     const T* prio,
     const T* values,
     const T* returns,
     const float* adv_mean,
-    const float* adv_var, // variance, not std
+    const float* adv_var,
     float clip_coef,
     float vf_clip_coef,
     float vf_coef,
@@ -2011,7 +2011,7 @@ __global__ void ppo_loss_backward_kernel(
 
     // === Retrieve saved values from forward pass ===
     const double* saved = saved_for_backward + idx * 5;
-    double new_logp = saved[0];   // new log prob of selected action
+    // double new_logp = saved[0];   // new log prob of selected action
     double ratio = saved[1];      // exp(new_logp - old_logp)
     double val_pred = saved[2];   // value prediction
     double v_clipped = saved[3];  // clipped value target
@@ -2481,10 +2481,10 @@ void launch_ppo_loss_backward_optimized_float(float* grad_logits, float* grad_va
     launch_ppo_loss_backward_optimized(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
-void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
+void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
     launch_ppo_loss_forward_optimized(loss_output, saved_for_backward, ratio_out, newvalue_out, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
-void launch_ppo_loss_backward_optimized_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
+void launch_ppo_loss_backward_optimized_bf16(float* grad_logits, float* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
     launch_ppo_loss_backward_optimized(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
diff --git a/pufferlib/extensions/cuda/kernels.h b/pufferlib/extensions/cuda/kernels.h
index 3faca5186..d90fe962d 100644
--- a/pufferlib/extensions/cuda/kernels.h
+++ b/pufferlib/extensions/cuda/kernels.h
@@ -40,8 +40,8 @@ void launch_logcumsumexp_forward_bf16(at::BFloat16* out, double* s_buf, const at
 void launch_logcumsumexp_backward_bf16(at::BFloat16* grad_x, const at::BFloat16* grad_out, const at::BFloat16* x, const double* s_buf, int T_total, int H, int B, cudaStream_t stream);
 void launch_ppo_loss_forward_bf16(float* loss_output, double* saved_for_backward, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream);
 void launch_ppo_loss_backward_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const double* saved_for_backward, const float* adv_mean, const float* adv_std, double clip_coef, double vf_clip_coef, double vf_coef, double ent_coef, int T_seq, int A, int N, cudaStream_t stream);
-void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
-void launch_ppo_loss_backward_optimized_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
+void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
+void launch_ppo_loss_backward_optimized_bf16(float* grad_logits, float* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream);
 void launch_sample_logits_bf16(double* actions, at::BFloat16* logprobs, at::BFloat16* value_out, const at::BFloat16* logits, const at::BFloat16* value, const int* act_sizes, uint64_t seed, const int64_t* offset_ptr, int num_atns, int B, int logits_stride, int value_stride, cudaStream_t stream);
 
 #endif // PUFFERLIB_KERNELS_H
diff --git a/pufferlib/extensions/models.cpp b/pufferlib/extensions/models.cpp
index 35aae40dd..a5f146f4e 100644
--- a/pufferlib/extensions/models.cpp
+++ b/pufferlib/extensions/models.cpp
@@ -201,7 +201,7 @@ class DefaultEncoder : public Encoder {
     }
 
     Tensor forward(Tensor x) override {
-        return linear->forward(x);
+        return linear->forward(x.to(linear->weight.dtype()));
     }
 };
 
@@ -223,8 +223,9 @@ class SnakeEncoder : public Encoder {
     Tensor forward(Tensor x) override {
         // x is [B, input_size] with values 0-7
         int64_t B = x.size(0);
+        auto target_dtype = linear->weight.dtype();
         // One-hot encode: [B, input_size] -> [B, input_size, num_classes]
-        Tensor onehot = torch::one_hot(x.to(torch::kLong), num_classes).to(torch::kFloat32);
+        Tensor onehot = torch::one_hot(x.to(torch::kLong), num_classes).to(target_dtype);
         // Flatten: [B, input_size * num_classes]
         onehot = onehot.view({B, -1});
         return linear->forward(onehot);
@@ -291,12 +292,13 @@ class G2048Encoder : public Encoder {
     Tensor forward(Tensor x) override {
         // x is (B, 16) uint8 tile values
         auto B = x.size(0);
+        auto target_dtype = linear1->weight.dtype();
 
         // value_embed(obs) -> (B, 16, embed_dim)
-        auto value_obs = value_embed->forward(x.to(torch::kLong));
+        auto value_obs = value_embed->forward(x.to(torch::kLong)).to(target_dtype);
 
         // pos_embed.weight expanded to (B, 16, embed_dim)
-        auto pos_obs = pos_embed->weight.unsqueeze(0).expand({B, num_grid_cells, embed_dim});
+        auto pos_obs = pos_embed->weight.unsqueeze(0).expand({B, num_grid_cells, embed_dim}).to(target_dtype);
 
         // grid_obs = (value_obs + pos_obs).flatten(1) -> (B, 48)
         auto grid_obs = (value_obs + pos_obs).flatten(1);
@@ -368,7 +370,7 @@ class NMMO3Encoder : public Encoder {
     Tensor forward(Tensor x) override {
         int64_t B = x.size(0);
         auto device = x.device();
-        auto dtype = x.dtype();
+        auto target_dtype = conv1->weight.dtype();
 
         // Split observations: map (1650), player (47), reward (10)
         Tensor ob_map = x.narrow(1, 0, 11*15*10).view({B, 11, 15, 10});
@@ -382,7 +384,7 @@ class NMMO3Encoder : public Encoder {
         Tensor codes = map_perm + offsets.to(device);
 
         // Create multi-hot buffer and scatter
-        Tensor map_buf = torch::zeros({B, 59, 11, 15}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+        Tensor map_buf = torch::zeros({B, 59, 11, 15}, torch::TensorOptions().dtype(target_dtype).device(device));
         map_buf.scatter_(1, codes.to(torch::kInt32), 1.0f);
 
         // Conv layers
@@ -391,11 +393,11 @@ class NMMO3Encoder : public Encoder {
         map_out = map_out.flatten(1);  // (B, 256)
 
         // Player discrete embedding
-        Tensor player_discrete = player_embed->forward(ob_player.to(torch::kInt64));
+        Tensor player_discrete = player_embed->forward(ob_player.to(torch::kInt64)).to(target_dtype);
         player_discrete = player_discrete.flatten(1);  // (B, 1504)
 
         // Concatenate: map_out + player_discrete + player_continuous + reward
-        Tensor obs = torch::cat({map_out, player_discrete, ob_player.to(torch::kFloat32), ob_reward.to(torch::kFloat32)}, 1);
+        Tensor obs = torch::cat({map_out, player_discrete, ob_player.to(target_dtype), ob_reward.to(target_dtype)}, 1);
 
         // Projection with ReLU
         obs = torch::relu(proj->forward(obs));
@@ -498,7 +500,8 @@ class DriveEncoder : public Encoder {
 
     Tensor forward(Tensor x) override {
         int64_t B = x.size(0);
-        x = x.to(torch::kFloat32);
+        auto target_dtype = ego_linear1->weight.dtype();
+        x = x.to(target_dtype);
 
         // Split observations: ego (7), partner (441), road (1400)
         Tensor ego_obs = x.narrow(1, 0, 7);
@@ -517,7 +520,7 @@ class DriveEncoder : public Encoder {
         Tensor road_objects = road_obs.view({B, 200, 7});
         Tensor road_continuous = road_objects.narrow(2, 0, 6);
         Tensor road_categorical = road_objects.narrow(2, 6, 1).squeeze(2);
-        Tensor road_onehot = torch::one_hot(road_categorical.to(torch::kInt64), 7).to(torch::kFloat32);
+        Tensor road_onehot = torch::one_hot(road_categorical.to(torch::kInt64), 7).to(target_dtype);
         Tensor road_combined = torch::cat({road_continuous, road_onehot}, 2);  // (B, 200, 13)
         Tensor road_enc = road_linear2->forward(road_norm->forward(road_linear1->forward(road_combined)));
         Tensor road_features = std::get<0>(road_enc.max(1));  // max pool over 200 objects
@@ -608,9 +611,10 @@ class PolicyMinGRU : public torch::nn::Module {
 
     Tensor initial_state(int64_t batch_size, torch::Device device) {
         // Layout: {num_layers, batch_size, hidden} - select(0, i) gives contiguous slice
+        auto dtype = this->parameters().empty() ? torch::kFloat32 : this->parameters()[0].scalar_type();
         return torch::zeros(
             {num_layers, batch_size, (int64_t)(hidden_size*expansion_factor)},
-            torch::dtype(torch::kFloat32).device(device)
+            torch::dtype(dtype).device(device)
         );
     }
 
@@ -862,3 +866,23 @@ void sync_fp16_fp32(PolicyLSTM* policy_16, PolicyLSTM* policy_32) {
         params_16[i].copy_(params_32[i].to(torch::kFloat32));
     }
 }
+
+// Sync bf16 working weights from fp32 master weights (for mixed-precision training)
+void sync_policy_weights(PolicyMinGRU* policy_bf16, PolicyMinGRU* policy_fp32) {
+    auto params_fp32 = policy_fp32->parameters();
+    auto params_bf16 = policy_bf16->parameters();
+    for (size_t i = 0; i < params_fp32.size(); ++i) {
+        params_bf16[i].data().copy_(params_fp32[i].data().to(torch::kBFloat16));
+    }
+}
+
+// Copy gradients from bf16 policy to fp32 policy (for optimizer step)
+void copy_gradients_to_fp32(PolicyMinGRU* policy_bf16, PolicyMinGRU* policy_fp32) {
+    auto params_fp32 = policy_fp32->parameters();
+    auto params_bf16 = policy_bf16->parameters();
+    for (size_t i = 0; i < params_fp32.size(); ++i) {
+        if (params_bf16[i].grad().defined()) {
+            params_fp32[i].mutable_grad() = params_bf16[i].grad().to(torch::kFloat32);
+        }
+    }
+}
diff --git a/pufferlib/extensions/modules.cpp b/pufferlib/extensions/modules.cpp
index 0b2f6c4e7..fdd11a44f 100644
--- a/pufferlib/extensions/modules.cpp
+++ b/pufferlib/extensions/modules.cpp
@@ -1043,12 +1043,12 @@ class PPOFusedLossOptimizedFunction : public torch::autograd::Function
             values_pred.data_ptr(),
             actions.data_ptr(),
             old_logprobs.data_ptr(),
-            advantages.data_ptr(),
+            advantages.data_ptr<float>(),  // keep in fp32 for precision and training stability
             prio.data_ptr(),
             values.data_ptr(),
             returns.data_ptr(),
-            adv_mean.data_ptr(),
-            adv_var.data_ptr(),
+            adv_mean.data_ptr<float>(),  // fp32 training and precision
+            adv_var.data_ptr<float>(),   // fp32 training and precision
             static_cast<float>(clip_coef),
             static_cast<float>(vf_clip_coef),
             static_cast<float>(vf_coef),
@@ -1098,8 +1098,9 @@ class PPOFusedLossOptimizedFunction : public torch::autograd::Function
-            grad_logits.data_ptr(),
-            grad_values_pred.data_ptr(),
+            grad_logits.data_ptr<float>(),
+            grad_values_pred.data_ptr<float>(),
             grad_loss.data_ptr(),
             logits.data_ptr(),
             values_pred.data_ptr(),
             actions.data_ptr(),
             old_logprobs.data_ptr(),
-            advantages.data_ptr(),
+            advantages.data_ptr<float>(),  // keep in fp32
             prio.data_ptr(),
             values.data_ptr(),
             returns.data_ptr(),
diff --git a/pufferlib/extensions/pufferlib.cpp b/pufferlib/extensions/pufferlib.cpp
index 99ff2343a..8f58f34dc 100644
--- a/pufferlib/extensions/pufferlib.cpp
+++ b/pufferlib/extensions/pufferlib.cpp
@@ -33,7 +33,10 @@ typedef torch::Tensor Tensor;
 // CUDA kernel wrappers
 #include "modules.cpp"
 
-auto DTYPE = torch::kFloat32;
+// get dtype based on bf16 flag
+inline torch::ScalarType get_dtype(bool bf16) {
+    return bf16 ? torch::kBFloat16 : torch::kFloat32;
+}
 
 namespace pufferlib {
 
@@ -185,16 +188,17 @@ typedef struct {
 } TrainGraph;
 
 TrainGraph create_train_graph(int minibatch_segments, int horizon, int input_size,
-    int num_layers, int hidden_size, int expansion_factor, int num_atns) {
+    int num_layers, int hidden_size, int expansion_factor, int num_atns, bool bf16) {
     TrainGraph g;
-    auto options = torch::TensorOptions().dtype(DTYPE).device(torch::kCUDA);
+    auto dtype = get_dtype(bf16);
+    auto options = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
     g.mb_obs = torch::zeros({minibatch_segments, horizon, input_size}, options);
     g.mb_state = torch::zeros({num_layers, minibatch_segments, 1, hidden_size * expansion_factor}, options);
     g.mb_newvalue = torch::zeros({minibatch_segments, horizon, 1}, options);
     g.mb_ratio = torch::zeros({minibatch_segments, horizon}, options);
     g.mb_actions = torch::zeros({minibatch_segments, horizon, num_atns}, options).to(torch::kInt64);
     g.mb_logprobs = torch::zeros({minibatch_segments, horizon}, options);
-    g.mb_advantages = torch::zeros({minibatch_segments, horizon}, options);
+    g.mb_advantages = torch::zeros({minibatch_segments, horizon}, options.dtype(torch::kFloat32));  // always fp32 for precision
     g.mb_prio = torch::zeros({minibatch_segments, 1}, options);
     g.mb_values = torch::zeros({minibatch_segments, horizon}, options);
     g.mb_returns = torch::zeros({minibatch_segments, horizon}, options);
@@ -212,16 +216,17 @@ typedef struct {
     Tensor importance;
 } RolloutBuf;
 
-RolloutBuf create_rollouts(int horizon, int segments, int input_size, int num_atns) {
+RolloutBuf create_rollouts(int horizon, int segments, int input_size, int num_atns, bool bf16) {
     RolloutBuf r;
-    r.observations = torch::zeros({horizon, segments, input_size}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+    auto dtype = get_dtype(bf16);
+    r.observations = torch::zeros({horizon, segments, input_size}, torch::dtype(dtype).device(torch::kCUDA));
     r.actions = torch::zeros({horizon, segments, num_atns}, torch::dtype(torch::kFloat64).device(torch::kCUDA));
-    r.values = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
-    r.logprobs = torch::zeros({horizon, segments}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    r.rewards = torch::zeros({horizon, segments}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    r.terminals = torch::zeros({horizon, segments}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    r.ratio = torch::zeros({horizon, segments}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    r.importance = torch::zeros({horizon, segments}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+    r.values = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
+    r.logprobs = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
+    r.rewards = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
+    r.terminals = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
+    r.ratio = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
+    r.importance = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
     return r;
 }
 
@@ -270,10 +275,12 @@ typedef struct {
     bool kernels;
     bool profile;
     bool use_omp;
+    bool bf16;  // bfloat16 mixed precision training
 } HypersT;
 
 typedef struct {
-    PolicyMinGRU* policy;
+    PolicyMinGRU* policy_bf16;  // Working weights (bf16) - used for forward/backward
+    PolicyMinGRU* policy_fp32;  // Master weights (fp32) - used for optimizer
     VecEnv* vec;
     torch::optim::Muon* muon;
     EnvExports* env_exports;
@@ -313,8 +320,8 @@ void fused_rollout_step(PuffeRL& pufferl, int h, int buf) {
     Tensor obs_slice = pufferl.env.obs.narrow(0, buf*block_size, block_size);
     Tensor& state = pufferl.buffer_states[buf];
 
-    // Run policy forward
-    auto [logits, value, state_out] = pufferl.policy->forward(obs_slice, state);
+    // Run policy forward using bf16 working weights
+    auto [logits, value, state_out] = pufferl.policy_bf16->forward(obs_slice, state);
 
     // Get output slices in rollouts storage
     Tensor actions_out = pufferl.rollouts.actions.select(0, h).narrow(0, buf*block_size, block_size);
@@ -362,9 +369,9 @@ void fused_rollout_step(PuffeRL& pufferl, int h, int buf) {
     pufferl.env.actions.narrow(0, buf*block_size, block_size).copy_(actions_out, true);
 }
 
-void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy,
+void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy_bf16, PolicyMinGRU* policy_fp32,
     torch::optim::Muon* muon, HypersT& hypers, Tensor& adv_mean, Tensor& adv_std, Tensor& act_sizes_cpu, bool kernels) {
-    auto [logits, newvalue] = policy->forward_train(graph.mb_obs.to(DTYPE), graph.mb_state);
+    auto [logits, newvalue] = policy_bf16->forward_train(graph.mb_obs, graph.mb_state);
 
     Tensor loss;
     if (kernels) {
@@ -373,11 +380,11 @@ void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy,
             logits,
             newvalue,
             graph.mb_actions,
-            graph.mb_logprobs.to(logits.dtype()),
-            graph.mb_advantages.to(logits.dtype()),
-            graph.mb_prio.to(logits.dtype()),
-            graph.mb_values.to(logits.dtype()),
-            graph.mb_returns.to(logits.dtype()),
+            graph.mb_logprobs,
+            graph.mb_advantages,
+            graph.mb_prio,
+            graph.mb_values,
+            graph.mb_returns,
             mb_adv_mean,
             mb_adv_var,  // variance, not std - kernel does sqrtf to avoid second kernel launch here
             graph.mb_ratio,
@@ -438,34 +445,23 @@ void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy,
         // Total loss
         loss = pg_loss + hypers.vf_coef*v_loss - hypers.ent_coef*entropy;
 
-        /*
-        {
-            torch::NoGradGuard no_grad;
-
-            // Accumulate stats
-            pg_sum += pg_loss.detach();
-            v_sum += v_loss.detach();
-            ent_sum += entropy.detach();
-            total_sum += loss.detach();
-
-            // KL and clipping diagnostics (matches Python)
-            auto old_kl = (-logratio).mean();
-            auto kl = ((ratio_new - 1) - logratio).mean();
-            auto cf = (ratio_new - 1.0).abs().gt(hypers.clip_coef).to(torch::kFloat32).mean();
-            auto imp = ratio_new.mean();
-
-            old_approx_kl_sum += old_kl.detach();
-            approx_kl_sum += kl.detach();
-            clipfrac_sum += cf.detach();
-            importance_sum += imp.detach();
-        }
-        */
     }
 
+    // computes gradients on bf16 weights (or fp32 if not using bf16)
     loss.backward();
-    clip_grad_norm_(policy->parameters(), hypers.max_grad_norm);
+
+    // copy gradients from bf16 to fp32, then optimizer step on fp32 master weights
+    if (hypers.bf16) {
+        copy_gradients_to_fp32(policy_bf16, policy_fp32);
+    }
+    clip_grad_norm_(policy_fp32->parameters(), hypers.max_grad_norm);
     muon->step();
     muon->zero_grad();
+    if (hypers.bf16) {
+        policy_bf16->zero_grad();  // also need to clear bf16 gradients
+        // sync updated fp32 weights back to bf16 for next forward pass
+        sync_policy_weights(policy_bf16, policy_fp32);
+    }
 }
 
 // Capture
@@ -594,36 +590,57 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers, const s
 
     // Create encoder/decoder based on env_name
     // Decoder output size is act_n (sum of all action space sizes)
-    std::shared_ptr enc;
-    std::shared_ptr dec;
-    if (env_name == "puffer_snake") {
-        enc = std::make_shared(input_size, hidden_size, 8);
-        dec = std::make_shared(hidden_size, act_n);
-    } else if (env_name == "puffer_g2048") {
-        enc = std::make_shared(input_size, hidden_size);
-        dec = std::make_shared(hidden_size, act_n);
-    } else if (env_name == "puffer_nmmo3") {
-        enc = std::make_shared(input_size, hidden_size);
-        dec = std::make_shared(hidden_size, act_n);
-    } else if (env_name == "puffer_drive") {
-        enc = std::make_shared(input_size, hidden_size);
-        dec = std::make_shared(hidden_size, act_n);
+    // We need two sets for mixed-precision: fp32 (master) and bf16 (working)
+    auto create_encoder_decoder = [&]() -> std::pair, std::shared_ptr> {
+        std::shared_ptr enc;
+        std::shared_ptr dec;
+        if (env_name == "puffer_snake") {
+            enc = std::make_shared(input_size, hidden_size, 8);
+            dec = std::make_shared(hidden_size, act_n);
+        } else if (env_name == "puffer_g2048") {
+            enc = std::make_shared(input_size, hidden_size);
+            dec = std::make_shared(hidden_size, act_n);
+        } else if (env_name == "puffer_nmmo3") {
+            enc = std::make_shared(input_size, hidden_size);
+            dec = std::make_shared(hidden_size, act_n);
+        } else if (env_name == "puffer_drive") {
+            enc = std::make_shared(input_size, hidden_size);
+            dec = std::make_shared(hidden_size, act_n);
+        } else {
+            enc = std::make_shared(input_size, hidden_size);
+            dec = std::make_shared(hidden_size, act_n);
+        }
+        return {enc, dec};
+    };
+
+    // Create fp32 master policy (for optimizer - precise gradient accumulation)
+    auto [enc_fp32, dec_fp32] = create_encoder_decoder();
+    PolicyMinGRU* policy_fp32 = new PolicyMinGRU(enc_fp32, dec_fp32, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
+    policy_fp32->to(torch::kCUDA);
+    policy_fp32->to(torch::kFloat32);
+    pufferl->policy_fp32 = policy_fp32;
+
+    if (hypers.bf16) {
+        // create bf16 working policy (for fwd/bwd)
+        auto [enc_bf16, dec_bf16] = create_encoder_decoder();
+        PolicyMinGRU* policy_bf16 = new PolicyMinGRU(enc_bf16, dec_bf16, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
+        policy_bf16->to(torch::kCUDA);
+        policy_bf16->to(torch::kBFloat16);
+        pufferl->policy_bf16 = policy_bf16;
+        sync_policy_weights(policy_bf16, policy_fp32);  // initial sync
     } else {
-        enc = std::make_shared(input_size, hidden_size);
-        dec = std::make_shared(hidden_size, act_n);
+        // just use same policy for both
+        pufferl->policy_bf16 = policy_fp32;
     }
 
-    PolicyMinGRU* policy = new PolicyMinGRU(enc, dec, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
-    policy->to(torch::kCUDA);
-    policy->to(DTYPE);
-    pufferl->policy = policy;
-
+    // Optimizer uses fp32 master weights for precise gradient accumulation
     float lr = hypers.lr;
     float beta1 = hypers.beta1;
     float eps = hypers.eps;
-    pufferl->muon = new torch::optim::Muon(policy->parameters(),
+    pufferl->muon = new torch::optim::Muon(policy_fp32->parameters(),
         torch::optim::MuonOptions(lr).momentum(beta1).eps(eps));
 
+    // Allocate buffers
     int segments = hypers.segments;
     int horizon = hypers.horizon;
@@ -635,17 +652,18 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers, const s
     printf("DEBUG: num_envs=%d, total_agents=%d, segments=%d, batch=%d, num_buffers=%d\n",
         vec->size, total_agents, segments, batch, num_buffers);
 
-    pufferl->rollouts = create_rollouts(horizon, total_agents, input_size, num_action_heads);
+    pufferl->rollouts = create_rollouts(horizon, total_agents, input_size, num_action_heads, hypers.bf16);
     pufferl->train_buf = create_train_graph(minibatch_segments, horizon, input_size,
-        policy->num_layers, policy->hidden_size, policy->expansion_factor, num_action_heads);
+        policy_fp32->num_layers, policy_fp32->hidden_size, policy_fp32->expansion_factor, num_action_heads, hypers.bf16);
 
-    pufferl->adv_mean = torch::zeros({1}, torch::dtype(DTYPE).device(torch::kCUDA));
-    pufferl->adv_std = torch::ones({1}, torch::dtype(DTYPE).device(torch::kCUDA));
+    // always fp32 since advantages are computed in fp32
+    pufferl->adv_mean = torch::zeros({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+    pufferl->adv_std = torch::ones({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
 
     // Per-buffer states: each is {num_layers, block_size, hidden} for contiguous access
     pufferl->buffer_states.resize(num_buffers);
     for (int i = 0; i < num_buffers; i++) {
-        pufferl->buffer_states[i] = policy->initial_state(batch, torch::kCUDA);
+        pufferl->buffer_states[i] = pufferl->policy_bf16->initial_state(batch, torch::kCUDA);
     }
 
     if (hypers.cudagraphs) {
@@ -653,7 +671,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers, const s
         auto* p = pufferl.get();
 
         capture_graph(&pufferl->train_cudagraph, [p]() {
-            train_forward_call(p->train_buf, p->policy, p->muon,
+            train_forward_call(p->train_buf, p->policy_bf16, p->policy_fp32, p->muon,
                 p->hypers, p->adv_mean, p->adv_std,
                 p->act_sizes_cpu, p->hypers.kernels);
         });
@@ -798,7 +816,8 @@ void train_impl(PuffeRL& pufferl) {
     Tensor clipfrac_sum = torch::zeros({}, scalar_opts);
     Tensor importance_sum = torch::zeros({}, scalar_opts);
 
-    PolicyMinGRU* policy = pufferl.policy;
+    PolicyMinGRU* policy_bf16 = pufferl.policy_bf16;
+    // PolicyMinGRU* policy_fp32 = pufferl.policy_fp32;
     torch::optim::Muon* muon = pufferl.muon;
 
     if (anneal_lr) {
@@ -813,10 +832,9 @@ void train_impl(PuffeRL& pufferl) {
     // Zero out ratio at start of epoch (matches Python: self.ratio[:] = 1)
     rollouts.ratio.fill_(1.0);
 
-    Tensor advantages = torch::zeros_like(rollouts.values);
+    Tensor advantages = torch::zeros_like(rollouts.values, torch::kFloat32);  // fp32 precision
     compute_advantage(rollouts, advantages, hypers);
-
     pufferl.adv_mean.copy_(advantages.mean().detach());
     pufferl.adv_std.copy_(advantages.std().detach());
 
@@ -824,9 +842,10 @@ void train_impl(PuffeRL& pufferl) {
     cudaEventCreate(&start);
     cudaEventCreate(&stop);
 
+    auto dtype = get_dtype(hypers.bf16);
     Tensor mb_state = torch::zeros(
-        {policy->num_layers, minibatch_segments, 1, (int64_t)(policy->hidden_size*policy->expansion_factor)},
-        torch::dtype(DTYPE).device(rollouts.values.device())
+        {policy_bf16->num_layers, minibatch_segments, 1, (int64_t)(policy_bf16->hidden_size*policy_bf16->expansion_factor)},
+        torch::dtype(dtype).device(rollouts.values.device())
     );
 
     // Temporary: random indices and uniform weights
@@ -856,7 +875,7 @@ void train_impl(PuffeRL& pufferl) {
             if (hypers.cudagraphs) {
                 pufferl.train_cudagraph.replay();
             } else {
-                train_forward_call(graph, pufferl.policy, pufferl.muon,
+                train_forward_call(graph, pufferl.policy_bf16, pufferl.policy_fp32, pufferl.muon,
                     hypers, pufferl.adv_mean, pufferl.adv_std,
                     pufferl.act_sizes_cpu, hypers.kernels);
             }
             profile_end(hypers.profile);
@@ -864,14 +883,8 @@ void train_impl(PuffeRL& pufferl) {
             // Update global ratio and values in-place (matches Python)
             // Buffers are {horizon, segments}, so index_copy_ along dim 1 (segments)
            // Source is {minibatch_segments, horizon}, need to transpose to {horizon, minibatch_segments}
-            // Temporary: use slice instead of index_copy_ for contiguous test
-            /*
-            pufferl.rollouts.ratio.slice(1, 0, minibatch_segments).copy_(graph.ratio.detach().squeeze(-1).to(torch::kFloat32).transpose(0, 1));
-            pufferl.rollouts.values.slice(1, 0, minibatch_segments).copy_(graph.newvalue.detach().squeeze(-1).to(torch::kFloat32).transpose(0, 1));
-            */
-            // Original index_copy_ version:
-            pufferl.rollouts.ratio.index_copy_(1, idx, graph.mb_ratio.detach().squeeze(-1).to(torch::kFloat32).transpose(0, 1));
-            pufferl.rollouts.values.index_copy_(1, idx, graph.mb_newvalue.detach().squeeze(-1).to(torch::kFloat32).transpose(0, 1));
+            pufferl.rollouts.ratio.index_copy_(1, idx, graph.mb_ratio.detach().squeeze(-1).to(dtype).transpose(0, 1));
+            pufferl.rollouts.values.index_copy_(1, idx, graph.mb_newvalue.detach().squeeze(-1).to(dtype).transpose(0, 1));
         }
 
         pufferl.epoch += 1;
diff --git a/pufferlib/ocean/drone/binding.c b/pufferlib/ocean/drone/binding.c
index ed5475ea9..849eef7be 100644
--- a/pufferlib/ocean/drone/binding.c
+++ b/pufferlib/ocean/drone/binding.c
@@ -1,5 +1,5 @@
 #include "drone.h"
-#include "render.h"
+// #include "render.h"
 
 #define Env DroneEnv
 #include "../env_binding.h"
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 83a4fdac3..2e48e9b46 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -156,6 +156,7 @@ def __init__(self, config, logger=None, verbose=True):
         config['kernels'] = True
         config['use_omp'] = True
         config['num_buffers'] = 2
+        config['bf16'] = config.get('bf16', True)  # bfloat16 mixed precision training
 
         self.pufferl_cpp = _C.create_pufferl(config)
         self.observations = self.pufferl_cpp.rollouts.observations
         self.actions = self.pufferl_cpp.rollouts.actions
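
Note (not part of the patch): the following is a minimal standalone libtorch sketch of the fp32-master / bf16-working update loop that sync_policy_weights, copy_gradients_to_fp32, and train_forward_call implement above. It substitutes torch::optim::SGD and a plain torch::nn::Linear for the repo's Muon optimizer and PolicyMinGRU; the names master/working and all hyperparameters are illustrative only.

```cpp
// Sketch: keep an fp32 master copy for the optimizer, run fwd/bwd on a bf16 copy.
#include <torch/torch.h>

int main() {
    torch::Device device(torch::kCUDA);
    torch::nn::Linear master(128, 4);    // fp32 master weights (optimizer state, updates)
    master->to(device, torch::kFloat32);
    torch::nn::Linear working(128, 4);   // bf16 working weights (forward/backward only)
    working->to(device, torch::kBFloat16);

    torch::optim::SGD opt(master->parameters(), /*lr=*/1e-3);

    auto sync_working = [&]() {          // mirrors sync_policy_weights()
        torch::NoGradGuard ng;
        auto p32 = master->parameters();
        auto p16 = working->parameters();
        for (size_t i = 0; i < p32.size(); ++i)
            p16[i].copy_(p32[i].to(torch::kBFloat16));
    };
    sync_working();                      // initial copy of master -> working

    for (int step = 0; step < 10; ++step) {
        auto x = torch::randn({32, 128}, torch::dtype(torch::kBFloat16).device(device));
        auto loss = working->forward(x).square().mean();  // bf16 forward
        loss.backward();                                   // bf16 gradients

        // mirrors copy_gradients_to_fp32(): move grads onto the fp32 master params
        auto p32 = master->parameters();
        auto p16 = working->parameters();
        for (size_t i = 0; i < p32.size(); ++i)
            if (p16[i].grad().defined())
                p32[i].mutable_grad() = p16[i].grad().to(torch::kFloat32);

        torch::nn::utils::clip_grad_norm_(master->parameters(), 1.0);
        opt.step();                      // precise fp32 update
        opt.zero_grad();
        working->zero_grad();            // clear bf16 grads, as in train_forward_call
        sync_working();                  // refresh bf16 copy for the next forward pass
    }
}
```

The ordering matches the patch: gradients land on the fp32 copy before clip_grad_norm_ and the optimizer step, and the bf16 working copy is refreshed only after the step.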