Skip to content

Commit 0947940

Browse files
committed
Update
[ghstack-poisoned]
2 parents d7685df + f300d25 commit 0947940

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

test/test_env.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pickle
1414
import random
1515
import re
16+
import time
1617
from collections import defaultdict
1718
from functools import partial
1819
from sys import platform
@@ -3715,26 +3716,39 @@ def test_batched_nondynamic(self, penv):
37153716
use_buffers=True,
37163717
mp_start_method=mp_ctx if penv is ParallelEnv else None,
37173718
)
3718-
env_buffers.set_seed(0)
3719-
torch.manual_seed(0)
3720-
rollout_buffers = env_buffers.rollout(
3721-
20, return_contiguous=True, break_when_any_done=False
3722-
)
3723-
del env_buffers
3719+
try:
3720+
env_buffers.set_seed(0)
3721+
torch.manual_seed(0)
3722+
rollout_buffers = env_buffers.rollout(
3723+
20, return_contiguous=True, break_when_any_done=False
3724+
)
3725+
finally:
3726+
env_buffers.close(raise_if_closed=False)
3727+
del env_buffers
37243728
gc.collect()
3729+
# Add a small delay to allow multiprocessing resource_sharer threads
3730+
# to fully clean up before creating the next environment. This prevents
3731+
# a race condition where the old resource_sharer service thread is still
3732+
# active when the new environment starts, causing a deadlock.
3733+
# See: https://bugs.python.org/issue30289
3734+
if penv is ParallelEnv:
3735+
time.sleep(0.1)
37253736

37263737
env_no_buffers = penv(
37273738
3,
37283739
lambda: GymEnv(CARTPOLE_VERSIONED(), device=None),
37293740
use_buffers=False,
37303741
mp_start_method=mp_ctx if penv is ParallelEnv else None,
37313742
)
3732-
env_no_buffers.set_seed(0)
3733-
torch.manual_seed(0)
3734-
rollout_no_buffers = env_no_buffers.rollout(
3735-
20, return_contiguous=True, break_when_any_done=False
3736-
)
3737-
del env_no_buffers
3743+
try:
3744+
env_no_buffers.set_seed(0)
3745+
torch.manual_seed(0)
3746+
rollout_no_buffers = env_no_buffers.rollout(
3747+
20, return_contiguous=True, break_when_any_done=False
3748+
)
3749+
finally:
3750+
env_no_buffers.close(raise_if_closed=False)
3751+
del env_no_buffers
37383752
gc.collect()
37393753
assert_allclose_td(rollout_buffers, rollout_no_buffers)
37403754

test/test_libs.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import collections
8+
import copy
89
import functools
910
import gc
1011
import importlib.util
@@ -2811,14 +2812,27 @@ def test_vmas_seeding(self, scenario_name):
28112812
final_seed = []
28122813
tdreset = []
28132814
tdrollout = []
2814-
for _ in range(2):
2815-
env = VmasEnv(
2815+
rollout_length = 10
2816+
2817+
def create_env():
2818+
return VmasEnv(
28162819
scenario=scenario_name,
28172820
num_envs=4,
28182821
)
2822+
2823+
env = create_env()
2824+
td_actions = [env.action_spec.rand() for _ in range(rollout_length)]
2825+
2826+
for _ in range(2):
2827+
env = create_env()
2828+
td_actions_buffer = copy.deepcopy(td_actions)
2829+
2830+
def policy(td, actions=td_actions_buffer):
2831+
return actions.pop(0)
2832+
28192833
final_seed.append(env.set_seed(0))
28202834
tdreset.append(env.reset())
2821-
tdrollout.append(env.rollout(max_steps=10))
2835+
tdrollout.append(env.rollout(max_steps=rollout_length, policy=policy))
28222836
env.close()
28232837
del env
28242838
assert final_seed[0] == final_seed[1]

0 commit comments

Comments
 (0)