Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions pufferlib/extensions/cuda/pufferlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,40 @@
namespace pufferlib {

// Computes V-trace-style advantages for one trajectory row of length
// `horizon`, filling advantages[0..horizon-1] in place via a backward
// recursion. Callable from host or device (shared with the CPU path).
//
// values/rewards/dones/importance/advantages: contiguous arrays, length horizon.
// value_bs/reward_bs/done_bs: bootstrap value, reward, and done flag for the
//   step just past the end of the row (t == horizon).
// trunc_bs: bootstrap truncation flag. NOTE(review): accepted for interface
//   parity with the Python caller but not used in the recursion yet — confirm
//   truncation handling is intentionally deferred.
// gamma, lambda: discount and GAE-lambda factors.
// rho_clip, c_clip: V-trace clips applied to the importance ratio for the
//   TD-error term (rho) and the backward trace (c).
__host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
        float* importance, float* advantages, float value_bs, float reward_bs, float done_bs,
        float trunc_bs, float gamma, float lambda, float rho_clip, float c_clip, int horizon) {
    (void)trunc_bs;  // unused for now; kept so CPU/CUDA signatures match
    float lastpufferlam = 0.0f;
    for (int t = horizon - 1; t >= 0; t--) {
        float nextnonterminal, next_value, next_reward;
        if (t + 1 == horizon) {
            // Past the end of the stored row: use the bootstrap entries.
            nextnonterminal = 1.0f - done_bs;
            next_value = value_bs;
            next_reward = reward_bs;
        } else {
            nextnonterminal = 1.0f - dones[t + 1];
            next_value = values[t + 1];
            next_reward = rewards[t + 1];
        }
        // Clipped importance weights (V-trace rho_t and c_t).
        float rho_t = fminf(importance[t], rho_clip);
        float c_t = fminf(importance[t], c_clip);
        // NOTE(review): the TD error uses the reward at t+1 (next_reward),
        // matching the original recursion — confirm this indexing is intended.
        float delta = rho_t*(next_reward + gamma*next_value*nextnonterminal - values[t]);
        lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
        advantages[t] = lastpufferlam;
    }
}

void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
int num_steps, int horizon) {
torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
torch::Tensor truncs_bs, int num_steps, int horizon) {

// Validate input tensors
torch::Device device = values.device();
for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU");
TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
TORCH_CHECK(t.device() == device, "All tensors must be on same device");
TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
Expand All @@ -35,28 +48,45 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
t.contiguous();
}
}
for (const torch::Tensor& t : {values_bs, rewards_bs, dones_bs, truncs_bs}) {
TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU");
TORCH_CHECK(t.dim() == 1, "Bootstrap Tensors must be 1D");
TORCH_CHECK(t.device() == device, "All tensors must be on same device");
TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
if (!t.is_contiguous()) {
t.contiguous();
}
}
}

// One thread per trajectory row over a [num_steps, horizon] row-major,
// contiguous layout. The *_bs arrays hold one bootstrap entry per row.
// Launch with enough threads to cover num_steps; the guard below handles
// the grid tail.
__global__ void puff_advantage_kernel(float* values, float* rewards,
        float* dones, float* importance, float* advantages, float* values_bs,
        float* rewards_bs, float* dones_bs, float* truncs_bs, float gamma,
        float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
    int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= num_steps) {
        return;  // grid may overshoot num_steps; extra threads do nothing
    }
    int offset = row*horizon;
    puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
        importance + offset, advantages + offset,
        values_bs[row], rewards_bs[row], dones_bs[row], truncs_bs[row],
        gamma, lambda, rho_clip, c_clip, horizon);
}

