diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index de235b4fa..d89e807bd 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -189,13 +189,13 @@ static int make_gif_from_frames(const char *pattern, int fps, const char *palett return 0; } -int eval_gif(const char *map_name, const char *policy_name, int show_grid, int obs_only, int lasers, - int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown, +int eval_gif(const char *map_name, const char *policy_name, const char *config_file, int show_grid, int obs_only, + int lasers, int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown, const char *output_agent, int num_maps, int zoom_in) { // Parse configuration from INI file env_init_config conf = {0}; - const char *ini_file = "pufferlib/config/ocean/drive.ini"; + const char *ini_file = config_file ? config_file : "pufferlib/config/ocean/drive.ini"; if (ini_parse(ini_file, handler, &conf) < 0) { fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); return -1; @@ -417,6 +417,7 @@ int main(int argc, char *argv[]) { // File paths and num_maps (not in [env] section) const char *map_name = NULL; const char *policy_name = "resources/drive/puffer_drive_weights.bin"; + const char *config_file = NULL; const char *output_topdown = NULL; const char *output_agent = NULL; int num_maps = 1; @@ -485,10 +486,18 @@ int main(int argc, char *argv[]) { num_maps = atoi(argv[i + 1]); i++; } + } else if (strcmp(argv[i], "--config") == 0) { + if (i + 1 < argc) { + config_file = argv[i + 1]; + i++; + } else { + fprintf(stderr, "Error: --config option requires a config file path\n"); + return 1; + } } } - eval_gif(map_name, policy_name, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, output_topdown, - output_agent, num_maps, zoom_in); + eval_gif(map_name, policy_name, config_file, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, + output_topdown, output_agent, num_maps, zoom_in); return 0; } diff --git a/pufferlib/utils.py b/pufferlib/utils.py index 860a61934..6796d9212 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -1,3 +1,4 @@ +import configparser import os import sys import glob @@ -167,6 +168,37 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}") +def _write_visualizer_config(env_cfg, output_path, base_ini="pufferlib/config/ocean/drive.ini"): + """Write a config file for the visualizer, inheriting from base INI but overriding with env_cfg.""" + cfg = configparser.ConfigParser() + if not cfg.read(base_ini): + raise FileNotFoundError(f"Base config file not found: {base_ini}") + + if env_cfg is not None and "env" in cfg: + # Special attribute mappings for Drive class (attr stored differently than INI key) + special_mappings = { + "action_type": ("_action_type_flag", lambda v: "discrete" if v == 0 else "continuous"), + } + + # Override [env] values with attributes from env_cfg + for key in cfg["env"]: + attr_name = key.replace("-", "_") + + if attr_name in special_mappings: + # Handle special cases where attribute name/format differs + real_attr, transform = special_mappings[attr_name] + if hasattr(env_cfg, real_attr): + cfg["env"][key] = transform(getattr(env_cfg, real_attr)) + elif hasattr(env_cfg, attr_name): + cfg["env"][key] = str(getattr(env_cfg, attr_name)) + elif hasattr(env_cfg, f"{attr_name}_str"): + # For attributes like control_mode that have a _str variant + cfg["env"][key] = str(getattr(env_cfg, f"{attr_name}_str")) + + with open(output_path, "w") as f: + cfg.write(f) + + def render_videos(config, vecenv, logger, epoch, global_step, bin_path): """ Generate and log training videos using C-based rendering. @@ -205,8 +237,13 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path): env_vars = os.environ.copy() env_vars["ASAN_OPTIONS"] = "exitcode=0" - # Base command with only visualization flags (env config comes from INI) - base_cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"] + # Write temp config inheriting base INI but with training env overrides + env_cfg = getattr(vecenv, "driver_env", None) + temp_config_path = os.path.join(model_dir, "visualizer_config.ini") + _write_visualizer_config(env_cfg, temp_config_path) + + # Base command with config file + base_cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize", "--config", temp_config_path] # Visualization config flags only if config.get("show_grid", False): @@ -230,7 +267,6 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path): base_cmd.extend(["--view", view_mode]) # Get num_maps if available - env_cfg = getattr(vecenv, "driver_env", None) if env_cfg is not None and getattr(env_cfg, "num_maps", None): base_cmd.extend(["--num-maps", str(env_cfg.num_maps)])