-
Notifications
You must be signed in to change notification settings - Fork 15
Add map_sources option for mixing maps from multiple directories #241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 3.0
Are you sure you want to change the base?
Changes from all commits
91ea58b
40cdece
3990c32
d5a8af5
ddfc3be
c19ada0
f75cd92
6e10a10
f87764d
81336d8
89d99c0
062ab62
36ae0cd
9181164
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,12 +3,89 @@ | |||||||||
| import json | ||||||||||
| import struct | ||||||||||
| import os | ||||||||||
| import tempfile | ||||||||||
| import random | ||||||||||
| import atexit | ||||||||||
| import shutil | ||||||||||
| import pufferlib | ||||||||||
| from pufferlib.ocean.drive import binding | ||||||||||
| from multiprocessing import Pool, cpu_count | ||||||||||
| from tqdm import tqdm | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def create_mixed_map_directory(map_sources_str, num_maps): | ||||||||||
| """ | ||||||||||
| Parse map_sources string and create a temp directory with symlinks. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| map_sources_str: Format "path1:weight1,path2:weight2,..." | ||||||||||
| num_maps: Total number of maps to create | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| Path to temp directory containing symlinked maps | ||||||||||
| """ | ||||||||||
| # Parse the map_sources string | ||||||||||
| sources = [] | ||||||||||
| for source in map_sources_str.split(","): | ||||||||||
| source = source.strip() | ||||||||||
| if ":" not in source: | ||||||||||
| raise ValueError(f"Invalid map_sources format: '{source}'. Expected 'path:weight'") | ||||||||||
| path, weight = source.rsplit(":", 1) | ||||||||||
| sources.append((path.strip(), float(weight.strip()))) | ||||||||||
|
|
||||||||||
| # Normalize weights | ||||||||||
| total_weight = sum(w for _, w in sources) | ||||||||||
| sources = [(p, w / total_weight) for p, w in sources] | ||||||||||
|
|
||||||||||
| # Calculate how many maps from each source | ||||||||||
| map_counts = [] | ||||||||||
| remaining = num_maps | ||||||||||
| for i, (path, weight) in enumerate(sources): | ||||||||||
| if i == len(sources) - 1: | ||||||||||
| count = remaining # Last source gets remainder | ||||||||||
| else: | ||||||||||
| count = int(round(num_maps * weight)) | ||||||||||
| remaining -= count | ||||||||||
| map_counts.append(count) | ||||||||||
|
|
||||||||||
| # Create temp directory and register cleanup | ||||||||||
| temp_dir = tempfile.mkdtemp(prefix="pufferdrive_maps_") | ||||||||||
| atexit.register(shutil.rmtree, temp_dir, ignore_errors=True) | ||||||||||
|
|
||||||||||
| # Collect maps from each source | ||||||||||
| selected_maps = [] | ||||||||||
| for (path, _), count in zip(sources, map_counts): | ||||||||||
| available = sorted([f for f in os.listdir(path) if f.endswith(".bin")]) | ||||||||||
| if len(available) == 0: | ||||||||||
| raise ValueError(f"No .bin files found in {path}") | ||||||||||
|
|
||||||||||
| # Sample without replacement if possible, with replacement if we need more than available | ||||||||||
| if count <= len(available): | ||||||||||
| sampled = random.sample(available, k=count) | ||||||||||
| else: | ||||||||||
| sampled = random.choices(available, k=count) | ||||||||||
| for map_file in sampled: | ||||||||||
| selected_maps.append(os.path.join(path, map_file)) | ||||||||||
|
|
||||||||||
| # Shuffle to mix sources | ||||||||||
| random.shuffle(selected_maps) | ||||||||||
|
Comment on lines
+64
to
+71
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
logic: Missing seed parameter for reproducibility - `random.sample()`, `random.choices()`, and `random.shuffle()` use global random state. Different runs with same config will produce different map selections. Consider adding a `seed` parameter to `create_mixed_map_directory()` and using a `random.Random(seed)` instance:
```python
def create_mixed_map_directory(map_sources_str, num_maps, seed=None):
    rng = random.Random(seed) if seed is not None else random.Random()
    # Then use rng.sample(), rng.choices(), rng.shuffle()
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Prompt To Fix With AI
This is a comment left during a code review.
Path: pufferlib/ocean/drive/drive.py
Line: 64:71
Comment:
**logic:** Missing seed parameter for reproducibility - `random.sample()`, `random.choices()`, and `random.shuffle()` use global random state. Different runs with same config will produce different map selections.
Consider adding a `seed` parameter to `create_mixed_map_directory()` and using `random.Random(seed)` instance:
```python
def create_mixed_map_directory(map_sources_str, num_maps, seed=None):
    rng = random.Random(seed) if seed is not None else random.Random()
    # Then use rng.sample(), rng.choices(), rng.shuffle()
