Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,17 @@ cd PipelineRL

Create the environments with dependencies.
```bash
conda create -n pipeline-rl -y python=3.11
conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0
conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation
conda create -n pipeline-rl -y python=3.12
conda run --no-capture-output -n pipeline-rl pip install -e .
conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation
```

Alternatively for `flash-attn`, you can install it via prebuilt packages (on Linux):
```bash
# Check your PyTorch's C++ ABI setting first:
# python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)"
# Use cxx11abiTRUE or cxx11abiFALSE in the URL accordingly
conda run --no-capture-output -n pipeline-rl pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
```

By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment:
Expand Down
9 changes: 6 additions & 3 deletions conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ vllm_config:
vllm_kwargs:
dtype: bfloat16
gpu-memory-utilization: 0.9
num-scheduler-steps: 1
disable-log-requests: ""
disable-frontend-multiprocessing: ""
max-num-seqs: ${actor.llm_max_rollouts}
max-num-batched-tokens: 1024
enable-chunked-prefill: ""
Expand All @@ -73,6 +70,12 @@ vllm_config:
pipeline-parallel-size: 1
generation-config: vllm
max_model_len: 10000
# V1 specific settings
# logprobs-mode: processed_logprobs
# V0 specific settings
num-scheduler-steps: 1
disable-log-requests: ""
disable-frontend-multiprocessing: ""

world:
replicas: 1
Expand Down
4 changes: 2 additions & 2 deletions conf/math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ environment:
dataset_loader: pipelinerl.domains.math.load_datasets

finetune:
seq_length: 18000
seq_length: 20000

vllm_config:
vllm_kwargs:
max_model_len: 18000
max_model_len: 20000

llm:
parameters:
Expand Down
2 changes: 1 addition & 1 deletion pipelinerl/finetune/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def load_model(args, model_class, current_dir):
is_ds_zero_3 = get_accelerator().state.deepspeed_plugin.zero_stage == 3 # type: ignore

if args.load_as_bf16:
loading_args["torch_dtype"] = torch.bfloat16
loading_args["dtype"] = torch.bfloat16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change? has the transformers API changed here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch_dtype became deprecated and prints a warning

if args.auto_device_map:
loading_args["device_map"] = "auto"
model_cls = get_auto_model_class(model_class)
Expand Down
2 changes: 1 addition & 1 deletion pipelinerl/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def merge_lora(lora_model_path):
assert os.path.exists(lora_model_config), f"{lora_model_config} does not exists"

