From 771a4f9c6807cef580d7b369adaf2d3f42f857b1 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 23 Dec 2025 14:31:21 +0000 Subject: [PATCH 01/13] library upgrades + vllm1 weight update changes --- README.md | 6 ++--- conf/base.yaml | 11 +++++--- pipelinerl/finetune_loop.py | 25 ++++++++++------- pipelinerl/torch_utils.py | 53 +++++++++++++++++++++++++++++++++++++ pipelinerl/vllm0.py | 34 +++++++++++++++++++++++- pipelinerl/vllm1.py | 32 ++++++++++++++-------- pyproject.toml | 28 +++++++++++--------- 7 files changed, 148 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 8a172c0d..fccf83a6 100644 --- a/README.md +++ b/README.md @@ -195,9 +195,9 @@ cd PipelineRL Create the environments with dependencies. ```bash -conda create -n pipeline-rl -y python=3.11 -conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0 -conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation +conda create -n pipeline-rl -y python=3.12 +conda run --no-capture-output -n pipeline-rl pip install -e . +conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation ``` By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment: diff --git a/conf/base.yaml b/conf/base.yaml index 1f8d73cc..2dd03d03 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -57,14 +57,11 @@ test_llm: top_k: 50 vllm_config: - use_v1: false + use_v1: true quantization: null # or bf16_last_layer_fp32 vllm_kwargs: dtype: bfloat16 gpu-memory-utilization: 0.9 - num-scheduler-steps: 1 - disable-log-requests: "" - disable-frontend-multiprocessing: "" max-num-seqs: ${actor.llm_max_rollouts} max-num-batched-tokens: 1024 enable-chunked-prefill: "" @@ -73,6 +70,12 @@ vllm_config: pipeline-parallel-size: 1 generation-config: vllm max_model_len: 10000 + # V1 specific settings + # logprobs-mode: processed_logprobs + # V0 specific settings + # num-scheduler-steps: 1 + # disable-log-requests: "" + # disable-frontend-multiprocessing: "" world: replicas: 1 diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 57d2950e..c72d2997 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -27,7 +27,7 @@ from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead -import pipelinerl.torch_utils +from pipelinerl.torch_utils import stateless_init_process_group from pipelinerl.finetune.types import PipelineBatchEncoding from pipelinerl.finetune.checkpoints import ( load_model, @@ -212,7 +212,8 @@ def send_weight_update( for name, parameter in named_parameters.items(): with deepspeed.zero.GatheredParameters([parameter]): if get_accelerator().is_main_process: - dist.broadcast(parameter.data, src=0, group=self.actor_update_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) if get_accelerator().is_main_process: logger.info("Wait for HTTP requests") for future in futures: # type: ignore @@ -254,8 +255,8 @@ def send_weight_update( futures = self.request_weight_updates(messages) logger.info(f"Published weight update request for version {version}") for _, parameter in named_parameters.items(): - dist.broadcast(parameter.data, src=0, group=self.actor_update_group) - dist.barrier(self.actor_update_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) for future in futures: future.result() logger.info("Finished broadcasting weights") @@ -408,13 +409,18 @@ def run_finetuning_loop( get_accelerator().wait_for_everyone() if get_accelerator().is_main_process and args.send_weight_updates: - logger.info("Initializing actor process group") - actor_update_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + logger.info("Initializing actor process group using StatelessProcessGroup") + + # Explicitly set CUDA device before creating NCCL process group + current_device = get_accelerator().device + torch.cuda.set_device(current_device) + logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)") + + actor_update_group = stateless_init_process_group( init_method=cfg.me.weight_update_group_init_method, rank=0, world_size=cfg.me.weight_update_group_world_size, + device=current_device, ) logger.info("Actor process group initialized") else: @@ -493,8 +499,7 @@ def run_finetuning_loop( finally: if weight_update_manager is not None: weight_update_manager.shutdown() - if actor_update_group: - dist.destroy_process_group(actor_update_group) + # PyNcclCommunicator doesn't need explicit destroy like torch.distributed process groups def rl_finetuning_worker( diff --git a/pipelinerl/torch_utils.py b/pipelinerl/torch_utils.py index 2d16d99f..588aeab5 100644 --- a/pipelinerl/torch_utils.py +++ b/pipelinerl/torch_utils.py @@ -1,14 +1,51 @@ +import logging from datetime import timedelta from typing import Any, Optional, Union +from urllib.parse import urlparse + +import torch +import torch.distributed as dist from torch.distributed.distributed_c10d import ( Backend, PrefixStore, + ProcessGroupNCCL, Store, _new_process_group_helper, _world, default_pg_timeout, rendezvous, ) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.utils import StatelessProcessGroup + +logger = logging.getLogger(__name__) + + +def stateless_init_process_group(init_method, rank, world_size, device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + + Args: + init_method: TCP init method string (e.g., "tcp://localhost:9000") + rank: The rank of this process in the group + world_size: Total number of processes in the group + device: The CUDA device to use for NCCL communication + """ + # Parse master_address and master_port from init_method (e.g., "tcp://localhost:9000") + parsed = urlparse(init_method) + master_address = parsed.hostname or "localhost" + master_port = parsed.port or 9000 + logger.debug(f"Parsed master_address: {master_address}, master_port: {master_port}") + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl # Copy from pytorch to allow creating multiple main groups. @@ -22,6 +59,7 @@ def init_extra_process_group( store: Optional[Store] = None, group_name: str = None, pg_options: Optional[Any] = None, + device_id: Optional[torch.device] = None, ): assert (store is None) or (init_method is None), "Cannot specify both init_method and store." @@ -49,6 +87,19 @@ def init_extra_process_group( # different systems (e.g. RPC) in case the store is multi-tenant. store = PrefixStore(group_name, store) + # Create NCCL-specific options if using NCCL backend + logger.info(f"[{group_name}] Backend: {backend}, str(backend): {str(backend)}") + if pg_options is None and str(backend) == "nccl": + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = False + logger.info(f"[{group_name}] Created NCCL options: {pg_options}") + + # Ensure CUDA is synchronized before creating NCCL process group + if device_id is not None: + torch.cuda.synchronize(device_id) + logger.info(f"[{group_name}] CUDA synchronized on {device_id}") + + logger.info(f"[{group_name}] Creating process group: rank={rank}, world_size={world_size}, device_id={device_id}") pg, _ = _new_process_group_helper( world_size, rank, @@ -58,7 +109,9 @@ def init_extra_process_group( group_name=group_name, backend_options=pg_options, timeout=timeout, + device_id=device_id, ) + logger.info(f"[{group_name}] Process group created successfully") _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 93e99d60..fb3bea58 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -1,3 +1,36 @@ +""" +DEPRECATED - Kept only for backward compatibility with older vLLM versions. + +This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture. + +Compatibility: + - vLLM versions <= 0.8.x only + - The V0 engine was removed in vLLM 0.11.0 + - Use vllm1.py instead +""" +import warnings +from packaging import version as version_parser +import vllm + +# Check vLLM version compatibility +vllm_version = version_parser.parse(vllm.__version__) + +if vllm_version >= version_parser.parse("0.9.0"): + raise ImportError( + f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. " + "This module only works with vLLM <= 0.8.x. " + "Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead." + ) + +# Only show deprecation warning for compatible versions +warnings.warn( + "pipelinerl.vllm0 is DEPRECATED and will be removed in a future version. " + "This module only works with vLLM <= 0.8.x. " + "Please use pipelinerl.vllm1 as it is actively maintained.", + DeprecationWarning, + stacklevel=2, +) + import asyncio import json import logging @@ -14,7 +47,6 @@ ) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, create_server_socket, build_app, init_app_state, diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 1ac611d0..86e6be4a 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -2,7 +2,8 @@ import signal import torch import uvloop -from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.system_utils import set_ulimit from vllm.entrypoints.openai.cli_args import ( make_arg_parser, validate_parsed_serve_args, @@ -26,8 +27,8 @@ from pipelinerl.finetune_loop import WeightUpdateRequest from pipelinerl.vllm_quantization import string_to_dtype # reuse mapping +from pipelinerl.torch_utils import stateless_init_process_group from typing import Any, Protocol, runtime_checkable -import pipelinerl.torch_utils import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config logger = logging.getLogger(__name__) @@ -40,7 +41,6 @@ handler.setFormatter(formatter) logger.addHandler(handler) - @runtime_checkable class LikeWorker(Protocol): rank: int @@ -72,15 +72,18 @@ def init_actor_update_group( prefix + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" ) - self.process_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + + # Use vLLM's StatelessProcessGroup instead of torch.distributed + self.model_update_group = stateless_init_process_group( init_method=weight_update_group_init_method, rank=self.pg_rank, world_size=weight_update_group_world_size, + device=self.device, ) + logger.info(prefix + "Actor update process group initialized") - def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest): + def receive_weight_update(self: LikeWorker, request_json: str): + request = WeightUpdateRequest.model_validate_json(request_json) torch.cuda.synchronize(self.device) logger.info("Start receiving weight update") expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) @@ -89,7 +92,8 @@ def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest): if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - torch.distributed.broadcast(buffer, src=0, group=self.process_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore if len(loaded_params) != 1: raise ValueError(f"model {info.name} not found in model state dict") @@ -114,8 +118,13 @@ async def input_process_groups(self): ) async def receive_weight_update(self, request: WeightUpdateRequest): + # Ensure workers are ready by executing a dummy batch first + # This synchronizes workers before the NCCL collective + # logger.info("Synchronizing workers before weight update...") + # await self.engine_client.execute_dummy_batch_async() + # logger.info("Workers synchronized, starting weight update") await self.engine_client.collective_rpc_async( - "receive_weight_update", args=(request,) + "receive_weight_update", args=(request.model_dump_json(),) ) logger.info("Weight update processed") @@ -157,7 +166,7 @@ def signal_handler(*_) -> None: vllm_config=engine_config, usage_context=UsageContext.OPENAI_API_SERVER, disable_log_stats=engine_args.disable_log_stats, - disable_log_requests=engine_args.disable_log_requests, + enable_log_requests=engine_args.enable_log_requests, ) assert isinstance(engine.engine_core, AsyncMPClient) @@ -172,10 +181,11 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): + logger.info("Received weight update request") await weight_update_manager.receive_weight_update(request) return {"status": "ok"} - await init_app_state(engine, engine_config, app.state, args) + await init_app_state(engine, app.state, args) shutdown_task = await serve_http( app, sock, diff --git a/pyproject.toml b/pyproject.toml index 7fc9978a..ab7e69b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,11 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] +requires = [ + "setuptools>=61.0", + "wheel", + "numpy>=1.26.0", + "packaging>=23.0", + "torch~=2.9.0", +] build-backend = "setuptools.build_meta" [project] @@ -14,15 +20,13 @@ authors = [ ] dependencies = [ "aiohttp>=3.9.0", - "torch>=2.6", - "vllm==0.8.5.post1", - "accelerate==1.7.0", - "deepspeed==0.15.4", + "vllm==0.11.2", + "accelerate==1.12.0", + "deepspeed~=0.18.0", "browsergym>=0.13.0", "datasets>=2.21.0", - "transformers==4.51.1" , + "transformers~=4.57.0" , "fastapi>=0.115.0", - "flash-attn==2.7.4.post1", "joblib>=1.3.2", "jsonref>=1.1.0", "litellm>=1.61.0", @@ -32,11 +36,11 @@ dependencies = [ "Pillow>=10.0.0", "psutil>=5.9.0", "pydantic>=2.9.0", - "ring-flash-attn==0.1.6", - "math-verify[antlr4_9_3]==0.7.0", - "orjson==3.10.16", + "ring-flash-attn==0.1.8", + "math-verify[antlr4_9_3]==0.8.0", + "orjson~=3.11.0", "requests>=2.31.0", - "redis==5.2.1", + "redis~=7.0.0", "safetensors>=0.4.0", "tenacity>=8.2.0", "uvicorn>=0.29.0", @@ -50,7 +54,7 @@ tapeagents = [ "Tapeagents[finetune]==0.1.16", ] lora = [ - "peft==0.12.0", + "peft==0.18.0", ] [tool.setuptools.packages.find] From c0dc029df4318e9142ab2e1c5b4e1b3c85607bcb Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 23 Dec 2025 14:38:27 +0000 Subject: [PATCH 02/13] unused parameter (device_id) removed --- pipelinerl/torch_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pipelinerl/torch_utils.py b/pipelinerl/torch_utils.py index 588aeab5..d2b78d53 100644 --- a/pipelinerl/torch_utils.py +++ b/pipelinerl/torch_utils.py @@ -59,7 +59,6 @@ def init_extra_process_group( store: Optional[Store] = None, group_name: str = None, pg_options: Optional[Any] = None, - device_id: Optional[torch.device] = None, ): assert (store is None) or (init_method is None), "Cannot specify both init_method and store." @@ -94,12 +93,6 @@ def init_extra_process_group( pg_options.is_high_priority_stream = False logger.info(f"[{group_name}] Created NCCL options: {pg_options}") - # Ensure CUDA is synchronized before creating NCCL process group - if device_id is not None: - torch.cuda.synchronize(device_id) - logger.info(f"[{group_name}] CUDA synchronized on {device_id}") - - logger.info(f"[{group_name}] Creating process group: rank={rank}, world_size={world_size}, device_id={device_id}") pg, _ = _new_process_group_helper( world_size, rank, @@ -109,7 +102,6 @@ def init_extra_process_group( group_name=group_name, backend_options=pg_options, timeout=timeout, - device_id=device_id, ) logger.info(f"[{group_name}] Process group created successfully") From 28ec2cd3a0913a853e63b3d1393e7ef709129a92 Mon Sep 17 00:00:00 2001 From: ehsk Date: Sun, 18 Jan 2026 16:37:19 +0000 Subject: [PATCH 03/13] transformers dtype warning fixed --- pipelinerl/finetune/checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index 0a673fbf..1dda0b99 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -179,7 +179,7 @@ def load_model(args, model_class, current_dir): is_ds_zero_3 = get_accelerator().state.deepspeed_plugin.zero_stage == 3 # type: ignore if args.load_as_bf16: - loading_args["torch_dtype"] = torch.bfloat16 + loading_args["dtype"] = torch.bfloat16 if args.auto_device_map: loading_args["device_map"] = "auto" model_cls = get_auto_model_class(model_class) From 88d6da84cd499271058d7be76b96fd15527661c5 Mon Sep 17 00:00:00 2001 From: ehsk Date: Sun, 18 Jan 2026 16:37:33 +0000 Subject: [PATCH 04/13] transformers dtype warning fixed --- pipelinerl/finetune/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/finetune/lora.py b/pipelinerl/finetune/lora.py index 81011be0..3c84610d 100644 --- a/pipelinerl/finetune/lora.py +++ b/pipelinerl/finetune/lora.py @@ -147,7 +147,7 @@ def merge_lora(lora_model_path): assert os.path.exists(lora_model_config), f"{lora_model_config} does not exists" logger.info(f"Merge lora checkpoint {lora_model_path}") - model = lora_load_and_merge(lora_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + model = lora_load_and_merge(lora_model_path, dtype=torch.bfloat16, low_cpu_mem_usage=True) tokenizer = AutoTokenizer.from_pretrained(lora_model_path) tmp_dir = f"{lora_model_path}_merged" From a28bdaf09915df9c85c12706369779fe73e7c695 Mon Sep 17 00:00:00 2001 From: ehsk Date: Sun, 18 Jan 2026 16:38:02 +0000 Subject: [PATCH 05/13] more fixes and updates --- pipelinerl/vllm1.py | 43 +++++++++++++++++++++++++++---------------- pyproject.toml | 4 ++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 86e6be4a..df3d04d1 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -1,3 +1,4 @@ +import asyncio import logging import signal import torch @@ -10,7 +11,6 @@ ) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, create_server_socket, build_app, init_app_state, @@ -41,12 +41,13 @@ handler.setFormatter(formatter) logger.addHandler(handler) + @runtime_checkable class LikeWorker(Protocol): rank: int local_rank: int device: torch.device - model_runner: GPUModelRunner + model_runner: GPUModelRunner pg_rank: int process_group: Any model_config: ModelConfig @@ -87,23 +88,32 @@ def receive_weight_update(self: LikeWorker, request_json: str): torch.cuda.synchronize(self.device) logger.info("Start receiving weight update") expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) + for info in request.parameters_info: target_dtype = string_to_dtype(info.dtype) if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - # Use PyNcclCommunicator's broadcast method instead of torch.distributed self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) - loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore + loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore if len(loaded_params) != 1: raise ValueError(f"model {info.name} not found in model state dict") + pipelinerl.vllm_quantization.invalidate_fp32_cache() logger.info("Weight update received") + def close_communicator(self): + """Closes the communicator when weight synchronization is no longer needed.""" + if hasattr(self, "model_update_group") and self.model_update_group is not None: + del self.model_update_group + self.model_update_group = None + logger.info("Weight update communicator closed") + class WeightUpdateManager: - def __init__(self, args, engine_client: AsyncMPClient): + def __init__(self, args, engine: AsyncLLM, engine_client: AsyncMPClient): self.args = args + self.engine = engine self.engine_client = engine_client async def input_process_groups(self): @@ -118,16 +128,16 @@ async def input_process_groups(self): ) async def receive_weight_update(self, request: WeightUpdateRequest): - # Ensure workers are ready by executing a dummy batch first - # This synchronizes workers before the NCCL collective - # logger.info("Synchronizing workers before weight update...") - # await self.engine_client.execute_dummy_batch_async() - # logger.info("Workers synchronized, starting weight update") + logger.info("Starting weight update...") await self.engine_client.collective_rpc_async( "receive_weight_update", args=(request.model_dump_json(),) ) logger.info("Weight update processed") + async def close_communicator(self): + """Closes the communicator when weight synchronization is no longer needed.""" + await self.engine_client.collective_rpc_async("close_communicator") + async def run_server(args, **uvicorn_kwargs) -> None: # COPIED FROM vllm/entrypoints/openai/api_server.py, vllm version 0.6.6.post1 @@ -170,7 +180,7 @@ def signal_handler(*_) -> None: ) assert isinstance(engine.engine_core, AsyncMPClient) - weight_update_manager = WeightUpdateManager(args, engine.engine_core) + weight_update_manager = WeightUpdateManager(args, engine, engine.engine_core) if not args.disable_weight_updates: await weight_update_manager.input_process_groups() @@ -181,8 +191,9 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): - logger.info("Received weight update request") - await weight_update_manager.receive_weight_update(request) + # Fire-and-forget: return immediately, weight update happens in background + logger.info("Received weight update request (fire-and-forget)") + asyncio.create_task(weight_update_manager.receive_weight_update(request)) return {"status": "ok"} await init_app_state(engine, app.state, args) @@ -204,11 +215,11 @@ async def _receive_weight_update(request: WeightUpdateRequest): # NB: Await server shutdown only after the backend context is exited await shutdown_task + # Cleanup + if not args.disable_weight_updates: + await weight_update_manager.close_communicator() sock.close() - # TODO: proper cleanup - # dist.destroy_process_group(actor_update_group) - def run_llm(): parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.") diff --git a/pyproject.toml b/pyproject.toml index ab7e69b0..238c14b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,14 +13,14 @@ name = "pipelinerl" version = "0.1.0" description = "A scalable asynchronous reinforcement learning implementation with in-flight weight updates." readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.12" license = { file = "LICENSE" } authors = [ { name = "ServiceNow" }, ] dependencies = [ "aiohttp>=3.9.0", - "vllm==0.11.2", + "vllm==0.13.0", "accelerate==1.12.0", "deepspeed~=0.18.0", "browsergym>=0.13.0", From 9eb5264d0d6667493956ccd63d2cadafc50a540d Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 19 Jan 2026 19:19:13 +0000 Subject: [PATCH 06/13] minor logging changes --- pipelinerl/finetune_loop.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index c72d2997..22b6e5a4 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -27,8 +27,9 @@ from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead -from pipelinerl.torch_utils import stateless_init_process_group +from pipelinerl import torch_utils from pipelinerl.finetune.types import PipelineBatchEncoding + from pipelinerl.finetune.checkpoints import ( load_model, load_tokenizer, @@ -409,14 +410,11 @@ def run_finetuning_loop( get_accelerator().wait_for_everyone() if get_accelerator().is_main_process and args.send_weight_updates: - logger.info("Initializing actor process group using StatelessProcessGroup") - - # Explicitly set CUDA device before creating NCCL process group current_device = get_accelerator().device torch.cuda.set_device(current_device) + logger.info("Initializing actor process group using StatelessProcessGroup") logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)") - - actor_update_group = stateless_init_process_group( + actor_update_group = torch_utils.stateless_init_process_group( init_method=cfg.me.weight_update_group_init_method, rank=0, world_size=cfg.me.weight_update_group_world_size, From 7fff37471323b7cef764a71149d36a94100f4055 Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 19 Jan 2026 19:30:46 +0000 Subject: [PATCH 07/13] upgrade vllm but stay on v0 --- conf/base.yaml | 8 ++++---- pipelinerl/vllm0.py | 39 +++++++++++++-------------------------- pyproject.toml | 4 ++-- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/conf/base.yaml b/conf/base.yaml index 2dd03d03..5418f846 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -57,7 +57,7 @@ test_llm: top_k: 50 vllm_config: - use_v1: true + use_v1: false quantization: null # or bf16_last_layer_fp32 vllm_kwargs: dtype: bfloat16 @@ -73,9 +73,9 @@ vllm_config: # V1 specific settings # logprobs-mode: processed_logprobs # V0 specific settings - # num-scheduler-steps: 1 - # disable-log-requests: "" - # disable-frontend-multiprocessing: "" + num-scheduler-steps: 1 + disable-log-requests: "" + disable-frontend-multiprocessing: "" world: replicas: 1 diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index fb3bea58..1dbd3f4c 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -4,39 +4,27 @@ This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture. Compatibility: - - vLLM versions <= 0.8.x only + - vLLM versions <= 0.10.0 only - The V0 engine was removed in vLLM 0.11.0 - Use vllm1.py instead """ -import warnings from packaging import version as version_parser import vllm # Check vLLM version compatibility vllm_version = version_parser.parse(vllm.__version__) -if vllm_version >= version_parser.parse("0.9.0"): +if vllm_version > version_parser.parse("0.10.0"): raise ImportError( f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. " - "This module only works with vLLM <= 0.8.x. " + "This module only works with vLLM <= 0.10.0. " "Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead." ) -# Only show deprecation warning for compatible versions -warnings.warn( - "pipelinerl.vllm0 is DEPRECATED and will be removed in a future version. " - "This module only works with vLLM <= 0.8.x. " - "Please use pipelinerl.vllm1 as it is actively maintained.", - DeprecationWarning, - stacklevel=2, -) import asyncio -import json import logging -import os import signal -from pydantic import TypeAdapter import torch import uvloop from vllm import AsyncLLMEngine @@ -53,7 +41,6 @@ ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.logger import init_logger from vllm._version import version from vllm.worker.worker import Worker from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper @@ -65,14 +52,13 @@ from vllm.worker.multi_step_model_runner import MultiStepModelRunner -import torch.distributed as dist from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest from pipelinerl.vllm_quantization import string_to_dtype # reuse dtype mapping -import pipelinerl.torch_utils +from pipelinerl import torch_utils import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config logger = logging.getLogger(__name__) -# configure this logger individually, in order to avoid messign +# configure this logger individually, in order to avoid messing # with the default vllm logger configuration logger.setLevel(logging.INFO) handler = logging.StreamHandler() @@ -104,13 +90,14 @@ def init_actor_update_group( prefix + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" ) - self.process_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + # Use StatelessProcessGroup + PyNcclCommunicator for cross-process NCCL communication + self.process_group = torch_utils.stateless_init_process_group( init_method=weight_update_group_init_method, rank=self.pg_rank, world_size=weight_update_group_world_size, + device=self.device, ) + logger.info(prefix + "Actor update process group initialized") def receive_weight_update(self, request: WeightUpdateRequest): torch.cuda.synchronize(self.device) @@ -120,7 +107,7 @@ def receive_weight_update(self, request: WeightUpdateRequest): if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - torch.distributed.broadcast(buffer, src=0, group=self.process_group) + self.process_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) if isinstance(self.model_runner, MultiStepModelRunner): loaded_params = self.model_runner._base_model_runner.model.load_weights( weights=[(info.name, buffer)] @@ -187,18 +174,18 @@ def input_process_groups(self): future.get() async def receive_weight_update(self, message: WeightUpdateRequest): - logger.info(f"Received weight update request") + logger.info("Received weight update request") async with executor_lock: if isinstance(self.executor, AsyncRLExecutor): await self.executor.stop_remote_worker_execution_loop_no_lock() - logger.info(f"Stopped remote worker") + logger.info("Stopped remote worker") futures = [] for worker in self.other_workers: futures.append(worker.execute_method("receive_weight_update", message)) self.driver_worker.receive_weight_update(message) for future in futures: future.get() - logger.info(f"All workers received weight updates") + logger.info("All workers received weight updates") async def run_server(args, **uvicorn_kwargs) -> None: diff --git a/pyproject.toml b/pyproject.toml index 238c14b6..ac0ccc28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "wheel", "numpy>=1.26.0", "packaging>=23.0", - "torch~=2.9.0", + "torch==2.7.1", ] build-backend = "setuptools.build_meta" @@ -20,7 +20,7 @@ authors = [ ] dependencies = [ "aiohttp>=3.9.0", - "vllm==0.13.0", + "vllm==0.10.0", "accelerate==1.12.0", "deepspeed~=0.18.0", "browsergym>=0.13.0", From 2c21fe45ef706449ee253803ba3214586d75c285 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 20 Jan 2026 01:28:10 +0000 Subject: [PATCH 08/13] flash_attention install command updated --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index fccf83a6..7b890b98 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,14 @@ conda run --no-capture-output -n pipeline-rl pip install -e . conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation ``` +Alternatively for `flash-attn`, you can install it via prebuilt packages (on Linux): +```bash +# Check your PyTorch's C++ ABI setting first: +# python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)" +# Use cxx11abiTRUE or cxx11abiFALSE in the URL accordingly +conda run --no-capture-output -n pipeline-rl pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl +``` + By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment: ```bash conda install redis-server==7.4.0 -c conda-forge From 98e9ebd702f1233d40a91311841c2c83cb63a671 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 20 Jan 2026 01:28:43 +0000 Subject: [PATCH 09/13] fixed typos and updated comments --- pipelinerl/vllm0.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 1dbd3f4c..54f40baf 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -1,12 +1,9 @@ """ -DEPRECATED - Kept only for backward compatibility with older vLLM versions. - This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture. Compatibility: - vLLM versions <= 0.10.0 only - - The V0 engine was removed in vLLM 0.11.0 - - Use vllm1.py instead + - Use vllm1.py for vLLM >= 0.11.0 as the V0 engine was removed in vLLM 0.11.0 """ from packaging import version as version_parser import vllm @@ -218,7 +215,7 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - # Build the engine with the bespoke Executor and orker clases + # Build the engine with the bespoke Executor and Worker classes multi_step = args.num_scheduler_steps > 1 engine_args = AsyncEngineArgs.from_cli_args(args) engine_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) From a6a292ca3410d0a5f437fe90999b16a1e10fbb7d Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 20 Jan 2026 01:29:05 +0000 Subject: [PATCH 10/13] increased seq_length for math --- conf/math.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conf/math.yaml b/conf/math.yaml index a772a190..5a6e27b7 100644 --- a/conf/math.yaml +++ b/conf/math.yaml @@ -12,11 +12,11 @@ environment: dataset_loader: pipelinerl.domains.math.load_datasets finetune: - seq_length: 18000 + seq_length: 20000 vllm_config: vllm_kwargs: - max_model_len: 18000 + max_model_len: 20000 llm: parameters: From fd912349552453a715c59f2f3329ae8ee344d534 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 21 Jan 2026 16:22:27 +0000 Subject: [PATCH 11/13] better comments for broadcast --- pipelinerl/finetune_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 22b6e5a4..0f1078c2 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -213,7 +213,7 @@ def send_weight_update( for name, parameter in named_parameters.items(): with deepspeed.zero.GatheredParameters([parameter]): if get_accelerator().is_main_process: - # Use PyNcclCommunicator's broadcast method instead of torch.distributed + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work (gets stuck) self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) if get_accelerator().is_main_process: logger.info("Wait for HTTP requests") @@ -256,7 +256,7 @@ def send_weight_update( futures = self.request_weight_updates(messages) logger.info(f"Published weight update request for version {version}") for _, parameter in named_parameters.items(): - # Use PyNcclCommunicator's broadcast method instead of torch.distributed + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work (gets stuck) self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) for future in futures: future.result() From 84309feb7aa45c99b80d11f1213919d30a12f2a6 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 21 Jan 2026 16:22:37 +0000 Subject: [PATCH 12/13] asyncio removed --- pipelinerl/vllm1.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index df3d04d1..7025f328 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -1,4 +1,3 @@ -import asyncio import logging import signal import torch @@ -191,9 +190,9 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): - # Fire-and-forget: return immediately, weight update happens in background - logger.info("Received weight update request (fire-and-forget)") - asyncio.create_task(weight_update_manager.receive_weight_update(request)) + # Blocking: wait for weight update to complete before returning + logger.info("Received weight update request") + await weight_update_manager.receive_weight_update(request) return {"status": "ok"} await init_app_state(engine, app.state, args) From fbbaf85e8e316e83ae1b79031a13cd61f93e2a19 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 21 Jan 2026 16:39:14 +0000 Subject: [PATCH 13/13] better clarification --- pipelinerl/finetune_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 0f1078c2..b23d6429 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -213,7 +213,7 @@ def send_weight_update( for name, parameter in named_parameters.items(): with deepspeed.zero.GatheredParameters([parameter]): if get_accelerator().is_main_process: - # Use PyNcclCommunicator's broadcast method as torch.distributed does not work (gets stuck) + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) if get_accelerator().is_main_process: logger.info("Wait for HTTP requests") @@ -256,7 +256,7 @@ def send_weight_update( futures = self.request_weight_updates(messages) logger.info(f"Published weight update request for version {version}") for _, parameter in named_parameters.items(): - # Use PyNcclCommunicator's broadcast method as torch.distributed does not work (gets stuck) + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) for future in futures: future.result()