```
<sub>Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!</sub>
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||
|
|
||||||||||
| # Create symlinks | ||||||||||
| for i, src_path in enumerate(selected_maps): | ||||||||||
| link_name = os.path.join(temp_dir, f"map_{i:03d}.bin") | ||||||||||
| os.symlink(os.path.abspath(src_path), link_name) | ||||||||||
|
|
||||||||||
| # Print summary | ||||||||||
| print("\n" + "=" * 80) | ||||||||||
| print("MIXED MAP SOURCES") | ||||||||||
| for (path, weight), count in zip(sources, map_counts): | ||||||||||
| print(f" {path}: {count} maps ({weight * 100:.1f}%)") | ||||||||||
| print(f"Total: {len(selected_maps)} maps in {temp_dir}") | ||||||||||
| print("=" * 80 + "\n") | ||||||||||
|
|
||||||||||
| return temp_dir | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class Drive(pufferlib.PufferEnv): | ||||||||||
| def __init__( | ||||||||||
| self, | ||||||||||
|
|
@@ -42,6 +119,7 @@ def __init__( | |||||||||
| init_mode="create_all_valid", | ||||||||||
| control_mode="control_vehicles", | ||||||||||
| map_dir="resources/drive/binaries/training", | ||||||||||
| map_sources=None, | ||||||||||
| use_all_maps=False, | ||||||||||
| allow_map_resampling=True, | ||||||||||
| ): | ||||||||||
|
|
@@ -129,6 +207,9 @@ def __init__( | |||||||||
|
|
||||||||||
| self._action_type_flag = 0 if action_type == "discrete" else 1 | ||||||||||
|
|
||||||||||
| if map_sources is not None and map_sources != "": | ||||||||||
| map_dir = create_mixed_map_directory(map_sources, num_maps) | ||||||||||
| self.map_dir = map_dir # Update instance variable to point to temp directory | ||||||||||
|
Comment on lines
+211
to
+212
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
logic: Need to pass seed to `create_mixed_map_directory()` for reproducible map selection.
Suggested change
Prompt To Fix With AI
This is a comment left during a code review.
Path: pufferlib/ocean/drive/drive.py
Line: 212:213
Comment:
**logic:** Need to pass seed to `create_mixed_map_directory()` for reproducible map selection:
```suggestion
map_dir = create_mixed_map_directory(map_sources, num_maps, seed)
self.map_dir = map_dir # Update instance variable to point to temp directory
```
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||
| # Check if map directory contains any bin files | ||||||||||
| if not os.path.isdir(map_dir): | ||||||||||
| raise FileNotFoundError( | ||||||||||
|
|
@@ -155,6 +236,7 @@ def __init__( | |||||||||
| f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). " | ||||||||||
| f"Please reduce num_maps, add more maps to {map_dir}, or set allow_map_resampling=True." | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| self.max_controlled_agents = int(max_controlled_agents) | ||||||||||
|
|
||||||||||
| # Iterate through all maps to count total agents that can be initialized for each map | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: Temp directory created but never cleaned up - will accumulate in `/tmp` over multiple runs. Consider using the `tempfile.TemporaryDirectory()` context manager or add cleanup in the `close()` method.
Prompt To Fix With AI