diff --git a/.gitignore b/.gitignore index 476aab77..1469bc67 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ celerybeat.pid # SageMath parsed files *.sage.py +node_modules/ # Environments .env @@ -185,4 +186,4 @@ results results/ data/ cache/ -dump.rdb \ No newline at end of file +dump.rdb diff --git a/conf/base.yaml b/conf/base.yaml index 5043aa56..80429dfb 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -5,6 +5,7 @@ defaults: - _self_ seed: 42 +use_ray: false finetune: seed: ${..seed} @@ -18,14 +19,18 @@ actor: result_queue_size: 64 throughput_window_size: 50 shared_memory_entry_size: 10000000 + async_batch_size: 1 # if ==1, rollout function will be called as synchronous + task_submission_delay_sec: 0.5 + collect_logprobs: true + environment: null preprocess: input: actor output: training_data n_workers: 8 - chunk_n_groups: 2 + chunk_n_groups: 8 # queue for loaded raw groups - raw_queue_size: 8 + raw_queue_size: 128 # queue for processed chunks of multiple groups input_queue_size: 32 # queue for ready chunks for multiple groups @@ -67,7 +72,7 @@ vllm_config: tensor-parallel-size: 1 pipeline-parallel-size: 1 generation-config: vllm - max_model_len: 10000 + max_model_len: 16000 world: replicas: 1 @@ -81,6 +86,8 @@ world: actor_group_port: 9000 environment_start_port: 7777 +# Remote vs embedded environment execution strategy + environment_mode: embedded # this will be autocreated based on the config jobs: [] diff --git a/conf/miniwob.yaml b/conf/miniwob.yaml index a8dc3868..7c051a93 100644 --- a/conf/miniwob.yaml +++ b/conf/miniwob.yaml @@ -9,12 +9,11 @@ world: preprocessor_fraction: 0 finetune_fraction: 6 -# debug: -# mode: actor -save_tapes: False +save_tapes: false output_dir: results/miniwob/${now:%Y-%m-%d}/${now:%H-%M-%S} model_path: meta-llama/Llama-3.1-8B-Instruct +use_ray: true finetune: seq_length: 16384 # input + output tokens @@ -25,6 +24,7 @@ finetune: eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps" llm: + use_cache: false parameters: max_tokens: 4096 # output tokens temperature: 1.0 @@ -36,13 +36,20 @@ test_llm: top_k: 50 vllm_config: + use_v1: false vllm_kwargs: - max_model_len: 16384 # input + output tokens + max-num-seqs: 256 + max-num-batched-tokens: 32000 + max_model_len: 16384 + gpu-memory-utilization: 0.9 actor: rollout_policy: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout + llm_max_rollouts: 256 + problem_queue_size: 256 + async_batch_size: 1 + rollout_workers: 32 shared_memory_entry_size: 100000000 - llm_max_rollouts: 32 preprocess: n_workers: 32 # Increase from 8 @@ -144,12 +151,7 @@ agent: # ENVIRONMENT CONFIGURATION start_attempts: 3 # number of attempts to start each task environment: - _target_: pipelinerl.domains.miniwob.environment_server.WebEnvironmentServer - miniwob_url: ??? - n_envs: 32 - host: "0.0.0.0" - env_call_timeout: 60 # timeout for each environment call (e.g. start_task, act, etc.) 
- web_env_target: examples.rl_webagent.environment.WebEnvironment + _target_: examples.rl_webagent.environment.WebEnvironment exp_path: null headless: true observation_format: html @@ -162,4 +164,4 @@ dataset_loader_params: train_dataset_names: - train test_dataset_names: - - test + - test \ No newline at end of file diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 6d3317af..100aa529 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -5,17 +5,20 @@ import os import queue import random +import sys import time from collections import defaultdict from multiprocessing.managers import SharedMemoryManager from pathlib import Path from queue import Empty -from typing import Dict, List +from typing import Any, Awaitable, Callable, Dict, List import aiohttp import hydra +import numpy as np +import ray import uvloop -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field import wandb @@ -32,8 +35,7 @@ set_streams_backend, write_to_streams, ) - -from .utils import ( +from pipelinerl.utils import ( always_or_never_success_stats, calculate_stats, setup_logging, @@ -57,8 +59,9 @@ class SlidingWindowData(BaseModel): class SlidingWindowAggregator: - def __init__(self, window_size: int): + def __init__(self, window_size: int, min_samples: int = 5): self.window_size = window_size + self.min_samples = min_samples self.data = SlidingWindowData() def update(self, prompt_tokens: list[int], output_tokens: list[int]): @@ -71,8 +74,11 @@ def update(self, prompt_tokens: list[int], output_tokens: list[int]): self.data.timestamps.pop(0) def get_stats(self): - if len(self.data.prompt_tokens_window) < self.window_size: + if len(self.data.prompt_tokens_window) < self.min_samples: + logger.warning("Not enough data to compute sliding stats") return None + elif len(self.data.prompt_tokens_window) < self.window_size: + logger.warning(f"Compute sliding stats over just {len(self.data.prompt_tokens_window)} samples") # 1. How many samples do we produce per second? # 2. How many output tokens do we produce per second? 
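Note on the hunk above: the aggregator now reports partial statistics once `min_samples` entries are available instead of staying silent until the window is full. The sketch below is only an illustration of that idea under assumed names (ThroughputWindow, samples_per_sec, output_tokens_per_sec); it is not the code in this diff, which keeps its own SlidingWindowData fields and stat names.

import time


class ThroughputWindow:
    """Bounded window of token counts + timestamps reduced to throughput numbers."""

    def __init__(self, window_size: int, min_samples: int = 5):
        self.window_size = window_size
        self.min_samples = min_samples
        self.output_tokens_window: list[int] = []
        self.timestamps: list[float] = []

    def update(self, output_tokens: list[int]) -> None:
        # One entry per finished rollout; trim the oldest entries beyond the window.
        self.output_tokens_window.extend(output_tokens)
        self.timestamps.extend([time.time()] * len(output_tokens))
        overflow = len(self.timestamps) - self.window_size
        if overflow > 0:
            del self.output_tokens_window[:overflow]
            del self.timestamps[:overflow]

    def get_stats(self) -> dict | None:
        if len(self.timestamps) < self.min_samples:
            return None  # not enough data yet, mirroring the early-return above
        elapsed = max(self.timestamps[-1] - self.timestamps[0], 1e-6)
        return {
            "samples_per_sec": len(self.timestamps) / elapsed,
            "output_tokens_per_sec": sum(self.output_tokens_window) / elapsed,
        }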
@@ -103,11 +109,14 @@ def get_stats(self): } - def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) +def get_number_of_tokens_in_result(result: RolloutResult) -> int: + return sum(training_text.prompt_tokens + training_text.output_tokens for training_text in result.training_texts) + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -133,19 +142,17 @@ async def schedule_rollouts( active_rollouts = [0] * len(llms) started_rollouts = 0 finished_rollouts = 0 + token_count = 0 # Track rollouts per problem group group_rollouts = {} rollout_policy = hydra.utils.get_method(cfg.actor.rollout_policy) - logger.info(f"Use rollout policy: {rollout_policy}") + logger.info(f"Use rollout policy: {rollout_policy.__name__}") final_steps = calculate_train_steps(cfg.finetune, cfg.finetune.interrupt_train_steps) samples_target = final_steps * cfg.finetune.train_batch_size * cfg.finetune.gradient_accumulation_passes def is_trainer_finished() -> bool: - return ( - trainer_state.samples_processed is not None - and trainer_state.samples_processed >= samples_target - ) + return trainer_state.samples_processed is not None and trainer_state.samples_processed >= samples_target def handle_rollout_exception(exc: Exception): if isinstance(exc, aiohttp.ClientError) and is_trainer_finished(): @@ -168,13 +175,16 @@ async def rollout_and_maybe_produce_result( llm_index: int, session: aiohttp.ClientSession, ): - nonlocal started_rollouts, finished_rollouts + nonlocal started_rollouts, finished_rollouts, token_count try: llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None - rollout_result = await rollout_policy(cfg, llm, problem, session) + logger.info(f"Starting rollout policy for problem {problem['id']}") + rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session) + logger.info(f"Finished rollout policy for problem {problem['id']}") rollout_result.model_version = model_version + token_count += get_number_of_tokens_in_result(rollout_result) # Make a group id that will be different from groups made by another rollout maker full_group_id = f"{scheduler_name}_{group_id}" rollout_result.group_id = full_group_id @@ -204,18 +214,22 @@ async def rollout_and_maybe_produce_result( logger.info("Starting rollout scheduler") connector = aiohttp.TCPConnector(limit=50000, limit_per_host=50000, keepalive_timeout=1.0) timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0) + old_finished_rollouts = 0 async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: if is_trainer_finished(): logger.info(f"{scheduler_name}: trainer signalled completion; stopping rollout scheduler") break if time.time() - last_logged > 10.0 and sum(active_rollouts): + if finished_rollouts > old_finished_rollouts: + old_finished_rollouts = finished_rollouts logger.info( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " f"groups in progress: {len(group_rollouts)}, " f"rollouts started so far: {started_rollouts}, " f"rollouts finished so far: {finished_rollouts}, " + f"total tokens produced so far: {token_count}, " f"groups started so far: {group_id}, " f"max group size in bytes: {result_queue.max_actual_entry_size()}, " ) @@ -238,7 +252,6 @@ async def rollout_and_maybe_produce_result( await asyncio.sleep(0.01) continue active_rollouts[next_llm] += 1 - started_rollouts += 1 assert problem is not None loop.create_task( rollout_and_maybe_produce_result( @@ -249,6 +262,7 @@ 
async def rollout_and_maybe_produce_result( session=session, ) ) + started_rollouts += 1 group_rollout_index += 1 logger.info("Rollout scheduler finished") @@ -302,40 +316,52 @@ def __init__( self.sliding_aggregator = SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size) self.llms = llms self.loop_start_time = -1 - self.cfg = cfg + self.cfg: DictConfig = cfg self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) + self.cfg: DictConfig = cfg - # Determine the number of processes to use - num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) - attempts = self.cfg.attempts if is_training else 1 - - # Divide LLMs approximately equally across processes - llm_groups = [[] for _ in range(num_processes)] - for i, llm in enumerate(self.llms): - llm_groups[i % num_processes].append((i, llm)) + self.smm: SharedMemoryManager | None = None + self.problem_queue: SharedMemoryQueue | None = None + self.result_queue: SharedMemoryQueue | None = None + self.rollout_errors = 0 + logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop") + def start_backend(self): self.smm = SharedMemoryManager() self.smm.start() - # Use SharedMemoryQueue instead of separate problem_queue, result_queue, and io_buffer - self.problem_queue = SharedMemoryQueue(self.smm, self.cfg.actor.problem_queue_size, cfg.actor.shared_memory_entry_size) - self.result_queue = SharedMemoryQueue(self.smm, self.cfg.actor.result_queue_size, cfg.actor.shared_memory_entry_size) - - logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop") - logger.info(f"Problem queue size: {self.problem_queue.max_size}, result queue size: {self.result_queue.max_size}") + self.problem_queue = SharedMemoryQueue( + self.smm, self.cfg.actor.problem_queue_size, self.cfg.actor.shared_memory_entry_size + ) + self.result_queue = SharedMemoryQueue( + self.smm, self.cfg.actor.result_queue_size, self.cfg.actor.shared_memory_entry_size + ) + + logger.info( + f"Problem queue size: {self.problem_queue.max_size}, result queue size: {self.result_queue.max_size}" + ) logger.info(f"Result queue buffer size: {self.result_queue.get_memory_size() / 2**30} Gb") # Create and start multiple rollout processes + attempts = self.cfg.attempts if self.is_training else 1 + # Determine the number of processes to use + num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) + + # Divide LLMs approximately equally across processes + llm_groups = [[] for _ in range(num_processes)] + for i, llm in enumerate(self.llms): + llm_groups[i % num_processes].append((i, llm)) + self.rollout_processes = [] for llm_group in llm_groups: assert llm_group llm_idxs = [llm[0] for llm in llm_group] llms = [llm[1] for llm in llm_group] scheduler_name = ( - f"{'train' if is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}" + f"{'train' if self.is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}" ) process = mp.Process( target=rollout_maker_entrypoint, @@ -349,15 +375,15 @@ def init_stats(self): self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) - + def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} - - metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ]) - metrics['num_turns'] = len(result.training_texts) - metrics['prompt_tokens'] = [training_text.prompt_tokens for training_text in 
result.training_texts] - metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts] - + + metrics["overflow"] = all([not training_text.finished for training_text in result.training_texts]) + metrics["num_turns"] = len(result.training_texts) + metrics["prompt_tokens"] = [training_text.prompt_tokens for training_text in result.training_texts] + metrics["output_tokens"] = [training_text.output_tokens for training_text in result.training_texts] + return metrics def update_stats(self, rollout_results: List[RolloutResult]): @@ -368,8 +394,10 @@ def update_stats(self, rollout_results: List[RolloutResult]): group_id = result.group_id self.latency_list.append(result.latency) self.model_versions_list.append(result.model_version) - domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) + domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) all_metrics = result.metrics.model_dump() | domain_agnostic_metrics + all_metrics["used_python"] = int(all_metrics.get("used_python", False)) + all_metrics["used_math_answer"] = int(all_metrics.get("used_math_answer", False)) for k, v in all_metrics.items(): if isinstance(v, list): self.stats[k][dataset_name][group_id] += v @@ -377,16 +405,18 @@ def update_stats(self, rollout_results: List[RolloutResult]): self.stats[k][dataset_name][group_id].append(v) else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") - - prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts] - output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts] + + prompt_length_tokens = [ + training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts + ] + output_length_tokens = [ + training_text.output_tokens for result in rollout_results for training_text in result.training_texts + ] self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) sliding_window_stats = self.sliding_aggregator.get_stats() if sliding_window_stats is not None: for k, v in sliding_window_stats.items(): self.sliding_stats[k].append(v) - - def run(self, dataset: list[tuple[str, dict]]): loop_start_time = time.time() @@ -438,6 +468,7 @@ def run(self, dataset: list[tuple[str, dict]]): can_submit_before_update = math.inf logger.info(f"Start {'train' if self.is_training else 'test'} actor loop") + rollouts_last_minute = [] with ( write_to_streams(self.data_stream, "a") as data_stream_writer, write_to_streams(self.stats_stream, "a") as stats_writer, @@ -447,8 +478,13 @@ def run(self, dataset: list[tuple[str, dict]]): yield final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps) - samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes - if self.trainer_state.samples_processed is not None and self.trainer_state.samples_processed >= samples_target: + samples_target = ( + final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes + ) + if ( + self.trainer_state.samples_processed is not None + and self.trainer_state.samples_processed >= samples_target + ): logger.info("Trainer signalled completion; stopping actor loop") break @@ -464,23 +500,22 @@ def run(self, dataset: list[tuple[str, dict]]): if not self.is_scheduling_paused: while True: blocked_by_lag = submitted_groups == can_submit_before_update and 
self.is_training - if not blocked_by_lag and not self.problem_queue.full(): + if not blocked_by_lag and self.have_capacity(): try: try: problem = next(problem_iter) - self.problem_queue.put(problem, block=False) + self.submit_problem(problem) submitted_groups += 1 - except queue.Full: + except queue.Full: assert False, "Problem queue was not full just a moment ago, but now it is full" except StopIteration: break - else: - break + break # Second, try return a result try: # Directly get the result from the SharedMemoryQueue - rollout_results = self.result_queue.get(block=False) + rollout_results = self.get_new_results() except queue.Empty: continue @@ -489,16 +524,17 @@ def run(self, dataset: list[tuple[str, dict]]): raise rollout_results assert isinstance(rollout_results, list) + if len(rollout_results) == 0: + continue assert isinstance(rollout_results[0], RolloutResult) - assert len(rollout_results) == attempts, ( - f"Expected {attempts} rollouts, got {len(rollout_results)}" - ) + assert len(rollout_results) == attempts, f"Expected {attempts} rollouts, got {len(rollout_results)}" group_samples = sum(len(r.training_texts) for r in rollout_results) published_samples += group_samples - samples_in_queue = self.result_queue.qsize() * attempts + samples_in_queue = self.results_ready_to_publish() all_text_dumps = [] for r in rollout_results: + rollouts_last_minute.append(time.perf_counter()) for text in r.training_texts: all_text_dumps.append(text.model_dump()) data_stream_writer.write(all_text_dumps) @@ -512,37 +548,43 @@ def run(self, dataset: list[tuple[str, dict]]): self.update_stats(rollout_results=rollout_results) finished_groups += 1 + logger.info( + f"Finished {'train' if self.is_training else 'test'} groups {finished_groups} out of {expected_rollouts}" + ) time_to_publish_train_stats = ( - self.is_training - and trainer_version_to_publish is not None - ) or self.debug_mode + self.is_training and trainer_version_to_publish is not None + ) or self.debug_mode time_to_publish_test_stats = finished_groups == expected_rollouts + # leave only the rollouts that are in the last minute + rollouts_last_minute = [t for t in rollouts_last_minute if t > time.perf_counter() - 60] + # Publish stats at every new model version or if all tapes are finished if time_to_publish_train_stats or time_to_publish_test_stats: if self.is_training: loop_stats = { "published_samples": published_samples, - "problem_queue_size": self.problem_queue.qsize(), - "result_queue_size": self.result_queue.qsize(), + "problem_queue_size": self.problem_queue_size(), + "result_queue_size": self.result_queue_size(), "finished_groups": finished_groups, - "trainer_model_version": trainer_version_to_publish, + "trainer_model_version": trainer_version_to_publish, "time_since_start": time.time() - loop_start_time, + "groups_in_progress": in_progress, + "rollout_errors": self.rollout_errors, + "rollouts_per_min": len(rollouts_last_minute), } trainer_version_to_publish = None else: - loop_stats = { - "trainer_model_version": last_trainer_version - } + loop_stats = {"trainer_model_version": last_trainer_version} self.publish_stats( stats_writer=stats_writer, loop_stats=loop_stats, ) - - if finished_groups == expected_rollouts: + if expected_rollouts >= 0 and finished_groups + self.rollout_errors >= expected_rollouts: logger.info(f"Finished {expected_rollouts} rollouts, stopping actor loop") + self.stop_tasks() break def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): @@ -558,31 +600,269 @@ def publish_stats(self, 
stats_writer: StreamWriter, loop_stats: Dict): stats[f"{dataset_name}/{metric_name}_{agg}"] = sub_stats stats |= ( - { - f"{split_name}{k}": v - for k, v in always_or_never_success_stats(self.stats["success"]).items() - } - | { - f"{split_name}latency_" + k: v - for k, v in calculate_stats(self.latency_list).items() - } - | { - f"{split_name}model_version_" + k: v - for k, v in calculate_stats(self.model_versions_list).items() - } + {f"{split_name}{k}": v for k, v in always_or_never_success_stats(self.stats["success"]).items()} + | {f"{split_name}latency_" + k: v for k, v in calculate_stats(self.latency_list).items()} + | {f"{split_name}model_version_" + k: v for k, v in calculate_stats(self.model_versions_list).items()} ) stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 + + rename_suffixes = { + "num_python_calls_mean": "python_calls_mean", + "used_python_mean": "python_usage_rate", + "num_math_answer_calls_mean": "math_answer_calls_mean", + "used_math_answer_mean": "math_answer_usage_rate", + } + + for key in list(stats.keys()): + for old_suffix, new_suffix in rename_suffixes.items(): + if key.endswith(old_suffix): + prefix = key[: -len(old_suffix)] + stats[f"{prefix}{new_suffix}"] = stats[key] + break + if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) self.init_stats() # Reset stats for the next iteration + def have_capacity(self) -> bool: + return not self.problem_queue.full() + + def submit_problem(self, problem: dict): + self.problem_queue.put(problem, block=False) + + def stop_tasks(self): + pass + + def get_new_results(self) -> list[RolloutResult]: + return self.result_queue.get(block=False) + + def results_ready_to_publish(self) -> int: + return self.result_queue_size() * self.cfg.attempts + + def problem_queue_size(self) -> int: + return self.problem_queue.qsize() + + def result_queue_size(self) -> int: + return self.result_queue.qsize() + + +class ActorLoopRay(ActorLoop): + """ + Loop that runs the ray tasks for n_jobs to perform rollouts in parallel + """ + + def __init__(self, cfg: DictConfig, *args, **kwargs): + assert cfg.attempts % cfg.actor.async_batch_size == 0, ( + f"attempts {cfg.attempts} must be divisible by actor.async_batch_size {cfg.actor.async_batch_size}" + ) + super().__init__(cfg, *args, **kwargs) + self.cfg_dict: dict = OmegaConf.to_container(self.cfg, resolve=True) # type: ignore + self.unfinished_tasks = [] + self.llms_by_url = {llm.get_base_url(): llm for llm in self.llms} + self.llms_utilization = {llm.get_base_url(): 0 for llm in self.llms} + self.scheduler_name = f"{'train' if self.is_training else 'test'} ray scheduler" + self.problem_id = 0 + self.attempts = self.cfg.attempts if self.is_training else 1 + self.unfinished_groups = defaultdict(list) # up to `attempts` rollout results for each problem + self.finished_groups = [] + self.token_count = 0 + self.finished_rollouts_count = 0 + self.log_dir = Path(self.cfg.output_dir) / "actor" / "ray" + + def start_backend(self): + logger.info(f"Initializing Ray with {self.cfg.actor.rollout_workers} workers..") + self.log_dir.mkdir(parents=True, exist_ok=True) + ray_context = ray.init( + num_cpus=self.cfg.actor.rollout_workers, + dashboard_host="0.0.0.0", + include_dashboard=True, + log_to_driver=True, + ignore_reinit_error=True, + ) + logger.info(f"Ray initialized, dashboard at {ray_context.dashboard_url}") + + assert self.trainer_state.propagated_weight_version is not None + rollout_policy = 
hydra.utils.get_method(self.cfg.actor.rollout_policy) + + def rollout_wrapper(cfg_dict: dict, llm: TrainableLLM, problems: list[dict]) -> list[RolloutResult]: + assert len(problems) == 1, "Sync mode should only be used with 1 problem at a time" + cfg = OmegaConf.create(cfg_dict) + problem = problems[0] + group_id = problem["_group_id"] + attempt = problem["_attempt"] + log_file = Path(cfg.output_dir) / "actor" / "ray" / f"{group_id}.log" + sys.stdout = open(log_file, "a", buffering=1) + sys.stderr = sys.stdout + logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True) + logger.info(f"Running sync rollout for task {group_id}_{attempt}") + start_ts = time.perf_counter() + rollout_result: RolloutResult = rollout_policy(cfg, llm, problem) + rollout_result.latency = time.perf_counter() - start_ts + rollout_result.llm_url = llm.get_base_url() + rollout_result.group_id = group_id + rollout_result.attempt = attempt + logger.info(f"Task {group_id}_{attempt} finished in {rollout_result.latency:.2f} seconds") + return [rollout_result] + + async def run_rollouts_with_session( + cfg: DictConfig, llm: TrainableLLM, problems: list[dict] + ) -> list[RolloutResult]: + connector = aiohttp.TCPConnector( + limit=cfg.actor.async_batch_size, limit_per_host=cfg.actor.async_batch_size, keepalive_timeout=1.0 + ) + timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0) + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + # Run all rollouts in parallel using asyncio.gather + async def run_rollout(problem) -> RolloutResult: + group_id = problem["_group_id"] + attempt = problem["_attempt"] + logger.info(f"Running async rollout loop for task {group_id}_{attempt}") + start_ts = time.perf_counter() + rollout_result = await rollout_policy(cfg, llm, problem, session) + rollout_result.latency = time.perf_counter() - start_ts + rollout_result.llm_url = llm.get_base_url() + rollout_result.group_id = group_id + rollout_result.attempt = attempt + logger.info(f"Task {group_id}_{attempt} finished in {rollout_result.latency:.2f} seconds") + return rollout_result + + tasks = [run_rollout(problem) for problem in problems] + rollout_results = await asyncio.gather(*tasks) + return rollout_results + + def rollout_async_batch_wrapper(cfg_dict: dict, llm: TrainableLLM, problems: list[dict]) -> list[RolloutResult]: + cfg = OmegaConf.create(cfg_dict) + group_id = problems[0]["_group_id"] + log_file = Path(cfg.output_dir) / "actor" / "ray" / f"{group_id}_async_{len(problems)}.log" + sys.stdout = open(log_file, "a", buffering=1) + sys.stderr = sys.stdout + logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True) + logger.info(f"Running async rollouts for group {group_id} with {len(problems)} problems") + rollout_results = asyncio.run(run_rollouts_with_session(cfg, llm, problems)) + return rollout_results + + if self.cfg.actor.async_batch_size > 1: + logger.info("Using async mode") + self.ray_remote = ray.remote()(rollout_async_batch_wrapper) + else: + logger.info("Using sync mode") + self.ray_remote = ray.remote()(rollout_wrapper) + self.start_time = time.time() + + def have_capacity(self) -> bool: + have_capacity = len(self.unfinished_tasks) < self.cfg.actor.problem_queue_size + have_llm_capacity = any( + self.llms_utilization[llm_url] < (self.cfg.actor.llm_max_rollouts - self.attempts) + for llm_url in self.llms_utilization + ) + have_capacity = have_capacity and have_llm_capacity + return have_capacity + + def submit_problem(self, problem: dict): + # 
Make a list of cfg.attempts identical problems (deepcopies can be used if necessary) + problems = [] + for attempt in range(self.attempts): + p = problem.copy() + p["_group_id"] = f"{self.scheduler_name}_{self.problem_id}" + p["_attempt"] = attempt + problems.append(p) + + # Split problems into batches of up to cfg.async_batch_size + batches = [ + problems[i : i + self.cfg.actor.async_batch_size] + for i in range(0, len(problems), self.cfg.actor.async_batch_size) + ] + for batch_idx, problem_batch in enumerate(batches): + llm_url, task_count = min(self.llms_utilization.items(), key=lambda x: x[1]) + logger.info( + f"Submitting problem {self.problem_id} batch {batch_idx + 1}/{len(batches)} to the least busy LLM {llm_url} with {task_count} tasks" + ) + llm = self.llms_by_url[llm_url] + task_ref = self.ray_remote.remote(self.cfg_dict, llm, problem_batch) + time.sleep(self.cfg.actor.task_submission_delay_sec) + self.llms_utilization[llm_url] += len(problem_batch) + self.unfinished_tasks.append(task_ref) + self.problem_id += 1 + + def stop_tasks(self): + ray.shutdown() + + def receive_finished_tasks(self): + num_returns = min(100, len(self.unfinished_tasks)) # query up to 100 tasks at a time + try: + finished_tasks, unfinished_tasks = ray.wait(self.unfinished_tasks, num_returns=num_returns, timeout=0.1) + except Exception as e: + logger.error(f"Error waiting for finished ray tasks: {e}") + return + if len(finished_tasks) > 0: + logger.info(f"Found {len(finished_tasks)} finished tasks, {len(unfinished_tasks)} unfinished tasks left") + self.unfinished_tasks = unfinished_tasks + dt = time.time() - self.start_time + rollout_results: list[RolloutResult] = [] + for finished_task in finished_tasks: + try: + rollout_results += ray.get(finished_task) + except Exception as e: + logger.error(f"Error getting finished ray task: {e}") + self.rollout_errors += self.cfg.actor.async_batch_size + continue + logger.info(f"Received {len(rollout_results)} rollout results from {len(finished_tasks)} finished tasks") + for rollout_result in rollout_results: + rollout_result.model_version = self.trainer_state.propagated_weight_version + rollout_index = len(self.unfinished_groups[rollout_result.group_id]) + for step_index, tr_text in enumerate(rollout_result.training_texts): + # Downstream in the pipeline we'll need these fields in every sample + tr_text.metadata["model_version"] = rollout_result.model_version + tr_text.metadata["rollout_index"] = rollout_index + tr_text.metadata["step_index"] = step_index + tr_text.group_id = rollout_result.group_id + if self.llms_utilization[rollout_result.llm_url] > 0: + self.llms_utilization[rollout_result.llm_url] -= 1 + else: + logger.warning(f"LLM {rollout_result.llm_url} utilization is 0, but got a result") # should not happen + self.token_count += get_number_of_tokens_in_result(rollout_result) + self.finished_rollouts_count += 1 + self.unfinished_groups[rollout_result.group_id].append(rollout_result) + + if len(self.unfinished_groups[rollout_result.group_id]) == self.attempts: + logger.info(f"Problem {rollout_result.group_id} group finished") + group = self.unfinished_groups[rollout_result.group_id] + random.shuffle(group) + self.finished_groups.append(group) + del self.unfinished_groups[rollout_result.group_id] + logger.info(f"{len(self.finished_groups)} finished groups ready to return") + logger.info( + f"Ray {'train' if self.is_training else 'test'} actor loop: " + f"rollouts in progress: {len(self.unfinished_tasks)}, " + f"groups in progress: {len(self.unfinished_groups)}, 
" + f"rollouts finished: {self.finished_rollouts_count}, " + f"total tokens: {self.token_count}, " + f"gen speed: {self.token_count / dt:.2f} tokens/sec, " + f"time elapsed: {dt:.2f} sec,\n" + f"LLMs utilization: {self.llms_utilization}" + ) + + def get_new_results(self) -> list[list[RolloutResult]]: + self.receive_finished_tasks() + if len(self.finished_groups) > 0: + logger.info(f"have {len(self.finished_groups)} finished problems, pop one") + return self.finished_groups.pop(0) + return [] + + def problem_queue_size(self) -> int: + return len(self.unfinished_tasks) + + def result_queue_size(self) -> int: + return len(self.finished_groups) + def run_actor_loop(cfg: DictConfig): set_streams_backend(**cfg.streams) + actor_loop_class = ActorLoopRay if cfg.use_ray else ActorLoop # set seed for reproducibility (mostly intended for dataset loading) random.seed(cfg.seed) @@ -603,7 +883,7 @@ def run_actor_loop(cfg: DictConfig): dataset_loader = hydra.utils.get_method(cfg.dataset_loader) # Get dataset loader parameters if they exist in config, otherwise use empty dict - dataset_loader_params = cfg.get('dataset_loader_params', {}) + dataset_loader_params = cfg.get("dataset_loader_params", {}) # Use **dataset_loader_params to pass parameters only if they exist train_dataset = dataset_loader(cfg.train_dataset_names, **dataset_loader_params) test_dataset = dataset_loader(cfg.test_dataset_names, **dataset_loader_params) @@ -617,12 +897,19 @@ def run_actor_loop(cfg: DictConfig): actor_model_path = finetune_model_path else: actor_model_path = cfg.model_path - + + # Align client-side context size with vLLM server max_model_len when available + try: + _context_size = int(cfg.vllm_config.vllm_kwargs.max_model_len) + except Exception: + _context_size = 32000 + train_llms = [ TrainableLLM( base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), + context_size=_context_size, parameters=cfg.llm.parameters, collect_logprobs=True, ) @@ -633,6 +920,7 @@ def run_actor_loop(cfg: DictConfig): base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), + context_size=_context_size, parameters=cfg.test_llm.parameters, collect_logprobs=True, ) @@ -648,13 +936,12 @@ def run_actor_loop(cfg: DictConfig): trainer_state.start_listening() trainer_state.wait_for_model_version() - train_loop = ActorLoop( + train_loop = actor_loop_class( data_stream=data_stream, cfg=cfg, trainer_state=trainer_state, stats_stream=stats_stream, llms=train_llms ) - train_loop_run = train_loop.run( - dataset=train_dataset, - ) - test_loop = ActorLoop( + train_loop.start_backend() + train_loop_run = train_loop.run(dataset=train_dataset) + test_loop = actor_loop_class( data_stream=test_data_stream, cfg=cfg, trainer_state=trainer_state, @@ -683,9 +970,8 @@ def run_actor_loop(cfg: DictConfig): and test_loop_run is None ): logger.info("Create test loop") - test_loop_run = test_loop.run( - dataset=test_dataset, - ) + test_loop.start_backend() + test_loop_run = test_loop.run(dataset=test_dataset) train_loop.is_scheduling_paused = True current_eval = next_regular_eval diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 4e78ebf9..122625ae 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -4,16 +4,20 @@ import aiohttp import numpy as np +from omegaconf import DictConfig, ListConfig, OmegaConf from PIL import Image -from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM -from pipelinerl.finetune.data import MASKED_TOKEN_ID -from 
pipelinerl.rollouts import TrainingText +from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM from pipelinerl.processor_factory import get_processor -from omegaconf import DictConfig, ListConfig, OmegaConf +from pipelinerl.rollouts import TrainingText logger = logging.getLogger(__name__) +# -100 is the default "ignore_index" in nn.CrossEntropyLoss +# Defined here to avoid importing dependencies from finetune.data +# Do not replace. Import from finetune module breaks ray parallelization! +MASKED_TOKEN_ID = -100 + def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]: """Extract PIL Images from multimodal messages.""" diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 62758c7e..41a61021 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -1,18 +1,17 @@ -import random import time +import random import aiohttp from omegaconf import DictConfig from pydantic import BaseModel - -from pipelinerl.async_llm import llm_async_generate, make_training_text -from pipelinerl.llm import Prompt, TrainableLLM -from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.rollouts import RolloutResult, BaseMetrics from pipelinerl.world import Job +from tapeagents.core import Prompt +from tapeagents.llms.trainable import TrainableLLM +from pipelinerl.async_llm import llm_async_generate, make_training_text from .verifier_api import verify_answer_rpc - class Metrics(BaseMetrics): penalty: float diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py index ea850814..9392efe6 100644 --- a/pipelinerl/domains/miniwob/rollouts.py +++ b/pipelinerl/domains/miniwob/rollouts.py @@ -1,51 +1,59 @@ import asyncio -import json import logging import os import random import time import traceback +from typing import Literal import aiohttp -from examples.rl_webagent.steps import WebTape from hydra.utils import instantiate from omegaconf import DictConfig from tapeagents.agent import DEFAULT, Agent -from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation +from tapeagents.core import LLMOutputParsingFailureAction, Observation from tapeagents.io import save_json_tape -from tapeagents.llms.trainable import TrainableLLM -from tapeagents.orchestrator import async_execute_agent +from tapeagents.orchestrator import async_execute_agent, execute_agent from tapeagents.remote_environment import AsyncRemoteEnvironment from tapeagents.tools.simple_browser import PageObservation from pipelinerl.async_llm import make_training_text -from pipelinerl.llm import LLMCall, TrainableLLM +from pipelinerl.domains.miniwob.environment import WebEnvironment +from pipelinerl.domains.miniwob.steps import WebTape +from pipelinerl.llm import LLMCall, TrainableLLM, TrainingText from pipelinerl.rollouts import BaseMetrics, RolloutResult from pipelinerl.world import Job -from .steps import WebTape - logger = logging.getLogger(__name__) +def task_id(problem: dict) -> str: + """Format task identifier for logging.""" + return f"{problem['dataset']}/{problem['task']}/{problem['seed']}" + + class MiniwobMetrics(BaseMetrics): - reward: float - success: bool - no_error: bool - no_answer: bool - overflow: bool - n_llm_calls: int - n_step_errors: int - n_page_observations: int - n_steps: int - total_execution_time: float - agent_execution_time: float - environment_execution_time: float - env_step_time: float - agent_step_time: float - - -def 
tape_contains_an_error(tape: WebTape) -> bool: + reward: float = -1.0 + success: bool = False + has_error: bool = False + no_answer: bool = True + overflow: bool = False + n_llm_calls: int = 0 + n_step_errors: int = 0 + n_observations: int = 0 + n_steps: int = 0 + env_creation_time: float = 0.0 + agent_creation_time: float = 0.0 + env_start_time: float = 0.0 + env_close_time: float = 0.0 + agent_execution_time: float = 0.0 + total_execution_time: float = 0.0 + llm_call_time: float = 0.0 + env_step_time: float = 0.0 + total_llm_call_time: float = 0.0 + total_env_call_time: float = 0.0 + + +def _tape_contains_an_error(tape: WebTape) -> bool: """ Returns true if the tape ends with an error, ie if one of the following is true: - the last step is an LLMOutputParsingFailureAction @@ -56,94 +64,294 @@ def tape_contains_an_error(tape: WebTape) -> bool: len(tape.steps) == 0 or isinstance(tape.steps[-1], LLMOutputParsingFailureAction) or tape.metadata.result.get("error") is not None - or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error) + or (isinstance(tape.steps[-1], PageObservation) and bool(tape.steps[-1].error)) ) +def _compute_reward( + tape: WebTape, + reward_computation: Literal["uic", "default"] = "default", + has_error: bool = False, +) -> tuple[float, bool]: + """ + Compute reward from tape. + + Args: + tape: The execution tape + cfg: Configuration with reward_computation setting + has_error: If there were errors during execution + + Returns: + tuple of (reward, has_error) + """ + # Extract raw reward from last observation + obs_steps = [step for step in tape if isinstance(step, Observation)] + if obs_steps: + last_obs = obs_steps[-1] + raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) + else: + raw_reward = -1.0 + + # Count errors and page observations + n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) + n_observations = len([step for step in tape.steps if isinstance(step, Observation)]) + + # Determine if tape has errors + has_error = has_error or _tape_contains_an_error(tape) + + # Compute final reward based on configuration + if reward_computation == "uic": + reward = float(raw_reward > 0) + if reward == 0.0: + reward = -1.0 + reward *= 0.98**n_observations + else: + reward = raw_reward * 0.99**n_step_errors if not has_error and raw_reward >= 0 else -1.0 + + return reward, has_error + + +def _extract_llm_calls(tape: WebTape) -> list[LLMCall]: + """Extract LLM calls from tape steps.""" + return [ + LLMCall(**step.metadata.other["llm_call"]) + if isinstance(step.metadata.other["llm_call"], dict) + else step.metadata.other["llm_call"] + for step in tape.steps + if "llm_call" in step.metadata.other + ] + + +def _compute_metrics( + tape: WebTape, + training_texts: list[TrainingText], + reward: float, + has_error: bool, + n_llm_calls: int, +) -> MiniwobMetrics: + # Create training texts + has_overflow = False + for text in training_texts: + text.reward = reward + has_overflow |= not text.finished + + # Extract timing information + llm_call_times = [float(step.metadata.other.get("llm_call_time", 0.0)) for step in tape.steps] + env_call_times = [float(step.metadata.other.get("action_execution_time", 0.0)) for step in tape.steps] + total_llm_call_time = sum(llm_call_times) + total_env_call_time = sum(env_call_times) + llm_call_time = total_llm_call_time / len(llm_call_times) if llm_call_times else -1.0 + env_step_time = total_env_call_time / len(env_call_times) if 
env_call_times else -1.0 + env_start_time = tape.metadata.result.get("env_start_time", -1.0) + env_close_time = tape.metadata.result.get("env_close_time", -1.0) + env_creation_time = tape.metadata.result.get("env_creation_time", -1) + agent_creation_time = tape.metadata.result.get("agent_creation_time", -1) + agent_execution_time = tape.metadata.result.get("agent_execution_time", -1.0) + total_execution_time = tape.metadata.result.get("total_execution_time", -1.0) + + # Compute step counts + n_observations = len([s for s in tape.steps if isinstance(s, Observation)]) + n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) + + metrics = MiniwobMetrics( + reward=reward, + success=reward > 0.5, + has_error=has_error, + no_answer=reward < 0, + overflow=has_overflow, + n_llm_calls=n_llm_calls, + n_step_errors=n_step_errors, + n_steps=len(tape.steps), + n_observations=n_observations, + + env_creation_time=env_creation_time, + env_start_time=env_start_time, + env_close_time=env_close_time, + + agent_creation_time=agent_creation_time, + agent_execution_time=agent_execution_time, + + llm_call_time=llm_call_time, + env_step_time=env_step_time, + total_llm_call_time=total_llm_call_time, + total_env_call_time=total_env_call_time, + total_execution_time=total_execution_time, + ) + return metrics + + async def check_env_server_health(env_job: Job, session: aiohttp.ClientSession) -> dict: """Check environment server health via HTTP API.""" try: url = f"http://{env_job.hostname}:{env_job.port}/health" - async with session.get(url, timeout=5) as response: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=5.0)) as response: if response.status == 200: health_data = await response.json() - return { - "healthy": True, - "health_data": health_data, - "last_check": time.time() - } + return {"healthy": True, "health_data": health_data, "last_check": time.time()} else: error_text = await response.text() - return {"healthy": False, "error_message": f"HTTP {response.status}: {error_text}", "last_check": time.time()} + health_data = f"HTTP {response.status}: {error_text}" + return {"healthy": False, "health_data": health_data, "last_check": time.time()} except Exception as e: exception_type = type(e).__name__ exception_message = str(e) if str(e) else "No message available" - logger.exception(f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True) - return {"healthy": False, "error_message": f"Exception: {exception_type}: {exception_message}", "last_check": time.time(), "error_stacktrace": traceback.format_exc()} + logger.exception( + f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True + ) + return { + "healthy": False, + "health_data": f"Exception: {exception_type}: {exception_message}", + "last_check": time.time(), + "error_stacktrace": traceback.format_exc(), + } + + +def generate_miniwob_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict) -> RolloutResult: + """ + Generate a MiniWoB rollout. 
Steps: + - make agent and env + - set the llm + - run the agent + - get llm calls from tape + - compute rewards + - get training text from llm calls + + Args: + cfg: Configuration for the rollout + llm: The LLM to use + problem: The problem dict + Returns: + RolloutResult with training texts and metrics + """ + tid = task_id(problem) + start_time = time.perf_counter() + environment: WebEnvironment = instantiate(cfg.environment) + environment.initialize() + env_creation_time = time.perf_counter() - start_time + logger.info(f"Environment tools: {environment.tools_description()}") + t = time.perf_counter() + agent: Agent = instantiate( + cfg.agent, + known_actions=environment.actions(), + tools_description=environment.tools_description(), + llms={DEFAULT: llm}, + ) + logger.info(f"Agent and environment loaded, using llm {llm.model_name} at {llm.get_base_url()}") + agent_creation_time = time.perf_counter() - t + try: + start_attempts = cfg.start_attempts + t = time.perf_counter() + while True: + try: + tape, _ = environment.start_task(problem) + break + except Exception as e: + logger.exception(f"Failed to start task {tid}: {e}") + start_attempts -= 1 + if start_attempts <= 0: + raise Exception(f"Failed to start task {tid} after {cfg.start_attempts} attempts") + else: + logger.warning("Retrying after 1 second") + time.sleep(1) + env_start_time = time.perf_counter() - t + logger.info(f"Task {tid} started in {env_start_time:.2f}s") + t = time.perf_counter() + tape = execute_agent(agent, tape, environment, max_loops=cfg.agent_max_loops) + agent_execution_time = time.perf_counter() - t + finally: + t = time.perf_counter() + environment.close() + env_close_time = time.perf_counter() - t + total_execution_time = time.perf_counter() - start_time + logger.info(f"Task {tid} finished in {total_execution_time:.2f}s") + tape.metadata.result.update( + { + "total_execution_time": total_execution_time, + "env_creation_time": env_creation_time, + "env_start_time": env_start_time, + "env_close_time": env_close_time, + "agent_creation_time": agent_creation_time, + "agent_execution_time": agent_execution_time, + } + ) + # save the tape as we go + if cfg.save_tapes: + _save_tapes(cfg, problem, tape) -async def generate_miniwob_rollout( + # Compute reward and metrics + reward, has_error = _compute_reward(tape, cfg.reward_computation) + llm_calls = _extract_llm_calls(tape) + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + metrics = _compute_metrics( + tape, + training_texts, + reward, + has_error, + len(llm_calls), + ) + latency = time.perf_counter() - start_time + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + ) + +def _save_tapes(cfg, problem, tape): + tape_name = problem.get("_task_id", tape.metadata.id) + try: + save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape_name) + except Exception as e: + logger.error(f"Error saving tape {tape_name}: {e}") + + +async def generate_miniwob_rollout_async( cfg: DictConfig, llm: TrainableLLM, problem: dict, session: aiohttp.ClientSession, ) -> RolloutResult: - # choose a random environment server - # Generate environment - # Generate TapeAgent - # run the agent - # get llm calls from tape - # compute rewards - # get training text from llm calls - - start_time = time.time() - - # Overall timeout for the entire rollout to prevent hanging - rollout_timeout = getattr(cfg, 'rollout_timeout', 600) # 10 minutes default + start_time = time.perf_counter() + 
tid = task_id(problem) + rollout_timeout = getattr(cfg, "rollout_timeout", 600) # 10 minutes default env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] env_jobs_url_tried = [] - # Try each environment server with health checks until one of them returns a rollout result for _ in range(len(env_jobs)): - # Choose the next environment server to try randomly from the ones that have not been tried yet - env_job = random.choice([job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried]) + env_job = random.choice( + [job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried] + ) env_job_url = f"http://{env_job.hostname}:{env_job.port}" env_jobs_url_tried.append(env_job_url) - # Check server health before using health = await check_env_server_health(env_job, session) if not health["healthy"]: - logger.warning(f"Environment server {env_job_url} is unhealthy: {health}") - logger.warning(f"Get health error stacktrace: {health['error_stacktrace']}") + logger.warning(f"Env server {env_job_url} unhealthy: {health.get('health_data', 'unknown')}, skip to next one") continue - # Log health status for monitoring - if health["healthy"]: - logger.info(f"Using healthy environment server {env_job_url}: {health}") + logger.debug(f"Using env server {env_job_url}") try: - # Execute the entire rollout with a timeout return await asyncio.wait_for( _execute_rollout_with_timeout(cfg, llm, problem, session, start_time, env_job_url), - timeout=rollout_timeout + timeout=rollout_timeout, ) except asyncio.TimeoutError: - health = await check_env_server_health(env_job, session) - if stack_trace := health.get("error_stacktrace"): - logger.warning(f"Get health error stacktrace: {stack_trace}") - logger.warning(f"Rollout timeout error stacktrace: {traceback.format_exc()}") - logger.warning(f"Rollout timed out after {rollout_timeout} seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.") + logger.warning(f"Task {tid} timed out after {rollout_timeout}s on {env_job_url}") continue except Exception as e: - health = await check_env_server_health(env_job, session) - if stack_trace := health.get("error_stacktrace"): - logger.warning(f"Get health error stacktrace: {stack_trace}") - logger.warning(f"Rollout failed error stacktrace: {traceback.format_exc()}") - logger.warning(f"Rollout failed for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.") + logger.warning(f"Task {tid} failed on {env_job_url}: {e}") continue - # If all servers failed - logger.error(f"All environment servers failed for task {problem['dataset']}/{problem['task']}/{problem['seed']}. 
Returning a failed rollout result.") - return _create_failed_rollout_result(problem, start_time, "all environment servers failed") + + logger.error(f"Task {tid}: all environment servers failed") + # Return a failed rollout result + return RolloutResult( + training_texts=[], + metrics=MiniwobMetrics(), + latency=time.perf_counter() - start_time, + dataset_name=problem["dataset"], + ) async def _execute_rollout_with_timeout( @@ -154,190 +362,106 @@ async def _execute_rollout_with_timeout( start_time: float, env_job_url: str, ) -> RolloutResult: - # (2) Generate environment, TapeAgent, and run them to get a Tape - no_error = True # track if there was an error in the tape + tid = task_id(problem) + has_error = False + start_time = time.perf_counter() environment = AsyncRemoteEnvironment(server_url=env_job_url) # type: ignore async with environment.acontext(session, wait_for_env=True) as env: + env_creation_time = time.perf_counter() - start_time + agent_creation_time = 0.0 start_attempts = cfg.start_attempts t = time.perf_counter() + tape_dict = {} while start_attempts > 0: try: tape_dict, info = await env.start_task(problem) if info.get("error"): - raise ValueError(info['error']) + raise ValueError(info["error"]) break except Exception as e: start_attempts -= 1 - logger.warning(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}. {start_attempts} attempts remaining. Error: {e}") + logger.warning(f"Task {tid} start failed, {start_attempts} attempts left: {e}") if start_attempts <= 0: - logger.error(f"Failed to start task after all retry attempts: {e}") - no_error = False - tape_dict = {} + logger.error(f"Task {tid} start failed after all retries: {e}") + has_error = True break else: - logger.warning("Retry start task after 5 seconds.") await asyncio.sleep(5) - logger.info( - f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds. Worker ID: {env.worker_id}. Tape dict: {tape_dict}" - ) - tape: WebTape = WebTape(**tape_dict) # convert http response dict to WebTape object + env_start_time = time.perf_counter() - t + logger.info(f"Task {tid} started in {env_start_time:.2f}s (worker={env.worker_id})") + tape: WebTape = WebTape(**tape_dict) t = time.perf_counter() - if no_error: # only run the agent if the task started successfully - logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}") + agent_execution_time = 0.0 + if not has_error: agent_attempts = cfg.agent_attempts while agent_attempts > 0: - # check if the worker is alive. try: - # this will either raise RuntimeError if worker is not alive anymore, or return a dictionary with the worker status worker_status = await env.check_worker_alive() if worker_status.get("status") == "starting": - logger.warning(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is starting, waiting 5 seconds for it to be fully started.") + logger.debug(f"Task {tid}: worker {env.worker_id} starting, waiting...") await asyncio.sleep(5) continue except Exception as e: - # if worker is dead, no need to retry - logger.exception(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is dead. 
Error: {e}", stack_info=True) - no_error = False + logger.exception(f"Task {tid}: worker {env.worker_id} dead: {e}") + has_error = True break - # if worker is alive, run the agent + try: + t = time.perf_counter() actions = await env.a_actions() tools_description = await env.a_tools_description() agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description) - agent.llms = {DEFAULT: llm} + agent.llms = {DEFAULT: llm} # type: ignore + agent_creation_time = time.perf_counter() - t + t = time.perf_counter() tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops) - # Check if the tape has an error from the orchestrator (e.g., SocketTimeoutError, RuntimeError: Worker is not alive, etc.) + agent_execution_time = time.perf_counter() - t if tape.metadata.error: - logger.error(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} returned a tape with error: {tape.metadata.error}") + logger.error(f"Task {tid}: agent error: {tape.metadata.error}") raise ValueError(tape.metadata.error) - else: - # Success - break out of retry loop - logger.info(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} finished successfully") - break + logger.info(f"Task {tid}: agent execution succeeded") + break except Exception as e: agent_attempts -= 1 - logger.warning(f"Error occurred while running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}. {agent_attempts} attempts remaining. Error: {e}") + logger.warning(f"Task {tid}: agent error, {agent_attempts} attempts left: {e}") if agent_attempts <= 0: - logger.error(f"Agent execution failed after all retry attempts for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}: {e}") - no_error = False + logger.error(f"Task {tid}: agent failed after all retries: {e}") + has_error = True break - else: - logger.warning(f"Retry agent execution after 5 seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}.") - await asyncio.sleep(5) - logger.info( - f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']} in {time.perf_counter() - t:.2f} seconds with worker ID: {env.worker_id} and tape ID {tape.metadata.id}" - ) - tape.metadata.result.update({"total_execution_time": time.perf_counter() - t}) - - # save the tape as we go - if cfg.save_tapes: - save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id) - - # (3) Compute rewards - obs_steps = [step for step in tape if isinstance(step, Observation)] - if obs_steps: - last_obs = obs_steps[-1] - # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0 - # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L188 - # Let's take directly the RAW_REWARD_GLOBAL from the metadata - # raw_reward = last_obs.metadata.other.get("reward", 0.0) - raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) - else: - raw_reward = -1.0 - - no_error = no_error and not tape_contains_an_error(tape) - # get the number of LLMOutputParsingFailureAction in the tape - n_step_errors = len([step for step in tape.steps 
if isinstance(step, LLMOutputParsingFailureAction)]) - # get the number of PageObservation steps in the tape - n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)]) + await asyncio.sleep(5) - if cfg.reward_computation == "nico": - reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0 - elif cfg.reward_computation == "uic": - reward = float(raw_reward>0) - if reward == 0.0: - reward = -1.0 - reward *= 0.98 ** n_page_observations - else: - raise ValueError(f"Invalid reward configuration: {cfg.reward_computation}") + logger.info(f"Task {tid} finished in {time.perf_counter() - t:.2f}s (worker={env.worker_id})") + t = time.perf_counter() + await env.aclose() + env_close_time = time.perf_counter() - t + total_execution_time=time.perf_counter() - start_time + tape.metadata.result.update({ + "total_execution_time": total_execution_time, + "env_creation_time": env_creation_time, + "env_start_time": env_start_time, + "env_close_time": env_close_time, + "agent_creation_time": agent_creation_time, + "agent_execution_time": agent_execution_time, + }) - # (3) Get LLM calls from Tape - llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None] - n_llm_calls = len(llm_calls) - llm_calls: list[LLMCall] = [ - LLMCall(**step.metadata.other["llm_call"]) if isinstance(step.metadata.other["llm_call"], dict) - else step.metadata.other["llm_call"] - for step in llm_calls - ] + if cfg.save_tapes: + _save_tapes(cfg, problem, tape) - # (4) # For each LLM interaction in the tape, make a training example. - all_finished = 1 - prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] - output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] + # Compute reward and metrics + reward, has_error = _compute_reward(tape, cfg.reward_computation, has_error) + llm_calls = _extract_llm_calls(tape) training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] - for text in training_texts: - text.reward = reward - all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 - - latency = time.time() - start_time - agent_time = tape.metadata.result.get("agent_execution_time", -1.0) - env_time = tape.metadata.result.get("environment_execution_time", -1.0) - n_observations = len([s for s in tape.steps if isinstance(s, Observation)]) # TODO: is this not the same n_page_observations?? 
- n_other_steps = len(tape.steps) - n_observations - metrics = MiniwobMetrics( - reward=reward, - success=reward > 0.5, - no_error=no_error, - no_answer=reward < 0, - overflow=not all_finished, - n_llm_calls=n_llm_calls, - n_step_errors=n_step_errors, - n_page_observations=n_page_observations, - n_steps=len(tape.steps), - total_execution_time=tape.metadata.result.get("total_execution_time", -1.0), - agent_execution_time=agent_time, - environment_execution_time=env_time, - env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0, - agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0, + metrics = _compute_metrics( + tape, + training_texts, + reward, + has_error, + len(llm_calls), ) - - return RolloutResult( - training_texts=training_texts, - metrics=metrics, - latency=latency, - dataset_name=problem["dataset"], - prompt_tokens=prompt_tokens, - output_tokens=output_tokens, - ) - - -def _create_failed_rollout_result(problem: dict, start_time: float, error_type: str) -> RolloutResult: - """Create a failed rollout result for timeout or other errors.""" latency = time.time() - start_time - - # Create empty training texts and metrics for failed rollout - metrics = MiniwobMetrics( - reward=-1.0, - success=False, - no_error=False, - no_answer=True, - overflow=False, - n_llm_calls=0, - n_step_errors=0, - n_page_observations=0, - n_steps=0, - total_execution_time=latency, - agent_execution_time=-1.0, - environment_execution_time=-1.0, - env_step_time=-1.0, - agent_step_time=-1.0, - ) - return RolloutResult( - training_texts=[], + training_texts=training_texts, metrics=metrics, latency=latency, dataset_name=problem["dataset"], - prompt_tokens=[], - output_tokens=[], ) diff --git a/pipelinerl/finetune/logging_.py b/pipelinerl/finetune/logging_.py index 0b221e24..0624f765 100644 --- a/pipelinerl/finetune/logging_.py +++ b/pipelinerl/finetune/logging_.py @@ -25,7 +25,7 @@ def setup_logging(cfg: DictConfig, output_dir: Path, run: wandb_run.Run | None = debug_handler = logging.FileHandler(log_dir / f"info_{get_accelerator().process_index}.log") debug_handler.setLevel(logging.INFO) logging.basicConfig( - format="[finetune]: %(asctime)s.%(msecs)03d - %(levelname)s - %(name)s - %(message)s", + format="[finetune]: %(asctime)s.%(msecs)03d - %(levelname)s - %(name)s:%(lineno)d - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, handlers=[debug_handler, logging.StreamHandler()], diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 9118078e..f2472388 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -2,22 +2,17 @@ import os from functools import partial from typing import Any -from pydantic import BaseModel, Field import numpy as np import pandas as pd import torch import torch.nn.functional as F -from datasets import Dataset -from transformers import PreTrainedModel -from pipelinerl.finetune.types import PipelineBatchEncoding -from pipelinerl.finetune.rl.utils import per_segment_sums +from pydantic import BaseModel, Field +from transformers.modeling_utils import PreTrainedModel -from .utils import ( - sum_sum, - mean_sum, - replace_dataset_column, -) +from pipelinerl.finetune.rl.utils import per_segment_sums +from pipelinerl.finetune.types import PipelineBatchEncoding +from pipelinerl.finetune.rl.utils import sum_sum # FIXME: remove a warnings, but might be worth investigating os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -39,8 +34,7 @@ class 
RLConfig(BaseModel): policy_loss: str = Field( default="ppo", - description="Policy Loss to use for RL", - choices=["ppo", "reinforce", "gspo"], + description="Policy Loss to use for RL, one of ['ppo', 'reinforce', 'gspo']", ) use_advantages: bool = Field( default=True, @@ -418,7 +412,11 @@ def rl_step( def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: RLConfig) -> list[dict[str, Any]]: - """Populate RL-specific columns (advantages, overflow, num_labels) using a leave-one-out baseline.""" + """ + Populates a dataset with reinforcement learning specific data columns including + rewards, advantages, and token weights. + Uses leave-one-out (LOO) reward mean: each rollout's baseline excludes its own reward. + """ # Convert to pandas for processing df_init = pd.DataFrame(dataset) assert isinstance(df_init, pd.DataFrame) @@ -450,7 +448,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R "group_tokens", ] - # Step 2: calculate advantages for each sample + # Step 2: calculate advantages for each sample (with LOO mean) df_advantages = pd.merge( df_init[["group_id", "rollout_index", "step_index", "rewards"]], df_grouped, @@ -458,6 +456,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R how="left" ) assert len(df_advantages) == len(df_init) + def calculate_advantages(row): rewards = row["rewards"] group_sum = row["rollout_reward_sum"] @@ -487,7 +486,6 @@ def calculate_advantages(row): # Step 3: bring advantages and group level stats back to the main df df = df_init.drop(columns=["advantages", "group_tokens"]) df = pd.merge(df, df_advantages, on=["group_id", "rollout_index", "step_index"], how="left") - # Debug print lengths of all dataframes assert len(df) == len(df_init) # Step 4: make token-level overflow and mean group length information diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index da39938e..71dac6ea 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -450,6 +450,7 @@ def run_finetuning_loop( logger.info("Load the first version of the model into inference LLMs") weight_update_manager.send_weight_update(training_metrics.samples) else: + logger.info("send_weight_updates disabled, weight_update_manager is None") weight_update_manager = None batch_queue = Queue(maxsize=1) diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 8d2c33d8..110109dd 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -1,6 +1,7 @@ import logging import math import os +import shlex import shutil import subprocess import sys @@ -19,8 +20,6 @@ logger = logging.getLogger(__name__) -# All the launch commands in this file pass the environment to child processes -os.environ["PYTHONPATH"] = f"/home/toolkit/TapeAgents" os.environ["NCCL_CUMEM_ENABLE"] = "0" os.environ["TORCH_DISABLE_SHARE_RDZV_TCP_STORE"] = "1" os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1" @@ -178,6 +177,29 @@ def run_actor_llm( str(world_map.weight_update_group_size), ] + # Provide deterministic rendezvous port defaults when env vars are absent. + # vLLM spins up a torch.distributed TCPStore using VLLM_PORT. On the remote + # scheduler we observed replica crashes (store collisions, connection + # refused) because every start script inherited the same default port. By + # exporting VLLM_PORT_BASE/VLLM_PORT_STRIDE we carve out a rendezvous range + # per actor_idx while keeping the public HTTP listener at 8080+local_idx. 
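A minimal sketch of the port derivation this comment describes, mirroring the VLLM_PORT logic added to pipelinerl/vllm0.py later in this diff (the helper name and fallback behaviour are illustrative, not part of the patch):

import os

def derive_vllm_port(actor_llm_idx: int) -> int | None:
    """Per-replica rendezvous port: a per-rank base plus a stride per actor index."""
    base = os.environ.get("VLLM_PORT_BASE", "")
    stride = os.environ.get("VLLM_PORT_STRIDE", "10")
    if not base.isdigit() or not stride.isdigit():
        return None  # leave VLLM_PORT unset and let vLLM fall back to its default rendezvous
    return int(base) + int(stride) * actor_llm_idx

# With the defaults set just below for rank 0 (base 43000, stride 20):
# actor 0 -> 43000, actor 1 -> 43020, actor 2 -> 43040, ...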
+ env = dict(os.environ) + if "VLLM_PORT_BASE" not in env: + # Each rank gets 1000 ports; 43000 leaves room below. + env["VLLM_PORT_BASE"] = str(43000 + 1000 * world_map.my_rank) + logger.debug( + "Setting default VLLM_PORT_BASE=%s for rank %s", + env["VLLM_PORT_BASE"], world_map.my_rank, + ) + if "VLLM_PORT_STRIDE" not in env: + env["VLLM_PORT_STRIDE"] = "20" + + env_overrides = { + key: str(env[key]) + for key in ("VLLM_PORT_BASE", "VLLM_PORT_STRIDE") + if key in env + } + # Add vLLM kwargs as separate arguments if cfg.vllm_config.vllm_kwargs: for k, v in cfg.vllm_config.vllm_kwargs.items(): @@ -190,13 +212,13 @@ def run_actor_llm( gpu_str = ",".join([str(gpu) for gpu in gpus]) logger.info(f"Running actor_llm with command: {' '.join(cmd)} on gpus: {gpu_str}") - save_command(log_dir, cmd) + save_command(log_dir, cmd, env_overrides or None) log_file_path = os.path.join(log_dir, "stdout.log") err_file_path = os.path.join(log_dir, "stderr.log") with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file: proc = _popen( cmd, - env={**os.environ, "CUDA_VISIBLE_DEVICES": gpu_str}, + env={**env, "CUDA_VISIBLE_DEVICES": gpu_str}, stdout=log_file, stderr=err_file, ) @@ -405,14 +427,21 @@ def run_redis(cfg: DictConfig): yield LaunchedProcess(kind="redis", handle=proc) -def save_command(script_dir: Path, cmd): +def save_command(script_dir: Path, cmd, env: dict | None = None): os.makedirs(script_dir, exist_ok=True) script_path = script_dir / "start.sh" with open(script_path, "w") as f: f.write("#!/bin/bash\n") + f.write("set -e\n") + if env: + for key, value in sorted(env.items()): + quoted_value = shlex.quote(value) + f.write(f"export {key}={quoted_value}\n") # Properly quote arguments for the shell script - quoted_cmd = [f"'{arg}'" if " " in arg or "$" in arg else arg for arg in cmd] - f.write(" ".join(quoted_cmd) + "\n") + quoted_cmd = [shlex.quote(arg) for arg in cmd] + f.write("exec ") + f.write(" ".join(quoted_cmd)) + f.write("\n") os.chmod(script_path, 0o755) logger.info(f"Saved start script to {script_path}") @@ -583,7 +612,8 @@ def main(cfg: DictConfig): group = str(exp_dir) root = cfg.wandb.wandb_workspace_root if root: - if not group.startswith(root + "/"): + check_root = (root + "/") if not root.endswith("/") else root + if not group.startswith(check_root): raise ValueError(f"run_dir {exp_dir} does not start with root {root}") cfg.wandb.wandb_group = group[len(root) + 1 :] if world_map.total_finetune_gpus: @@ -641,6 +671,8 @@ def main(cfg: DictConfig): if cfg.debug.mode == "finetune": processes.extend(launch_jobs(cfg, world_map, ["finetune"])) + elif cfg.debug.mode == "llm": + processes.extend(launch_jobs(cfg, world_map, ["actor_llm"])) elif cfg.debug.mode == "actor": processes.extend(launch_jobs(cfg, world_map, ["actor", "environment", "actor_llm"])) elif cfg.debug.mode == "preprocessor": diff --git a/pipelinerl/llm.py b/pipelinerl/llm.py index cc099c15..04dd93b7 100644 --- a/pipelinerl/llm.py +++ b/pipelinerl/llm.py @@ -21,45 +21,12 @@ from pydantic import BaseModel, Field, TypeAdapter from tenacity import retry, stop_after_attempt, wait_exponential +from pipelinerl.rollouts import TrainingText + logger = logging.getLogger(__name__) PIPELINERL_LLM_TOKEN = "PIPELINERL_LLM_TOKEN" -class TrainingText(BaseModel): - """ - Training text instance used to finetune a language model. - - Attributes: - text (str): The full text of the training instance. - n_predicted (int): The number of predicted characters in the text. 
- reward (float): The reward associated with the training instance. Defaults to 0.0. - logprobs (List[float]): A list of log probabilities of the completion tokens from the assistant model. - ref_logprobs (List[float]): A list of reference log probabilities of the completion tokens from the reference model. - input_ids (List[int]): The tokenized input ids of the text. - labels (List[int]): The tokenized labels of the text (i.e., masked token ids for the prompt and regular token ids for the prediction). - group_id (str, optional): ID of the group. It is used by the RL finetuning script to normalize rewards. - prompt_text (str): Portion of the text that serves as the prompt (i.e., the text excluding the predicted characters). - output_text (str): Portion of the text that represents the predicted output (i.e., the last n_predicted characters). - """ - - text: str - n_predicted: int - reward: float = 0.0 - logprobs: list[float] = Field(default_factory=list) - ref_logprobs: list[float] = Field(default_factory=list) - input_ids: list[int] = Field(default_factory=list) - labels: list[int] = Field(default_factory=list) - group_id: str | None = None - metadata: dict = Field(default_factory=dict) - - @property - def prompt_text(self) -> str: - return self.text[: -self.n_predicted] - - @property - def output_text(self) -> str: - return self.text[-self.n_predicted :] - class Prompt(BaseModel): """ @@ -392,10 +359,6 @@ class TrainableLLM(LLM): base_url (str): Base URL of the API endpoint api_token (str): Authentication token for API access """ - - # TODO: use OpenAI Python client when the certificate issue is resolved. - # TODO: consider using litellm - base_url: str = "https://api.openai.com" api_token: str = Field(default="", exclude=True) collect_logprobs: bool = False @@ -403,7 +366,7 @@ class TrainableLLM(LLM): max_parallel_requests: int = 32 max_retries: int = 5 base_delay: float = 0.5 - _semaphore: asyncio.Semaphore + _semaphore: asyncio.Semaphore = None # type: ignore def model_post_init(self, __context): super().model_post_init(__context) diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 0a6015e4..a45c4438 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -211,7 +211,7 @@ def run_dataset_loader( # This is a blocking call, but in most cases there will be space raw_chunk_queue.put(buffer) except Exception as e: - logger.error(f"Error in dataset loader: {e}") + logger.exception(f"Error in dataset loader: {e}") raw_chunk_queue.put(e) break @@ -398,8 +398,8 @@ def run_preprocessing_loop( # Initialize TrainerState trainer_state = TrainerState(exp_root_dir) - if cfg.debug.mode == "preprocessor": - logger.info("Debug mode: preprocessor") + if cfg.debug.mode == "preprocessor" or cfg.debug.mode == "actor+preprocessor": + logger.info(f"Debug mode: {cfg.debug.mode}") trainer_state.debug_mode_init() elif cfg.debug.mode == "finetune+preprocessor": logger.info("Debug mode: finetune+preprocessor") @@ -554,7 +554,7 @@ def run_preprocessing_loop( else: processed_entries_queue_popped_data += 1 if processed_entries_queue_popped_data % 100 == 0 and last_time_notice != processed_entries_queue_popped_data // 100: - logger.warning(f"Popped {processed_entries_queue_popped_data} old entries from processed entries queue") + logger.warning(f"Popped {processed_entries_queue_popped_data} old entries from processed entries queue of max size {processed_entries_queue.maxlen}") last_time_notice = processed_entries_queue_popped_data // 100 entry = buffer.popleft() 
                processed_entries_queue.append(entry)  # drop from the left if full
@@ -590,6 +590,10 @@ def run_preprocessing_loop(
             sample_length = len(entry["input_ids"])
             if current_length + sample_length > cfg.finetune.seq_length:
+                if len(current_batch) == 0:
+                    raise ValueError(
+                        f"sample_length is {sample_length}, but cfg.finetune.seq_length is {cfg.finetune.seq_length}"
+                    )
                 time_to_write = True
                 break  # Current micro batch is full
@@ -654,6 +658,7 @@ def run_preprocessing_loop(
                 "preprocessor/queue/output": output_queue.qsize(),
                 "preprocessor/filtered_out_samples": num_filtered_out,
                 "preprocessor/total_filtered_out_samples": total_filtered_out,
+                "preprocessor/dropped_after_preprocessing": processed_entries_queue_popped_data,
             }
             if stats_aggregator.has_enough_data():
                 stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py
new file mode 100644
index 00000000..12e6fc2d
--- /dev/null
+++ b/pipelinerl/rl_tool_parser_plugin.py
@@ -0,0 +1,247 @@
+"""
+Tool parser plugin for the RL tool calling format.
+"""
+
+import json
+import re
+from typing import Any  # noqa: F401
+import logging
+
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ExtractedToolCallInformation,
+    ToolCall,
+    FunctionCall
+)
+
+
+@ToolParserManager.register_module("rl_tool")
+class HermesRLToolParser(ToolParser):
+    """
+    Tool parser for the RL tool calling format using <tool_call> ... </tool_call> markers.
+    Supports both the standard format and Apriel-style formats:
+    - <tool_calls>[{...}, {...}]</tool_calls> (preferred if present)
+    - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper
+    """
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+
+        # Tool call markers
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+        # Regex pattern for parsing <tool_call>...</tool_call> blocks
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
+        )
+
+        # Apriel-specific patterns
+        self.apriel_final_response_regex = re.compile(
+            r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL
+        )
+        # Prefer parsing aggregated tool calls from <tool_calls> ... </tool_calls>.
+        # Be lenient: case-insensitive; tolerate a missing closing tag by capturing to the end.
+        self.apriel_tool_calls_regex = re.compile(
+            r"<tool_calls>\s*(.*?)\s*(?:</tool_calls>|$)", re.DOTALL | re.IGNORECASE
+        )
+
+        # State for streaming
+        self.current_tool_name_sent = False
+        self.prev_tool_call_arr = []
+        self.current_tool_id = -1
+        self.streamed_args_for_tool = []
+
+    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        """
+        Extract tool calls from the model output.
+ + Args: + model_output: The raw model output string + request: The request object + + Returns: + ExtractedToolCallInformation with tool calls and metadata + """ + logger = logging.getLogger("pipelinerl.tool_parser") + # Ensure variable exists for any fallback references below + final_response_match = None + + try: + # 1) Apriel aggregated tool calls block has priority + tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output)) + if tool_calls_matches: + # Use the last match (in case of multiple blocks) + last_match = tool_calls_matches[-1] + tool_calls_json = last_match.group(1).strip() + parsed_calls = [] + try: + parsed_calls = json.loads(tool_calls_json) if tool_calls_json else [] + except Exception: + logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True) + parsed_calls = [] + + tool_calls: list[ToolCall] = [] + for i, pc in enumerate(parsed_calls): + try: + name = pc.get("name", "") + args_obj = pc.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + args_str = json.dumps(args_obj, ensure_ascii=False) + call_id = pc.get("id", f"call_{i}") + tool_calls.append( + ToolCall( + id=call_id, + type="function", + function=FunctionCall(name=str(name), arguments=args_str), + ) + ) + except Exception: + logger.debug("Skipping malformed aggregated tool call", exc_info=True) + continue + + # Prefer final response content if present; otherwise empty string + final_response_match = self.apriel_final_response_regex.search(model_output) + content = final_response_match.group(1).strip() if final_response_match else "" + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + # 2) Try bare JSON tool-calls (no tags), but only if tools are declared in the request + # Accept either a list of {name, arguments} or a single dict + try: + tools_declared = bool(getattr(request, "tools", None)) + except Exception: + tools_declared = False + + if tools_declared: + candidate_strings: list[str] = [] + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + candidate_strings.append(final_response_match.group(1).strip()) + candidate_strings.append(model_output.strip()) + + for candidate in candidate_strings: + try: + parsed = json.loads(candidate) + except Exception: + continue + parsed_list = [] + if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed: + parsed_list = [parsed] + elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed): + parsed_list = [it for it in parsed if "name" in it and "arguments" in it] + if not parsed_list: + continue + tool_calls: list[ToolCall] = [] + for i, pc in enumerate(parsed_list): + try: + name = pc.get("name", "") + args_obj = pc.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + args_str = json.dumps(args_obj, ensure_ascii=False) + call_id = pc.get("id", f"call_{i}") + tool_calls.append( + ToolCall( + id=call_id, + type="function", + function=FunctionCall(name=str(name), arguments=args_str), + ) + ) + except Exception: + logger.debug("Skipping malformed bare-JSON tool call", exc_info=True) + continue + content = final_response_match.group(1).strip() if final_response_match else "" + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + # 3) Fallback: look for single blocks (legacy / other models) + 
content_to_search = model_output + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + final_response_content = final_response_match.group(1).strip() + if self.tool_call_start_token in final_response_content: + content_to_search = final_response_content + elif self.tool_call_start_token not in model_output: + # No tool calls found, return final response as content + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_content + ) + + # Quick check to avoid unnecessary processing + if self.tool_call_start_token not in content_to_search: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + + # Find all tool call matches + function_call_tuples = self.tool_call_regex.findall(content_to_search) + + # Parse JSON from matches + tool_calls = [] + for i, match in enumerate(function_call_tuples): + json_str = match[0] if match[0] else match[1] + try: + parsed_call = json.loads(json_str.strip()) + args_obj = parsed_call.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + tool_call = ToolCall( + id=f"call_{i}", + type="function", + function=FunctionCall( + name=str(parsed_call.get("name", "")), + arguments=json.dumps(args_obj, ensure_ascii=False) + ) + ) + tool_calls.append(tool_call) + except Exception: + logger.debug("Skipping malformed JSON", exc_info=True) + continue + + # Determine content based on whether we found tool calls + if tool_calls and final_response_match: + # If we found tool calls in final response, use just the tool calls + content = "" + elif final_response_match: + # If we have final response but no tool calls there, use final response + content = final_response_match.group(1).strip() + else: + # Standard processing + content = model_output + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content + ) + + except Exception: + # Never propagate exceptions to the server; log and return a safe fallback. + logger.exception("Tool parser encountered an exception; returning safe fallback.") + if final_response_match: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_match.group(1).strip() + ) + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + \ No newline at end of file diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index dcb27f2d..c755cf31 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -64,3 +64,5 @@ class RolloutResult(BaseModel): model_version: int | None = None dataset_name: str | None = None group_id: str | None = None + llm_url: str = "" + attempt: int = 0 # number of attempt in the group of that problem diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py index bd7fe5b5..f061d216 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -294,19 +294,19 @@ def wait_for_inference_servers(urls: list[str]): def wait_for_environments(cfg: DictConfig): - """ - Wait for the verifier to be ready. 
- """ + """Wait for remote environment servers to report healthy.""" + if cfg.world.environment_mode != "remote": + return + env_jobs = [Job(**job) for job in cfg.jobs if job.kind == "environment"] for job in env_jobs: while True: url = f"http://{job.hostname}:{job.port}/health" - # use requests try: response = requests.get(url) if response.status_code == 200: break - except: + except requests.exceptions.RequestException: logger.info(f"Waiting for environment at {url} to be ready...") time.sleep(5.0) diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 8cd023bd..6858c7cd 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -3,39 +3,39 @@ import logging import os import signal -from pydantic import TypeAdapter + import torch +import torch.distributed as dist import uvloop +from pydantic import TypeAdapter from vllm import AsyncLLMEngine -from vllm.utils import FlexibleArgumentParser, set_ulimit -from vllm.entrypoints.openai.cli_args import ( - make_arg_parser, - validate_parsed_serve_args, -) +from vllm._version import version +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, - create_server_socket, build_app, + create_server_socket, init_app_state, + run_server, +) +from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, ) -from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.logger import init_logger -from vllm._version import version -from vllm.worker.worker import Worker -from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper from vllm.executor.mp_distributed_executor import MultiprocessingDistributedExecutor +from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.usage.usage_lib import UsageContext -from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.utils import FlexibleArgumentParser, set_ulimit from vllm.worker.multi_step_model_runner import MultiStepModelRunner +from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.worker.worker import Worker - -import torch.distributed as dist -from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest import pipelinerl.torch_utils +from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign @@ -180,6 +180,25 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid tool call parser: {args.tool_call_parser} (chose from {{ {','.join(valide_tool_parses)} }})" ) + # Choose a unique rendezvous port per actor to avoid torch.distributed + # TCPStore collisions across concurrently launched vLLM processes. 
+ try: + if "VLLM_PORT" not in os.environ: + actor_idx = getattr(args, "actor_llm_idx", None) + base_str = os.environ.get("VLLM_PORT_BASE", "") + stride_str = os.environ.get("VLLM_PORT_STRIDE", "10") + if actor_idx is not None and base_str.isdigit(): + base = int(base_str) + stride = int(stride_str) if stride_str.isdigit() else 10 + port = base + stride * int(actor_idx) + os.environ["VLLM_PORT"] = str(port) + logger.info( + "Using VLLM_PORT=%s (base=%s stride=%s actor_idx=%s)", + port, base, stride, actor_idx, + ) + except Exception as e: + logger.warning("Failed to set VLLM_PORT from actor_idx: %s", e) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index be98f76f..38d1bc96 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -1,32 +1,32 @@ import logging import signal +from typing import Any, Protocol, runtime_checkable + import torch import uvloop -from vllm.utils import FlexibleArgumentParser, set_ulimit -from vllm.entrypoints.openai.cli_args import ( - make_arg_parser, - validate_parsed_serve_args, -) +from vllm._version import version +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, - create_server_socket, build_app, + create_server_socket, init_app_state, + run_server, +) +from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, ) -from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm._version import version from vllm.usage.usage_lib import UsageContext -from vllm.config import ModelConfig +from vllm.utils import FlexibleArgumentParser, set_ulimit from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import AsyncMPClient from vllm.v1.worker.gpu_model_runner import GPUModelRunner - -from pipelinerl.finetune_loop import WeightUpdateRequest -from typing import Any, Protocol, runtime_checkable import pipelinerl.torch_utils +from pipelinerl.finetune_loop import WeightUpdateRequest logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign diff --git a/pipelinerl/world.py b/pipelinerl/world.py index 992a7c4d..54ee2bf8 100644 --- a/pipelinerl/world.py +++ b/pipelinerl/world.py @@ -71,7 +71,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False): if place_inference_jobs: self._place_inference_jobs(cfg) self._place_pipeline_stages(cfg) - if cfg.environment: + if cfg.environment and cfg.world.environment_mode == "remote": self._place_environments(cfg) # Place the finetune workers on the remaining gpus, take all remaining GPUs diff --git a/pyproject.toml b/pyproject.toml index 7fc9978a..32504c1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "uvloop>=0.19.0", "wandb>=0.16.0", "hydra-core>=1.3.2", + "ray[default]~=2.47.1", ] [project.optional-dependencies]
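For reference, a minimal sketch of the leave-one-out baseline described in the new populate_rl_data docstring above: each rollout's baseline is the mean reward of the other rollouts in its group. The helper below is illustrative and not part of the patch; the real populate_rl_data additionally handles per-token columns, overflow flags, and group length statistics.

import numpy as np

def loo_advantages(group_rewards: list[float]) -> list[float]:
    """Advantage of each rollout against the mean reward of its group, excluding itself."""
    r = np.asarray(group_rewards, dtype=float)
    n = len(r)
    if n < 2:
        return [0.0] * n  # a single rollout has no leave-one-out baseline
    loo_mean = (r.sum() - r) / (n - 1)  # mean of the other rollouts' rewards
    return (r - loo_mean).tolist()

# Example: rewards [1.0, 0.0, 0.0] -> advantages [1.0, -0.5, -0.5]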