Subject: Bootstrap the last step of the puffer advantage from per-row next-step scalars

Summary of the change
  - puff_advantage_row_cuda / puff_advantage_row now walk t = horizon-1 .. 0
    (was horizon-2 .. 0). At t == horizon-1 the next reward, next value and
    next done flag come from the scalars reward_bs / value_bs / done_bs; for
    every other t they are read from index t+1 of the row, as before.
  - vtrace_check_cuda / vtrace_check validate four extra tensors (1-D, same
    device as `values`, length num_steps, float32). The CUDA checker also
    asserts is_cuda() on every tensor, replacing the single check that used
    to sit in compute_puff_advantage_cuda.
  - The kernel reads one bootstrap scalar per row; the CPU loop does the same
    through a running index.
  - The op schema, the CPU/CUDA entry points and the Python wrapper take the
    four tensors between `advantages` and `gamma`.
  - pufferl.py allocates (segments, 1) bootstrap buffers, writes them on the
    step where the stored length `l` equals bptt_horizon, and passes them
    squeezed to 1-D.

Changes made to the added ('+') lines during review
  - Float literals carry an `f` suffix in both row functions. A bare 1.0 is a
    double, so `1.0 - done_bs` was evaluated in double precision and narrowed
    back to float on every step of the __host__ __device__ loop.
  - The bootstrap-tensor loops called t.contiguous() and dropped the result,
    which leaves `t` untouched. They now TORCH_CHECK(t.is_contiguous(), ...),
    because the row loops index the raw float pointers directly. The check is
    laid out over three lines so the hunk line counts are unchanged.
  - data_ptr() is written data_ptr<float>() everywhere. The non-template
    form returns void*, which does not convert implicitly to the float*
    parameters, so the template argument appears to have been stripped from
    the copy this text was rebuilt from.

Open review notes
  - trunc_bs / truncs_bs is passed down to both row functions but is never
    read there.
  - The older 2-D tensor loop keeps its discarded t.contiguous() call; it is
    unchanged context and not touched by this patch.
  - evaluate() writes self.truncations, which this patch does not allocate;
    presumably it already exists in __init__ -- confirm.
  - The schema marks the four bootstrap tensors as mutable (Tensor(f!) ..
    Tensor(i!)) although the code shown here only reads them.
  - This patch was rebuilt from a copy whose line breaks had been lost.
    Line counts per hunk match the hunk headers, but indentation and the
    positions of blank context lines are best effort (in particular the two
    doubled blank lines in the second pufferlib.cpp hunk). Apply with
    `git apply --ignore-whitespace` and compare against the base blobs. The
    `index` lines carry the original blob ids; the post-image ids no longer
    describe the edited '+' lines.

