def create_mixed_map_directory(map_sources_str, num_maps):
    """
    Build a temporary directory of symlinks that mixes maps from several sources.

    Each source directory contributes a share of ``num_maps`` proportional to
    its weight. Shares are apportioned with the largest-remainder method so the
    per-source counts are never negative and always sum to ``num_maps``.

    Args:
        map_sources_str: Format "path1:weight1,path2:weight2,...". Weights are
            non-negative numbers and are normalized internally, so they do not
            need to sum to 1. Empty entries (e.g. a trailing comma) are ignored.
        num_maps: Total number of maps (symlinks) to create.

    Returns:
        Path to a temp directory containing ``map_000.bin``, ``map_001.bin``, ...
        symlinks. The directory is removed at interpreter exit.

    Raises:
        ValueError: If the string is malformed, a weight is not a finite
            non-negative number, all weights are zero, ``num_maps`` is
            negative, or a source directory contains no ``.bin`` files.
    """
    num_maps = int(num_maps)
    if num_maps < 0:
        raise ValueError(f"num_maps must be non-negative, got {num_maps}")

    # Parse the map_sources string
    sources = []
    for source in map_sources_str.split(","):
        source = source.strip()
        if not source:
            continue  # tolerate stray/trailing commas
        if ":" not in source:
            raise ValueError(f"Invalid map_sources format: '{source}'. Expected 'path:weight'")
        # rsplit so that paths which themselves contain ':' still parse
        path, weight_str = source.rsplit(":", 1)
        path = path.strip()
        if not path:
            raise ValueError(f"Invalid map_sources format: '{source}'. Expected 'path:weight'")
        try:
            weight = float(weight_str.strip())
        except ValueError as err:
            raise ValueError(f"Invalid weight in map_sources entry '{source}'") from err
        # The chained comparison is False for NaN, so NaN is rejected as well
        if not 0 <= weight < float("inf"):
            raise ValueError(f"Weight must be a finite non-negative number in map_sources entry '{source}'")
        sources.append((path, weight))

    if not sources:
        raise ValueError("map_sources does not contain any 'path:weight' entries")

    total_weight = sum(w for _, w in sources)
    if total_weight <= 0:
        raise ValueError("map_sources weights must sum to a positive value")

    # Largest-remainder apportionment: floor every quota, then hand the
    # leftover maps to the sources with the largest fractional parts.
    # Independent rounding per source could overshoot num_maps and leave the
    # last source with a negative count.
    quotas = [num_maps * w / total_weight for _, w in sources]
    map_counts = [int(q) for q in quotas]
    leftover = max(num_maps - sum(map_counts), 0)
    by_remainder = sorted(
        range(len(sources)),
        key=lambda i: quotas[i] - map_counts[i],
        reverse=True,  # stable: ties keep the order given in map_sources
    )
    for i in by_remainder[:leftover]:
        map_counts[i] += 1

    # Normalize weights (used for the printed summary)
    sources = [(p, w / total_weight) for p, w in sources]

    # Collect maps from each source before touching the filesystem, so a bad
    # source does not leave a half-built temp directory behind.
    selected_maps = []
    for (path, _), count in zip(sources, map_counts):
        available = sorted(f for f in os.listdir(path) if f.endswith(".bin"))
        if not available:
            raise ValueError(f"No .bin files found in {path}")

        # Sample without replacement if possible, with replacement if we need more than available
        if count <= len(available):
            sampled = random.sample(available, k=count)
        else:
            sampled = random.choices(available, k=count)
        selected_maps.extend(os.path.join(path, map_file) for map_file in sampled)

    # Shuffle to mix sources
    random.shuffle(selected_maps)

    # Create temp directory and register cleanup
    temp_dir = tempfile.mkdtemp(prefix="pufferdrive_maps_")
    atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

    # Create symlinks (absolute targets so they resolve regardless of cwd)
    for i, src_path in enumerate(selected_maps):
        link_name = os.path.join(temp_dir, f"map_{i:03d}.bin")
        os.symlink(os.path.abspath(src_path), link_name)

    # Print summary
    print("\n" + "=" * 80)
    print("MIXED MAP SOURCES")
    for (path, weight), count in zip(sources, map_counts):
        print(f"  {path}: {count} maps ({weight * 100:.1f}%)")
    print(f"Total: {len(selected_maps)} maps in {temp_dir}")
    print("=" * 80 + "\n")

    return temp_dir
map_dir="resources/drive/binaries/training", + map_sources=None, use_all_maps=False, allow_map_resampling=True, ): @@ -129,6 +207,9 @@ def __init__( self._action_type_flag = 0 if action_type == "discrete" else 1 + if map_sources is not None and map_sources != "": + map_dir = create_mixed_map_directory(map_sources, num_maps) + self.map_dir = map_dir # Update instance variable to point to temp directory # Check if map directory contains any bin files if not os.path.isdir(map_dir): raise FileNotFoundError( @@ -155,6 +236,7 @@ def __init__( f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). " f"Please reduce num_maps, add more maps to {map_dir}, or set allow_map_resampling=True." ) + self.max_controlled_agents = int(max_controlled_agents) # Iterate through all maps to count total agents that can be initialized for each map