diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h
index 2bfe3c640..45832a8d9 100644
--- a/pufferlib/ocean/drive/drive.h
+++ b/pufferlib/ocean/drive/drive.h
@@ -283,6 +283,7 @@ struct Drive {
     float *actions;
     float *rewards;
     unsigned char *terminals;
+    unsigned char *truncations;
     Log log;
     Log *logs;
     int num_agents;
@@ -1436,6 +1437,7 @@ void allocate(Drive *env) {
     env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float));
     env->rewards = (float *)calloc(env->active_agent_count, sizeof(float));
     env->terminals = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char));
+    env->truncations = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char));
 }
 
 void free_allocated(Drive *env) {
@@ -1443,6 +1445,7 @@
     free(env->actions);
     free(env->rewards);
     free(env->terminals);
+    free(env->truncations);
     c_close(env);
 }
@@ -1993,24 +1996,9 @@ void respawn_agent(Drive *env, int agent_idx) {
 void c_step(Drive *env) {
     memset(env->rewards, 0, env->active_agent_count * sizeof(float));
     memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char));
+    memset(env->truncations, 0, env->active_agent_count * sizeof(unsigned char));
     env->timestep++;
 
-    int originals_remaining = 0;
-    for (int i = 0; i < env->active_agent_count; i++) {
-        int agent_idx = env->active_agent_indices[i];
-        // Keep flag true if there is at least one agent that has not been respawned yet
-        if (env->entities[agent_idx].respawn_count == 0) {
-            originals_remaining = 1;
-            break;
-        }
-    }
-
-    if (env->timestep == env->episode_length || (!originals_remaining && env->termination_mode == 1)) {
-        add_log(env);
-        c_reset(env);
-        return;
-    }
-
     // Move static experts
     for (int i = 0; i < env->expert_static_agent_count; i++) {
         int expert_idx = env->expert_static_agent_indices[i];
@@ -2107,6 +2095,7 @@ void c_step(Drive *env) {
         int agent_idx = env->active_agent_indices[i];
         int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX];
         if (reached_goal) {
+            env->terminals[i] = 1;
             respawn_agent(env, agent_idx);
             env->entities[agent_idx].respawn_count++;
         }
@@ -2122,6 +2111,27 @@ void c_step(Drive *env) {
         }
     }
 
+    // Episode ends after this step: treat both the time limit and early termination as truncation.
+    // The check moved from the start of the next step (`timestep == episode_length`) to the end of this one, hence the `+ 1`.
+    int originals_remaining = 0;
+    for (int i = 0; i < env->active_agent_count; i++) {
+        int agent_idx = env->active_agent_indices[i];
+        if (env->entities[agent_idx].respawn_count == 0) {
+            originals_remaining = 1;
+            break;
+        }
+    }
+    int reached_time_limit = (env->timestep + 1) >= env->episode_length;
+    int reached_early_termination = (!originals_remaining && env->termination_mode == 1);
+    if (reached_time_limit || reached_early_termination) {
+        for (int i = 0; i < env->active_agent_count; i++) {
+            env->truncations[i] = 1;
+        }
+        add_log(env);
+        c_reset(env);
+        return;
+    }
+
     compute_observations(env);
 }
diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py
index d40aea9a5..942feb05a 100644
--- a/pufferlib/ocean/drive/drive.py
+++ b/pufferlib/ocean/drive/drive.py
@@ -205,10 +205,12 @@ def __init__(
     def reset(self, seed=0):
         binding.vec_reset(self.c_envs, seed)
         self.tick = 0
+        self.truncations[:] = 0
         return self.observations, []
 
     def step(self, actions):
         self.terminals[:] = 0
+        self.truncations[:] = 0
         self.actions[:] = actions
         binding.vec_step(self.c_envs)
         self.tick += 1
@@ -276,7 +278,8 @@ def step(self, actions):
             self.c_envs = binding.vectorize(*env_ids)
             binding.vec_reset(self.c_envs, seed)
-            self.terminals[:] = 1
+            # Map resampling is an external reset boundary (dataset/map switch). Treat as truncation.
+            self.truncations[:] = 1
 
         return (self.observations, self.rewards, self.terminals, self.truncations, info)
 
     def get_global_agent_state(self):
diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h
index 3e2b90b9f..8a377aab2 100644
--- a/pufferlib/ocean/env_binding.h
+++ b/pufferlib/ocean/env_binding.h
@@ -128,7 +128,7 @@ static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) {
         PyErr_SetString(PyExc_ValueError, "Truncations must be 1D");
         return NULL;
     }
-    // env->truncations = PyArray_DATA(truncations);
+    env->truncations = PyArray_DATA(truncations);
 
     PyObject *seed_arg = PyTuple_GetItem(args, 5);
     if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) {
@@ -412,7 +412,7 @@ static PyObject *vec_init(PyObject *self, PyObject *args, PyObject *kwargs) {
         env->actions = (void *)((char *)PyArray_DATA(actions) + i * PyArray_STRIDE(actions, 0));
         env->rewards = (void *)((char *)PyArray_DATA(rewards) + i * PyArray_STRIDE(rewards, 0));
         env->terminals = (void *)((char *)PyArray_DATA(terminals) + i * PyArray_STRIDE(terminals, 0));
-        // env->truncations = (void*)((char*)PyArray_DATA(truncations) + i*PyArray_STRIDE(truncations, 0));
+        env->truncations = (void *)((char *)PyArray_DATA(truncations) + i * PyArray_STRIDE(truncations, 0));
 
         // Assumes each process has the same number of environments
         int env_seed = i + seed * vec->num_envs;
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index fdb7a1dd4..452ee0b62 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -258,7 +258,6 @@ def evaluate(self):
             profile("eval_misc", epoch)
             env_id = slice(env_id[0], env_id[-1] + 1)
-            done_mask = d + t # TODO: Handle truncations separately
             self.global_step += int(mask.sum())
 
             profile("eval_copy", epoch)
@@ -266,12 +265,14 @@ def evaluate(self):
             o_device = o.to(device) # , non_blocking=True)
             r = torch.as_tensor(r).to(device) # , non_blocking=True)
             d = torch.as_tensor(d).to(device) # , non_blocking=True)
+            t = torch.as_tensor(t).to(device) # , non_blocking=True)
+            done_mask = (d + t).clamp(max=1)
 
             profile("eval_forward", epoch)
             with torch.no_grad(), self.amp_context:
                 state = dict(
                     reward=r,
-                    done=d,
+                    done=done_mask,
                     env_id=env_id,
                     mask=mask,
                 )
@@ -301,8 +302,16 @@ def evaluate(self):
             self.actions[batch_rows, l] = action
             self.logprobs[batch_rows, l] = logprob
 
+            # Truncation bootstrap hack for auto-reset envs.
+            # Ideally we would add `gamma * V(s_{t+1})` on truncation steps, but Drive resets inside C,
+            # so the value at index `l` already belongs to the post-reset state. We use `values[..., l - 1]`
+            # as a heuristic proxy for the pre-reset value (bootstrap term is not clipped; l == 0 is skipped).
+            if l > 0:
+                trunc_mask = (t > 0) & (d == 0)
+                r = r + trunc_mask.to(r.dtype) * config["gamma"] * self.values[batch_rows, l - 1]
             self.rewards[batch_rows, l] = r
-            self.terminals[batch_rows, l] = d.float()
+            self.terminals[batch_rows, l] = done_mask.float()
+            self.truncations[batch_rows, l] = t.float()
             self.values[batch_rows, l] = value.flatten()
 
         # Note: We are not yet handling masks in this version
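
Reviewer note on the new boundary check in drive.h: a quick sanity check of the `+ 1` arithmetic. The snippet below is a hypothetical mini-model of the two checks, not code from this diff; it shows that the old start-of-step reset and the new end-of-step truncation both yield exactly `episode_length - 1` worked steps per episode.

# Hypothetical mini-model of the old and new episode-boundary checks in c_step.
def old_steps(episode_length):
    timestep, worked = 0, 0
    while True:
        timestep += 1
        if timestep == episode_length:  # old: checked at step start; a reset-only call
            return worked
        worked += 1  # the step's work happened only after the check

def new_steps(episode_length):
    timestep, worked = 0, 0
    while True:
        timestep += 1
        worked += 1  # new: the step's work happens first
        if timestep + 1 >= episode_length:  # then truncate at the end of the step
            return worked

assert all(old_steps(n) == new_steps(n) == n - 1 for n in range(2, 10))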
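The truncation bootstrap in pufferl.py is easier to see outside the trainer's buffer indexing. Below is a minimal NumPy sketch of the same heuristic, assuming float arrays of shape [T, N] (time, num_envs); the function name and layout are illustrative, not part of the PR. On truncation-only steps it folds `gamma * values[l - 1]` into the reward, since the value at index `l` already belongs to the post-reset episode.

import numpy as np

def apply_truncation_bootstrap(rewards, values, terminals, truncations, gamma):
    # The ideal bootstrap target is gamma * V(s_{t+1}) for the pre-reset next
    # state, but with in-C auto-reset that state is never observed, so the
    # one-step-stale values[l - 1] stands in for it.
    out = rewards.copy()
    for l in range(1, rewards.shape[0]):  # l == 0 has no previous value to borrow
        trunc_only = (truncations[l] > 0) & (terminals[l] == 0)
        out[l] += trunc_only * gamma * values[l - 1]
    return out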
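For contrast, the treatment the in-code comment calls ideal, sketched as a standalone GAE pass. This assumes a hypothetical `next_values` array holding the pre-reset next-state value, which Drive does not currently expose because the reset happens inside C: a terminal zeroes the bootstrap, while a truncation keeps it and only breaks the recursion across the reset.

import numpy as np

def gae_with_truncation(rewards, values, next_values, terminals, truncations, gamma, lam):
    # rewards/values/next_values: float [T, N]; terminals/truncations: 0/1 [T, N].
    advantages = np.zeros_like(rewards)
    last_gae = np.zeros_like(rewards[0])
    for l in reversed(range(rewards.shape[0])):
        not_terminal = 1.0 - terminals[l]  # truncation keeps the bootstrap term
        not_done = 1.0 - np.clip(terminals[l] + truncations[l], 0, 1)
        delta = rewards[l] + gamma * next_values[l] * not_terminal - values[l]
        last_gae = delta + gamma * lam * not_done * last_gae  # recursion breaks at any boundary
        advantages[l] = last_gae
    return advantages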