1 change: 1 addition & 0 deletions pufferlib/config/ocean/drive.ini
@@ -87,6 +87,7 @@ vtrace_rho_clip = 1
checkpoint_interval = 1000
; Rendering options
render = True
render_async = False # A render_interval below 50 may cause process starvation and slow training
render_interval = 1000
; If True, show exactly what the agent sees in agent observation
obs_only = True
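For context, a minimal sketch (not part of this PR) of how the rendering options above could be read with Python's configparser. PufferLib has its own config loader, so the section name "train" and the inline-comment handling are assumptions here; the point is only how render_async and render_interval relate.

```python
# Illustrative only: read drive.ini's rendering options with configparser.
# PufferLib's real loader may differ; "train" is an assumed section name.
import configparser

parser = configparser.ConfigParser(inline_comment_prefixes=("#", ";"))
parser.read("pufferlib/config/ocean/drive.ini")
section = parser["train"]  # assumption

render = section.getboolean("render", fallback=False)
# Async rendering only applies when rendering is enabled at all.
render_async = section.getboolean("render_async", fallback=False) and render
render_interval = section.getint("render_interval", fallback=1000)
# Per the note above, a render_interval below roughly 50 can spawn render
# subprocesses often enough to starve the training loop.
```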
10 changes: 5 additions & 5 deletions pufferlib/ocean/drive/drive.py
@@ -145,17 +145,17 @@ def __init__(

# Check maps availability
available_maps = len(self.map_files)
if num_maps > available_maps:
if self.num_maps > available_maps:
if allow_fewer_maps:
print("\n" + "=" * 80)
print("WARNING: FEWER MAPS THAN REQUESTED")
print(f"Requested {num_maps} maps but only {available_maps} available in {map_dir}")
print(f"Requested {self.num_maps} maps but only {available_maps} available in {map_dir}")
print(f"Maps will be randomly sampled from the {available_maps} available maps.")
print("=" * 80 + "\n")
num_maps = available_maps
self.num_maps = available_maps
else:
raise ValueError(
f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). "
f"num_maps ({self.num_maps}) exceeds available maps in directory ({available_maps}). "
f"Please reduce num_maps, add more maps to {map_dir}, or set allow_fewer_maps=True."
)
self.max_controlled_agents = int(max_controlled_agents)
@@ -164,7 +164,7 @@ def __init__(
agent_offsets, map_ids, num_envs = binding.shared(
map_files=self.map_files,
num_agents=num_agents,
num_maps=num_maps,
num_maps=self.num_maps,
init_mode=self.init_mode,
control_mode=self.control_mode,
init_steps=self.init_steps,
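The drive.py hunks all serve one fix: the clamped map count is written back to self.num_maps so the same value later reaches binding.shared. A standalone sketch of that logic follows; the helper name resolve_num_maps is hypothetical and not part of the PR.

```python
# Sketch of the clamping behavior the hunks above enforce (illustrative helper).
def resolve_num_maps(requested: int, available: int, allow_fewer: bool) -> int:
    """Return the number of maps to actually use."""
    if requested <= available:
        return requested
    if allow_fewer:
        print(f"WARNING: requested {requested} maps, only {available} available; "
              "maps will be randomly sampled from those.")
        return available
    raise ValueError(
        f"num_maps ({requested}) exceeds available maps ({available}). "
        "Reduce num_maps, add more maps, or set allow_fewer_maps=True."
    )

# Usage mirroring the diff: store the resolved value on self so the same
# number is later passed to binding.shared(num_maps=self.num_maps).
# self.num_maps = resolve_num_maps(self.num_maps, len(self.map_files), allow_fewer_maps)
```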
116 changes: 113 additions & 3 deletions pufferlib/pufferl.py
@@ -53,6 +53,9 @@

import signal # Aggressively exit on ctrl+c

import multiprocessing
import queue

signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0))

# Assume advantage kernel has been built if CUDA compiler is available
@@ -121,11 +124,16 @@ def __init__(self, config, vecenv, policy, logger=None):
self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32)
self.free_idx = total_agents
self.render = config["render"]
self.render_async = config["render_async"] and self.render # Only supported if rendering is enabled
self.render_interval = config["render_interval"]

if self.render:
ensure_drive_binary()

if self.render_async:
self.render_queue = multiprocessing.Queue()
self.render_processes = []

# LSTM
if config["use_rnn"]:
n = vecenv.agents_per_batch
@@ -193,6 +201,9 @@ def __init__(self, config, vecenv, policy, logger=None):
self.logger = logger
if logger is None:
self.logger = NoLogger(config)
if self.render_async and hasattr(self.logger, "wandb") and self.logger.wandb:
self.logger.wandb.define_metric("render_step", hidden=True)
self.logger.wandb.define_metric("render/*", step_metric="render_step")

# Learning rate scheduler
epochs = config["total_timesteps"] // config["batch_size"]
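The define_metric calls added above exist because async render results can arrive after the run's default step has moved on; logging them with step= would violate W&B's requirement that the default step only increase. Routing everything under render/* onto a separate render_step axis sidesteps that. A minimal standalone sketch of the pattern (the project name, video path, and step value are placeholders, not the PR's code):

```python
# Illustrative only: log render videos on their own step axis so late
# async results do not trip W&B's monotonically-increasing step rule.
import wandb

run = wandb.init(project="drive-render-demo")  # hypothetical project name
run.define_metric("render_step", hidden=True)
run.define_metric("render/*", step_metric="render_step")

# Even if the run's default step has already advanced past 123456, this log
# succeeds because render/* is keyed to render_step. The video path is a
# placeholder; wandb.Video expects an existing file.
run.log({
    "render/world_state": wandb.Video("topdown.mp4", format="mp4"),
    "render_step": 123456,
})
run.finish()
```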
@@ -514,9 +525,53 @@ def train(self):
path=bin_path,
silent=True,
)
pufferlib.utils.render_videos(
self.config, self.vecenv, self.logger, self.epoch, self.global_step, bin_path
)

bin_path_epoch = f"{model_dir}_epoch_{self.epoch:06d}.bin"
shutil.copy2(bin_path, bin_path_epoch)

env_cfg = getattr(self.vecenv, "driver_env", None)
wandb_log = True if hasattr(self.logger, "wandb") and self.logger.wandb else False
wandb_run = self.logger.wandb if hasattr(self.logger, "wandb") else None
if self.render_async:
# Clean up finished processes
self.render_processes = [p for p in self.render_processes if p.is_alive()]

# Cap the number of processes to num_workers
max_processes = self.config.get("num_workers", 1)
if len(self.render_processes) >= max_processes:
print("Waiting for render processes to finish...")
while len(self.render_processes) >= max_processes:
time.sleep(1)
self.render_processes = [p for p in self.render_processes if p.is_alive()]

render_proc = multiprocessing.Process(
target=pufferlib.utils.render_videos,
args=(
self.config,
env_cfg,
self.logger.run_id,
wandb_log,
self.epoch,
self.global_step,
bin_path_epoch,
self.render_async,
self.render_queue,
),
)
render_proc.start()
self.render_processes.append(render_proc)
else:
pufferlib.utils.render_videos(
self.config,
env_cfg,
self.logger.run_id,
wandb_log,
self.epoch,
self.global_step,
bin_path_epoch,
self.render_async,
wandb_run=wandb_run,
)

except Exception as e:
print(f"Failed to export model weights: {e}")
@@ -531,7 +586,41 @@ def train(self):
):
pufferlib.utils.run_human_replay_eval_in_subprocess(self.config, self.logger, self.global_step)

def check_render_queue(self):
"""Check if any async render jobs finished and log them."""
if not self.render_async or not hasattr(self, "render_queue"):
return

try:
while not self.render_queue.empty():
result = self.render_queue.get_nowait()
step = result["step"]
videos = result["videos"]

# Log to wandb if available
if hasattr(self.logger, "wandb") and self.logger.wandb:
import wandb

payload = {}
if videos["output_topdown"]:
payload["render/world_state"] = [wandb.Video(p, format="mp4") for p in videos["output_topdown"]]
if videos["output_agent"]:
payload["render/agent_view"] = [wandb.Video(p, format="mp4") for p in videos["output_agent"]]

if payload:
# Log render results under a custom step so W&B does not reject out-of-order (non-monotonic) steps
payload["render_step"] = step
self.logger.wandb.log(payload)

except queue.Empty:
pass
except Exception as e:
print(f"Error reading render queue: {e}")

def mean_and_log(self):
# Check render queue for finished async jobs
self.check_render_queue()

config = self.config
for k in list(self.stats.keys()):
v = self.stats[k]
@@ -567,6 +656,25 @@ def mean_and_log(self):
def close(self):
self.vecenv.close()
self.utilization.stop()

if self.render_async: # Ensure all render processes are properly terminated before closing the queue
if hasattr(self, "render_processes"):
for p in self.render_processes:
try:
if p.is_alive():
p.terminate()
p.join(timeout=5)
if p.is_alive():
p.kill()
except Exception:
# Best-effort cleanup; avoid letting close() crash on process errors
print(f"Failed to terminate render process {p.pid}")
# Clear the list to drop references to terminated processes
self.render_processes = []
if hasattr(self, "render_queue"):
self.render_queue.close()
self.render_queue.join_thread()

model_path = self.save_checkpoint()
run_id = self.logger.run_id
path = os.path.join(self.config["data_dir"], f"{self.config['env']}_{run_id}.pt")
@@ -999,6 +1107,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
logger = None

train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {}))
if "vec" in args and "num_workers" in args["vec"]:
train_config["num_workers"] = args["vec"]["num_workers"]
pufferl = PuffeRL(train_config, vecenv, policy, logger)

all_logs = []
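Taken together, the pufferl.py changes implement a small producer/consumer loop: each checkpoint spawns a one-shot render process (capped at num_workers), workers push finished video paths onto a multiprocessing.Queue, and mean_and_log() drains the queue via check_render_queue(). A compact, self-contained sketch of that pattern follows; render_job, launch_capped, and drain are illustrative names, not the PR's code.

```python
# Illustrative sketch of the async render pattern: capped one-shot workers
# that push results onto a Queue, drained non-blockingly by the trainer.
import multiprocessing
import queue
import time

def render_job(result_queue, step):
    # Stand-in for pufferlib.utils.render_videos(..., render_queue=result_queue)
    result_queue.put({"step": step, "videos": {"output_topdown": [], "output_agent": []}})

def launch_capped(processes, result_queue, step, max_workers=1):
    # Drop references to finished workers, then block if at capacity.
    processes[:] = [p for p in processes if p.is_alive()]
    while len(processes) >= max_workers:
        time.sleep(1)
        processes[:] = [p for p in processes if p.is_alive()]
    p = multiprocessing.Process(target=render_job, args=(result_queue, step))
    p.start()
    processes.append(p)

def drain(result_queue):
    # Non-blocking drain; queue.Empty ends the loop.
    try:
        while True:
            result = result_queue.get_nowait()
            print("render finished for step", result["step"])
    except queue.Empty:
        pass

if __name__ == "__main__":
    q = multiprocessing.Queue()
    workers = []
    for step in (1000, 2000):
        launch_capped(workers, q, step, max_workers=1)
    for p in workers:
        p.join()
    drain(q)
```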
60 changes: 39 additions & 21 deletions pufferlib/utils.py
@@ -167,7 +167,9 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}")


def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
def render_videos(
config, env_cfg, run_id, wandb_log, epoch, global_step, bin_path, render_async, render_queue=None, wandb_run=None
):
"""
Generate and log training videos using C-based rendering.

@@ -186,7 +188,6 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
print(f"Binary weights file does not exist: {bin_path}")
return

run_id = logger.run_id
model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}")

# Now call the C rendering function
@@ -195,11 +196,6 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
video_output_dir = os.path.join(model_dir, "videos")
os.makedirs(video_output_dir, exist_ok=True)

# Copy the binary weights to the expected location
expected_weights_path = "resources/drive/puffer_drive_weights.bin"
os.makedirs(os.path.dirname(expected_weights_path), exist_ok=True)
shutil.copy2(bin_path, expected_weights_path)

# TODO: Fix memory leaks so that this is not needed
# Suppress AddressSanitizer exit code (temp)
env_vars = os.environ.copy()
@@ -230,10 +226,11 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
base_cmd.extend(["--view", view_mode])

# Get num_maps if available
env_cfg = getattr(vecenv, "driver_env", None)
if env_cfg is not None and getattr(env_cfg, "num_maps", None):
base_cmd.extend(["--num-maps", str(env_cfg.num_maps)])

base_cmd.extend(["--policy-name", bin_path])

# Handle single or multiple map rendering
render_maps = config.get("render_map", None)
if render_maps is None or render_maps == "none":
@@ -262,40 +259,49 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
# Collect videos to log as lists so W&B shows all in the same step
videos_to_log_world = []
videos_to_log_agent = []
generated_videos = {"output_topdown": [], "output_agent": []}
output_topdown = f"resources/drive/output_topdown_{epoch}"
output_agent = f"resources/drive/output_agent_{epoch}"

for i, map_path in enumerate(render_maps):
cmd = list(base_cmd) # copy
if map_path is not None and os.path.exists(map_path):
cmd.extend(["--map-name", str(map_path)])

output_topdown_map = output_topdown + (f"_map{i:02d}.mp4" if len(render_maps) > 1 else ".mp4")
output_agent_map = output_agent + (f"_map{i:02d}.mp4" if len(render_maps) > 1 else ".mp4")

# Output paths (overwrite each iteration; then moved/renamed)
cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"])
cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"])
cmd.extend(["--output-topdown", output_topdown_map])
cmd.extend(["--output-agent", output_agent_map])

result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=600, env=env_vars)

vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists(
"resources/drive/output_agent.mp4"
)
vids_exist = os.path.exists(output_topdown_map) and os.path.exists(output_agent_map)

if result.returncode == 0 or (result.returncode == 1 and vids_exist):
videos = [
(
"resources/drive/output_topdown.mp4",
"output_topdown",
output_topdown_map,
f"epoch_{epoch:06d}_map{i:02d}_topdown.mp4" if map_path else f"epoch_{epoch:06d}_topdown.mp4",
),
(
"resources/drive/output_agent.mp4",
"output_agent",
output_agent_map,
f"epoch_{epoch:06d}_map{i:02d}_agent.mp4" if map_path else f"epoch_{epoch:06d}_agent.mp4",
),
]

for source_vid, target_filename in videos:
for vid_type, source_vid, target_filename in videos:
if os.path.exists(source_vid):
target_path = os.path.join(video_output_dir, target_filename)
shutil.move(source_vid, target_path)
generated_videos[vid_type].append(target_path)
if render_async:
continue
# Accumulate for a single wandb.log call
if hasattr(logger, "wandb") and logger.wandb:
if wandb_log:
import wandb

if "topdown" in target_filename:
@@ -304,17 +310,29 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
videos_to_log_agent.append(wandb.Video(target_path, format="mp4"))
else:
print(f"Video generation completed but {source_vid} not found")
if result.stdout:
print(f"StdOUT: {result.stdout}")
if result.stderr:
print(f"StdERR: {result.stderr}")
else:
print(f"C rendering failed (map index {i}) with exit code {result.returncode}: {result.stdout}")

if render_async:
render_queue.put(
{
"videos": generated_videos,
"step": global_step,
}
)

# Log all videos at once so W&B keeps all of them under the same step
if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent):
if wandb_log and (videos_to_log_world or videos_to_log_agent) and not render_async:
payload = {}
if videos_to_log_world:
payload["render/world_state"] = videos_to_log_world
if videos_to_log_agent:
payload["render/agent_view"] = videos_to_log_agent
logger.wandb.log(payload, step=global_step)
wandb_run.log(payload, step=global_step)

except subprocess.TimeoutExpired:
print("C rendering timed out")
Expand All @@ -323,5 +341,5 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):

finally:
# Clean up bin weights file
if os.path.exists(expected_weights_path):
os.remove(expected_weights_path)
if os.path.exists(bin_path):
os.remove(bin_path)