diff --git a/README.md b/README.md index 8a172c0d..7b890b98 100644 --- a/README.md +++ b/README.md @@ -195,9 +195,17 @@ cd PipelineRL Create the environments with dependencies. ```bash -conda create -n pipeline-rl -y python=3.11 -conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0 -conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation +conda create -n pipeline-rl -y python=3.12 +conda run --no-capture-output -n pipeline-rl pip install -e . +conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation +``` + +Alternatively for `flash-attn`, you can install it via prebuilt packages (on Linux): +```bash +# Check your PyTorch's C++ ABI setting first: +# python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)" +# Use cxx11abiTRUE or cxx11abiFALSE in the URL accordingly +conda run --no-capture-output -n pipeline-rl pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ``` By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment: diff --git a/conf/base.yaml b/conf/base.yaml index 1f8d73cc..5418f846 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -62,9 +62,6 @@ vllm_config: vllm_kwargs: dtype: bfloat16 gpu-memory-utilization: 0.9 - num-scheduler-steps: 1 - disable-log-requests: "" - disable-frontend-multiprocessing: "" max-num-seqs: ${actor.llm_max_rollouts} max-num-batched-tokens: 1024 enable-chunked-prefill: "" @@ -73,6 +70,12 @@ vllm_config: pipeline-parallel-size: 1 generation-config: vllm max_model_len: 10000 + # V1 specific settings + # logprobs-mode: processed_logprobs + # V0 specific settings + num-scheduler-steps: 1 + disable-log-requests: "" + disable-frontend-multiprocessing: "" world: replicas: 1 diff --git a/conf/math.yaml b/conf/math.yaml index a772a190..5a6e27b7 100644 --- a/conf/math.yaml +++ b/conf/math.yaml @@ -12,11 +12,11 @@ environment: dataset_loader: pipelinerl.domains.math.load_datasets finetune: - seq_length: 18000 + seq_length: 20000 vllm_config: vllm_kwargs: - max_model_len: 18000 + max_model_len: 20000 llm: parameters: diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index 0a673fbf..1dda0b99 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -179,7 +179,7 @@ def load_model(args, model_class, current_dir): is_ds_zero_3 = get_accelerator().state.deepspeed_plugin.zero_stage == 3 # type: ignore if args.load_as_bf16: - loading_args["torch_dtype"] = torch.bfloat16 + loading_args["dtype"] = torch.bfloat16 if args.auto_device_map: loading_args["device_map"] = "auto" model_cls = get_auto_model_class(model_class) diff --git a/pipelinerl/finetune/lora.py b/pipelinerl/finetune/lora.py index 81011be0..3c84610d 100644 --- a/pipelinerl/finetune/lora.py +++ b/pipelinerl/finetune/lora.py @@ -147,7 +147,7 @@ def merge_lora(lora_model_path): assert os.path.exists(lora_model_config), f"{lora_model_config} does not exists" logger.info(f"Merge lora checkpoint {lora_model_path}") - model = lora_load_and_merge(lora_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + model = lora_load_and_merge(lora_model_path, dtype=torch.bfloat16, low_cpu_mem_usage=True) tokenizer = AutoTokenizer.from_pretrained(lora_model_path) tmp_dir = f"{lora_model_path}_merged" diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 57d2950e..b23d6429 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -27,8 +27,9 @@ from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead -import pipelinerl.torch_utils +from pipelinerl import torch_utils from pipelinerl.finetune.types import PipelineBatchEncoding + from pipelinerl.finetune.checkpoints import ( load_model, load_tokenizer, @@ -212,7 +213,8 @@ def send_weight_update( for name, parameter in named_parameters.items(): with deepspeed.zero.GatheredParameters([parameter]): if get_accelerator().is_main_process: - dist.broadcast(parameter.data, src=0, group=self.actor_update_group) + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) if get_accelerator().is_main_process: logger.info("Wait for HTTP requests") for future in futures: # type: ignore @@ -254,8 +256,8 @@ def send_weight_update( futures = self.request_weight_updates(messages) logger.info(f"Published weight update request for version {version}") for _, parameter in named_parameters.items(): - dist.broadcast(parameter.data, src=0, group=self.actor_update_group) - dist.barrier(self.actor_update_group) + # Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) for future in futures: future.result() logger.info("Finished broadcasting weights") @@ -408,13 +410,15 @@ def run_finetuning_loop( get_accelerator().wait_for_everyone() if get_accelerator().is_main_process and args.send_weight_updates: - logger.info("Initializing actor process group") - actor_update_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + current_device = get_accelerator().device + torch.cuda.set_device(current_device) + logger.info("Initializing actor process group using StatelessProcessGroup") + logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)") + actor_update_group = torch_utils.stateless_init_process_group( init_method=cfg.me.weight_update_group_init_method, rank=0, world_size=cfg.me.weight_update_group_world_size, + device=current_device, ) logger.info("Actor process group initialized") else: @@ -493,8 +497,7 @@ def run_finetuning_loop( finally: if weight_update_manager is not None: weight_update_manager.shutdown() - if actor_update_group: - dist.destroy_process_group(actor_update_group) + # PyNcclCommunicator doesn't need explicit destroy like torch.distributed process groups def rl_finetuning_worker( diff --git a/pipelinerl/torch_utils.py b/pipelinerl/torch_utils.py index 2d16d99f..d2b78d53 100644 --- a/pipelinerl/torch_utils.py +++ b/pipelinerl/torch_utils.py @@ -1,14 +1,51 @@ +import logging from datetime import timedelta from typing import Any, Optional, Union +from urllib.parse import urlparse + +import torch +import torch.distributed as dist from torch.distributed.distributed_c10d import ( Backend, PrefixStore, + ProcessGroupNCCL, Store, _new_process_group_helper, _world, default_pg_timeout, rendezvous, ) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.utils import StatelessProcessGroup + +logger = logging.getLogger(__name__) + + +def stateless_init_process_group(init_method, rank, world_size, device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + + Args: + init_method: TCP init method string (e.g., "tcp://localhost:9000") + rank: The rank of this process in the group + world_size: Total number of processes in the group + device: The CUDA device to use for NCCL communication + """ + # Parse master_address and master_port from init_method (e.g., "tcp://localhost:9000") + parsed = urlparse(init_method) + master_address = parsed.hostname or "localhost" + master_port = parsed.port or 9000 + logger.debug(f"Parsed master_address: {master_address}, master_port: {master_port}") + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl # Copy from pytorch to allow creating multiple main groups. @@ -49,6 +86,13 @@ def init_extra_process_group( # different systems (e.g. RPC) in case the store is multi-tenant. store = PrefixStore(group_name, store) + # Create NCCL-specific options if using NCCL backend + logger.info(f"[{group_name}] Backend: {backend}, str(backend): {str(backend)}") + if pg_options is None and str(backend) == "nccl": + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = False + logger.info(f"[{group_name}] Created NCCL options: {pg_options}") + pg, _ = _new_process_group_helper( world_size, rank, @@ -59,6 +103,7 @@ def init_extra_process_group( backend_options=pg_options, timeout=timeout, ) + logger.info(f"[{group_name}] Process group created successfully") _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 93e99d60..54f40baf 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -1,9 +1,27 @@ +""" +This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture. + +Compatibility: + - vLLM versions <= 0.10.0 only + - Use vllm1.py for vLLM >= 0.11.0 as the V0 engine was removed in vLLM 0.11.0 +""" +from packaging import version as version_parser +import vllm + +# Check vLLM version compatibility +vllm_version = version_parser.parse(vllm.__version__) + +if vllm_version > version_parser.parse("0.10.0"): + raise ImportError( + f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. " + "This module only works with vLLM <= 0.10.0. " + "Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead." + ) + + import asyncio -import json import logging -import os import signal -from pydantic import TypeAdapter import torch import uvloop from vllm import AsyncLLMEngine @@ -14,14 +32,12 @@ ) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, create_server_socket, build_app, init_app_state, ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.logger import init_logger from vllm._version import version from vllm.worker.worker import Worker from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper @@ -33,14 +49,13 @@ from vllm.worker.multi_step_model_runner import MultiStepModelRunner -import torch.distributed as dist from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest from pipelinerl.vllm_quantization import string_to_dtype # reuse dtype mapping -import pipelinerl.torch_utils +from pipelinerl import torch_utils import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config logger = logging.getLogger(__name__) -# configure this logger individually, in order to avoid messign +# configure this logger individually, in order to avoid messing # with the default vllm logger configuration logger.setLevel(logging.INFO) handler = logging.StreamHandler() @@ -72,13 +87,14 @@ def init_actor_update_group( prefix + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" ) - self.process_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + # Use StatelessProcessGroup + PyNcclCommunicator for cross-process NCCL communication + self.process_group = torch_utils.stateless_init_process_group( init_method=weight_update_group_init_method, rank=self.pg_rank, world_size=weight_update_group_world_size, + device=self.device, ) + logger.info(prefix + "Actor update process group initialized") def receive_weight_update(self, request: WeightUpdateRequest): torch.cuda.synchronize(self.device) @@ -88,7 +104,7 @@ def receive_weight_update(self, request: WeightUpdateRequest): if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - torch.distributed.broadcast(buffer, src=0, group=self.process_group) + self.process_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) if isinstance(self.model_runner, MultiStepModelRunner): loaded_params = self.model_runner._base_model_runner.model.load_weights( weights=[(info.name, buffer)] @@ -155,18 +171,18 @@ def input_process_groups(self): future.get() async def receive_weight_update(self, message: WeightUpdateRequest): - logger.info(f"Received weight update request") + logger.info("Received weight update request") async with executor_lock: if isinstance(self.executor, AsyncRLExecutor): await self.executor.stop_remote_worker_execution_loop_no_lock() - logger.info(f"Stopped remote worker") + logger.info("Stopped remote worker") futures = [] for worker in self.other_workers: futures.append(worker.execute_method("receive_weight_update", message)) self.driver_worker.receive_weight_update(message) for future in futures: future.get() - logger.info(f"All workers received weight updates") + logger.info("All workers received weight updates") async def run_server(args, **uvicorn_kwargs) -> None: @@ -199,7 +215,7 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - # Build the engine with the bespoke Executor and orker clases + # Build the engine with the bespoke Executor and Worker classes multi_step = args.num_scheduler_steps > 1 engine_args = AsyncEngineArgs.from_cli_args(args) engine_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 1ac611d0..7025f328 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -2,14 +2,14 @@ import signal import torch import uvloop -from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.system_utils import set_ulimit from vllm.entrypoints.openai.cli_args import ( make_arg_parser, validate_parsed_serve_args, ) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, create_server_socket, build_app, init_app_state, @@ -26,8 +26,8 @@ from pipelinerl.finetune_loop import WeightUpdateRequest from pipelinerl.vllm_quantization import string_to_dtype # reuse mapping +from pipelinerl.torch_utils import stateless_init_process_group from typing import Any, Protocol, runtime_checkable -import pipelinerl.torch_utils import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ class LikeWorker(Protocol): rank: int local_rank: int device: torch.device - model_runner: GPUModelRunner + model_runner: GPUModelRunner pg_rank: int process_group: Any model_config: ModelConfig @@ -72,34 +72,47 @@ def init_actor_update_group( prefix + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" ) - self.process_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + + # Use vLLM's StatelessProcessGroup instead of torch.distributed + self.model_update_group = stateless_init_process_group( init_method=weight_update_group_init_method, rank=self.pg_rank, world_size=weight_update_group_world_size, + device=self.device, ) + logger.info(prefix + "Actor update process group initialized") - def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest): + def receive_weight_update(self: LikeWorker, request_json: str): + request = WeightUpdateRequest.model_validate_json(request_json) torch.cuda.synchronize(self.device) logger.info("Start receiving weight update") expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) + for info in request.parameters_info: target_dtype = string_to_dtype(info.dtype) if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - torch.distributed.broadcast(buffer, src=0, group=self.process_group) - loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore + self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) + loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore if len(loaded_params) != 1: raise ValueError(f"model {info.name} not found in model state dict") + pipelinerl.vllm_quantization.invalidate_fp32_cache() logger.info("Weight update received") + def close_communicator(self): + """Closes the communicator when weight synchronization is no longer needed.""" + if hasattr(self, "model_update_group") and self.model_update_group is not None: + del self.model_update_group + self.model_update_group = None + logger.info("Weight update communicator closed") + class WeightUpdateManager: - def __init__(self, args, engine_client: AsyncMPClient): + def __init__(self, args, engine: AsyncLLM, engine_client: AsyncMPClient): self.args = args + self.engine = engine self.engine_client = engine_client async def input_process_groups(self): @@ -114,11 +127,16 @@ async def input_process_groups(self): ) async def receive_weight_update(self, request: WeightUpdateRequest): + logger.info("Starting weight update...") await self.engine_client.collective_rpc_async( - "receive_weight_update", args=(request,) + "receive_weight_update", args=(request.model_dump_json(),) ) logger.info("Weight update processed") + async def close_communicator(self): + """Closes the communicator when weight synchronization is no longer needed.""" + await self.engine_client.collective_rpc_async("close_communicator") + async def run_server(args, **uvicorn_kwargs) -> None: # COPIED FROM vllm/entrypoints/openai/api_server.py, vllm version 0.6.6.post1 @@ -157,11 +175,11 @@ def signal_handler(*_) -> None: vllm_config=engine_config, usage_context=UsageContext.OPENAI_API_SERVER, disable_log_stats=engine_args.disable_log_stats, - disable_log_requests=engine_args.disable_log_requests, + enable_log_requests=engine_args.enable_log_requests, ) assert isinstance(engine.engine_core, AsyncMPClient) - weight_update_manager = WeightUpdateManager(args, engine.engine_core) + weight_update_manager = WeightUpdateManager(args, engine, engine.engine_core) if not args.disable_weight_updates: await weight_update_manager.input_process_groups() @@ -172,10 +190,12 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): + # Blocking: wait for weight update to complete before returning + logger.info("Received weight update request") await weight_update_manager.receive_weight_update(request) return {"status": "ok"} - await init_app_state(engine, engine_config, app.state, args) + await init_app_state(engine, app.state, args) shutdown_task = await serve_http( app, sock, @@ -194,11 +214,11 @@ async def _receive_weight_update(request: WeightUpdateRequest): # NB: Await server shutdown only after the backend context is exited await shutdown_task + # Cleanup + if not args.disable_weight_updates: + await weight_update_manager.close_communicator() sock.close() - # TODO: proper cleanup - # dist.destroy_process_group(actor_update_group) - def run_llm(): parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.") diff --git a/pyproject.toml b/pyproject.toml index 7fc9978a..ac0ccc28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,11 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] +requires = [ + "setuptools>=61.0", + "wheel", + "numpy>=1.26.0", + "packaging>=23.0", + "torch==2.7.1", +] build-backend = "setuptools.build_meta" [project] @@ -7,22 +13,20 @@ name = "pipelinerl" version = "0.1.0" description = "A scalable asynchronous reinforcement learning implementation with in-flight weight updates." readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.12" license = { file = "LICENSE" } authors = [ { name = "ServiceNow" }, ] dependencies = [ "aiohttp>=3.9.0", - "torch>=2.6", - "vllm==0.8.5.post1", - "accelerate==1.7.0", - "deepspeed==0.15.4", + "vllm==0.10.0", + "accelerate==1.12.0", + "deepspeed~=0.18.0", "browsergym>=0.13.0", "datasets>=2.21.0", - "transformers==4.51.1" , + "transformers~=4.57.0" , "fastapi>=0.115.0", - "flash-attn==2.7.4.post1", "joblib>=1.3.2", "jsonref>=1.1.0", "litellm>=1.61.0", @@ -32,11 +36,11 @@ dependencies = [ "Pillow>=10.0.0", "psutil>=5.9.0", "pydantic>=2.9.0", - "ring-flash-attn==0.1.6", - "math-verify[antlr4_9_3]==0.7.0", - "orjson==3.10.16", + "ring-flash-attn==0.1.8", + "math-verify[antlr4_9_3]==0.8.0", + "orjson~=3.11.0", "requests>=2.31.0", - "redis==5.2.1", + "redis~=7.0.0", "safetensors>=0.4.0", "tenacity>=8.2.0", "uvicorn>=0.29.0", @@ -50,7 +54,7 @@ tapeagents = [ "Tapeagents[finetune]==0.1.16", ] lora = [ - "peft==0.12.0", + "peft==0.18.0", ] [tool.setuptools.packages.find]