
feat: optimize worker rollout: batched multi-env inference with torch.compile #89

Open

smly wants to merge 5 commits into main from feat/profiling-and-update-model-config

Conversation

@smly smly commented Feb 8, 2026

Summary

  • Add standalone rollout profiler (scripts/profile_rollout.py) that measures per-step timings (env_step, encode, to_device, forward, action_select, store)
  • Redesign worker to run N environments per worker with batched GPU inference, replacing sequential single-env rollout (see the sketch after this list)
  • Apply torch.compile() automatically for GPU workers to fuse kernels and reduce launch overhead
  • Reduce default model size from 24 blocks / 512ch (43.2M params) to 8 blocks / 128ch (2.0M params)
  • Add num_envs_per_worker config option (default: 16) to all config files
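
The batched rollout referenced above boils down to a lockstep loop: every active environment contributes one encoded observation, the whole batch goes through a single forward pass, and the chosen actions are scattered back per environment. A minimal sketch of that idea, assuming an encode_obs helper and a greedy policy for brevity (names are illustrative, not the actual MahjongWorker API, and legal-action masking is omitted):

```python
import torch

def batched_rollout_step(model, envs, encode_obs, device):
    """One lockstep step: batch observations from all active envs,
    run a single forward pass, and pick one action per env."""
    active = [i for i, env in enumerate(envs) if not env.done()]
    if not active:
        return []

    # Stack per-env observations into a single (B, C, 34) batch.
    obs = torch.stack([encode_obs(envs[i]) for i in active]).to(device)

    with torch.no_grad():
        logits = model(obs)  # (B, num_actions)

    actions = logits.argmax(dim=-1).tolist()
    for env_idx, action in zip(active, actions):
        envs[env_idx].step(action)
    return list(zip(active, actions))

# GPU workers can additionally wrap the model once at startup so the first
# forward pass triggers JIT compilation (an assumption about how the worker
# applies torch.compile, not a copy of its code):
# model = torch.compile(model) if device.type == "cuda" else model
```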

Profiling Results

Bottleneck analysis showed that model.forward() consumed 90-99% of rollout time. The combined optimizations yield the following throughput (trans/s = transitions per second):

| Configuration | Params | trans/s | Speedup |
| --- | --- | --- | --- |
| CPU, 24b/512ch, single-env (original) | 43.2M | 26.2 | 1.0x |
| GPU, 24b/512ch, single-env | 43.2M | 162.0 | 6.2x |
| GPU + compile, 24b/512ch, single-env | 43.2M | 447.5 | 17.1x |
| GPU + fp16, 24b/512ch, single-env | 43.2M | 193.6 | 7.4x |
| GPU + compile + fp16, 24b/512ch, single-env | 43.2M | 408.9 | 15.6x |
| CPU, 8b/128ch, single-env | 2.0M | 736.8 | 28.1x |
| GPU, 8b/128ch, single-env | 2.0M | 582.4 | 22.2x |
| GPU, 24b/512ch, batch4 | 43.2M | 733.1 | 28.0x |
| GPU, 24b/512ch, batch8 | 43.2M | 1,355.4 | 51.7x |
| GPU, 24b/512ch, batch16 | 43.2M | 2,307.9 | 88.1x |
| GPU + compile, 24b/512ch, batch4 | 43.2M | 1,286.2 | 49.1x |
| GPU + compile, 24b/512ch, batch8 | 43.2M | 2,303.5 | 87.9x |
| GPU + compile, 24b/512ch, batch16 | 43.2M | 3,857.1 | 147.2x |
| GPU + compile, 8b/128ch, batch16 (new default) | 2.0M | 8,773.7 | 335x |
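
The per-stage timings behind these numbers (env_step, encode, to_device, forward, action_select, store) can be collected with a simple accumulate-and-report pattern; the sketch below illustrates that pattern and is not the actual scripts/profile_rollout.py implementation:

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class StageTimer:
    """Accumulate wall-clock time per rollout stage."""

    def __init__(self):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def stage(self, name):
        # Note: for accurate GPU timings, call torch.cuda.synchronize()
        # before reading the clock on either side of the measured region.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self):
        grand_total = sum(self.totals.values()) or 1.0
        for name, total in sorted(self.totals.items(), key=lambda kv: -kv[1]):
            share = 100.0 * total / grand_total
            print(f"{name:>14}: {total:8.3f}s ({share:5.1f}%, {self.counts[name]} calls)")

# Usage inside a rollout loop:
#   timer = StageTimer()
#   with timer.stage("forward"):
#       logits = model(batch)
#   timer.report()
```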

…erformance

- Reduced sample limits in baseline and discard history configs to 500k.
- Adjusted model architecture parameters (num_blocks and conv_channels) for baseline, discard history, and discard history shanten configs.
- Added profiling script (profile_rollout.py) to measure performance bottlenecks during rollout.
- Enhanced train_online.py to support num_envs_per_worker argument for batched rollout.
- Updated OnlineConfig to include num_envs_per_worker with a default value of 16.
- Modified MahjongWorker to handle multiple environments and improved action selection logic.
@smly smly self-assigned this Feb 8, 2026
@smly smly added the enhancement New feature or request label Feb 8, 2026
@smly smly added this to the v0.3.0 milestone Feb 8, 2026
@smly smly requested a review from Copilot February 8, 2026 03:27
Copilot AI left a comment

Pull request overview

This PR optimizes the online RL rollout pipeline by switching Ray workers from single-env sequential stepping to multi-env lockstep stepping with batched GPU inference, and adds a standalone rollout profiler to measure per-stage timings.

Changes:

  • Redesign Ray worker rollout to run num_envs_per_worker environments in parallel with batched model inference (and auto-torch.compile on CUDA).
  • Wire the new num_envs_per_worker option through config/CLI and update trainer dispatch to call collect_episodes() (see the sketch after this list).
  • Add scripts/profile_rollout.py for standalone rollout profiling and update default model/config sizes in YAMLs.
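
For reference, plumbing an option like num_envs_per_worker through a dataclass config and a CLI flag typically looks like the sketch below (illustrative only; the actual OnlineConfig fields and train_online.py argument handling may differ):

```python
import argparse
from dataclasses import dataclass

@dataclass
class OnlineConfig:
    # How many environments each Ray worker steps in lockstep (new option).
    num_envs_per_worker: int = 16
    num_blocks: int = 8
    conv_channels: int = 128

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_envs_per_worker",
        type=int,
        default=None,
        help="Environments per worker for batched rollout (overrides the config file).",
    )
    return parser.parse_args()

def apply_overrides(config: OnlineConfig, args: argparse.Namespace) -> OnlineConfig:
    # A CLI value, when given, takes precedence over the YAML/dataclass default.
    if args.num_envs_per_worker is not None:
        config.num_envs_per_worker = args.num_envs_per_worker
    return config
```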

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| demos/ml_baseline/src/riichienv_ml/training/ray_actor.py | Batched multi-env rollout, CUDA torch.compile, weight loading for compiled models |
| demos/ml_baseline/src/riichienv_ml/training/online_trainer.py | Trainer updated to pass num_envs_per_worker, call collect_episodes(), and count episodes accordingly |
| demos/ml_baseline/src/riichienv_ml/config.py | Adds num_envs_per_worker and changes the default model size |
| demos/ml_baseline/scripts/train_online.py | Adds CLI arg plumbing for --num_envs_per_worker |
| demos/ml_baseline/scripts/profile_rollout.py | New standalone rollout profiler (single-env and batched modes) |
| demos/ml_baseline/configs/baseline.yml | Updates model size defaults and adds num_envs_per_worker |
| demos/ml_baseline/configs/discard_history.yml | Same as above for the discard-history config |
| demos/ml_baseline/configs/discard_history_shanten.yml | Same as above for the discard-history + shanten config |


Comment on lines +270 to +271
probs = torch.softmax(logits, dim=-1)
sampled = torch.multinomial(probs, 1).squeeze(-1)
Copilot AI Feb 8, 2026

In batched mode, top_p is accepted as an argument but not used (sampling is softmax + multinomial over the full action distribution). If training uses top-p / nucleus sampling for boltzmann exploration, this profiler will not match actual rollout behavior. Consider applying the same sample_top_p logic per batch row when boltzmann sampling is selected.

Suggested change
- probs = torch.softmax(logits, dim=-1)
- sampled = torch.multinomial(probs, 1).squeeze(-1)
+ # If top-p / nucleus sampling is enabled, use the same helper as in training;
+ # otherwise fall back to full-distribution multinomial sampling.
+ if top_p is not None and top_p < 1.0:
+     sampled = sample_top_p(logits, top_p=top_p)
+ else:
+     probs = torch.softmax(logits, dim=-1)
+     sampled = torch.multinomial(probs, 1).squeeze(-1)
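
For context, sample_top_p is the training-side nucleus-sampling helper referenced above; its exact implementation lives in the repo, but a typical batched version looks like this (a sketch under that assumption, operating on (B, num_actions) logits):

```python
import torch

def sample_top_p(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Nucleus (top-p) sampling over the last dimension of `logits`.

    Keeps the smallest set of actions whose cumulative probability exceeds
    `top_p`, renormalizes, and samples one action per row.
    """
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # Drop actions whose cumulative probability (excluding themselves) already
    # exceeds top_p; the most likely action is always kept.
    mask = cumulative - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    choice = torch.multinomial(sorted_probs, 1)       # index into sorted order
    return sorted_idx.gather(-1, choice).squeeze(-1)  # map back to action ids
```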

Comment on lines +161 to +162
probs = torch.softmax(logits, dim=-1)
sampled = torch.multinomial(probs, 1).squeeze(-1)
Copilot AI Feb 8, 2026

In the batched boltzmann path, top_p is no longer applied (sampling uses softmax + multinomial over all actions). This changes exploration behavior vs the previous per-step sample_top_p(...) logic and makes the top_p config effectively unused here. Consider applying nucleus sampling per row (reusing sample_top_p) when boltzmann sampling is chosen.

Suggested change
- probs = torch.softmax(logits, dim=-1)
- sampled = torch.multinomial(probs, 1).squeeze(-1)
+ sampled = sample_top_p(logits, self.top_p)

Comment on lines +125 to +127
for ei in range(self.num_envs):
    if active[ei] and self.envs[ei].done():
        active[ei] = False
Copilot AI Feb 8, 2026

When batch_items is empty, the loop breaks even if some envs are still not done(). That can produce truncated trajectories and then lead to calling ranks() on envs that may not be in a terminal state. The previous single-env implementation explicitly treated this as a deadlock and aborted. Please restore explicit deadlock handling here (e.g., detect "not done + no legal actions" and return an empty transition list, reset the env, or mark it inactive) so training data isn't silently corrupted.

Suggested change
- for ei in range(self.num_envs):
-     if active[ei] and self.envs[ei].done():
-         active[ei] = False
+ # No env produced any legal actions. This can happen if:
+ # (a) all active envs are done(), or
+ # (b) some envs are not done() but have no legal actions (deadlock).
+ deadlocked_envs = []
+ for ei in range(self.num_envs):
+     if not active[ei]:
+         continue
+     if self.envs[ei].done():
+         # Normal termination.
+         active[ei] = False
+     else:
+         # Deadlock: env is not done but has no legal actions.
+         active[ei] = False
+         deadlocked_envs.append(ei)
+         # Discard any partial trajectories from this env so they
+         # are not treated as complete episodes downstream.
+         for pid in all_buffers[ei]:
+             all_buffers[ei][pid].clear()
+ if deadlocked_envs:
+     # Warn but do not crash; trajectories from these envs were dropped.
+     print(
+         f"[MahjongWorker {self.worker_id}] Deadlock detected in "
+         f"envs {deadlocked_envs}; discarding partial trajectories."
+     )

Comment on lines 79 to 90
def _warmup_compile(self):
    """Run a dummy forward pass to trigger torch.compile JIT."""
    if self._compiled_warmup:
        return
    mc = {}
    # Infer in_channels from model
    target = self.model._orig_mod if hasattr(self.model, "_orig_mod") else self.model
    in_ch = target.backbone.conv_in.in_channels
    dummy = torch.randn(1, in_ch, 34, device=self.device)
    with torch.no_grad():
        self.model(dummy)
    self._compiled_warmup = True
Copilot AI Feb 8, 2026

_warmup_compile() assumes the model has backbone.conv_in.in_channels. That’s true for QNetwork, but model_class is configurable, so this can crash for other architectures. Also mc = {} is unused. Consider inferring input shape more generically (e.g., from model_config["in_channels"] passed to the worker, or a small real encoded observation) and gating warmup to only run when the model is actually torch.compile-wrapped.
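
One way to address both points is to take the input shape from the worker's model config instead of QNetwork internals and to compile/warm up only on CUDA. A sketch of that approach, assuming a model_config dict with an "in_channels" key is available on the worker (this is not the current ray_actor.py code):

```python
import torch

def _maybe_compile_and_warmup(self):
    """Compile on CUDA and trigger the torch.compile JIT with one dummy forward.

    The input shape comes from model_config rather than QNetwork internals,
    so any model_class taking a (B, in_channels, 34) tensor works.
    """
    if self.device.type != "cuda" or self._compiled_warmup:
        return
    self.model = torch.compile(self.model)
    in_ch = self.model_config["in_channels"]  # assumed to be passed to the worker
    dummy = torch.randn(1, in_ch, 34, device=self.device)
    with torch.no_grad():
        self.model(dummy)
    self._compiled_warmup = True
```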

Comment on lines +100 to +101
if not self._compiled_warmup:
    self._warmup_compile()
Copilot AI Feb 8, 2026

collect_episodes() always runs _warmup_compile() on the first call, even for CPU workers where the model is not compiled. This adds an extra forward pass per worker startup that isn’t needed. Consider guarding the warmup call with if self.device.type == "cuda" (and/or hasattr(self.model, "_orig_mod")).
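
A guarded call along those lines could look like this (a sketch; it assumes the same attribute names as the diff above):

```python
# Only warm up when the model was actually wrapped by torch.compile on a GPU worker.
if self.device.type == "cuda" and hasattr(self.model, "_orig_mod") and not self._compiled_warmup:
    self._warmup_compile()
```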

smly added 4 commits February 8, 2026 03:46
- Updated `train_online.py` to support both DQN and PPO algorithms.
- Introduced `PPOLearner` class for PPO-specific training logic.
- Added `PPOWorker` class for parallel trajectory collection with Ray.
- Created `ActorCriticNetwork` model for PPO.
- Enhanced configuration management in `config.py` to include PPO parameters.
- Integrated logging with Loguru for better monitoring.
- Updated dependencies in `pyproject.toml` to include Loguru.