Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
922 changes: 697 additions & 225 deletions profile_kernels.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pufferlib/config/ocean/breakout.ini
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mean = 1
scale = auto

[train]
total_timesteps = 120_000_000
total_timesteps = 100_000_000
adam_beta1 = 0.8946507418260217
adam_beta2 = 0.9
adam_eps = 0.0001
Expand Down
13 changes: 4 additions & 9 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ rnn_name = Recurrent

[vec]
total_agents = 8192
num_buffers = 2
num_buffers = 8

[policy]
input_size = 64
Expand All @@ -30,16 +30,11 @@ num_maps = 10000

[train]
total_timesteps = 2_000_000_000
#learning_rate = 0.02
#gamma = 0.985
anneal_lr = True
batch_size = 745472
minibatch_size = 11648
max_minibatch_size = 11648
#minibatch_size = 32768
batch_size = auto
minibatch_size = 32768
num_minibatches = 16
bptt_horizon = 91
#bptt_horizon = 64
bptt_horizon = 64
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_eps = 1e-8
Expand Down
14 changes: 9 additions & 5 deletions pufferlib/extensions/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ Tensor initial_state(pybind11::object pufferl_obj, int64_t batch_size, torch::De
}

void python_vec_recv(pybind11::object pufferl_obj, int buf) {
auto& pufferl = pufferl_obj.cast<PuffeRL&>();
pufferl.env_exports->vec_recv(pufferl.vec, buf, pufferl.vec->streams[buf]);
// Not used in static/OMP path
}

void python_vec_send(pybind11::object pufferl_obj, int buf) {
auto& pufferl = pufferl_obj.cast<PuffeRL&>();
pufferl.env_exports->vec_send(pufferl.vec, buf, pufferl.vec->streams[buf]);
// Not used in static/OMP path
}

torch::autograd::tensor_list env_buffers(pybind11::object pufferl_obj) {
Expand All @@ -44,7 +42,7 @@ void rollouts(pybind11::object pufferl_obj) {
PuffeRL& pufferl = pufferl_obj.cast<PuffeRL&>();
pybind11::gil_scoped_release no_gil;
if (pufferl.hypers.use_omp) {
pufferl.env_exports->vec_omp_step(pufferl.vec);
static_vec_omp_step(pufferl.vec);
} else {
rollouts_impl(pufferl);
}
Expand Down Expand Up @@ -151,6 +149,8 @@ TORCH_LIBRARY(_C, m) {
m.def("log_coeffs_and_values(Tensor gate, Tensor hidden) -> (Tensor, Tensor)");
m.def("fused_scan(Tensor combined, Tensor state) -> (Tensor, Tensor)");
m.def("fused_ppo_loss(Tensor logits, Tensor values, Tensor actions, Tensor old_logprobs, Tensor advantages, Tensor prio, Tensor values, Tensor returns, Tensor adv_mean, Tensor adv_std, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef) -> Tensor");
m.def("fc_relu_fc_max(Tensor x, Tensor W1, Tensor b1, Tensor W2, Tensor b2) -> Tensor");
m.def("fc_max(Tensor x, Tensor W, Tensor b) -> Tensor");
m.def("policy_forward(Tensor obs, Tensor state) -> (Tensor, Tensor, Tensor)");
}

Expand All @@ -165,6 +165,10 @@ PYBIND11_MODULE(_C, m) {
m.def("log_coeffs_and_values", &log_coeffs_and_values);
m.def("fused_scan", &fused_scan);
m.def("fused_ppo_loss", &fused_ppo_loss);
m.def("fc_relu_fc_max", &fc_relu_fc_max);
m.def("fc_relu_fc_max_cpp", &fc_relu_fc_max_cpp);
m.def("fc_max", &fc_max);
m.def("fc_max_cpp", &fc_max_cpp);
m.def("sample_logits", &sample_logits);
m.def("python_vec_recv", &python_vec_recv);
m.def("python_vec_send", &python_vec_send);
Expand Down
Loading
Loading