logger.info(f"Merge lora checkpoint {lora_model_path}")
model = lora_load_and_merge(lora_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
model = lora_load_and_merge(lora_model_path, dtype=torch.bfloat16, low_cpu_mem_usage=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch_dtype renamed to dtype

tokenizer = AutoTokenizer.from_pretrained(lora_model_path)

tmp_dir = f"{lora_model_path}_merged"
Expand Down
23 changes: 13 additions & 10 deletions pipelinerl/finetune_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params

from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead
import pipelinerl.torch_utils
from pipelinerl import torch_utils
from pipelinerl.finetune.types import PipelineBatchEncoding

from pipelinerl.finetune.checkpoints import (
load_model,
load_tokenizer,
Expand Down Expand Up @@ -212,7 +213,8 @@ def send_weight_update(
for name, parameter in named_parameters.items():
with deepspeed.zero.GatheredParameters([parameter]):
if get_accelerator().is_main_process:
dist.broadcast(parameter.data, src=0, group=self.actor_update_group)
# Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path
self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream())
if get_accelerator().is_main_process:
logger.info("Wait for HTTP requests")
for future in futures: # type: ignore
Expand Down Expand Up @@ -254,8 +256,8 @@ def send_weight_update(
futures = self.request_weight_updates(messages)
logger.info(f"Published weight update request for version {version}")
for _, parameter in named_parameters.items():
dist.broadcast(parameter.data, src=0, group=self.actor_update_group)
dist.barrier(self.actor_update_group)
# Use PyNcclCommunicator's broadcast method as torch.distributed does not work since vLLM disabled that transfer path
self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream())
for future in futures:
future.result()
logger.info("Finished broadcasting weights")
Expand Down Expand Up @@ -408,13 +410,15 @@ def run_finetuning_loop(
get_accelerator().wait_for_everyone()

if get_accelerator().is_main_process and args.send_weight_updates:
logger.info("Initializing actor process group")
actor_update_group = pipelinerl.torch_utils.init_extra_process_group(
group_name="actor",
backend="nccl",
current_device = get_accelerator().device
torch.cuda.set_device(current_device)
logger.info("Initializing actor process group using StatelessProcessGroup")
logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)")
actor_update_group = torch_utils.stateless_init_process_group(
init_method=cfg.me.weight_update_group_init_method,
rank=0,
world_size=cfg.me.weight_update_group_world_size,
device=current_device,
)
logger.info("Actor process group initialized")
else:
Expand Down Expand Up @@ -493,8 +497,7 @@ def run_finetuning_loop(
finally:
if weight_update_manager is not None:
weight_update_manager.shutdown()
if actor_update_group:
dist.destroy_process_group(actor_update_group)
# PyNcclCommunicator doesn't need explicit destroy like torch.distributed process groups


def rl_finetuning_worker(
Expand Down
45 changes: 45 additions & 0 deletions pipelinerl/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
import logging
from datetime import timedelta
from typing import Any, Optional, Union
from urllib.parse import urlparse

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
ProcessGroupNCCL,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

logger = logging.getLogger(__name__)


def stateless_init_process_group(init_method, rank, world_size, device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.

Args:
init_method: TCP init method string (e.g., "tcp://localhost:9000")
rank: The rank of this process in the group
world_size: Total number of processes in the group
device: The CUDA device to use for NCCL communication
"""
# Parse master_address and master_port from init_method (e.g., "tcp://localhost:9000")
parsed = urlparse(init_method)
master_address = parsed.hostname or "localhost"
master_port = parsed.port or 9000
logger.debug(f"Parsed master_address: {master_address}, master_port: {master_port}")

pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


# Copy from pytorch to allow creating multiple main groups.
Expand Down Expand Up @@ -49,6 +86,13 @@ def init_extra_process_group(
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)

# Create NCCL-specific options if using NCCL backend
logger.info(f"[{group_name}] Backend: {backend}, str(backend): {str(backend)}")
if pg_options is None and str(backend) == "nccl":
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = False
logger.info(f"[{group_name}] Created NCCL options: {pg_options}")

pg, _ = _new_process_group_helper(
world_size,
rank,
Expand All @@ -59,6 +103,7 @@ def init_extra_process_group(
backend_options=pg_options,
timeout=timeout,
)
logger.info(f"[{group_name}] Process group created successfully")

_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

Expand Down
48 changes: 32 additions & 16 deletions pipelinerl/vllm0.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
"""
This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture.

Compatibility:
- vLLM versions <= 0.10.0 only
- Use vllm1.py for vLLM >= 0.11.0 as the V0 engine was removed in vLLM 0.11.0
"""
from packaging import version as version_parser
import vllm

# Check vLLM version compatibility
vllm_version = version_parser.parse(vllm.__version__)

if vllm_version > version_parser.parse("0.10.0"):
raise ImportError(
f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. "
"This module only works with vLLM <= 0.10.0. "
"Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead."
)


import asyncio
import json
import logging
import os
import signal
from pydantic import TypeAdapter
import torch
import uvloop
from vllm import AsyncLLMEngine
Expand All @@ -14,14 +32,12 @@
)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
run_server,
create_server_socket,
build_app,
init_app_state,
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm._version import version
from vllm.worker.worker import Worker
from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper
Expand All @@ -33,14 +49,13 @@
from vllm.worker.multi_step_model_runner import MultiStepModelRunner


import torch.distributed as dist
from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest
from pipelinerl.vllm_quantization import string_to_dtype # reuse dtype mapping
import pipelinerl.torch_utils
from pipelinerl import torch_utils
import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config

logger = logging.getLogger(__name__)
# configure this logger individually, in order to avoid messign
# configure this logger individually, in order to avoid messing
# with the default vllm logger configuration
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
Expand Down Expand Up @@ -72,13 +87,14 @@ def init_actor_update_group(
prefix
+ f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}"
)
self.process_group = pipelinerl.torch_utils.init_extra_process_group(
group_name="actor",
backend="nccl",
# Use StatelessProcessGroup + PyNcclCommunicator for cross-process NCCL communication
self.process_group = torch_utils.stateless_init_process_group(
init_method=weight_update_group_init_method,
rank=self.pg_rank,
world_size=weight_update_group_world_size,
device=self.device,
)
logger.info(prefix + "Actor update process group initialized")

def receive_weight_update(self, request: WeightUpdateRequest):
torch.cuda.synchronize(self.device)
Expand All @@ -88,7 +104,7 @@ def receive_weight_update(self, request: WeightUpdateRequest):
if target_dtype not in expected_dtypes:
logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}")
buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device)
torch.distributed.broadcast(buffer, src=0, group=self.process_group)
self.process_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream())
if isinstance(self.model_runner, MultiStepModelRunner):
loaded_params = self.model_runner._base_model_runner.model.load_weights(
weights=[(info.name, buffer)]
Expand Down Expand Up @@ -155,18 +171,18 @@ def input_process_groups(self):
future.get()

async def receive_weight_update(self, message: WeightUpdateRequest):
logger.info(f"Received weight update request")
logger.info("Received weight update request")
async with executor_lock:
if isinstance(self.executor, AsyncRLExecutor):
await self.executor.stop_remote_worker_execution_loop_no_lock()
logger.info(f"Stopped remote worker")
logger.info("Stopped remote worker")
futures = []
for worker in self.other_workers:
futures.append(worker.execute_method("receive_weight_update", message))
self.driver_worker.receive_weight_update(message)
for future in futures:
future.get()
logger.info(f"All workers received weight updates")
logger.info("All workers received weight updates")


async def run_server(args, **uvicorn_kwargs) -> None:
Expand Down Expand Up @@ -199,7 +215,7 @@ def signal_handler(*_) -> None:

signal.signal(signal.SIGTERM, signal_handler)

# Build the engine with the bespoke Executor and orker clases
# Build the engine with the bespoke Executor and Worker classes
multi_step = args.num_scheduler_steps > 1
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER)
Expand Down
Loading