File tree Expand file tree Collapse file tree 4 files changed +21
-24
lines changed
modules/llm/backends/vllm Expand file tree Collapse file tree 4 files changed +21
-24
lines changed Original file line number Diff line number Diff line change 99export DEBIAN_FRONTEND=noninteractive
1010export TZ=UTC
1111apt-get update
12- apt-get install -yq --no-install-recommends git cmake
12+ apt-get install -yq --no-install-recommends git wget unzip curl patchelf
1313# Avoid error: "fatal: unsafe repository"
1414git config --global --add safe.directory ' *'
15- apt-get install -yq --no-install-recommends wget \
16- gcc \
17- g++ \
18- unzip \
19- curl \
20- patchelf \
21- libosmesa6-dev \
22- libgl1-mesa-glx \
23- libglfw3 \
24- swig3.0 \
25- libglew-dev \
26- libglvnd0 \
27- libgl1 \
28- libglx0 \
29- libegl1 \
30- libgles2
15+ # The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps.
16+ # Install only minimal utilities not guaranteed to be present.
3117
32- # Upgrade specific package
33- apt-get install -yq --no-install-recommends --only-upgrade libstdc++6
18+ # CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient.
3419
35- apt-get clean
36- rm -rf /var/lib/apt/lists/*
20+ # Cleanup APT cache
21+ apt-get clean && rm -rf /var/lib/apt/lists/*
3722
3823this_dir=" $( cd " $( dirname " ${BASH_SOURCE[0]} " ) " > /dev/null 2>&1 && pwd ) "
3924root_dir=" $( git rev-parse --show-toplevel) "
Original file line number Diff line number Diff line change 3232 runner : " linux.g6.4xlarge.experimental.nvidia.gpu"
3333 # gpu-arch-type: cuda
3434 # gpu-arch-version: "11.7"
35- docker-image : " nvidia/cudagl:11.4 .0-base "
35+ docker-image : " pytorch/pytorch:2.8 .0-cuda12.8-cudnn9-devel "
3636 timeout : 120
3737 script : |
3838 if [[ "${{ github.ref }}" =~ release/* ]]; then
4545
4646 set -euo pipefail
4747 export PYTHON_VERSION="3.9"
48- export CU_VERSION="cu117 "
48+ export CU_VERSION="cu128 "
4949 export TAR_OPTIONS="--no-same-owner"
5050 export UPLOAD_CHANNEL="nightly"
5151 export TF_CPP_MIN_LOG_LEVEL=0
Original file line number Diff line number Diff line change @@ -475,6 +475,18 @@ def update_policy_weights_(
475475 # Apply to local policy
476476 if hasattr (self , "policy" ) and isinstance (self .policy , nn .Module ):
477477 strategy .apply_weights (self .policy , weights )
478+ elif (
479+ hasattr (self , "_original_policy" )
480+ and isinstance (self ._original_policy , nn .Module )
481+ and hasattr (self , "policy" )
482+ and isinstance (self .policy , nn .Module )
483+ ):
484+ # If no weights were provided, mirror weights from the original (trainer) policy
485+ from torchrl .weight_update .weight_sync_schemes import WeightStrategy
486+
487+ strategy = WeightStrategy (extract_as = "tensordict" )
488+ weights = strategy .extract_weights (self ._original_policy )
489+ strategy .apply_weights (self .policy , weights )
478490 # Otherwise, no action needed - policy is local and changes are immediately visible
479491
480492 def __iter__ (self ) -> Iterator [TensorDictBase ]:
Original file line number Diff line number Diff line change 2020from concurrent .futures import ThreadPoolExecutor , wait
2121from typing import Any , Literal , TYPE_CHECKING
2222
23-
2423import torch
2524
2625from torchrl ._utils import logger as torchrl_logger
@@ -58,6 +57,7 @@ def _get_ray():
5857 "ray is not installed. Please install it with `pip install ray`."
5958 ) from e
6059
60+
6161class _AsyncvLLMWorker :
6262 """Async vLLM worker for Ray with weight update capabilities.
6363
You can’t perform that action at this time.
0 commit comments