Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions pufferlib/ocean/drive/visualize.c
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ static int make_gif_from_frames(const char *pattern, int fps, const char *palett
return 0;
}

int eval_gif(const char *map_name, const char *policy_name, int show_grid, int obs_only, int lasers,
int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown,
int eval_gif(const char *map_name, const char *policy_name, const char *config_file, int show_grid, int obs_only,
int lasers, int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown,
const char *output_agent, int num_maps, int zoom_in) {

// Parse configuration from INI file
env_init_config conf = {0};
const char *ini_file = "pufferlib/config/ocean/drive.ini";
const char *ini_file = config_file ? config_file : "pufferlib/config/ocean/drive.ini";
if (ini_parse(ini_file, handler, &conf) < 0) {
fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file);
return -1;
Expand Down Expand Up @@ -417,6 +417,7 @@ int main(int argc, char *argv[]) {
// File paths and num_maps (not in [env] section)
const char *map_name = NULL;
const char *policy_name = "resources/drive/puffer_drive_weights.bin";
const char *config_file = NULL;
const char *output_topdown = NULL;
const char *output_agent = NULL;
int num_maps = 1;
Expand Down Expand Up @@ -485,10 +486,18 @@ int main(int argc, char *argv[]) {
num_maps = atoi(argv[i + 1]);
i++;
}
} else if (strcmp(argv[i], "--config") == 0) {
if (i + 1 < argc) {
config_file = argv[i + 1];
i++;
} else {
fprintf(stderr, "Error: --config option requires a config file path\n");
return 1;
}
}
}

eval_gif(map_name, policy_name, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, output_topdown,
output_agent, num_maps, zoom_in);
eval_gif(map_name, policy_name, config_file, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode,
output_topdown, output_agent, num_maps, zoom_in);
return 0;
}
42 changes: 39 additions & 3 deletions pufferlib/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import configparser
import os
import sys
import glob
Expand Down Expand Up @@ -167,6 +168,37 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}")


def _write_visualizer_config(env_cfg, output_path, base_ini="pufferlib/config/ocean/drive.ini"):
"""Write a config file for the visualizer, inheriting from base INI but overriding with env_cfg."""
cfg = configparser.ConfigParser()
if not cfg.read(base_ini):
raise FileNotFoundError(f"Base config file not found: {base_ini}")

if env_cfg is not None and "env" in cfg:
# Special attribute mappings for Drive class (attr stored differently than INI key)
special_mappings = {
"action_type": ("_action_type_flag", lambda v: "discrete" if v == 0 else "continuous"),
}

# Override [env] values with attributes from env_cfg
for key in cfg["env"]:
attr_name = key.replace("-", "_")

if attr_name in special_mappings:
# Handle special cases where attribute name/format differs
real_attr, transform = special_mappings[attr_name]
if hasattr(env_cfg, real_attr):
cfg["env"][key] = transform(getattr(env_cfg, real_attr))
elif hasattr(env_cfg, attr_name):
cfg["env"][key] = str(getattr(env_cfg, attr_name))
elif hasattr(env_cfg, f"{attr_name}_str"):
# For attributes like control_mode that have a _str variant
cfg["env"][key] = str(getattr(env_cfg, f"{attr_name}_str"))

with open(output_path, "w") as f:
cfg.write(f)


def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
"""
Generate and log training videos using C-based rendering.
Expand Down Expand Up @@ -205,8 +237,13 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
env_vars = os.environ.copy()
env_vars["ASAN_OPTIONS"] = "exitcode=0"

# Base command with only visualization flags (env config comes from INI)
base_cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"]
# Write temp config inheriting base INI but with training env overrides
env_cfg = getattr(vecenv, "driver_env", None)
temp_config_path = os.path.join(model_dir, "visualizer_config.ini")
_write_visualizer_config(env_cfg, temp_config_path)

# Base command with config file
base_cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize", "--config", temp_config_path]

# Visualization config flags only
if config.get("show_grid", False):
Expand All @@ -230,7 +267,6 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path):
base_cmd.extend(["--view", view_mode])

# Get num_maps if available
env_cfg = getattr(vecenv, "driver_env", None)
if env_cfg is not None and getattr(env_cfg, "num_maps", None):
base_cmd.extend(["--num-maps", str(env_cfg.num_maps)])

Expand Down