void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
double gamma, double lambda, double rho_clip, double c_clip) {
torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
torch::Tensor truncs_bs, double gamma, double lambda, double rho_clip, double c_clip) {
int num_steps = values.size(0);
int horizon = values.size(1);
vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU");
vtrace_check_cuda(values, rewards, dones, importance, advantages, values_bs, rewards_bs,
dones_bs, truncs_bs, num_steps, horizon);

int threads_per_block = 256;
int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
Expand All @@ -67,6 +97,10 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
dones.data_ptr<float>(),
importance.data_ptr<float>(),
advantages.data_ptr<float>(),
values_bs.data_ptr<float>(),
rewards_bs.data_ptr<float>(),
dones_bs.data_ptr<float>(),
truncs_bs.data_ptr<float>(),
gamma,
lambda,
rho_clip,
Expand Down
78 changes: 59 additions & 19 deletions pufferlib/extensions/pufferlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,36 @@ extern "C" {

namespace pufferlib {

// Computes V-trace-style advantages for one trajectory row of length
// `horizon`, filling advantages[0..horizon-1] in place via a backward
// recursion. CPU twin of puff_advantage_row_cuda.
//
// values/rewards/dones/importance/advantages: contiguous arrays, length horizon.
// value_bs/reward_bs/done_bs: bootstrap value, reward, and done flag for the
//   step just past the end of the row (t == horizon).
// trunc_bs: bootstrap truncation flag. NOTE(review): accepted for interface
//   parity with the Python caller but not used in the recursion yet — confirm
//   truncation handling is intentionally deferred.
// gamma, lambda: discount and GAE-lambda factors.
// rho_clip, c_clip: V-trace clips applied to the importance ratio for the
//   TD-error term (rho) and the backward trace (c).
void puff_advantage_row(float* values, float* rewards, float* dones, float* importance,
        float* advantages, float value_bs, float reward_bs, float done_bs, float trunc_bs,
        float gamma, float lambda, float rho_clip, float c_clip, int horizon) {
    (void)trunc_bs;  // unused for now; kept so CPU/CUDA signatures match
    float lastpufferlam = 0.0f;
    for (int t = horizon - 1; t >= 0; t--) {
        float nextnonterminal, next_value, next_reward;
        if (t + 1 == horizon) {
            // Past the end of the stored row: use the bootstrap entries.
            nextnonterminal = 1.0f - done_bs;
            next_value = value_bs;
            next_reward = reward_bs;
        } else {
            nextnonterminal = 1.0f - dones[t + 1];
            next_value = values[t + 1];
            next_reward = rewards[t + 1];
        }
        // Clipped importance weights (V-trace rho_t and c_t).
        float rho_t = fminf(importance[t], rho_clip);
        float c_t = fminf(importance[t], c_clip);
        // NOTE(review): the TD error uses the reward at t+1 (next_reward),
        // matching the original recursion — confirm this indexing is intended.
        float delta = rho_t*(next_reward + gamma*next_value*nextnonterminal - values[t]);
        lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
        advantages[t] = lastpufferlam;
    }
}

void vtrace_check(torch::Tensor values, torch::Tensor rewards,
torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
int num_steps, int horizon) {
void vtrace_check(torch::Tensor values, torch::Tensor rewards, torch::Tensor dones,
torch::Tensor importance, torch::Tensor advantages, torch::Tensor values_bs,
torch::Tensor rewards_bs, torch::Tensor dones_bs, torch::Tensor truncs_bs,
int num_steps, int horizon) {

// Validate input tensors
torch::Device device = values.device();
Expand All @@ -56,37 +68,65 @@ void vtrace_check(torch::Tensor values, torch::Tensor rewards,
t.contiguous();
}
}
for (const torch::Tensor& t : {values_bs, rewards_bs, dones_bs, truncs_bs}) {
TORCH_CHECK(t.dim() == 1, "Bootstrap Tensors must be 1D");
TORCH_CHECK(t.device() == device, "All tensors must be on same device");
TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
if (!t.is_contiguous()) {
t.contiguous();
}
}
}


// [num_steps, horizon]
void puff_advantage(float* values, float* rewards, float* dones, float* importance,
float* advantages, float gamma, float lambda, float rho_clip, float c_clip,
int num_steps, const int horizon){
float* advantages, float* values_bs, float* rewards_bs, float* dones_bs, float* truncs_bs,
float gamma, float lambda, float rho_clip, float c_clip, int num_steps, const int horizon){
int idx = 0;
for (int offset = 0; offset < num_steps*horizon; offset+=horizon) {
puff_advantage_row(values + offset, rewards + offset,
dones + offset, importance + offset, advantages + offset,
values_bs[idx], rewards_bs[idx], dones_bs[idx], truncs_bs[idx],
gamma, lambda, rho_clip, c_clip, horizon
);
idx++;
}
}


// Torch CPU entry point. Validates tensor shapes/dtypes/devices via
// vtrace_check, then runs the advantage recursion in place over `advantages`.
// Main tensors are [num_steps, horizon] float32; *_bs bootstrap tensors are
// [num_steps] float32. NOTE(review): gamma/lambda/rho_clip/c_clip arrive as
// double (torch schema `float`) and are narrowed to float here.
void compute_puff_advantage_cpu(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        torch::Tensor values_bs, torch::Tensor rewards_bs, torch::Tensor dones_bs,
        torch::Tensor truncs_bs, double gamma, double lambda, double rho_clip,
        double c_clip) {
    int num_steps = values.size(0);
    int horizon = values.size(1);
    vtrace_check(values, rewards, dones, importance, advantages, values_bs, rewards_bs,
        dones_bs, truncs_bs, num_steps, horizon);

    puff_advantage(
        values.data_ptr<float>(),
        rewards.data_ptr<float>(),
        dones.data_ptr<float>(),
        importance.data_ptr<float>(),
        advantages.data_ptr<float>(),
        values_bs.data_ptr<float>(),
        rewards_bs.data_ptr<float>(),
        dones_bs.data_ptr<float>(),
        truncs_bs.data_ptr<float>(),
        gamma,
        lambda,
        rho_clip,
        c_clip,
        num_steps,
        horizon
    );
}

// Registers the op schema. All tensor arguments are marked mutable
// (a! .. i!) because the kernel writes advantages in place and may touch
// the other buffers; bootstrap tensors (*_bs) are the per-row values at
// t == horizon.
TORCH_LIBRARY(pufferlib, m) {
    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, Tensor(f!) values_bs, Tensor(g!) rewards_bs, Tensor(h!) dones_bs, Tensor(i!) truncs_bs, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
}

TORCH_LIBRARY_IMPL(pufferlib, CPU, m) {
m.impl("compute_puff_advantage", &compute_puff_advantage_cpu);
Expand Down
54 changes: 35 additions & 19 deletions pufferlib/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def __init__(self, config, vecenv, policy, logger=None):
self.ep_lengths = torch.zeros(total_agents, device=device, dtype=torch.int32)
self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32)
self.free_idx = total_agents
self.values_bs = torch.zeros(segments, 1, device=device)
self.rewards_bs = torch.zeros(segments, 1, device=device)
self.terminals_bs = torch.zeros(segments, 1, device=device)
self.truncations_bs = torch.zeros(segments, 1, device=device)

# LSTM
if config['use_rnn']:
Expand Down Expand Up @@ -256,6 +260,7 @@ def evaluate(self):
o_device = o.to(device)#, non_blocking=True)
r = torch.as_tensor(r).to(device)#, non_blocking=True)
d = torch.as_tensor(d).to(device)#, non_blocking=True)
t = torch.as_tensor(t).to(device)#, non_blocking=True)

profile('eval_forward', epoch)
with torch.no_grad(), self.amp_context:
Expand Down Expand Up @@ -284,20 +289,24 @@ def evaluate(self):
l = self.ep_lengths[env_id.start].item()
batch_rows = slice(self.ep_indices[env_id.start].item(), 1+self.ep_indices[env_id.stop - 1].item())

if config['cpu_offload']:
self.observations[batch_rows, l] = o
else:
self.observations[batch_rows, l] = o_device

self.actions[batch_rows, l] = action
self.logprobs[batch_rows, l] = logprob
self.rewards[batch_rows, l] = r
self.terminals[batch_rows, l] = d.float()
self.values[batch_rows, l] = value.flatten()

# Note: We are not yet handling masks in this version
self.ep_lengths[env_id] += 1
if l+1 >= config['bptt_horizon']:
if l < config['bptt_horizon']:
if config['cpu_offload']:
self.observations[batch_rows, l] = o
else:
self.observations[batch_rows, l] = o_device

self.actions[batch_rows, l] = action
self.logprobs[batch_rows, l] = logprob
self.rewards[batch_rows, l] = r
self.terminals[batch_rows, l] = d.float()
self.truncations[batch_rows, l] = t.float()
self.values[batch_rows, l] = value.flatten()
self.ep_lengths[env_id] += 1
elif l == config['bptt_horizon']:
self.values_bs[batch_rows, 0] = value.flatten()
self.rewards_bs[batch_rows, 0] = r
self.terminals_bs[batch_rows, 0] = d.float()
self.truncations_bs[batch_rows, 0] = t.float()
num_full = env_id.stop - env_id.start
self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config['device']).int()
self.ep_lengths[env_id] = 0
Expand Down Expand Up @@ -352,8 +361,9 @@ def train(self):
shape = self.values.shape
advantages = torch.zeros(shape, device=device)
advantages = compute_puff_advantage(self.values, self.rewards,
self.terminals, self.ratio, advantages, config['gamma'],
config['gae_lambda'], config['vtrace_rho_clip'], config['vtrace_c_clip'])
self.terminals, self.ratio, advantages, self.values_bs.squeeze(1),
self.rewards_bs.squeeze(1), self.terminals_bs.squeeze(1), self.truncations_bs.squeeze(1),
config['gamma'], config['gae_lambda'], config['vtrace_rho_clip'], config['vtrace_c_clip'])

# Prioritize experience by advantage magnitude
adv = advantages.abs().sum(axis=1)
Expand Down Expand Up @@ -657,8 +667,9 @@ def print_dashboard(self, clear=False, idx=[0],

print('\033[0;0H' + capture.get())

def compute_puff_advantage(values, rewards, terminals,
ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip):
def compute_puff_advantage(values, rewards, terminals, ratio, advantages,
values_bs, rewards_bs, terminals_bs, truncations_bs,
gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip):
'''CUDA kernel for puffer advantage with automatic CPU fallback. You need
nvcc (in cuda-dev-tools or in a cuda-dev docker base) for PufferLib to
compile the fast version.'''
Expand All @@ -670,9 +681,14 @@ def compute_puff_advantage(values, rewards, terminals,
terminals = terminals.cpu()
ratio = ratio.cpu()
advantages = advantages.cpu()
values_bs = values_bs.cpu()
rewards_bs = rewards_bs.cpu()
terminals_bs = terminals_bs.cpu()
truncations_bs = truncations_bs.cpu()

torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals,
ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)
ratio, advantages, values_bs, rewards_bs, terminals_bs, truncations_bs,
gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)

if not ADVANTAGE_CUDA:
return advantages.to(device)
Expand Down