diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini
index e8a650864..f68434856 100644
--- a/pufferlib/config/ocean/drive.ini
+++ b/pufferlib/config/ocean/drive.ini
@@ -87,6 +87,7 @@ vtrace_rho_clip = 1
 checkpoint_interval = 1000
 ; Rendering options
 render = True
+render_async = False
 # Render interval of below 50 might cause process starvation and slowness in training
 render_interval = 1000
 ; If True, show exactly what the agent sees in agent observation
 obs_only = True
diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py
index c533fc648..61c4f7ed1 100644
--- a/pufferlib/ocean/drive/drive.py
+++ b/pufferlib/ocean/drive/drive.py
@@ -145,17 +145,17 @@ def __init__(
 
         # Check maps availability
         available_maps = len(self.map_files)
-        if num_maps > available_maps:
+        if self.num_maps > available_maps:
             if allow_fewer_maps:
                 print("\n" + "=" * 80)
                 print("WARNING: FEWER MAPS THAN REQUESTED")
-                print(f"Requested {num_maps} maps but only {available_maps} available in {map_dir}")
+                print(f"Requested {self.num_maps} maps but only {available_maps} available in {map_dir}")
                 print(f"Maps will be randomly sampled from the {available_maps} available maps.")
                 print("=" * 80 + "\n")
-                num_maps = available_maps
+                self.num_maps = available_maps
             else:
                 raise ValueError(
-                    f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). "
+                    f"num_maps ({self.num_maps}) exceeds available maps in directory ({available_maps}). "
                     f"Please reduce num_maps, add more maps to {map_dir}, or set allow_fewer_maps=True."
                 )
         self.max_controlled_agents = int(max_controlled_agents)
@@ -164,7 +164,7 @@ def __init__(
         agent_offsets, map_ids, num_envs = binding.shared(
             map_files=self.map_files,
             num_agents=num_agents,
-            num_maps=num_maps,
+            num_maps=self.num_maps,
             init_mode=self.init_mode,
             control_mode=self.control_mode,
             init_steps=self.init_steps,
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index a8a131583..908bd4537 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -53,6 +53,9 @@
 
 import signal # Aggressively exit on ctrl+c
+import multiprocessing
+import queue
+
 signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0))
 
 # Assume advantage kernel has been built if CUDA compiler is available
@@ -121,11 +124,16 @@ def __init__(self, config, vecenv, policy, logger=None):
         self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32)
         self.free_idx = total_agents
         self.render = config["render"]
+        self.render_async = config["render_async"] and self.render # Only supported if rendering is enabled
         self.render_interval = config["render_interval"]
         if self.render:
             ensure_drive_binary()
+        if self.render_async:
+            self.render_queue = multiprocessing.Queue()
+            self.render_processes = []
+
         # LSTM
         if config["use_rnn"]:
             n = vecenv.agents_per_batch
@@ -193,6 +201,9 @@ def __init__(self, config, vecenv, policy, logger=None):
         self.logger = logger
         if logger is None:
             self.logger = NoLogger(config)
+        if self.render_async and hasattr(self.logger, "wandb") and self.logger.wandb:
+            self.logger.wandb.define_metric("render_step", hidden=True)
+            self.logger.wandb.define_metric("render/*", step_metric="render_step")
 
         # Learning rate scheduler
         epochs = config["total_timesteps"] // config["batch_size"]
@@ -514,9 +525,53 @@ def train(self):
                     path=bin_path,
                     silent=True,
                 )
-                pufferlib.utils.render_videos(
-                    self.config, self.vecenv, self.logger, self.epoch, self.global_step, bin_path
-                )
+
+                bin_path_epoch = f"{model_dir}_epoch_{self.epoch:06d}.bin"
+                shutil.copy2(bin_path, bin_path_epoch)
+
+                env_cfg = getattr(self.vecenv, "driver_env", None)
+                wandb_log = True if hasattr(self.logger, "wandb") and self.logger.wandb else False
+                wandb_run = self.logger.wandb if hasattr(self.logger, "wandb") else None
+                if self.render_async:
+                    # Clean up finished processes
+                    self.render_processes = [p for p in self.render_processes if p.is_alive()]
+
+                    # Cap the number of processes to num_workers
+                    max_processes = self.config.get("num_workers", 1)
+                    if len(self.render_processes) >= max_processes:
+                        print("Waiting for render processes to finish...")
+                        while len(self.render_processes) >= max_processes:
+                            time.sleep(1)
+                            self.render_processes = [p for p in self.render_processes if p.is_alive()]
+
+                    render_proc = multiprocessing.Process(
+                        target=pufferlib.utils.render_videos,
+                        args=(
+                            self.config,
+                            env_cfg,
+                            self.logger.run_id,
+                            wandb_log,
+                            self.epoch,
+                            self.global_step,
+                            bin_path_epoch,
+                            self.render_async,
+                            self.render_queue,
+                        ),
+                    )
+                    render_proc.start()
+                    self.render_processes.append(render_proc)
+                else:
+                    pufferlib.utils.render_videos(
+                        self.config,
+                        env_cfg,
+                        self.logger.run_id,
+                        wandb_log,
+                        self.epoch,
+                        self.global_step,
+                        bin_path_epoch,
+                        self.render_async,
+                        wandb_run=wandb_run,
+                    )
             except Exception as e:
                 print(f"Failed to export model weights: {e}")
@@ -531,7 +586,41 @@ def train(self):
         ):
             pufferlib.utils.run_human_replay_eval_in_subprocess(self.config, self.logger, self.global_step)
 
+    def check_render_queue(self):
+        """Check if any async render jobs finished and log them."""
+        if not self.render_async or not hasattr(self, "render_queue"):
+            return
+
+        try:
+            while not self.render_queue.empty():
+                result = self.render_queue.get_nowait()
+                step = result["step"]
+                videos = result["videos"]
+
+                # Log to wandb if available
+                if hasattr(self.logger, "wandb") and self.logger.wandb:
+                    import wandb
+
+                    payload = {}
+                    if videos["output_topdown"]:
+                        payload["render/world_state"] = [wandb.Video(p, format="mp4") for p in videos["output_topdown"]]
+                    if videos["output_agent"]:
+                        payload["render/agent_view"] = [wandb.Video(p, format="mp4") for p in videos["output_agent"]]
+
+                    if payload:
+                        # Use a custom step for render logs to avoid non-monotonic step errors in wandb
+                        payload["render_step"] = step
+                        self.logger.wandb.log(payload)
+
+        except queue.Empty:
+            pass
+        except Exception as e:
+            print(f"Error reading render queue: {e}")
+
     def mean_and_log(self):
+        # Check render queue for finished async jobs
+        self.check_render_queue()
+
         config = self.config
         for k in list(self.stats.keys()):
             v = self.stats[k]
@@ -567,6 +656,25 @@ def mean_and_log(self):
     def close(self):
         self.vecenv.close()
         self.utilization.stop()
+
+        if self.render_async: # Ensure all render processes are properly terminated before closing the queue
+            if hasattr(self, "render_processes"):
+                for p in self.render_processes:
+                    try:
+                        if p.is_alive():
+                            p.terminate()
+                            p.join(timeout=5)
+                            if p.is_alive():
+                                p.kill()
+                    except Exception:
+                        # Best-effort cleanup; avoid letting close() crash on process errors
+                        print(f"Failed to terminate render process {p.pid}")
+                # Optionally clear the list to drop references to finished processes
+                self.render_processes = []
+            if hasattr(self, "render_queue"):
+                self.render_queue.close()
+                self.render_queue.join_thread()
+
         model_path = self.save_checkpoint()
         run_id = self.logger.run_id
         path = os.path.join(self.config["data_dir"], f"{self.config['env']}_{run_id}.pt")
@@ -999,6 +1107,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
         logger = None
 
     train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {}))
+    if "vec" in args and "num_workers" in args["vec"]:
+        train_config["num_workers"] = args["vec"]["num_workers"]
     pufferl = PuffeRL(train_config, vecenv, policy, logger)
 
     all_logs = []
diff --git a/pufferlib/utils.py b/pufferlib/utils.py
index 860a61934..1f2ccd514 100644
--- a/pufferlib/utils.py
+++ b/pufferlib/utils.py
@@ -167,7 +167,9 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
         print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}")
 
 
-def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
+def render_videos(
+    config, env_cfg, run_id, wandb_log, epoch, global_step, bin_path, render_async, render_queue=None, wandb_run=None
+):
     """
     Generate and log training videos using C-based rendering.
 
@@ -186,7 +188,6 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
         print(f"Binary weights file does not exist: {bin_path}")
         return
 
-    run_id = logger.run_id
     model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}")
 
     # Now call the C rendering function
@@ -195,11 +196,6 @@
         video_output_dir = os.path.join(model_dir, "videos")
         os.makedirs(video_output_dir, exist_ok=True)
 
-        # Copy the binary weights to the expected location
-        expected_weights_path = "resources/drive/puffer_drive_weights.bin"
-        os.makedirs(os.path.dirname(expected_weights_path), exist_ok=True)
-        shutil.copy2(bin_path, expected_weights_path)
-
         # TODO: Fix memory leaks so that this is not needed
         # Suppress AddressSanitizer exit code (temp)
         env_vars = os.environ.copy()
@@ -230,10 +226,11 @@
         base_cmd.extend(["--view", view_mode])
 
         # Get num_maps if available
-        env_cfg = getattr(vecenv, "driver_env", None)
         if env_cfg is not None and getattr(env_cfg, "num_maps", None):
             base_cmd.extend(["--num-maps", str(env_cfg.num_maps)])
 
+        base_cmd.extend(["--policy-name", bin_path])
+
         # Handle single or multiple map rendering
        render_maps = config.get("render_map", None)
         if render_maps is None or render_maps == "none":
@@ -262,40 +259,49 @@
         # Collect videos to log as lists so W&B shows all in the same step
         videos_to_log_world = []
         videos_to_log_agent = []
+        generated_videos = {"output_topdown": [], "output_agent": []}
+        output_topdown = f"resources/drive/output_topdown_{epoch}"
+        output_agent = f"resources/drive/output_agent_{epoch}"
 
         for i, map_path in enumerate(render_maps):
             cmd = list(base_cmd) # copy
             if map_path is not None and os.path.exists(map_path):
                 cmd.extend(["--map-name", str(map_path)])
 
+            output_topdown_map = output_topdown + (f"_map{i:02d}.mp4" if len(render_maps) > 1 else ".mp4")
+            output_agent_map = output_agent + (f"_map{i:02d}.mp4" if len(render_maps) > 1 else ".mp4")
+
             # Output paths (overwrite each iteration; then moved/renamed)
-            cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"])
-            cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"])
+            cmd.extend(["--output-topdown", output_topdown_map])
+            cmd.extend(["--output-agent", output_agent_map])
 
             result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=600, env=env_vars)
 
-            vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists(
-                "resources/drive/output_agent.mp4"
-            )
+            vids_exist = os.path.exists(output_topdown_map) and os.path.exists(output_agent_map)
             if result.returncode == 0 or (result.returncode == 1 and vids_exist):
                 videos = [
                     (
-                        "resources/drive/output_topdown.mp4",
+                        "output_topdown",
+                        output_topdown_map,
                         f"epoch_{epoch:06d}_map{i:02d}_topdown.mp4" if map_path else f"epoch_{epoch:06d}_topdown.mp4",
                     ),
                     (
-                        "resources/drive/output_agent.mp4",
+                        "output_agent",
+                        output_agent_map,
                         f"epoch_{epoch:06d}_map{i:02d}_agent.mp4" if map_path else f"epoch_{epoch:06d}_agent.mp4",
                     ),
                 ]
-                for source_vid, target_filename in videos:
+                for vid_type, source_vid, target_filename in videos:
                     if os.path.exists(source_vid):
                         target_path = os.path.join(video_output_dir, target_filename)
                         shutil.move(source_vid, target_path)
+                        generated_videos[vid_type].append(target_path)
+                        if render_async:
+                            continue
 
                         # Accumulate for a single wandb.log call
-                        if hasattr(logger, "wandb") and logger.wandb:
+                        if wandb_log:
                             import wandb
 
                             if "topdown" in target_filename:
@@ -304,17 +310,29 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
                                 videos_to_log_agent.append(wandb.Video(target_path, format="mp4"))
                     else:
                         print(f"Video generation completed but {source_vid} not found")
+                        if result.stdout:
+                            print(f"StdOUT: {result.stdout}")
+                        if result.stderr:
+                            print(f"StdERR: {result.stderr}")
             else:
                 print(f"C rendering failed (map index {i}) with exit code {result.returncode}: {result.stdout}")
 
+        if render_async:
+            render_queue.put(
+                {
+                    "videos": generated_videos,
+                    "step": global_step,
+                }
+            )
+
         # Log all videos at once so W&B keeps all of them under the same step
-        if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent):
+        if wandb_log and (videos_to_log_world or videos_to_log_agent) and not render_async:
             payload = {}
             if videos_to_log_world:
                 payload["render/world_state"] = videos_to_log_world
             if videos_to_log_agent:
                 payload["render/agent_view"] = videos_to_log_agent
-            logger.wandb.log(payload, step=global_step)
+            wandb_run.log(payload, step=global_step)
 
     except subprocess.TimeoutExpired:
         print("C rendering timed out")
@@ -323,5 +341,5 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
     finally:
         # Clean up bin weights file
-        if os.path.exists(expected_weights_path):
-            os.remove(expected_weights_path)
+        if os.path.exists(bin_path):
+            os.remove(bin_path)
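
Not part of the patch: a minimal, self-contained sketch of the queue handoff the async path above relies on, assuming render_videos puts a dict with "step" and "videos" keys on a multiprocessing.Queue from a child process and check_render_queue drains it on the training side. The payload keys mirror the diff; the worker body and literal values here are illustrative only.

# Illustrative sketch of the async render handoff (producer stands in for
# pufferlib.utils.render_videos, consumer stands in for PuffeRL.check_render_queue).
import multiprocessing


def render_worker(q):
    # Stand-in for render_videos(..., render_async=True, render_queue=q)
    q.put({"step": 1234, "videos": {"output_topdown": ["topdown.mp4"], "output_agent": []}})


if __name__ == "__main__":
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=render_worker, args=(q,))
    p.start()
    result = q.get(timeout=10)  # check_render_queue() polls with empty()/get_nowait() instead
    p.join()
    print(result["step"], result["videos"])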