feat: optimize worker rollout: batched multi-env inference with torch.compile #89
…erformance

- Reduced sample limits in the baseline and discard history configs to 500k.
- Adjusted model architecture parameters (`num_blocks` and `conv_channels`) for the baseline, discard history, and discard history shanten configs.
- Added a profiling script (`profile_rollout.py`) to measure performance bottlenecks during rollout.
- Enhanced `train_online.py` to support the `num_envs_per_worker` argument for batched rollout (a config/CLI plumbing sketch follows this list).
- Updated `OnlineConfig` to include `num_envs_per_worker` with a default value of 16.
- Modified `MahjongWorker` to handle multiple environments and improved the action selection logic.
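As a rough illustration of the CLI/config plumbing listed above, the sketch below shows one way a `--num_envs_per_worker` flag could override an `OnlineConfig` default. The argument and field names come from this PR, but the parsing/override code itself is an assumption, not the actual `train_online.py` implementation.

```python
import argparse
from dataclasses import dataclass


@dataclass
class OnlineConfig:
    # Default from this PR: each Ray worker steps 16 envs in lockstep.
    num_envs_per_worker: int = 16


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Online RL training (sketch)")
    # None means "keep the value from the YAML config / dataclass default".
    parser.add_argument("--num_envs_per_worker", type=int, default=None)
    return parser.parse_args()


def build_config(args: argparse.Namespace) -> OnlineConfig:
    cfg = OnlineConfig()
    if args.num_envs_per_worker is not None:
        cfg.num_envs_per_worker = args.num_envs_per_worker
    return cfg
```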
Pull request overview
This PR optimizes the online RL rollout pipeline by switching Ray workers from single-env sequential stepping to multi-env lockstep stepping with batched GPU inference, and adds a standalone rollout profiler to measure per-stage timings.
Changes:
- Redesign the Ray worker rollout to run `num_envs_per_worker` environments in parallel with batched model inference (and automatic `torch.compile` on CUDA); a minimal sketch of this lockstep pattern follows this list.
- Wire the new `num_envs_per_worker` option through the config/CLI and update the trainer dispatch to call `collect_episodes()`.
- Add `scripts/profile_rollout.py` for standalone rollout profiling and update default model/config sizes in the YAMLs.
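For orientation, here is a minimal sketch of the lockstep pattern described above: step many environments together and run one batched forward pass per step. The environment/encoder interface (`done()`, `step()`, `encode_obs`) is a hypothetical stand-in for the worker's real API, not the PR's code.

```python
import torch


def lockstep_rollout(envs, model, encode_obs, device, max_steps=512):
    """Sketch: multi-env lockstep stepping with one batched forward per step."""
    active = [True] * len(envs)
    for _ in range(max_steps):
        # Gather observations only from envs that still need an action.
        batch_idx = [i for i in range(len(envs)) if active[i] and not envs[i].done()]
        if not batch_idx:
            break
        obs = torch.stack([encode_obs(envs[i]) for i in batch_idx]).to(device)

        # One batched forward pass replaces len(batch_idx) single-env calls.
        with torch.no_grad():
            logits = model(obs)
        actions = logits.argmax(dim=-1).tolist()

        for i, action in zip(batch_idx, actions):
            envs[i].step(action)
            if envs[i].done():
                active[i] = False
```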
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `demos/ml_baseline/src/riichienv_ml/training/ray_actor.py` | Batched multi-env rollout, CUDA `torch.compile`, weight loading for compiled models |
| `demos/ml_baseline/src/riichienv_ml/training/online_trainer.py` | Trainer updated to pass `num_envs_per_worker`, call `collect_episodes()`, and count episodes accordingly |
| `demos/ml_baseline/src/riichienv_ml/config.py` | Adds `num_envs_per_worker` and changes the default model size |
| `demos/ml_baseline/scripts/train_online.py` | Adds CLI arg plumbing for `--num_envs_per_worker` |
| `demos/ml_baseline/scripts/profile_rollout.py` | New standalone rollout profiler (single-env and batched modes) |
| `demos/ml_baseline/configs/baseline.yml` | Updates model size defaults and adds `num_envs_per_worker` |
| `demos/ml_baseline/configs/discard_history.yml` | Same as above, for the discard-history config |
| `demos/ml_baseline/configs/discard_history_shanten.yml` | Same as above, for the discard-history + shanten config |
```python
probs = torch.softmax(logits, dim=-1)
sampled = torch.multinomial(probs, 1).squeeze(-1)
```
In batched mode, `top_p` is accepted as an argument but never used (sampling is softmax + multinomial over the full action distribution). If training uses top-p / nucleus sampling for Boltzmann exploration, this profiler will not match actual rollout behavior. Consider applying the same `sample_top_p` logic per batch row when Boltzmann sampling is selected.
Suggested change:

```diff
-probs = torch.softmax(logits, dim=-1)
-sampled = torch.multinomial(probs, 1).squeeze(-1)
+# If top-p / nucleus sampling is enabled, use the same helper as in training;
+# otherwise fall back to full-distribution multinomial sampling.
+if top_p is not None and top_p < 1.0:
+    sampled = sample_top_p(logits, top_p=top_p)
+else:
+    probs = torch.softmax(logits, dim=-1)
+    sampled = torch.multinomial(probs, 1).squeeze(-1)
```
```python
probs = torch.softmax(logits, dim=-1)
sampled = torch.multinomial(probs, 1).squeeze(-1)
```
In the batched Boltzmann path, `top_p` is no longer applied (sampling uses softmax + multinomial over all actions). This changes exploration behavior versus the previous per-step `sample_top_p(...)` logic and makes the `top_p` config effectively unused here. Consider applying nucleus sampling per row (reusing `sample_top_p`) when Boltzmann sampling is chosen; a generic batched nucleus-sampling sketch follows the suggestion below.
Suggested change:

```diff
-probs = torch.softmax(logits, dim=-1)
-sampled = torch.multinomial(probs, 1).squeeze(-1)
+sampled = sample_top_p(logits, self.top_p)
```
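Both suggestions above reuse the repo's existing `sample_top_p` helper, whose implementation is not shown in this excerpt. For reference, a generic batched (per-row) nucleus-sampling routine looks roughly like the sketch below; it illustrates the standard technique and is not the repo's actual helper.

```python
import torch


def sample_top_p_batched(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) sampling applied independently to each row of `logits`."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumprobs = torch.cumsum(sorted_probs, dim=-1)

    # Zero out actions once the cumulative mass (excluding the current action)
    # exceeds top_p; the most probable action per row is always kept.
    mask = cumprobs - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # Sample within the truncated distribution, then map back to action ids.
    picked = torch.multinomial(sorted_probs, 1)
    return sorted_idx.gather(-1, picked).squeeze(-1)
```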
```python
for ei in range(self.num_envs):
    if active[ei] and self.envs[ei].done():
        active[ei] = False
```
When `batch_items` is empty, the loop breaks even if some envs are still not `done()`. That can produce truncated trajectories and then calls `ranks()` on envs that may not be in a terminal state. The previous single-env implementation explicitly treated this as a deadlock and aborted. Please restore explicit deadlock handling here (e.g., detect not-done + no-legal-actions and return an empty transition list, reset the env, or mark it inactive) so training data isn't silently corrupted.
Suggested change:

```diff
-for ei in range(self.num_envs):
-    if active[ei] and self.envs[ei].done():
-        active[ei] = False
+# No env produced any legal actions. This can happen if:
+# (a) all active envs are done(), or
+# (b) some envs are not done() but have no legal actions (deadlock).
+deadlocked_envs = []
+for ei in range(self.num_envs):
+    if not active[ei]:
+        continue
+    if self.envs[ei].done():
+        # Normal termination.
+        active[ei] = False
+    else:
+        # Deadlock: env is not done but has no legal actions.
+        active[ei] = False
+        deadlocked_envs.append(ei)
+        # Discard any partial trajectories from this env so they
+        # are not treated as complete episodes downstream.
+        for pid in all_buffers[ei]:
+            all_buffers[ei][pid].clear()
+if deadlocked_envs:
+    # Warn but do not crash; trajectories from these envs were dropped.
+    print(
+        f"[MahjongWorker {self.worker_id}] Deadlock detected in "
+        f"envs {deadlocked_envs}; discarding partial trajectories."
+    )
```
```python
def _warmup_compile(self):
    """Run a dummy forward pass to trigger torch.compile JIT."""
    if self._compiled_warmup:
        return
    mc = {}
    # Infer in_channels from model
    target = self.model._orig_mod if hasattr(self.model, "_orig_mod") else self.model
    in_ch = target.backbone.conv_in.in_channels
    dummy = torch.randn(1, in_ch, 34, device=self.device)
    with torch.no_grad():
        self.model(dummy)
    self._compiled_warmup = True
```
`_warmup_compile()` assumes the model has `backbone.conv_in.in_channels`. That's true for `QNetwork`, but `model_class` is configurable, so this can crash for other architectures. Also, `mc = {}` is unused. Consider inferring the input shape more generically (e.g., from the `model_config["in_channels"]` passed to the worker, or from a small real encoded observation) and gating warmup so it only runs when the model is actually `torch.compile`-wrapped.
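A minimal sketch of that suggestion, assuming the worker keeps a `model_config` dict with an `in_channels` entry (an assumed name, not confirmed by the PR) and that `torch` is already imported at module level:

```python
def _warmup_compile(self):
    """Run one dummy forward pass to trigger torch.compile's first-call JIT.

    Illustrative only: warm up solely when the model is compile-wrapped, and
    take the input channel count from the worker's config instead of reaching
    into QNetwork internals.
    """
    if self._compiled_warmup or not hasattr(self.model, "_orig_mod"):
        return
    in_ch = self.model_config["in_channels"]  # assumed config key
    dummy = torch.randn(1, in_ch, 34, device=self.device)
    with torch.no_grad():
        self.model(dummy)
    self._compiled_warmup = True
```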
```python
if not self._compiled_warmup:
    self._warmup_compile()
```
`collect_episodes()` always runs `_warmup_compile()` on the first call, even for CPU workers where the model is not compiled. This adds an unnecessary extra forward pass per worker startup. Consider guarding the warmup call with `if self.device.type == "cuda"` (and/or `hasattr(self.model, "_orig_mod")`).
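A minimal sketch of the suggested guard (the checks mirror the conditions named above):

```python
# Only pay the warmup forward pass when the model was actually compiled
# (CUDA workers wrap the model, exposing `_orig_mod`).
if self.device.type == "cuda" and hasattr(self.model, "_orig_mod") and not self._compiled_warmup:
    self._warmup_compile()
```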
…batched_episodes functions
- Updated `train_online.py` to support both DQN and PPO algorithms.
- Introduced a `PPOLearner` class for PPO-specific training logic.
- Added a `PPOWorker` class for parallel trajectory collection with Ray.
- Created an `ActorCriticNetwork` model for PPO (an illustrative sketch follows this list).
- Enhanced configuration management in `config.py` to include PPO parameters.
- Integrated logging with Loguru for better monitoring.
- Updated dependencies in `pyproject.toml` to include Loguru.
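The `ActorCriticNetwork` itself is not shown in this excerpt. As a generic illustration, a PPO-style actor-critic pairs a shared trunk with a policy head and a value head, roughly as below; the layer sizes and trunk are placeholders, not the repo's architecture.

```python
import torch
import torch.nn as nn


class ActorCriticSketch(nn.Module):
    """Illustrative PPO actor-critic: shared trunk, policy logits + state value."""

    def __init__(self, in_features: int, num_actions: int, hidden: int = 256):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, num_actions)  # action logits
        self.value_head = nn.Linear(hidden, 1)             # state-value estimate

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = self.trunk(x)
        return self.policy_head(h), self.value_head(h).squeeze(-1)
```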
Summary
- Added a standalone rollout profiler (`scripts/profile_rollout.py`) that measures per-step timings (env_step, encode, to_device, forward, action_select, store); a minimal timing sketch follows this list.
- Applied `torch.compile()` automatically for GPU workers to fuse kernels and reduce launch overhead.
- Added the `num_envs_per_worker` config option (default: 16) to all config files.
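As an illustration of the kind of per-stage accounting such a profiler performs, one common pattern accumulates wall-clock time per labelled stage. The stage names below come from the PR description, while the helper itself is a sketch rather than `profile_rollout.py`'s actual code.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# Accumulated wall-clock seconds per rollout stage.
stage_totals: dict[str, float] = defaultdict(float)


@contextmanager
def timed(stage: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        stage_totals[stage] += time.perf_counter() - start


# Usage inside a rollout step (stage names from the PR description):
# with timed("encode"):    features = encode(obs)
# with timed("to_device"): batch = features.to(device)
# with timed("forward"):   logits = model(batch)
```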
Profiling Results

Bottleneck analysis showed `model.forward()` consumed 90-99% of rollout time. Combined optimizations yield: