diff --git a/.gitignore b/.gitignore
index 476aab77..1469bc67 100644
--- a/.gitignore
+++ b/.gitignore
@@ -120,6 +120,7 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
+node_modules/
# Environments
.env
@@ -185,4 +186,4 @@ results
results/
data/
cache/
-dump.rdb
\ No newline at end of file
+dump.rdb
diff --git a/conf/base.yaml b/conf/base.yaml
index 5043aa56..80429dfb 100644
--- a/conf/base.yaml
+++ b/conf/base.yaml
@@ -5,6 +5,7 @@ defaults:
- _self_
seed: 42
+use_ray: false
finetune:
seed: ${..seed}
@@ -18,14 +19,18 @@ actor:
result_queue_size: 64
throughput_window_size: 50
shared_memory_entry_size: 10000000
+ async_batch_size: 1 # if set to 1, the rollout function is called synchronously
+ task_submission_delay_sec: 0.5
+ collect_logprobs: true
+
environment: null
preprocess:
input: actor
output: training_data
n_workers: 8
- chunk_n_groups: 2
+ chunk_n_groups: 8
# queue for loaded raw groups
- raw_queue_size: 8
+ raw_queue_size: 128
# queue for processed chunks of multiple groups
input_queue_size: 32
# queue for ready chunks for multiple groups
@@ -67,7 +72,7 @@ vllm_config:
tensor-parallel-size: 1
pipeline-parallel-size: 1
generation-config: vllm
- max_model_len: 10000
+ max_model_len: 16000
world:
replicas: 1
@@ -81,6 +86,8 @@ world:
actor_group_port: 9000
environment_start_port: 7777
+ # Remote vs embedded environment execution strategy
+ environment_mode: embedded
# this will be autocreated based on the config
jobs: []
diff --git a/conf/miniwob.yaml b/conf/miniwob.yaml
index a8dc3868..7c051a93 100644
--- a/conf/miniwob.yaml
+++ b/conf/miniwob.yaml
@@ -9,12 +9,11 @@ world:
preprocessor_fraction: 0
finetune_fraction: 6
-# debug:
-# mode: actor
-save_tapes: False
+save_tapes: false
output_dir: results/miniwob/${now:%Y-%m-%d}/${now:%H-%M-%S}
model_path: meta-llama/Llama-3.1-8B-Instruct
+use_ray: true
finetune:
seq_length: 16384 # input + output tokens
@@ -25,6 +24,7 @@ finetune:
eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps"
llm:
+ use_cache: false
parameters:
max_tokens: 4096 # output tokens
temperature: 1.0
@@ -36,13 +36,20 @@ test_llm:
top_k: 50
vllm_config:
+ use_v1: false
vllm_kwargs:
- max_model_len: 16384 # input + output tokens
+ max-num-seqs: 256
+ max-num-batched-tokens: 32000
+ max_model_len: 16384
+ gpu-memory-utilization: 0.9
actor:
rollout_policy: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout
+ llm_max_rollouts: 256
+ problem_queue_size: 256
+ async_batch_size: 1
+ rollout_workers: 32
shared_memory_entry_size: 100000000
- llm_max_rollouts: 32
preprocess:
n_workers: 32 # Increase from 8
@@ -144,12 +151,7 @@ agent:
# ENVIRONMENT CONFIGURATION
start_attempts: 3 # number of attempts to start each task
environment:
- _target_: pipelinerl.domains.miniwob.environment_server.WebEnvironmentServer
- miniwob_url: ???
- n_envs: 32
- host: "0.0.0.0"
- env_call_timeout: 60 # timeout for each environment call (e.g. start_task, act, etc.)
- web_env_target: examples.rl_webagent.environment.WebEnvironment
+ _target_: examples.rl_webagent.environment.WebEnvironment
exp_path: null
headless: true
observation_format: html
@@ -162,4 +164,4 @@ dataset_loader_params:
train_dataset_names:
- train
test_dataset_names:
- - test
+ - test
\ No newline at end of file
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 6d3317af..100aa529 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -5,17 +5,20 @@
import os
import queue
import random
+import sys
import time
from collections import defaultdict
from multiprocessing.managers import SharedMemoryManager
from pathlib import Path
from queue import Empty
-from typing import Dict, List
+from typing import Any, Awaitable, Callable, Dict, List
import aiohttp
import hydra
+import numpy as np
+import ray
import uvloop
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
import wandb
@@ -32,8 +35,7 @@
set_streams_backend,
write_to_streams,
)
-
-from .utils import (
+from pipelinerl.utils import (
always_or_never_success_stats,
calculate_stats,
setup_logging,
@@ -57,8 +59,9 @@ class SlidingWindowData(BaseModel):
class SlidingWindowAggregator:
- def __init__(self, window_size: int):
+ def __init__(self, window_size: int, min_samples: int = 5):
self.window_size = window_size
+ self.min_samples = min_samples
self.data = SlidingWindowData()
def update(self, prompt_tokens: list[int], output_tokens: list[int]):
@@ -71,8 +74,11 @@ def update(self, prompt_tokens: list[int], output_tokens: list[int]):
self.data.timestamps.pop(0)
def get_stats(self):
- if len(self.data.prompt_tokens_window) < self.window_size:
+ if len(self.data.prompt_tokens_window) < self.min_samples:
+ logger.warning("Not enough data to compute sliding stats")
return None
+ elif len(self.data.prompt_tokens_window) < self.window_size:
+ logger.warning(f"Compute sliding stats over just {len(self.data.prompt_tokens_window)} samples")
# 1. How many samples do we produce per second?
# 2. How many output tokens do we produce per second?
@@ -103,11 +109,14 @@ def get_stats(self):
}
-
def make_stats_dict() -> dict:
return defaultdict(lambda: defaultdict(list))
+def get_number_of_tokens_in_result(result: RolloutResult) -> int:
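+    """Total number of prompt and output tokens across all training texts of a rollout result."""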
+ return sum(training_text.prompt_tokens + training_text.output_tokens for training_text in result.training_texts)
+
+
async def schedule_rollouts(
cfg: DictConfig,
attempts: int,
@@ -133,19 +142,17 @@ async def schedule_rollouts(
active_rollouts = [0] * len(llms)
started_rollouts = 0
finished_rollouts = 0
+ token_count = 0
# Track rollouts per problem group
group_rollouts = {}
rollout_policy = hydra.utils.get_method(cfg.actor.rollout_policy)
- logger.info(f"Use rollout policy: {rollout_policy}")
+ logger.info(f"Use rollout policy: {rollout_policy.__name__}")
final_steps = calculate_train_steps(cfg.finetune, cfg.finetune.interrupt_train_steps)
samples_target = final_steps * cfg.finetune.train_batch_size * cfg.finetune.gradient_accumulation_passes
def is_trainer_finished() -> bool:
- return (
- trainer_state.samples_processed is not None
- and trainer_state.samples_processed >= samples_target
- )
+ return trainer_state.samples_processed is not None and trainer_state.samples_processed >= samples_target
def handle_rollout_exception(exc: Exception):
if isinstance(exc, aiohttp.ClientError) and is_trainer_finished():
@@ -168,13 +175,16 @@ async def rollout_and_maybe_produce_result(
llm_index: int,
session: aiohttp.ClientSession,
):
- nonlocal started_rollouts, finished_rollouts
+ nonlocal started_rollouts, finished_rollouts, token_count
try:
llm = llms[llm_index]
model_version = trainer_state.propagated_weight_version
assert model_version is not None
- rollout_result = await rollout_policy(cfg, llm, problem, session)
+ logger.info(f"Starting rollout policy for problem {problem['id']}")
+ rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session)
+ logger.info(f"Finished rollout policy for problem {problem['id']}")
rollout_result.model_version = model_version
+ token_count += get_number_of_tokens_in_result(rollout_result)
# Make a group id that will be different from groups made by another rollout maker
full_group_id = f"{scheduler_name}_{group_id}"
rollout_result.group_id = full_group_id
@@ -204,18 +214,22 @@ async def rollout_and_maybe_produce_result(
logger.info("Starting rollout scheduler")
connector = aiohttp.TCPConnector(limit=50000, limit_per_host=50000, keepalive_timeout=1.0)
timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0)
+ old_finished_rollouts = 0
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
while True:
if is_trainer_finished():
logger.info(f"{scheduler_name}: trainer signalled completion; stopping rollout scheduler")
break
if time.time() - last_logged > 10.0 and sum(active_rollouts):
+ if finished_rollouts > old_finished_rollouts:
+ old_finished_rollouts = finished_rollouts
logger.info(
f"{scheduler_name}: "
f"rollouts in progress: {sum(active_rollouts)}, "
f"groups in progress: {len(group_rollouts)}, "
f"rollouts started so far: {started_rollouts}, "
f"rollouts finished so far: {finished_rollouts}, "
+ f"total tokens produced so far: {token_count}, "
f"groups started so far: {group_id}, "
f"max group size in bytes: {result_queue.max_actual_entry_size()}, "
)
@@ -238,7 +252,6 @@ async def rollout_and_maybe_produce_result(
await asyncio.sleep(0.01)
continue
active_rollouts[next_llm] += 1
- started_rollouts += 1
assert problem is not None
loop.create_task(
rollout_and_maybe_produce_result(
@@ -249,6 +262,7 @@ async def rollout_and_maybe_produce_result(
session=session,
)
)
+ started_rollouts += 1
group_rollout_index += 1
logger.info("Rollout scheduler finished")
@@ -302,40 +316,52 @@ def __init__(
self.sliding_aggregator = SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size)
self.llms = llms
self.loop_start_time = -1
- self.cfg = cfg
+ self.cfg: DictConfig = cfg
self.is_training = is_training
self.is_scheduling_paused = False
self.debug_mode = bool(cfg.debug.mode)
- # Determine the number of processes to use
- num_processes = min(self.cfg.actor.rollout_workers, len(self.llms))
- attempts = self.cfg.attempts if is_training else 1
-
- # Divide LLMs approximately equally across processes
- llm_groups = [[] for _ in range(num_processes)]
- for i, llm in enumerate(self.llms):
- llm_groups[i % num_processes].append((i, llm))
+ self.smm: SharedMemoryManager | None = None
+ self.problem_queue: SharedMemoryQueue | None = None
+ self.result_queue: SharedMemoryQueue | None = None
+ self.rollout_errors = 0
+ logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop")
+ def start_backend(self):
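+        """Create the shared-memory queues and spawn the rollout scheduler processes."""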
self.smm = SharedMemoryManager()
self.smm.start()
-
# Use SharedMemoryQueue instead of separate problem_queue, result_queue, and io_buffer
- self.problem_queue = SharedMemoryQueue(self.smm, self.cfg.actor.problem_queue_size, cfg.actor.shared_memory_entry_size)
- self.result_queue = SharedMemoryQueue(self.smm, self.cfg.actor.result_queue_size, cfg.actor.shared_memory_entry_size)
-
- logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop")
- logger.info(f"Problem queue size: {self.problem_queue.max_size}, result queue size: {self.result_queue.max_size}")
+ self.problem_queue = SharedMemoryQueue(
+ self.smm, self.cfg.actor.problem_queue_size, self.cfg.actor.shared_memory_entry_size
+ )
+ self.result_queue = SharedMemoryQueue(
+ self.smm, self.cfg.actor.result_queue_size, self.cfg.actor.shared_memory_entry_size
+ )
+
+ logger.info(
+ f"Problem queue size: {self.problem_queue.max_size}, result queue size: {self.result_queue.max_size}"
+ )
logger.info(f"Result queue buffer size: {self.result_queue.get_memory_size() / 2**30} Gb")
# Create and start multiple rollout processes
+ attempts = self.cfg.attempts if self.is_training else 1
+ # Determine the number of processes to use
+ num_processes = min(self.cfg.actor.rollout_workers, len(self.llms))
+
+ # Divide LLMs approximately equally across processes
+ llm_groups = [[] for _ in range(num_processes)]
+ for i, llm in enumerate(self.llms):
+ llm_groups[i % num_processes].append((i, llm))
+
self.rollout_processes = []
for llm_group in llm_groups:
assert llm_group
llm_idxs = [llm[0] for llm in llm_group]
llms = [llm[1] for llm in llm_group]
scheduler_name = (
- f"{'train' if is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}"
+ f"{'train' if self.is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}"
)
process = mp.Process(
target=rollout_maker_entrypoint,
@@ -349,15 +375,15 @@ def init_stats(self):
self.latency_list = []
self.model_versions_list = []
self.sliding_stats = defaultdict(list)
-
+
def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
metrics = {}
-
- metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ])
- metrics['num_turns'] = len(result.training_texts)
- metrics['prompt_tokens'] = [training_text.prompt_tokens for training_text in result.training_texts]
- metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts]
-
+
+ metrics["overflow"] = all([not training_text.finished for training_text in result.training_texts])
+ metrics["num_turns"] = len(result.training_texts)
+ metrics["prompt_tokens"] = [training_text.prompt_tokens for training_text in result.training_texts]
+ metrics["output_tokens"] = [training_text.output_tokens for training_text in result.training_texts]
+
return metrics
def update_stats(self, rollout_results: List[RolloutResult]):
@@ -368,8 +394,10 @@ def update_stats(self, rollout_results: List[RolloutResult]):
group_id = result.group_id
self.latency_list.append(result.latency)
self.model_versions_list.append(result.model_version)
- domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result)
+ domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result)
all_metrics = result.metrics.model_dump() | domain_agnostic_metrics
+ all_metrics["used_python"] = int(all_metrics.get("used_python", False))
+ all_metrics["used_math_answer"] = int(all_metrics.get("used_math_answer", False))
for k, v in all_metrics.items():
if isinstance(v, list):
self.stats[k][dataset_name][group_id] += v
@@ -377,16 +405,18 @@ def update_stats(self, rollout_results: List[RolloutResult]):
self.stats[k][dataset_name][group_id].append(v)
else:
raise ValueError(f"Unsupported metric type: {type(v)} for key {k}")
-
- prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts]
- output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts]
+
+ prompt_length_tokens = [
+ training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts
+ ]
+ output_length_tokens = [
+ training_text.output_tokens for result in rollout_results for training_text in result.training_texts
+ ]
self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens)
sliding_window_stats = self.sliding_aggregator.get_stats()
if sliding_window_stats is not None:
for k, v in sliding_window_stats.items():
self.sliding_stats[k].append(v)
-
-
def run(self, dataset: list[tuple[str, dict]]):
loop_start_time = time.time()
@@ -438,6 +468,7 @@ def run(self, dataset: list[tuple[str, dict]]):
can_submit_before_update = math.inf
logger.info(f"Start {'train' if self.is_training else 'test'} actor loop")
+ rollouts_last_minute = []
with (
write_to_streams(self.data_stream, "a") as data_stream_writer,
write_to_streams(self.stats_stream, "a") as stats_writer,
@@ -447,8 +478,13 @@ def run(self, dataset: list[tuple[str, dict]]):
yield
final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps)
- samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes
- if self.trainer_state.samples_processed is not None and self.trainer_state.samples_processed >= samples_target:
+ samples_target = (
+ final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes
+ )
+ if (
+ self.trainer_state.samples_processed is not None
+ and self.trainer_state.samples_processed >= samples_target
+ ):
logger.info("Trainer signalled completion; stopping actor loop")
break
@@ -464,23 +500,22 @@ def run(self, dataset: list[tuple[str, dict]]):
if not self.is_scheduling_paused:
while True:
blocked_by_lag = submitted_groups == can_submit_before_update and self.is_training
- if not blocked_by_lag and not self.problem_queue.full():
+ if not blocked_by_lag and self.have_capacity():
try:
try:
problem = next(problem_iter)
- self.problem_queue.put(problem, block=False)
+ self.submit_problem(problem)
submitted_groups += 1
- except queue.Full:
+ except queue.Full:
assert False, "Problem queue was not full just a moment ago, but now it is full"
except StopIteration:
break
- else:
- break
+ break
# Second, try return a result
try:
# Directly get the result from the SharedMemoryQueue
- rollout_results = self.result_queue.get(block=False)
+ rollout_results = self.get_new_results()
except queue.Empty:
continue
@@ -489,16 +524,17 @@ def run(self, dataset: list[tuple[str, dict]]):
raise rollout_results
assert isinstance(rollout_results, list)
+ if len(rollout_results) == 0:
+ continue
assert isinstance(rollout_results[0], RolloutResult)
- assert len(rollout_results) == attempts, (
- f"Expected {attempts} rollouts, got {len(rollout_results)}"
- )
+ assert len(rollout_results) == attempts, f"Expected {attempts} rollouts, got {len(rollout_results)}"
group_samples = sum(len(r.training_texts) for r in rollout_results)
published_samples += group_samples
- samples_in_queue = self.result_queue.qsize() * attempts
+ samples_in_queue = self.results_ready_to_publish()
all_text_dumps = []
for r in rollout_results:
+ rollouts_last_minute.append(time.perf_counter())
for text in r.training_texts:
all_text_dumps.append(text.model_dump())
data_stream_writer.write(all_text_dumps)
@@ -512,37 +548,43 @@ def run(self, dataset: list[tuple[str, dict]]):
self.update_stats(rollout_results=rollout_results)
finished_groups += 1
+ logger.info(
+ f"Finished {'train' if self.is_training else 'test'} groups {finished_groups} out of {expected_rollouts}"
+ )
time_to_publish_train_stats = (
- self.is_training
- and trainer_version_to_publish is not None
- ) or self.debug_mode
+ self.is_training and trainer_version_to_publish is not None
+ ) or self.debug_mode
time_to_publish_test_stats = finished_groups == expected_rollouts
+            # keep only the rollouts completed within the last minute
+ rollouts_last_minute = [t for t in rollouts_last_minute if t > time.perf_counter() - 60]
+
# Publish stats at every new model version or if all tapes are finished
if time_to_publish_train_stats or time_to_publish_test_stats:
if self.is_training:
loop_stats = {
"published_samples": published_samples,
- "problem_queue_size": self.problem_queue.qsize(),
- "result_queue_size": self.result_queue.qsize(),
+ "problem_queue_size": self.problem_queue_size(),
+ "result_queue_size": self.result_queue_size(),
"finished_groups": finished_groups,
- "trainer_model_version": trainer_version_to_publish,
+ "trainer_model_version": trainer_version_to_publish,
"time_since_start": time.time() - loop_start_time,
+ "groups_in_progress": in_progress,
+ "rollout_errors": self.rollout_errors,
+ "rollouts_per_min": len(rollouts_last_minute),
}
trainer_version_to_publish = None
else:
- loop_stats = {
- "trainer_model_version": last_trainer_version
- }
+ loop_stats = {"trainer_model_version": last_trainer_version}
self.publish_stats(
stats_writer=stats_writer,
loop_stats=loop_stats,
)
-
- if finished_groups == expected_rollouts:
+ if expected_rollouts >= 0 and finished_groups + self.rollout_errors >= expected_rollouts:
logger.info(f"Finished {expected_rollouts} rollouts, stopping actor loop")
+ self.stop_tasks()
break
def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
@@ -558,31 +600,269 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
stats[f"{dataset_name}/{metric_name}_{agg}"] = sub_stats
stats |= (
- {
- f"{split_name}{k}": v
- for k, v in always_or_never_success_stats(self.stats["success"]).items()
- }
- | {
- f"{split_name}latency_" + k: v
- for k, v in calculate_stats(self.latency_list).items()
- }
- | {
- f"{split_name}model_version_" + k: v
- for k, v in calculate_stats(self.model_versions_list).items()
- }
+ {f"{split_name}{k}": v for k, v in always_or_never_success_stats(self.stats["success"]).items()}
+ | {f"{split_name}latency_" + k: v for k, v in calculate_stats(self.latency_list).items()}
+ | {f"{split_name}model_version_" + k: v for k, v in calculate_stats(self.model_versions_list).items()}
)
stats |= loop_stats
for k, v in self.sliding_stats.items():
stats[k] = sum(v) / len(v) if v else 0
+
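+    # Publish renamed copies of tool-usage metrics (e.g. used_python_mean -> python_usage_rate)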
+ rename_suffixes = {
+ "num_python_calls_mean": "python_calls_mean",
+ "used_python_mean": "python_usage_rate",
+ "num_math_answer_calls_mean": "math_answer_calls_mean",
+ "used_math_answer_mean": "math_answer_usage_rate",
+ }
+
+ for key in list(stats.keys()):
+ for old_suffix, new_suffix in rename_suffixes.items():
+ if key.endswith(old_suffix):
+ prefix = key[: -len(old_suffix)]
+ stats[f"{prefix}{new_suffix}"] = stats[key]
+ break
+
if self.cfg.wandb.use_wandb:
wandb.log({f"actor/{k}": v for k, v in stats.items()})
stats_writer.write(stats)
self.init_stats() # Reset stats for the next iteration
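+    # Backend hooks: the base ActorLoop routes problems and results through shared-memory
+    # queues; ActorLoopRay overrides these methods to schedule rollouts as Ray tasks instead.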
+ def have_capacity(self) -> bool:
+ return not self.problem_queue.full()
+
+ def submit_problem(self, problem: dict):
+ self.problem_queue.put(problem, block=False)
+
+ def stop_tasks(self):
+ pass
+
+ def get_new_results(self) -> list[RolloutResult]:
+ return self.result_queue.get(block=False)
+
+ def results_ready_to_publish(self) -> int:
+ return self.result_queue_size() * self.cfg.attempts
+
+ def problem_queue_size(self) -> int:
+ return self.problem_queue.qsize()
+
+ def result_queue_size(self) -> int:
+ return self.result_queue.qsize()
+
+
+class ActorLoopRay(ActorLoop):
+ """
+    Actor loop that runs rollouts in parallel as Ray tasks.
+ """
+
+ def __init__(self, cfg: DictConfig, *args, **kwargs):
+ assert cfg.attempts % cfg.actor.async_batch_size == 0, (
+ f"attempts {cfg.attempts} must be divisible by actor.async_batch_size {cfg.actor.async_batch_size}"
+ )
+ super().__init__(cfg, *args, **kwargs)
+ self.cfg_dict: dict = OmegaConf.to_container(self.cfg, resolve=True) # type: ignore
+ self.unfinished_tasks = []
+ self.llms_by_url = {llm.get_base_url(): llm for llm in self.llms}
+ self.llms_utilization = {llm.get_base_url(): 0 for llm in self.llms}
+ self.scheduler_name = f"{'train' if self.is_training else 'test'} ray scheduler"
+ self.problem_id = 0
+ self.attempts = self.cfg.attempts if self.is_training else 1
+ self.unfinished_groups = defaultdict(list) # up to `attempts` rollout results for each problem
+ self.finished_groups = []
+ self.token_count = 0
+ self.finished_rollouts_count = 0
+ self.log_dir = Path(self.cfg.output_dir) / "actor" / "ray"
+
+ def start_backend(self):
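+        """Initialize Ray and build the remote rollout wrapper (synchronous or async-batched)."""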
+ logger.info(f"Initializing Ray with {self.cfg.actor.rollout_workers} workers..")
+ self.log_dir.mkdir(parents=True, exist_ok=True)
+ ray_context = ray.init(
+ num_cpus=self.cfg.actor.rollout_workers,
+ dashboard_host="0.0.0.0",
+ include_dashboard=True,
+ log_to_driver=True,
+ ignore_reinit_error=True,
+ )
+ logger.info(f"Ray initialized, dashboard at {ray_context.dashboard_url}")
+
+ assert self.trainer_state.propagated_weight_version is not None
+ rollout_policy = hydra.utils.get_method(self.cfg.actor.rollout_policy)
+
+ def rollout_wrapper(cfg_dict: dict, llm: TrainableLLM, problems: list[dict]) -> list[RolloutResult]:
+ assert len(problems) == 1, "Sync mode should only be used with 1 problem at a time"
+ cfg = OmegaConf.create(cfg_dict)
+ problem = problems[0]
+ group_id = problem["_group_id"]
+ attempt = problem["_attempt"]
+ log_file = Path(cfg.output_dir) / "actor" / "ray" / f"{group_id}.log"
+ sys.stdout = open(log_file, "a", buffering=1)
+ sys.stderr = sys.stdout
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
+ logger.info(f"Running sync rollout for task {group_id}_{attempt}")
+ start_ts = time.perf_counter()
+ rollout_result: RolloutResult = rollout_policy(cfg, llm, problem)
+ rollout_result.latency = time.perf_counter() - start_ts
+ rollout_result.llm_url = llm.get_base_url()
+ rollout_result.group_id = group_id
+ rollout_result.attempt = attempt
+ logger.info(f"Task {group_id}_{attempt} finished in {rollout_result.latency:.2f} seconds")
+ return [rollout_result]
+
+ async def run_rollouts_with_session(
+ cfg: DictConfig, llm: TrainableLLM, problems: list[dict]
+ ) -> list[RolloutResult]:
+ connector = aiohttp.TCPConnector(
+ limit=cfg.actor.async_batch_size, limit_per_host=cfg.actor.async_batch_size, keepalive_timeout=1.0
+ )
+ timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0)
+ async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
+ # Run all rollouts in parallel using asyncio.gather
+ async def run_rollout(problem) -> RolloutResult:
+ group_id = problem["_group_id"]
+ attempt = problem["_attempt"]
+ logger.info(f"Running async rollout loop for task {group_id}_{attempt}")
+ start_ts = time.perf_counter()
+ rollout_result = await rollout_policy(cfg, llm, problem, session)
+ rollout_result.latency = time.perf_counter() - start_ts
+ rollout_result.llm_url = llm.get_base_url()
+ rollout_result.group_id = group_id
+ rollout_result.attempt = attempt
+ logger.info(f"Task {group_id}_{attempt} finished in {rollout_result.latency:.2f} seconds")
+ return rollout_result
+
+ tasks = [run_rollout(problem) for problem in problems]
+ rollout_results = await asyncio.gather(*tasks)
+ return rollout_results
+
+ def rollout_async_batch_wrapper(cfg_dict: dict, llm: TrainableLLM, problems: list[dict]) -> list[RolloutResult]:
+ cfg = OmegaConf.create(cfg_dict)
+ group_id = problems[0]["_group_id"]
+ log_file = Path(cfg.output_dir) / "actor" / "ray" / f"{group_id}_async_{len(problems)}.log"
+ sys.stdout = open(log_file, "a", buffering=1)
+ sys.stderr = sys.stdout
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
+ logger.info(f"Running async rollouts for group {group_id} with {len(problems)} problems")
+ rollout_results = asyncio.run(run_rollouts_with_session(cfg, llm, problems))
+ return rollout_results
+
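+        # async_batch_size > 1: one Ray task runs a whole batch of attempts concurrently;
+        # async_batch_size == 1: one Ray task runs a single rollout synchronously.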
+ if self.cfg.actor.async_batch_size > 1:
+ logger.info("Using async mode")
+ self.ray_remote = ray.remote()(rollout_async_batch_wrapper)
+ else:
+ logger.info("Using sync mode")
+ self.ray_remote = ray.remote()(rollout_wrapper)
+ self.start_time = time.time()
+
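+    # Capacity check: bound the number of in-flight Ray tasks by problem_queue_size and require
+    # at least one LLM with headroom for a full group of `attempts` rollouts below llm_max_rollouts.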
+ def have_capacity(self) -> bool:
+ have_capacity = len(self.unfinished_tasks) < self.cfg.actor.problem_queue_size
+ have_llm_capacity = any(
+ self.llms_utilization[llm_url] < (self.cfg.actor.llm_max_rollouts - self.attempts)
+ for llm_url in self.llms_utilization
+ )
+ have_capacity = have_capacity and have_llm_capacity
+ return have_capacity
+
+ def submit_problem(self, problem: dict):
+        # Make cfg.attempts copies of the problem, one per attempt (deep copies can be used if necessary)
+ problems = []
+ for attempt in range(self.attempts):
+ p = problem.copy()
+ p["_group_id"] = f"{self.scheduler_name}_{self.problem_id}"
+ p["_attempt"] = attempt
+ problems.append(p)
+
+        # Split the copies into batches of up to cfg.actor.async_batch_size problems
+ batches = [
+ problems[i : i + self.cfg.actor.async_batch_size]
+ for i in range(0, len(problems), self.cfg.actor.async_batch_size)
+ ]
+ for batch_idx, problem_batch in enumerate(batches):
+ llm_url, task_count = min(self.llms_utilization.items(), key=lambda x: x[1])
+ logger.info(
+ f"Submitting problem {self.problem_id} batch {batch_idx + 1}/{len(batches)} to the least busy LLM {llm_url} with {task_count} tasks"
+ )
+ llm = self.llms_by_url[llm_url]
+ task_ref = self.ray_remote.remote(self.cfg_dict, llm, problem_batch)
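+            # Throttle submissions with a small configurable delay between Ray tasks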
+ time.sleep(self.cfg.actor.task_submission_delay_sec)
+ self.llms_utilization[llm_url] += len(problem_batch)
+ self.unfinished_tasks.append(task_ref)
+ self.problem_id += 1
+
+ def stop_tasks(self):
+ ray.shutdown()
+
+ def receive_finished_tasks(self):
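+        """Poll Ray for finished tasks and fold their results into per-problem groups."""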
+ num_returns = min(100, len(self.unfinished_tasks)) # query up to 100 tasks at a time
+ try:
+ finished_tasks, unfinished_tasks = ray.wait(self.unfinished_tasks, num_returns=num_returns, timeout=0.1)
+ except Exception as e:
+ logger.error(f"Error waiting for finished ray tasks: {e}")
+ return
+ if len(finished_tasks) > 0:
+ logger.info(f"Found {len(finished_tasks)} finished tasks, {len(unfinished_tasks)} unfinished tasks left")
+ self.unfinished_tasks = unfinished_tasks
+ dt = time.time() - self.start_time
+ rollout_results: list[RolloutResult] = []
+ for finished_task in finished_tasks:
+ try:
+ rollout_results += ray.get(finished_task)
+ except Exception as e:
+ logger.error(f"Error getting finished ray task: {e}")
+ self.rollout_errors += self.cfg.actor.async_batch_size
+ continue
+ logger.info(f"Received {len(rollout_results)} rollout results from {len(finished_tasks)} finished tasks")
+ for rollout_result in rollout_results:
+ rollout_result.model_version = self.trainer_state.propagated_weight_version
+ rollout_index = len(self.unfinished_groups[rollout_result.group_id])
+ for step_index, tr_text in enumerate(rollout_result.training_texts):
+ # Downstream in the pipeline we'll need these fields in every sample
+ tr_text.metadata["model_version"] = rollout_result.model_version
+ tr_text.metadata["rollout_index"] = rollout_index
+ tr_text.metadata["step_index"] = step_index
+ tr_text.group_id = rollout_result.group_id
+ if self.llms_utilization[rollout_result.llm_url] > 0:
+ self.llms_utilization[rollout_result.llm_url] -= 1
+ else:
+ logger.warning(f"LLM {rollout_result.llm_url} utilization is 0, but got a result") # should not happen
+ self.token_count += get_number_of_tokens_in_result(rollout_result)
+ self.finished_rollouts_count += 1
+ self.unfinished_groups[rollout_result.group_id].append(rollout_result)
+
+ if len(self.unfinished_groups[rollout_result.group_id]) == self.attempts:
+ logger.info(f"Problem {rollout_result.group_id} group finished")
+ group = self.unfinished_groups[rollout_result.group_id]
+ random.shuffle(group)
+ self.finished_groups.append(group)
+ del self.unfinished_groups[rollout_result.group_id]
+ logger.info(f"{len(self.finished_groups)} finished groups ready to return")
+ logger.info(
+ f"Ray {'train' if self.is_training else 'test'} actor loop: "
+ f"rollouts in progress: {len(self.unfinished_tasks)}, "
+ f"groups in progress: {len(self.unfinished_groups)}, "
+ f"rollouts finished: {self.finished_rollouts_count}, "
+ f"total tokens: {self.token_count}, "
+ f"gen speed: {self.token_count / dt:.2f} tokens/sec, "
+ f"time elapsed: {dt:.2f} sec,\n"
+ f"LLMs utilization: {self.llms_utilization}"
+ )
+
+ def get_new_results(self) -> list[list[RolloutResult]]:
+ self.receive_finished_tasks()
+ if len(self.finished_groups) > 0:
+ logger.info(f"have {len(self.finished_groups)} finished problems, pop one")
+ return self.finished_groups.pop(0)
+ return []
+
+ def problem_queue_size(self) -> int:
+ return len(self.unfinished_tasks)
+
+ def result_queue_size(self) -> int:
+ return len(self.finished_groups)
+
def run_actor_loop(cfg: DictConfig):
set_streams_backend(**cfg.streams)
+ actor_loop_class = ActorLoopRay if cfg.use_ray else ActorLoop
# set seed for reproducibility (mostly intended for dataset loading)
random.seed(cfg.seed)
@@ -603,7 +883,7 @@ def run_actor_loop(cfg: DictConfig):
dataset_loader = hydra.utils.get_method(cfg.dataset_loader)
# Get dataset loader parameters if they exist in config, otherwise use empty dict
- dataset_loader_params = cfg.get('dataset_loader_params', {})
+ dataset_loader_params = cfg.get("dataset_loader_params", {})
# Use **dataset_loader_params to pass parameters only if they exist
train_dataset = dataset_loader(cfg.train_dataset_names, **dataset_loader_params)
test_dataset = dataset_loader(cfg.test_dataset_names, **dataset_loader_params)
@@ -617,12 +897,19 @@ def run_actor_loop(cfg: DictConfig):
actor_model_path = finetune_model_path
else:
actor_model_path = cfg.model_path
-
+
+ # Align client-side context size with vLLM server max_model_len when available
+ try:
+ _context_size = int(cfg.vllm_config.vllm_kwargs.max_model_len)
+ except Exception:
+ _context_size = 32000
+
train_llms = [
TrainableLLM(
base_url=url,
model_name=str(actor_model_path),
tokenizer_name=str(actor_model_path),
+ context_size=_context_size,
parameters=cfg.llm.parameters,
collect_logprobs=True,
)
@@ -633,6 +920,7 @@ def run_actor_loop(cfg: DictConfig):
base_url=url,
model_name=str(actor_model_path),
tokenizer_name=str(actor_model_path),
+ context_size=_context_size,
parameters=cfg.test_llm.parameters,
collect_logprobs=True,
)
@@ -648,13 +936,12 @@ def run_actor_loop(cfg: DictConfig):
trainer_state.start_listening()
trainer_state.wait_for_model_version()
- train_loop = ActorLoop(
+ train_loop = actor_loop_class(
data_stream=data_stream, cfg=cfg, trainer_state=trainer_state, stats_stream=stats_stream, llms=train_llms
)
- train_loop_run = train_loop.run(
- dataset=train_dataset,
- )
- test_loop = ActorLoop(
+ train_loop.start_backend()
+ train_loop_run = train_loop.run(dataset=train_dataset)
+ test_loop = actor_loop_class(
data_stream=test_data_stream,
cfg=cfg,
trainer_state=trainer_state,
@@ -683,9 +970,8 @@ def run_actor_loop(cfg: DictConfig):
and test_loop_run is None
):
logger.info("Create test loop")
- test_loop_run = test_loop.run(
- dataset=test_dataset,
- )
+ test_loop.start_backend()
+ test_loop_run = test_loop.run(dataset=test_dataset)
train_loop.is_scheduling_paused = True
current_eval = next_regular_eval
diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py
index 4e78ebf9..122625ae 100644
--- a/pipelinerl/async_llm.py
+++ b/pipelinerl/async_llm.py
@@ -4,16 +4,20 @@
import aiohttp
import numpy as np
+from omegaconf import DictConfig, ListConfig, OmegaConf
from PIL import Image
-from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM
-from pipelinerl.finetune.data import MASKED_TOKEN_ID
-from pipelinerl.rollouts import TrainingText
+from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM
from pipelinerl.processor_factory import get_processor
-from omegaconf import DictConfig, ListConfig, OmegaConf
+from pipelinerl.rollouts import TrainingText
logger = logging.getLogger(__name__)
+# -100 is the default "ignore_index" in nn.CrossEntropyLoss
+# Defined here to avoid importing dependencies from finetune.data
+# Do not replace this with an import: importing from the finetune module breaks Ray parallelization!
+MASKED_TOKEN_ID = -100
+
def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]:
"""Extract PIL Images from multimodal messages."""
diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 62758c7e..41a61021 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -1,18 +1,17 @@
-import random
import time
+import random
import aiohttp
from omegaconf import DictConfig
from pydantic import BaseModel
-
-from pipelinerl.async_llm import llm_async_generate, make_training_text
-from pipelinerl.llm import Prompt, TrainableLLM
-from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.rollouts import RolloutResult, BaseMetrics
from pipelinerl.world import Job
+from tapeagents.core import Prompt
+from tapeagents.llms.trainable import TrainableLLM
+from pipelinerl.async_llm import llm_async_generate, make_training_text
from .verifier_api import verify_answer_rpc
-
class Metrics(BaseMetrics):
penalty: float
diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py
index ea850814..9392efe6 100644
--- a/pipelinerl/domains/miniwob/rollouts.py
+++ b/pipelinerl/domains/miniwob/rollouts.py
@@ -1,51 +1,59 @@
import asyncio
-import json
import logging
import os
import random
import time
import traceback
+from typing import Literal
import aiohttp
-from examples.rl_webagent.steps import WebTape
from hydra.utils import instantiate
from omegaconf import DictConfig
from tapeagents.agent import DEFAULT, Agent
-from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation
+from tapeagents.core import LLMOutputParsingFailureAction, Observation
from tapeagents.io import save_json_tape
-from tapeagents.llms.trainable import TrainableLLM
-from tapeagents.orchestrator import async_execute_agent
+from tapeagents.orchestrator import async_execute_agent, execute_agent
from tapeagents.remote_environment import AsyncRemoteEnvironment
from tapeagents.tools.simple_browser import PageObservation
from pipelinerl.async_llm import make_training_text
-from pipelinerl.llm import LLMCall, TrainableLLM
+from pipelinerl.domains.miniwob.environment import WebEnvironment
+from pipelinerl.domains.miniwob.steps import WebTape
+from pipelinerl.llm import LLMCall, TrainableLLM, TrainingText
from pipelinerl.rollouts import BaseMetrics, RolloutResult
from pipelinerl.world import Job
-from .steps import WebTape
-
logger = logging.getLogger(__name__)
+def task_id(problem: dict) -> str:
+ """Format task identifier for logging."""
+ return f"{problem['dataset']}/{problem['task']}/{problem['seed']}"
+
+
class MiniwobMetrics(BaseMetrics):
- reward: float
- success: bool
- no_error: bool
- no_answer: bool
- overflow: bool
- n_llm_calls: int
- n_step_errors: int
- n_page_observations: int
- n_steps: int
- total_execution_time: float
- agent_execution_time: float
- environment_execution_time: float
- env_step_time: float
- agent_step_time: float
-
-
-def tape_contains_an_error(tape: WebTape) -> bool:
+ reward: float = -1.0
+ success: bool = False
+ has_error: bool = False
+ no_answer: bool = True
+ overflow: bool = False
+ n_llm_calls: int = 0
+ n_step_errors: int = 0
+ n_observations: int = 0
+ n_steps: int = 0
+ env_creation_time: float = 0.0
+ agent_creation_time: float = 0.0
+ env_start_time: float = 0.0
+ env_close_time: float = 0.0
+ agent_execution_time: float = 0.0
+ total_execution_time: float = 0.0
+ llm_call_time: float = 0.0
+ env_step_time: float = 0.0
+ total_llm_call_time: float = 0.0
+ total_env_call_time: float = 0.0
+
+
+def _tape_contains_an_error(tape: WebTape) -> bool:
"""
Returns true if the tape ends with an error, ie if one of the following is true:
- the last step is an LLMOutputParsingFailureAction
@@ -56,94 +64,294 @@ def tape_contains_an_error(tape: WebTape) -> bool:
len(tape.steps) == 0
or isinstance(tape.steps[-1], LLMOutputParsingFailureAction)
or tape.metadata.result.get("error") is not None
- or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error)
+ or (isinstance(tape.steps[-1], PageObservation) and bool(tape.steps[-1].error))
)
+def _compute_reward(
+ tape: WebTape,
+ reward_computation: Literal["uic", "default"] = "default",
+ has_error: bool = False,
+) -> tuple[float, bool]:
+ """
+ Compute reward from tape.
+
+ Args:
+ tape: The execution tape
+        reward_computation: Reward computation mode, either "uic" or "default"
+        has_error: Whether errors occurred during execution so far
+
+ Returns:
+ tuple of (reward, has_error)
+ """
+ # Extract raw reward from last observation
+ obs_steps = [step for step in tape if isinstance(step, Observation)]
+ if obs_steps:
+ last_obs = obs_steps[-1]
+ raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0)
+ else:
+ raw_reward = -1.0
+
+ # Count errors and page observations
+ n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)])
+ n_observations = len([step for step in tape.steps if isinstance(step, Observation)])
+
+ # Determine if tape has errors
+ has_error = has_error or _tape_contains_an_error(tape)
+
+ # Compute final reward based on configuration
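+    # "uic": binary success in {-1, +1}, discounted by 0.98 per observation step;
+    # default: raw reward discounted by 0.99 per parsing error, or -1.0 on error / negative reward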
+ if reward_computation == "uic":
+ reward = float(raw_reward > 0)
+ if reward == 0.0:
+ reward = -1.0
+ reward *= 0.98**n_observations
+ else:
+ reward = raw_reward * 0.99**n_step_errors if not has_error and raw_reward >= 0 else -1.0
+
+ return reward, has_error
+
+
+def _extract_llm_calls(tape: WebTape) -> list[LLMCall]:
+ """Extract LLM calls from tape steps."""
+ return [
+ LLMCall(**step.metadata.other["llm_call"])
+ if isinstance(step.metadata.other["llm_call"], dict)
+ else step.metadata.other["llm_call"]
+ for step in tape.steps
+ if "llm_call" in step.metadata.other
+ ]
+
+
+def _compute_metrics(
+ tape: WebTape,
+ training_texts: list[TrainingText],
+ reward: float,
+ has_error: bool,
+ n_llm_calls: int,
+) -> MiniwobMetrics:
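+    """Assemble MiniwobMetrics from the tape, training texts, and computed reward."""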
+    # Assign the reward to each training text and detect overflow (unfinished generations)
+ has_overflow = False
+ for text in training_texts:
+ text.reward = reward
+ has_overflow |= not text.finished
+
+ # Extract timing information
+ llm_call_times = [float(step.metadata.other.get("llm_call_time", 0.0)) for step in tape.steps]
+ env_call_times = [float(step.metadata.other.get("action_execution_time", 0.0)) for step in tape.steps]
+ total_llm_call_time = sum(llm_call_times)
+ total_env_call_time = sum(env_call_times)
+ llm_call_time = total_llm_call_time / len(llm_call_times) if llm_call_times else -1.0
+ env_step_time = total_env_call_time / len(env_call_times) if env_call_times else -1.0
+ env_start_time = tape.metadata.result.get("env_start_time", -1.0)
+ env_close_time = tape.metadata.result.get("env_close_time", -1.0)
+ env_creation_time = tape.metadata.result.get("env_creation_time", -1)
+ agent_creation_time = tape.metadata.result.get("agent_creation_time", -1)
+ agent_execution_time = tape.metadata.result.get("agent_execution_time", -1.0)
+ total_execution_time = tape.metadata.result.get("total_execution_time", -1.0)
+
+ # Compute step counts
+ n_observations = len([s for s in tape.steps if isinstance(s, Observation)])
+ n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)])
+
+ metrics = MiniwobMetrics(
+ reward=reward,
+ success=reward > 0.5,
+ has_error=has_error,
+ no_answer=reward < 0,
+ overflow=has_overflow,
+ n_llm_calls=n_llm_calls,
+ n_step_errors=n_step_errors,
+ n_steps=len(tape.steps),
+ n_observations=n_observations,
+
+ env_creation_time=env_creation_time,
+ env_start_time=env_start_time,
+ env_close_time=env_close_time,
+
+ agent_creation_time=agent_creation_time,
+ agent_execution_time=agent_execution_time,
+
+ llm_call_time=llm_call_time,
+ env_step_time=env_step_time,
+ total_llm_call_time=total_llm_call_time,
+ total_env_call_time=total_env_call_time,
+ total_execution_time=total_execution_time,
+ )
+ return metrics
+
+
async def check_env_server_health(env_job: Job, session: aiohttp.ClientSession) -> dict:
"""Check environment server health via HTTP API."""
try:
url = f"http://{env_job.hostname}:{env_job.port}/health"
- async with session.get(url, timeout=5) as response:
+ async with session.get(url, timeout=aiohttp.ClientTimeout(total=5.0)) as response:
if response.status == 200:
health_data = await response.json()
- return {
- "healthy": True,
- "health_data": health_data,
- "last_check": time.time()
- }
+ return {"healthy": True, "health_data": health_data, "last_check": time.time()}
else:
error_text = await response.text()
- return {"healthy": False, "error_message": f"HTTP {response.status}: {error_text}", "last_check": time.time()}
+ health_data = f"HTTP {response.status}: {error_text}"
+ return {"healthy": False, "health_data": health_data, "last_check": time.time()}
except Exception as e:
exception_type = type(e).__name__
exception_message = str(e) if str(e) else "No message available"
- logger.exception(f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True)
- return {"healthy": False, "error_message": f"Exception: {exception_type}: {exception_message}", "last_check": time.time(), "error_stacktrace": traceback.format_exc()}
+ logger.exception(
+ f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True
+ )
+ return {
+ "healthy": False,
+ "health_data": f"Exception: {exception_type}: {exception_message}",
+ "last_check": time.time(),
+ "error_stacktrace": traceback.format_exc(),
+ }
+
+
+def generate_miniwob_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict) -> RolloutResult:
+ """
+ Generate a MiniWoB rollout. Steps:
+    - instantiate the agent and the environment
+    - attach the LLM to the agent
+    - run the agent on the task
+    - extract LLM calls from the resulting tape
+    - compute the reward
+    - build training texts from the LLM calls
+
+ Args:
+ cfg: Configuration for the rollout
+ llm: The LLM to use
+ problem: The problem dict
+ Returns:
+ RolloutResult with training texts and metrics
+ """
+ tid = task_id(problem)
+ start_time = time.perf_counter()
+ environment: WebEnvironment = instantiate(cfg.environment)
+ environment.initialize()
+ env_creation_time = time.perf_counter() - start_time
+ logger.info(f"Environment tools: {environment.tools_description()}")
+ t = time.perf_counter()
+ agent: Agent = instantiate(
+ cfg.agent,
+ known_actions=environment.actions(),
+ tools_description=environment.tools_description(),
+ llms={DEFAULT: llm},
+ )
+ logger.info(f"Agent and environment loaded, using llm {llm.model_name} at {llm.get_base_url()}")
+ agent_creation_time = time.perf_counter() - t
+ try:
+ start_attempts = cfg.start_attempts
+ t = time.perf_counter()
+ while True:
+ try:
+ tape, _ = environment.start_task(problem)
+ break
+ except Exception as e:
+ logger.exception(f"Failed to start task {tid}: {e}")
+ start_attempts -= 1
+ if start_attempts <= 0:
+ raise Exception(f"Failed to start task {tid} after {cfg.start_attempts} attempts")
+ else:
+ logger.warning("Retrying after 1 second")
+ time.sleep(1)
+ env_start_time = time.perf_counter() - t
+ logger.info(f"Task {tid} started in {env_start_time:.2f}s")
+ t = time.perf_counter()
+ tape = execute_agent(agent, tape, environment, max_loops=cfg.agent_max_loops)
+ agent_execution_time = time.perf_counter() - t
+ finally:
+ t = time.perf_counter()
+ environment.close()
+ env_close_time = time.perf_counter() - t
+ total_execution_time = time.perf_counter() - start_time
+ logger.info(f"Task {tid} finished in {total_execution_time:.2f}s")
+ tape.metadata.result.update(
+ {
+ "total_execution_time": total_execution_time,
+ "env_creation_time": env_creation_time,
+ "env_start_time": env_start_time,
+ "env_close_time": env_close_time,
+ "agent_creation_time": agent_creation_time,
+ "agent_execution_time": agent_execution_time,
+ }
+ )
+ # save the tape as we go
+ if cfg.save_tapes:
+ _save_tapes(cfg, problem, tape)
-async def generate_miniwob_rollout(
+ # Compute reward and metrics
+ reward, has_error = _compute_reward(tape, cfg.reward_computation)
+ llm_calls = _extract_llm_calls(tape)
+ training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls]
+ metrics = _compute_metrics(
+ tape,
+ training_texts,
+ reward,
+ has_error,
+ len(llm_calls),
+ )
+ latency = time.perf_counter() - start_time
+ return RolloutResult(
+ training_texts=training_texts,
+ metrics=metrics,
+ latency=latency,
+ dataset_name=problem["dataset"],
+ )
+
+
+def _save_tapes(cfg, problem, tape):
+ tape_name = problem.get("_task_id", tape.metadata.id)
+ try:
+ save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape_name)
+ except Exception as e:
+ logger.error(f"Error saving tape {tape_name}: {e}")
+
+
+async def generate_miniwob_rollout_async(
cfg: DictConfig,
llm: TrainableLLM,
problem: dict,
session: aiohttp.ClientSession,
) -> RolloutResult:
- # choose a random environment server
- # Generate environment
- # Generate TapeAgent
- # run the agent
- # get llm calls from tape
- # compute rewards
- # get training text from llm calls
-
- start_time = time.time()
-
- # Overall timeout for the entire rollout to prevent hanging
- rollout_timeout = getattr(cfg, 'rollout_timeout', 600) # 10 minutes default
+ start_time = time.perf_counter()
+ tid = task_id(problem)
+ rollout_timeout = getattr(cfg, "rollout_timeout", 600) # 10 minutes default
env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
env_jobs_url_tried = []
- # Try each environment server with health checks until one of them returns a rollout result
for _ in range(len(env_jobs)):
- # Choose the next environment server to try randomly from the ones that have not been tried yet
- env_job = random.choice([job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried])
+ env_job = random.choice(
+ [job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried]
+ )
env_job_url = f"http://{env_job.hostname}:{env_job.port}"
env_jobs_url_tried.append(env_job_url)
- # Check server health before using
health = await check_env_server_health(env_job, session)
if not health["healthy"]:
- logger.warning(f"Environment server {env_job_url} is unhealthy: {health}")
- logger.warning(f"Get health error stacktrace: {health['error_stacktrace']}")
+ logger.warning(f"Env server {env_job_url} unhealthy: {health.get('health_data', 'unknown')}, skip to next one")
continue
- # Log health status for monitoring
- if health["healthy"]:
- logger.info(f"Using healthy environment server {env_job_url}: {health}")
+ logger.debug(f"Using env server {env_job_url}")
try:
- # Execute the entire rollout with a timeout
return await asyncio.wait_for(
_execute_rollout_with_timeout(cfg, llm, problem, session, start_time, env_job_url),
- timeout=rollout_timeout
+ timeout=rollout_timeout,
)
except asyncio.TimeoutError:
- health = await check_env_server_health(env_job, session)
- if stack_trace := health.get("error_stacktrace"):
- logger.warning(f"Get health error stacktrace: {stack_trace}")
- logger.warning(f"Rollout timeout error stacktrace: {traceback.format_exc()}")
- logger.warning(f"Rollout timed out after {rollout_timeout} seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+ logger.warning(f"Task {tid} timed out after {rollout_timeout}s on {env_job_url}")
continue
except Exception as e:
- health = await check_env_server_health(env_job, session)
- if stack_trace := health.get("error_stacktrace"):
- logger.warning(f"Get health error stacktrace: {stack_trace}")
- logger.warning(f"Rollout failed error stacktrace: {traceback.format_exc()}")
- logger.warning(f"Rollout failed for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+ logger.warning(f"Task {tid} failed on {env_job_url}: {e}")
continue
- # If all servers failed
- logger.error(f"All environment servers failed for task {problem['dataset']}/{problem['task']}/{problem['seed']}. Returning a failed rollout result.")
- return _create_failed_rollout_result(problem, start_time, "all environment servers failed")
+
+ logger.error(f"Task {tid}: all environment servers failed")
+ # Return a failed rollout result
+ return RolloutResult(
+ training_texts=[],
+ metrics=MiniwobMetrics(),
+ latency=time.perf_counter() - start_time,
+ dataset_name=problem["dataset"],
+ )
async def _execute_rollout_with_timeout(
@@ -154,190 +362,106 @@ async def _execute_rollout_with_timeout(
start_time: float,
env_job_url: str,
) -> RolloutResult:
- # (2) Generate environment, TapeAgent, and run them to get a Tape
- no_error = True # track if there was an error in the tape
+ tid = task_id(problem)
+ has_error = False
+ start_time = time.perf_counter()
environment = AsyncRemoteEnvironment(server_url=env_job_url) # type: ignore
async with environment.acontext(session, wait_for_env=True) as env:
+ env_creation_time = time.perf_counter() - start_time
+ agent_creation_time = 0.0
start_attempts = cfg.start_attempts
t = time.perf_counter()
+ tape_dict = {}
while start_attempts > 0:
try:
tape_dict, info = await env.start_task(problem)
if info.get("error"):
- raise ValueError(info['error'])
+ raise ValueError(info["error"])
break
except Exception as e:
start_attempts -= 1
- logger.warning(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}. {start_attempts} attempts remaining. Error: {e}")
+ logger.warning(f"Task {tid} start failed, {start_attempts} attempts left: {e}")
if start_attempts <= 0:
- logger.error(f"Failed to start task after all retry attempts: {e}")
- no_error = False
- tape_dict = {}
+ logger.error(f"Task {tid} start failed after all retries: {e}")
+ has_error = True
break
else:
- logger.warning("Retry start task after 5 seconds.")
await asyncio.sleep(5)
- logger.info(
- f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds. Worker ID: {env.worker_id}. Tape dict: {tape_dict}"
- )
- tape: WebTape = WebTape(**tape_dict) # convert http response dict to WebTape object
+ env_start_time = time.perf_counter() - t
+ logger.info(f"Task {tid} started in {env_start_time:.2f}s (worker={env.worker_id})")
+ tape: WebTape = WebTape(**tape_dict)
t = time.perf_counter()
- if no_error: # only run the agent if the task started successfully
- logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}")
+ agent_execution_time = 0.0
+ if not has_error:
agent_attempts = cfg.agent_attempts
while agent_attempts > 0:
- # check if the worker is alive.
try:
- # this will either raise RuntimeError if worker is not alive anymore, or return a dictionary with the worker status
worker_status = await env.check_worker_alive()
if worker_status.get("status") == "starting":
- logger.warning(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is starting, waiting 5 seconds for it to be fully started.")
+ logger.debug(f"Task {tid}: worker {env.worker_id} starting, waiting...")
await asyncio.sleep(5)
continue
except Exception as e:
- # if worker is dead, no need to retry
- logger.exception(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is dead. Error: {e}", stack_info=True)
- no_error = False
+ logger.exception(f"Task {tid}: worker {env.worker_id} dead: {e}")
+ has_error = True
break
- # if worker is alive, run the agent
+
try:
+ t = time.perf_counter()
actions = await env.a_actions()
tools_description = await env.a_tools_description()
agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description)
- agent.llms = {DEFAULT: llm}
+ agent.llms = {DEFAULT: llm} # type: ignore
+ agent_creation_time = time.perf_counter() - t
+ t = time.perf_counter()
tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops)
- # Check if the tape has an error from the orchestrator (e.g., SocketTimeoutError, RuntimeError: Worker is not alive, etc.)
+ agent_execution_time = time.perf_counter() - t
if tape.metadata.error:
- logger.error(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} returned a tape with error: {tape.metadata.error}")
+ logger.error(f"Task {tid}: agent error: {tape.metadata.error}")
raise ValueError(tape.metadata.error)
- else:
- # Success - break out of retry loop
- logger.info(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} finished successfully")
- break
+ logger.info(f"Task {tid}: agent execution succeeded")
+ break
except Exception as e:
agent_attempts -= 1
- logger.warning(f"Error occurred while running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}. {agent_attempts} attempts remaining. Error: {e}")
+ logger.warning(f"Task {tid}: agent error, {agent_attempts} attempts left: {e}")
if agent_attempts <= 0:
- logger.error(f"Agent execution failed after all retry attempts for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}: {e}")
- no_error = False
+ logger.error(f"Task {tid}: agent failed after all retries: {e}")
+ has_error = True
break
- else:
- logger.warning(f"Retry agent execution after 5 seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}.")
- await asyncio.sleep(5)
- logger.info(
- f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']} in {time.perf_counter() - t:.2f} seconds with worker ID: {env.worker_id} and tape ID {tape.metadata.id}"
- )
- tape.metadata.result.update({"total_execution_time": time.perf_counter() - t})
-
- # save the tape as we go
- if cfg.save_tapes:
- save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id)
-
- # (3) Compute rewards
- obs_steps = [step for step in tape if isinstance(step, Observation)]
- if obs_steps:
- last_obs = obs_steps[-1]
- # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0
- # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L188
- # Let's take directly the RAW_REWARD_GLOBAL from the metadata
- # raw_reward = last_obs.metadata.other.get("reward", 0.0)
- raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0)
- else:
- raw_reward = -1.0
-
- no_error = no_error and not tape_contains_an_error(tape)
- # get the number of LLMOutputParsingFailureAction in the tape
- n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)])
- # get the number of PageObservation steps in the tape
- n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)])
+ await asyncio.sleep(5)
- if cfg.reward_computation == "nico":
- reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0
- elif cfg.reward_computation == "uic":
- reward = float(raw_reward>0)
- if reward == 0.0:
- reward = -1.0
- reward *= 0.98 ** n_page_observations
- else:
- raise ValueError(f"Invalid reward configuration: {cfg.reward_computation}")
+ logger.info(f"Task {tid} finished in {time.perf_counter() - t:.2f}s (worker={env.worker_id})")
+ t = time.perf_counter()
+ await env.aclose()
+ env_close_time = time.perf_counter() - t
+    total_execution_time = time.perf_counter() - start_time
+ tape.metadata.result.update({
+ "total_execution_time": total_execution_time,
+ "env_creation_time": env_creation_time,
+ "env_start_time": env_start_time,
+ "env_close_time": env_close_time,
+ "agent_creation_time": agent_creation_time,
+ "agent_execution_time": agent_execution_time,
+ })
- # (3) Get LLM calls from Tape
- llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None]
- n_llm_calls = len(llm_calls)
- llm_calls: list[LLMCall] = [
- LLMCall(**step.metadata.other["llm_call"]) if isinstance(step.metadata.other["llm_call"], dict)
- else step.metadata.other["llm_call"]
- for step in llm_calls
- ]
+ if cfg.save_tapes:
+ _save_tapes(cfg, problem, tape)
- # (4) # For each LLM interaction in the tape, make a training example.
- all_finished = 1
- prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls]
- output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls]
+ # Compute reward and metrics
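+    # _compute_reward is assumed to encapsulate the previously inlined schemes:
+    #   "nico": raw_reward * 0.99**n_step_errors, or -1.0 on error / negative raw reward
+    #   "uic":  +1 if raw_reward > 0 else -1, then scaled by 0.98**n_page_observations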
+ reward, has_error = _compute_reward(tape, cfg.reward_computation, has_error)
+ llm_calls = _extract_llm_calls(tape)
training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls]
- for text in training_texts:
- text.reward = reward
- all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0
-
- latency = time.time() - start_time
- agent_time = tape.metadata.result.get("agent_execution_time", -1.0)
- env_time = tape.metadata.result.get("environment_execution_time", -1.0)
- n_observations = len([s for s in tape.steps if isinstance(s, Observation)]) # TODO: is this not the same n_page_observations??
- n_other_steps = len(tape.steps) - n_observations
- metrics = MiniwobMetrics(
- reward=reward,
- success=reward > 0.5,
- no_error=no_error,
- no_answer=reward < 0,
- overflow=not all_finished,
- n_llm_calls=n_llm_calls,
- n_step_errors=n_step_errors,
- n_page_observations=n_page_observations,
- n_steps=len(tape.steps),
- total_execution_time=tape.metadata.result.get("total_execution_time", -1.0),
- agent_execution_time=agent_time,
- environment_execution_time=env_time,
- env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0,
- agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0,
+ metrics = _compute_metrics(
+ tape,
+ training_texts,
+ reward,
+ has_error,
+ len(llm_calls),
)
-
- return RolloutResult(
- training_texts=training_texts,
- metrics=metrics,
- latency=latency,
- dataset_name=problem["dataset"],
- prompt_tokens=prompt_tokens,
- output_tokens=output_tokens,
- )
-
-
-def _create_failed_rollout_result(problem: dict, start_time: float, error_type: str) -> RolloutResult:
- """Create a failed rollout result for timeout or other errors."""
latency = time.time() - start_time
-
- # Create empty training texts and metrics for failed rollout
- metrics = MiniwobMetrics(
- reward=-1.0,
- success=False,
- no_error=False,
- no_answer=True,
- overflow=False,
- n_llm_calls=0,
- n_step_errors=0,
- n_page_observations=0,
- n_steps=0,
- total_execution_time=latency,
- agent_execution_time=-1.0,
- environment_execution_time=-1.0,
- env_step_time=-1.0,
- agent_step_time=-1.0,
- )
-
return RolloutResult(
- training_texts=[],
+ training_texts=training_texts,
metrics=metrics,
latency=latency,
dataset_name=problem["dataset"],
- prompt_tokens=[],
- output_tokens=[],
)
diff --git a/pipelinerl/finetune/logging_.py b/pipelinerl/finetune/logging_.py
index 0b221e24..0624f765 100644
--- a/pipelinerl/finetune/logging_.py
+++ b/pipelinerl/finetune/logging_.py
@@ -25,7 +25,7 @@ def setup_logging(cfg: DictConfig, output_dir: Path, run: wandb_run.Run | None =
debug_handler = logging.FileHandler(log_dir / f"info_{get_accelerator().process_index}.log")
debug_handler.setLevel(logging.INFO)
logging.basicConfig(
- format="[finetune]: %(asctime)s.%(msecs)03d - %(levelname)s - %(name)s - %(message)s",
+ format="[finetune]: %(asctime)s.%(msecs)03d - %(levelname)s - %(name)s:%(lineno)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[debug_handler, logging.StreamHandler()],
diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py
index 9118078e..f2472388 100644
--- a/pipelinerl/finetune/rl/__init__.py
+++ b/pipelinerl/finetune/rl/__init__.py
@@ -2,22 +2,17 @@
import os
from functools import partial
from typing import Any
-from pydantic import BaseModel, Field
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
-from datasets import Dataset
-from transformers import PreTrainedModel
-from pipelinerl.finetune.types import PipelineBatchEncoding
-from pipelinerl.finetune.rl.utils import per_segment_sums
+from pydantic import BaseModel, Field
+from transformers.modeling_utils import PreTrainedModel
-from .utils import (
- sum_sum,
- mean_sum,
- replace_dataset_column,
-)
+from pipelinerl.finetune.rl.utils import per_segment_sums, sum_sum
+from pipelinerl.finetune.types import PipelineBatchEncoding
# FIXME: remove a warnings, but might be worth investigating
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -39,8 +34,7 @@
class RLConfig(BaseModel):
policy_loss: str = Field(
default="ppo",
- description="Policy Loss to use for RL",
- choices=["ppo", "reinforce", "gspo"],
+ description="Policy Loss to use for RL, one of ['ppo', 'reinforce', 'gspo']",
)
use_advantages: bool = Field(
default=True,
@@ -418,7 +412,11 @@ def rl_step(
def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: RLConfig) -> list[dict[str, Any]]:
- """Populate RL-specific columns (advantages, overflow, num_labels) using a leave-one-out baseline."""
+ """
+    Populates a dataset with reinforcement-learning-specific columns, including
+    rewards, advantages, and token weights.
+ Uses leave-one-out (LOO) reward mean: each rollout's baseline excludes its own reward.
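+
+    Illustrative example (ignoring any further normalization): for a group with
+    rewards [1.0, 0.0, 1.0], the LOO baseline of the first rollout is
+    (0.0 + 1.0) / 2 = 0.5, so its raw advantage is 1.0 - 0.5 = 0.5.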
+ """
# Convert to pandas for processing
df_init = pd.DataFrame(dataset)
assert isinstance(df_init, pd.DataFrame)
@@ -450,7 +448,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R
"group_tokens",
]
- # Step 2: calculate advantages for each sample
+ # Step 2: calculate advantages for each sample (with LOO mean)
df_advantages = pd.merge(
df_init[["group_id", "rollout_index", "step_index", "rewards"]],
df_grouped,
@@ -458,6 +456,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R
how="left"
)
assert len(df_advantages) == len(df_init)
+
def calculate_advantages(row):
rewards = row["rewards"]
group_sum = row["rollout_reward_sum"]
@@ -487,7 +486,6 @@ def calculate_advantages(row):
# Step 3: bring advantages and group level stats back to the main df
df = df_init.drop(columns=["advantages", "group_tokens"])
df = pd.merge(df, df_advantages, on=["group_id", "rollout_index", "step_index"], how="left")
- # Debug print lengths of all dataframes
assert len(df) == len(df_init)
# Step 4: make token-level overflow and mean group length information
diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py
index da39938e..71dac6ea 100644
--- a/pipelinerl/finetune_loop.py
+++ b/pipelinerl/finetune_loop.py
@@ -450,6 +450,7 @@ def run_finetuning_loop(
logger.info("Load the first version of the model into inference LLMs")
weight_update_manager.send_weight_update(training_metrics.samples)
else:
+ logger.info("send_weight_updates disabled, weight_update_manager is None")
weight_update_manager = None
batch_queue = Queue(maxsize=1)
diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py
index 8d2c33d8..110109dd 100644
--- a/pipelinerl/launch.py
+++ b/pipelinerl/launch.py
@@ -1,6 +1,7 @@
import logging
import math
import os
+import shlex
import shutil
import subprocess
import sys
@@ -19,8 +20,6 @@
logger = logging.getLogger(__name__)
-# All the launch commands in this file pass the environment to child processes
-os.environ["PYTHONPATH"] = f"/home/toolkit/TapeAgents"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["TORCH_DISABLE_SHARE_RDZV_TCP_STORE"] = "1"
os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1"
@@ -178,6 +177,29 @@ def run_actor_llm(
str(world_map.weight_update_group_size),
]
+ # Provide deterministic rendezvous port defaults when env vars are absent.
+ # vLLM spins up a torch.distributed TCPStore using VLLM_PORT. On the remote
+ # scheduler we observed replica crashes (store collisions, connection
+ # refused) because every start script inherited the same default port. By
+ # exporting VLLM_PORT_BASE/VLLM_PORT_STRIDE we carve out a rendezvous range
+ # per actor_idx while keeping the public HTTP listener at 8080+local_idx.
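+    # Illustrative example: rank 1 defaults to VLLM_PORT_BASE=44000; with the
+    # default stride of 20, the replica with actor_llm_idx=3 then binds VLLM_PORT=44060.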
+ env = dict(os.environ)
+ if "VLLM_PORT_BASE" not in env:
+ # Each rank gets 1000 ports; 43000 leaves room below.
+ env["VLLM_PORT_BASE"] = str(43000 + 1000 * world_map.my_rank)
+ logger.debug(
+ "Setting default VLLM_PORT_BASE=%s for rank %s",
+ env["VLLM_PORT_BASE"], world_map.my_rank,
+ )
+ if "VLLM_PORT_STRIDE" not in env:
+ env["VLLM_PORT_STRIDE"] = "20"
+
+ env_overrides = {
+ key: str(env[key])
+ for key in ("VLLM_PORT_BASE", "VLLM_PORT_STRIDE")
+ if key in env
+ }
+
# Add vLLM kwargs as separate arguments
if cfg.vllm_config.vllm_kwargs:
for k, v in cfg.vllm_config.vllm_kwargs.items():
@@ -190,13 +212,13 @@ def run_actor_llm(
gpu_str = ",".join([str(gpu) for gpu in gpus])
logger.info(f"Running actor_llm with command: {' '.join(cmd)} on gpus: {gpu_str}")
- save_command(log_dir, cmd)
+ save_command(log_dir, cmd, env_overrides or None)
log_file_path = os.path.join(log_dir, "stdout.log")
err_file_path = os.path.join(log_dir, "stderr.log")
with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file:
proc = _popen(
cmd,
- env={**os.environ, "CUDA_VISIBLE_DEVICES": gpu_str},
+ env={**env, "CUDA_VISIBLE_DEVICES": gpu_str},
stdout=log_file,
stderr=err_file,
)
@@ -405,14 +427,21 @@ def run_redis(cfg: DictConfig):
yield LaunchedProcess(kind="redis", handle=proc)
-def save_command(script_dir: Path, cmd):
+def save_command(script_dir: Path, cmd, env: dict | None = None):
os.makedirs(script_dir, exist_ok=True)
script_path = script_dir / "start.sh"
with open(script_path, "w") as f:
f.write("#!/bin/bash\n")
+ f.write("set -e\n")
+ if env:
+ for key, value in sorted(env.items()):
+ quoted_value = shlex.quote(value)
+ f.write(f"export {key}={quoted_value}\n")
# Properly quote arguments for the shell script
- quoted_cmd = [f"'{arg}'" if " " in arg or "$" in arg else arg for arg in cmd]
- f.write(" ".join(quoted_cmd) + "\n")
+ quoted_cmd = [shlex.quote(arg) for arg in cmd]
+ f.write("exec ")
+ f.write(" ".join(quoted_cmd))
+ f.write("\n")
os.chmod(script_path, 0o755)
logger.info(f"Saved start script to {script_path}")
@@ -583,7 +612,8 @@ def main(cfg: DictConfig):
group = str(exp_dir)
root = cfg.wandb.wandb_workspace_root
if root:
- if not group.startswith(root + "/"):
+ check_root = (root + "/") if not root.endswith("/") else root
+ if not group.startswith(check_root):
raise ValueError(f"run_dir {exp_dir} does not start with root {root}")
cfg.wandb.wandb_group = group[len(root) + 1 :]
if world_map.total_finetune_gpus:
@@ -641,6 +671,8 @@ def main(cfg: DictConfig):
if cfg.debug.mode == "finetune":
processes.extend(launch_jobs(cfg, world_map, ["finetune"]))
+ elif cfg.debug.mode == "llm":
+ processes.extend(launch_jobs(cfg, world_map, ["actor_llm"]))
elif cfg.debug.mode == "actor":
processes.extend(launch_jobs(cfg, world_map, ["actor", "environment", "actor_llm"]))
elif cfg.debug.mode == "preprocessor":
diff --git a/pipelinerl/llm.py b/pipelinerl/llm.py
index cc099c15..04dd93b7 100644
--- a/pipelinerl/llm.py
+++ b/pipelinerl/llm.py
@@ -21,45 +21,12 @@
from pydantic import BaseModel, Field, TypeAdapter
from tenacity import retry, stop_after_attempt, wait_exponential
+from pipelinerl.rollouts import TrainingText
+
logger = logging.getLogger(__name__)
PIPELINERL_LLM_TOKEN = "PIPELINERL_LLM_TOKEN"
-class TrainingText(BaseModel):
- """
- Training text instance used to finetune a language model.
-
- Attributes:
- text (str): The full text of the training instance.
- n_predicted (int): The number of predicted characters in the text.
- reward (float): The reward associated with the training instance. Defaults to 0.0.
- logprobs (List[float]): A list of log probabilities of the completion tokens from the assistant model.
- ref_logprobs (List[float]): A list of reference log probabilities of the completion tokens from the reference model.
- input_ids (List[int]): The tokenized input ids of the text.
- labels (List[int]): The tokenized labels of the text (i.e., masked token ids for the prompt and regular token ids for the prediction).
- group_id (str, optional): ID of the group. It is used by the RL finetuning script to normalize rewards.
- prompt_text (str): Portion of the text that serves as the prompt (i.e., the text excluding the predicted characters).
- output_text (str): Portion of the text that represents the predicted output (i.e., the last n_predicted characters).
- """
-
- text: str
- n_predicted: int
- reward: float = 0.0
- logprobs: list[float] = Field(default_factory=list)
- ref_logprobs: list[float] = Field(default_factory=list)
- input_ids: list[int] = Field(default_factory=list)
- labels: list[int] = Field(default_factory=list)
- group_id: str | None = None
- metadata: dict = Field(default_factory=dict)
-
- @property
- def prompt_text(self) -> str:
- return self.text[: -self.n_predicted]
-
- @property
- def output_text(self) -> str:
- return self.text[-self.n_predicted :]
-
class Prompt(BaseModel):
"""
@@ -392,10 +359,6 @@ class TrainableLLM(LLM):
base_url (str): Base URL of the API endpoint
api_token (str): Authentication token for API access
"""
-
- # TODO: use OpenAI Python client when the certificate issue is resolved.
- # TODO: consider using litellm
-
base_url: str = "https://api.openai.com"
api_token: str = Field(default="", exclude=True)
collect_logprobs: bool = False
@@ -403,7 +366,7 @@ class TrainableLLM(LLM):
max_parallel_requests: int = 32
max_retries: int = 5
base_delay: float = 0.5
- _semaphore: asyncio.Semaphore
+ _semaphore: asyncio.Semaphore = None # type: ignore
def model_post_init(self, __context):
super().model_post_init(__context)
diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py
index 0a6015e4..a45c4438 100644
--- a/pipelinerl/preprocess.py
+++ b/pipelinerl/preprocess.py
@@ -211,7 +211,7 @@ def run_dataset_loader(
# This is a blocking call, but in most cases there will be space
raw_chunk_queue.put(buffer)
except Exception as e:
- logger.error(f"Error in dataset loader: {e}")
+ logger.exception(f"Error in dataset loader: {e}")
raw_chunk_queue.put(e)
break
@@ -398,8 +398,8 @@ def run_preprocessing_loop(
# Initialize TrainerState
trainer_state = TrainerState(exp_root_dir)
- if cfg.debug.mode == "preprocessor":
- logger.info("Debug mode: preprocessor")
+ if cfg.debug.mode == "preprocessor" or cfg.debug.mode == "actor+preprocessor":
+ logger.info(f"Debug mode: {cfg.debug.mode}")
trainer_state.debug_mode_init()
elif cfg.debug.mode == "finetune+preprocessor":
logger.info("Debug mode: finetune+preprocessor")
@@ -554,7 +554,7 @@ def run_preprocessing_loop(
else:
processed_entries_queue_popped_data += 1
if processed_entries_queue_popped_data % 100 == 0 and last_time_notice != processed_entries_queue_popped_data // 100:
- logger.warning(f"Popped {processed_entries_queue_popped_data} old entries from processed entries queue")
+ logger.warning(f"Popped {processed_entries_queue_popped_data} old entries from processed entries queue of max size {processed_entries_queue.maxlen}")
last_time_notice = processed_entries_queue_popped_data // 100
entry = buffer.popleft()
processed_entries_queue.append(entry) # drop from the left if full
@@ -590,6 +590,10 @@ def run_preprocessing_loop(
sample_length = len(entry["input_ids"])
if current_length + sample_length > cfg.finetune.seq_length:
+ if len(current_batch) == 0:
+                        raise ValueError(
+                            f"Sample of {sample_length} tokens can never fit into an empty micro batch: "
+                            f"cfg.finetune.seq_length is only {cfg.finetune.seq_length}"
+                        )
time_to_write = True
break # Current micro batch is full
@@ -654,6 +658,7 @@ def run_preprocessing_loop(
"preprocessor/queue/output": output_queue.qsize(),
"preprocessor/filtered_out_samples": num_filtered_out,
"preprocessor/total_filtered_out_samples": total_filtered_out,
+ "preprocessor/dropped_after_preprocessing": processed_entries_queue_popped_data,
}
if stats_aggregator.has_enough_data():
stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py
new file mode 100644
index 00000000..12e6fc2d
--- /dev/null
+++ b/pipelinerl/rl_tool_parser_plugin.py
@@ -0,0 +1,247 @@
+"""
+Tool parser plugin for RL tool calling format.
+"""
+
+import json
+import re
+from typing import Any # noqa: F401
+import logging
+
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ExtractedToolCallInformation,
+ ToolCall,
+ FunctionCall
+)
+
+
+@ToolParserManager.register_module("rl_tool")
+class HermesRLToolParser(ToolParser):
+ """
+    Tool parser for RL tool calling format using <tool_call> markers.
+    Supports both the standard <tool_call> format and Apriel-style formats:
+    - <tool_calls>[{...}, {...}]</tool_calls> (preferred if present)
+ - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper
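+
+    Illustrative model output (tool name and arguments are hypothetical):
+        <tool_calls>[{"name": "click", "arguments": {"bid": "12"}}]</tool_calls>
+        [BEGIN FINAL RESPONSE] Clicked the requested element. [END FINAL RESPONSE]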
+ """
+
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer)
+
+ # Tool call markers
+ self.tool_call_start_token = ""
+ self.tool_call_end_token = ""
+
+ # Regex pattern for parsing tool calls
+ self.tool_call_regex = re.compile(
+ r"(.*?)|(.*)", re.DOTALL
+ )
+
+ # Apriel-specific patterns
+ self.apriel_final_response_regex = re.compile(
+ r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL
+ )
+        # Prefer parsing aggregated tool calls from <tool_calls>...</tool_calls>
+ # Be lenient: case-insensitive; tolerate missing closing tag by capturing to end.
+ self.apriel_tool_calls_regex = re.compile(
+ r"\s*(.*?)\s*(?:|$)", re.DOTALL | re.IGNORECASE
+ )
+
+ # State for streaming
+ self.current_tool_name_sent = False
+ self.prev_tool_call_arr = []
+ self.current_tool_id = -1
+ self.streamed_args_for_tool = []
+
+ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ """
+ Extract tool calls from the model output.
+
+ Args:
+ model_output: The raw model output string
+ request: The request object
+
+ Returns:
+ ExtractedToolCallInformation with tool calls and metadata
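+
+        Parsing preference (see the numbered branches below): an aggregated
+        <tool_calls> block first, then bare JSON tool calls when tools are
+        declared in the request, then individual <tool_call> blocks as a
+        legacy fallback.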
+ """
+ logger = logging.getLogger("pipelinerl.tool_parser")
+ # Ensure variable exists for any fallback references below
+ final_response_match = None
+
+ try:
+ # 1) Apriel aggregated tool calls block has priority
+ tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output))
+ if tool_calls_matches:
+ # Use the last match (in case of multiple blocks)
+ last_match = tool_calls_matches[-1]
+ tool_calls_json = last_match.group(1).strip()
+ parsed_calls = []
+ try:
+ parsed_calls = json.loads(tool_calls_json) if tool_calls_json else []
+ except Exception:
+ logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True)
+ parsed_calls = []
+
+ tool_calls: list[ToolCall] = []
+ for i, pc in enumerate(parsed_calls):
+ try:
+ name = pc.get("name", "")
+ args_obj = pc.get("arguments", {})
+ if not isinstance(args_obj, (dict, list, str, int, float, bool)):
+ args_obj = {}
+ args_str = json.dumps(args_obj, ensure_ascii=False)
+ call_id = pc.get("id", f"call_{i}")
+ tool_calls.append(
+ ToolCall(
+ id=call_id,
+ type="function",
+ function=FunctionCall(name=str(name), arguments=args_str),
+ )
+ )
+ except Exception:
+ logger.debug("Skipping malformed aggregated tool call", exc_info=True)
+ continue
+
+ # Prefer final response content if present; otherwise empty string
+ final_response_match = self.apriel_final_response_regex.search(model_output)
+ content = final_response_match.group(1).strip() if final_response_match else ""
+
+ return ExtractedToolCallInformation(
+ tools_called=bool(tool_calls),
+ tool_calls=tool_calls,
+ content=content,
+ )
+
+ # 2) Try bare JSON tool-calls (no tags), but only if tools are declared in the request
+ # Accept either a list of {name, arguments} or a single dict
+ try:
+ tools_declared = bool(getattr(request, "tools", None))
+ except Exception:
+ tools_declared = False
+
+ if tools_declared:
+ candidate_strings: list[str] = []
+ final_response_match = self.apriel_final_response_regex.search(model_output)
+ if final_response_match:
+ candidate_strings.append(final_response_match.group(1).strip())
+ candidate_strings.append(model_output.strip())
+
+ for candidate in candidate_strings:
+ try:
+ parsed = json.loads(candidate)
+ except Exception:
+ continue
+ parsed_list = []
+ if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed:
+ parsed_list = [parsed]
+ elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed):
+ parsed_list = [it for it in parsed if "name" in it and "arguments" in it]
+ if not parsed_list:
+ continue
+ tool_calls: list[ToolCall] = []
+ for i, pc in enumerate(parsed_list):
+ try:
+ name = pc.get("name", "")
+ args_obj = pc.get("arguments", {})
+ if not isinstance(args_obj, (dict, list, str, int, float, bool)):
+ args_obj = {}
+ args_str = json.dumps(args_obj, ensure_ascii=False)
+ call_id = pc.get("id", f"call_{i}")
+ tool_calls.append(
+ ToolCall(
+ id=call_id,
+ type="function",
+ function=FunctionCall(name=str(name), arguments=args_str),
+ )
+ )
+ except Exception:
+ logger.debug("Skipping malformed bare-JSON tool call", exc_info=True)
+ continue
+ content = final_response_match.group(1).strip() if final_response_match else ""
+ return ExtractedToolCallInformation(
+ tools_called=bool(tool_calls),
+ tool_calls=tool_calls,
+ content=content,
+ )
+
+            # 3) Fallback: look for single <tool_call> blocks (legacy / other models)
+ content_to_search = model_output
+ final_response_match = self.apriel_final_response_regex.search(model_output)
+ if final_response_match:
+ final_response_content = final_response_match.group(1).strip()
+ if self.tool_call_start_token in final_response_content:
+ content_to_search = final_response_content
+ elif self.tool_call_start_token not in model_output:
+ # No tool calls found, return final response as content
+ return ExtractedToolCallInformation(
+ tools_called=False,
+ tool_calls=[],
+ content=final_response_content
+ )
+
+ # Quick check to avoid unnecessary processing
+ if self.tool_call_start_token not in content_to_search:
+ return ExtractedToolCallInformation(
+ tools_called=False,
+ tool_calls=[],
+ content=model_output
+ )
+
+ # Find all tool call matches
+ function_call_tuples = self.tool_call_regex.findall(content_to_search)
+
+ # Parse JSON from matches
+ tool_calls = []
+ for i, match in enumerate(function_call_tuples):
+ json_str = match[0] if match[0] else match[1]
+ try:
+ parsed_call = json.loads(json_str.strip())
+ args_obj = parsed_call.get("arguments", {})
+ if not isinstance(args_obj, (dict, list, str, int, float, bool)):
+ args_obj = {}
+ tool_call = ToolCall(
+ id=f"call_{i}",
+ type="function",
+ function=FunctionCall(
+ name=str(parsed_call.get("name", "")),
+ arguments=json.dumps(args_obj, ensure_ascii=False)
+ )
+ )
+ tool_calls.append(tool_call)
+ except Exception:
+ logger.debug("Skipping malformed JSON", exc_info=True)
+ continue
+
+ # Determine content based on whether we found tool calls
+ if tool_calls and final_response_match:
+ # If we found tool calls in final response, use just the tool calls
+ content = ""
+ elif final_response_match:
+ # If we have final response but no tool calls there, use final response
+ content = final_response_match.group(1).strip()
+ else:
+ # Standard processing
+ content = model_output
+
+ return ExtractedToolCallInformation(
+ tools_called=bool(tool_calls),
+ tool_calls=tool_calls,
+ content=content
+ )
+
+ except Exception:
+ # Never propagate exceptions to the server; log and return a safe fallback.
+ logger.exception("Tool parser encountered an exception; returning safe fallback.")
+ if final_response_match:
+ return ExtractedToolCallInformation(
+ tools_called=False,
+ tool_calls=[],
+ content=final_response_match.group(1).strip()
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False,
+ tool_calls=[],
+ content=model_output
+ )
+
\ No newline at end of file
diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py
index dcb27f2d..c755cf31 100644
--- a/pipelinerl/rollouts.py
+++ b/pipelinerl/rollouts.py
@@ -64,3 +64,5 @@ class RolloutResult(BaseModel):
model_version: int | None = None
dataset_name: str | None = None
group_id: str | None = None
+ llm_url: str = ""
+    attempt: int = 0  # attempt index within this problem's rollout group
diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py
index bd7fe5b5..f061d216 100644
--- a/pipelinerl/utils.py
+++ b/pipelinerl/utils.py
@@ -294,19 +294,19 @@ def wait_for_inference_servers(urls: list[str]):
def wait_for_environments(cfg: DictConfig):
- """
- Wait for the verifier to be ready.
- """
+ """Wait for remote environment servers to report healthy."""
+ if cfg.world.environment_mode != "remote":
+ return
+
env_jobs = [Job(**job) for job in cfg.jobs if job.kind == "environment"]
for job in env_jobs:
while True:
url = f"http://{job.hostname}:{job.port}/health"
- # use requests
try:
response = requests.get(url)
if response.status_code == 200:
break
- except:
+ except requests.exceptions.RequestException:
logger.info(f"Waiting for environment at {url} to be ready...")
time.sleep(5.0)
diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py
index 8cd023bd..6858c7cd 100644
--- a/pipelinerl/vllm0.py
+++ b/pipelinerl/vllm0.py
@@ -3,39 +3,39 @@
import logging
import os
import signal
-from pydantic import TypeAdapter
+
import torch
+import torch.distributed as dist
import uvloop
+from pydantic import TypeAdapter
from vllm import AsyncLLMEngine
-from vllm.utils import FlexibleArgumentParser, set_ulimit
-from vllm.entrypoints.openai.cli_args import (
- make_arg_parser,
- validate_parsed_serve_args,
-)
+from vllm._version import version
+from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
- run_server,
- create_server_socket,
build_app,
+ create_server_socket,
init_app_state,
+ run_server,
+)
+from vllm.entrypoints.openai.cli_args import (
+ make_arg_parser,
+ validate_parsed_serve_args,
)
-from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.logger import init_logger
-from vllm._version import version
-from vllm.worker.worker import Worker
-from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper
from vllm.executor.mp_distributed_executor import MultiprocessingDistributedExecutor
+from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper
+from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.usage.usage_lib import UsageContext
-from vllm.worker.multi_step_worker import MultiStepWorker
+from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.worker.multi_step_model_runner import MultiStepModelRunner
+from vllm.worker.multi_step_worker import MultiStepWorker
+from vllm.worker.worker import Worker
-
-import torch.distributed as dist
-from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest
import pipelinerl.torch_utils
+from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest
logger = logging.getLogger(__name__)
# configure this logger individually, in order to avoid messign
@@ -180,6 +180,25 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid tool call parser: {args.tool_call_parser} (chose from {{ {','.join(valide_tool_parses)} }})"
)
+ # Choose a unique rendezvous port per actor to avoid torch.distributed
+ # TCPStore collisions across concurrently launched vLLM processes.
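+    # Illustrative: with VLLM_PORT_BASE=44000, VLLM_PORT_STRIDE=20 and
+    # actor_llm_idx=2, this resolves to VLLM_PORT=44040.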
+ try:
+ if "VLLM_PORT" not in os.environ:
+ actor_idx = getattr(args, "actor_llm_idx", None)
+ base_str = os.environ.get("VLLM_PORT_BASE", "")
+ stride_str = os.environ.get("VLLM_PORT_STRIDE", "10")
+ if actor_idx is not None and base_str.isdigit():
+ base = int(base_str)
+ stride = int(stride_str) if stride_str.isdigit() else 10
+ port = base + stride * int(actor_idx)
+ os.environ["VLLM_PORT"] = str(port)
+ logger.info(
+ "Using VLLM_PORT=%s (base=%s stride=%s actor_idx=%s)",
+ port, base, stride, actor_idx,
+ )
+ except Exception as e:
+ logger.warning("Failed to set VLLM_PORT from actor_idx: %s", e)
+
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py
index be98f76f..38d1bc96 100644
--- a/pipelinerl/vllm1.py
+++ b/pipelinerl/vllm1.py
@@ -1,32 +1,32 @@
import logging
import signal
+from typing import Any, Protocol, runtime_checkable
+
import torch
import uvloop
-from vllm.utils import FlexibleArgumentParser, set_ulimit
-from vllm.entrypoints.openai.cli_args import (
- make_arg_parser,
- validate_parsed_serve_args,
-)
+from vllm._version import version
+from vllm.config import ModelConfig
+from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
- run_server,
- create_server_socket,
build_app,
+ create_server_socket,
init_app_state,
+ run_server,
+)
+from vllm.entrypoints.openai.cli_args import (
+ make_arg_parser,
+ validate_parsed_serve_args,
)
-from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm._version import version
from vllm.usage.usage_lib import UsageContext
-from vllm.config import ModelConfig
+from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import AsyncMPClient
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
-
-from pipelinerl.finetune_loop import WeightUpdateRequest
-from typing import Any, Protocol, runtime_checkable
import pipelinerl.torch_utils
+from pipelinerl.finetune_loop import WeightUpdateRequest
logger = logging.getLogger(__name__)
# configure this logger individually, in order to avoid messign
diff --git a/pipelinerl/world.py b/pipelinerl/world.py
index 992a7c4d..54ee2bf8 100644
--- a/pipelinerl/world.py
+++ b/pipelinerl/world.py
@@ -71,7 +71,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False):
if place_inference_jobs:
self._place_inference_jobs(cfg)
self._place_pipeline_stages(cfg)
- if cfg.environment:
+ if cfg.environment and cfg.world.environment_mode == "remote":
self._place_environments(cfg)
# Place the finetune workers on the remaining gpus, take all remaining GPUs
diff --git a/pyproject.toml b/pyproject.toml
index 7fc9978a..32504c1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"uvloop>=0.19.0",
"wandb>=0.16.0",
"hydra-core>=1.3.2",
+ "ray[default]~=2.47.1",
]
[project.optional-dependencies]