diff --git a/pufferlib/extensions/cuda/pufferlib.cu b/pufferlib/extensions/cuda/pufferlib.cu
index 9426a7e5e..a57287f8f 100644
--- a/pufferlib/extensions/cuda/pufferlib.cu
+++ b/pufferlib/extensions/cuda/pufferlib.cu
@@ -5,15 +5,26 @@ namespace pufferlib {
 
 
 __host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
-        float* importance, float* advantages, float gamma, float lambda,
-        float rho_clip, float c_clip, int horizon) {
-    float lastpufferlam = 0;
-    for (int t = horizon-2; t >= 0; t--) {
+        float* importance, float* advantages, float value_bs, float reward_bs, float done_bs,
+        float trunc_bs, float gamma, float lambda, float rho_clip, float c_clip, int horizon) {
+    float lastpufferlam = 0.0f;
+    float next_value = 0.0f;
+    float nextnonterminal = 1.0f;
+    float next_reward = 0.0f;
+    for (int t = horizon-1; t >= 0; t--) {
         int t_next = t + 1;
-        float nextnonterminal = 1.0 - dones[t_next];
+        if ((t+1) == horizon) {
+            nextnonterminal = 1.0f - done_bs;
+            next_value = value_bs;
+            next_reward = reward_bs;
+        } else {
+            nextnonterminal = 1.0f - dones[t_next];
+            next_value = values[t_next];
+            next_reward = rewards[t_next];
+        }
         float rho_t = fminf(importance[t], rho_clip);
         float c_t = fminf(importance[t], c_clip);
-        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
+        float delta = rho_t*(next_reward + gamma*next_value*nextnonterminal - values[t]);
         lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
         advantages[t] = lastpufferlam;
     }
@@ -21,11 +32,13 @@ __host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards,
 
 void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        int num_steps, int horizon) {
+        torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
+        torch::Tensor truncs_bs, int num_steps, int horizon) {
 
     // Validate input tensors
     torch::Device device = values.device();
     for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
+        TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU");
         TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
         TORCH_CHECK(t.device() == device, "All tensors must be on same device");
         TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
@@ -35,28 +48,45 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    for (const torch::Tensor& t : {values_bs, rewards_bs, dones_bs, truncs_bs}) {
+        TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU");
+        TORCH_CHECK(t.dim() == 1, "Bootstrap Tensors must be 1D");
+        TORCH_CHECK(t.device() == device, "All tensors must be on same device");
+        TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
+        TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
+        TORCH_CHECK(t.is_contiguous(),
+            "Bootstrap Tensors must be contiguous"
+        );
+    }
 }
 
- // [num_steps, horizon]
+// [num_steps, horizon]
 __global__ void puff_advantage_kernel(float* values, float* rewards,
-        float* dones, float* importance, float* advantages, float gamma,
+        float* dones, float* importance, float* advantages, float* values_bs,
+        float* rewards_bs, float* dones_bs, float* truncs_bs, float gamma,
         float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
         return;
     }
     int offset = row*horizon;
+    float value_bs = values_bs[row];
+    float reward_bs = rewards_bs[row];
+    float done_bs = dones_bs[row];
+    float trunc_bs = truncs_bs[row];
     puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
-        importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
+        importance + offset, advantages + offset, value_bs, reward_bs, done_bs, trunc_bs,
+        gamma, lambda, rho_clip, c_clip, horizon);
 }
 
 void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        double gamma, double lambda, double rho_clip, double c_clip) {
+        torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
+        torch::Tensor truncs_bs, double gamma, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
-    vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
-    TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU");
+    vtrace_check_cuda(values, rewards, dones, importance, advantages, values_bs, rewards_bs,
+            dones_bs, truncs_bs, num_steps, horizon);
 
     int threads_per_block = 256;
     int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
@@ -67,6 +97,10 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
         dones.data_ptr<float>(),
         importance.data_ptr<float>(),
         advantages.data_ptr<float>(),
+        values_bs.data_ptr<float>(),
+        rewards_bs.data_ptr<float>(),
+        dones_bs.data_ptr<float>(),
+        truncs_bs.data_ptr<float>(),
         gamma,
         lambda,
         rho_clip,
diff --git a/pufferlib/extensions/pufferlib.cpp b/pufferlib/extensions/pufferlib.cpp
index a20d58bc2..d01f51b11 100644
--- a/pufferlib/extensions/pufferlib.cpp
+++ b/pufferlib/extensions/pufferlib.cpp
@@ -25,24 +25,36 @@ extern "C" {
 
 namespace pufferlib {
 
-void puff_advantage_row(float* values, float* rewards, float* dones,
-        float* importance, float* advantages, float gamma, float lambda,
-        float rho_clip, float c_clip, int horizon) {
-    float lastpufferlam = 0;
-    for (int t = horizon-2; t >= 0; t--) {
+void puff_advantage_row(float* values, float* rewards, float* dones, float* importance,
+        float* advantages, float value_bs, float reward_bs, float done_bs, float trunc_bs,
+        float gamma, float lambda, float rho_clip, float c_clip, int horizon) {
+    float lastpufferlam = 0.0f;
+    float next_value = 0.0f;
+    float nextnonterminal = 1.0f;
+    float next_reward = 0.0f;
+    for (int t = horizon-1; t >= 0; t--) {
         int t_next = t + 1;
-        float nextnonterminal = 1.0 - dones[t_next];
+        if ((t+1) == horizon) {
+            nextnonterminal = 1.0f - done_bs;
+            next_value = value_bs;
+            next_reward = reward_bs;
+        } else {
+            nextnonterminal = 1.0f - dones[t_next];
+            next_value = values[t_next];
+            next_reward = rewards[t_next];
+        }
         float rho_t = fminf(importance[t], rho_clip);
         float c_t = fminf(importance[t], c_clip);
-        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
+        float delta = rho_t*(next_reward + gamma*next_value*nextnonterminal - values[t]);
         lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
         advantages[t] = lastpufferlam;
     }
 }
 
-void vtrace_check(torch::Tensor values, torch::Tensor rewards,
-        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        int num_steps, int horizon) {
+void vtrace_check(torch::Tensor values, torch::Tensor rewards, torch::Tensor dones,
+        torch::Tensor importance, torch::Tensor advantages, torch::Tensor values_bs,
+        torch::Tensor rewards_bs, torch::Tensor dones_bs, torch::Tensor truncs_bs,
+        int num_steps, int horizon) {
 
     // Validate input tensors
     torch::Device device = values.device();
@@ -56,37 +68,65 @@ void vtrace_check(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    for (const torch::Tensor& t : {values_bs, rewards_bs, dones_bs, truncs_bs}) {
+        TORCH_CHECK(t.dim() == 1, "Bootstrap Tensors must be 1D");
+        TORCH_CHECK(t.device() == device, "All tensors must be on same device");
+        TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
+        TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
+        TORCH_CHECK(t.is_contiguous(),
+            "Bootstrap Tensors must be contiguous"
+        );
+    }
 }
 
 // [num_steps, horizon]
 void puff_advantage(float* values, float* rewards, float* dones, float* importance,
-        float* advantages, float gamma, float lambda, float rho_clip, float c_clip,
-        int num_steps, const int horizon){
+        float* advantages, float* values_bs, float* rewards_bs, float* dones_bs, float* truncs_bs,
+        float gamma, float lambda, float rho_clip, float c_clip, int num_steps, const int horizon){
+    int idx = 0;
     for (int offset = 0; offset < num_steps*horizon; offset+=horizon) {
         puff_advantage_row(values + offset, rewards + offset,
             dones + offset, importance + offset, advantages + offset,
+            values_bs[idx], rewards_bs[idx], dones_bs[idx], truncs_bs[idx],
             gamma, lambda, rho_clip, c_clip, horizon
         );
+        idx++;
     }
 }
 
 
 void compute_puff_advantage_cpu(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        double gamma, double lambda, double rho_clip, double c_clip) {
+        torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
+        torch::Tensor truncs_bs, double gamma, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
-    vtrace_check(values, rewards, dones, importance, advantages, num_steps, horizon);
-    puff_advantage(values.data_ptr<float>(), rewards.data_ptr<float>(),
-        dones.data_ptr<float>(), importance.data_ptr<float>(), advantages.data_ptr<float>(),
-        gamma, lambda, rho_clip, c_clip, num_steps, horizon
+    vtrace_check(values, rewards, dones, importance, advantages, values_bs, rewards_bs,
+            dones_bs, truncs_bs, num_steps, horizon);
+
+    puff_advantage(
+        values.data_ptr<float>(),
+        rewards.data_ptr<float>(),
+        dones.data_ptr<float>(),
+        importance.data_ptr<float>(),
+        advantages.data_ptr<float>(),
+        values_bs.data_ptr<float>(),
+        rewards_bs.data_ptr<float>(),
+        dones_bs.data_ptr<float>(),
+        truncs_bs.data_ptr<float>(),
+        gamma,
+        lambda,
+        rho_clip,
+        c_clip,
+        num_steps,
+        horizon
     );
 }
 
 
 TORCH_LIBRARY(pufferlib, m) {
-   m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
- }
+    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, Tensor(f!) values_bs, Tensor(g!) rewards_bs, Tensor(h!) dones_bs, Tensor(i!) truncs_bs, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
+}
 
 TORCH_LIBRARY_IMPL(pufferlib, CPU, m) {
     m.impl("compute_puff_advantage", &compute_puff_advantage_cpu);
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 7132a19f9..ec06debcc 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -108,6 +108,10 @@ def __init__(self, config, vecenv, policy, logger=None):
         self.ep_lengths = torch.zeros(total_agents, device=device, dtype=torch.int32)
         self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32)
         self.free_idx = total_agents
+        self.values_bs = torch.zeros(segments, 1, device=device)
+        self.rewards_bs = torch.zeros(segments, 1, device=device)
+        self.terminals_bs = torch.zeros(segments, 1, device=device)
+        self.truncations_bs = torch.zeros(segments, 1, device=device)
 
         # LSTM
         if config['use_rnn']:
@@ -256,6 +260,7 @@ def evaluate(self):
             o_device = o.to(device)#, non_blocking=True)
             r = torch.as_tensor(r).to(device)#, non_blocking=True)
             d = torch.as_tensor(d).to(device)#, non_blocking=True)
+            t = torch.as_tensor(t).to(device)#, non_blocking=True)
 
             profile('eval_forward', epoch)
             with torch.no_grad(), self.amp_context:
@@ -284,20 +289,24 @@ def evaluate(self):
                 l = self.ep_lengths[env_id.start].item()
                 batch_rows = slice(self.ep_indices[env_id.start].item(), 1+self.ep_indices[env_id.stop - 1].item())
 
-                if config['cpu_offload']:
-                    self.observations[batch_rows, l] = o
-                else:
-                    self.observations[batch_rows, l] = o_device
-
-                self.actions[batch_rows, l] = action
-                self.logprobs[batch_rows, l] = logprob
-                self.rewards[batch_rows, l] = r
-                self.terminals[batch_rows, l] = d.float()
-                self.values[batch_rows, l] = value.flatten()
-
-                # Note: We are not yet handling masks in this version
-                self.ep_lengths[env_id] += 1
-                if l+1 >= config['bptt_horizon']:
+                if l < config['bptt_horizon']:
+                    if config['cpu_offload']:
+                        self.observations[batch_rows, l] = o
+                    else:
+                        self.observations[batch_rows, l] = o_device
+
+                    self.actions[batch_rows, l] = action
+                    self.logprobs[batch_rows, l] = logprob
+                    self.rewards[batch_rows, l] = r
+                    self.terminals[batch_rows, l] = d.float()
+                    self.truncations[batch_rows, l] = t.float()
+                    self.values[batch_rows, l] = value.flatten()
+                    self.ep_lengths[env_id] += 1
+                elif l == config['bptt_horizon']:
+                    self.values_bs[batch_rows, 0] = value.flatten()
+                    self.rewards_bs[batch_rows, 0] = r
+                    self.terminals_bs[batch_rows, 0] = d.float()
+                    self.truncations_bs[batch_rows, 0] = t.float()
                     num_full = env_id.stop - env_id.start
                     self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config['device']).int()
                     self.ep_lengths[env_id] = 0
@@ -352,8 +361,9 @@ def train(self):
             shape = self.values.shape
             advantages = torch.zeros(shape, device=device)
             advantages = compute_puff_advantage(self.values, self.rewards,
-                self.terminals, self.ratio, advantages, config['gamma'],
-                config['gae_lambda'], config['vtrace_rho_clip'], config['vtrace_c_clip'])
+                self.terminals, self.ratio, advantages, self.values_bs.squeeze(1),
+                self.rewards_bs.squeeze(1), self.terminals_bs.squeeze(1), self.truncations_bs.squeeze(1),
+                config['gamma'], config['gae_lambda'], config['vtrace_rho_clip'], config['vtrace_c_clip'])
 
             # Prioritize experience by advantage magnitude
             adv = advantages.abs().sum(axis=1)
@@ -657,8 +667,9 @@ def print_dashboard(self, clear=False, idx=[0],
         print('\033[0;0H' + capture.get())
 
 
-def compute_puff_advantage(values, rewards, terminals,
-        ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip):
+def compute_puff_advantage(values, rewards, terminals, ratio, advantages,
+        values_bs, rewards_bs, terminals_bs, truncations_bs,
+        gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip):
     '''CUDA kernel for puffer advantage with automatic CPU fallback. You need
     nvcc (in cuda-dev-tools or in a cuda-dev docker base) for PufferLib to
     compile the fast version.'''
@@ -670,9 +681,14 @@ def compute_puff_advantage(values, rewards, terminals,
         terminals = terminals.cpu()
         ratio = ratio.cpu()
         advantages = advantages.cpu()
+        values_bs = values_bs.cpu()
+        rewards_bs = rewards_bs.cpu()
+        terminals_bs = terminals_bs.cpu()
+        truncations_bs = truncations_bs.cpu()
 
     torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals,
-        ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)
+        ratio, advantages, values_bs, rewards_bs, terminals_bs, truncations_bs,
+        gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)
 
     if not ADVANTAGE_CUDA:
         return advantages.to(device)