Commit ce4d189

Update (base update)
[ghstack-poisoned]
1 parent 7f2caeb commit ce4d189

4 files changed: 21 additions, 24 deletions

.github/unittest/llm/scripts_llm/setup_env.sh

Lines changed: 6 additions & 21 deletions
@@ -9,31 +9,16 @@ set -e
 export DEBIAN_FRONTEND=noninteractive
 export TZ=UTC
 apt-get update
-apt-get install -yq --no-install-recommends git cmake
+apt-get install -yq --no-install-recommends git wget unzip curl patchelf
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
-apt-get install -yq --no-install-recommends wget \
-gcc \
-g++ \
-unzip \
-curl \
-patchelf \
-libosmesa6-dev \
-libgl1-mesa-glx \
-libglfw3 \
-swig3.0 \
-libglew-dev \
-libglvnd0 \
-libgl1 \
-libglx0 \
-libegl1 \
-libgles2
+# The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps.
+# Install only minimal utilities not guaranteed to be present.

-# Upgrade specific package
-apt-get install -yq --no-install-recommends --only-upgrade libstdc++6
+# CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient.

-apt-get clean
-rm -rf /var/lib/apt/lists/*
+# Cleanup APT cache
+apt-get clean && rm -rf /var/lib/apt/lists/*

 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 root_dir="$(git rev-parse --show-toplevel)"
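
The new comments in this hunk assume that the PyTorch devel image already ships compilers and CMake >= 3.22. A minimal Python sketch of a sanity check for that assumption is given here; the helper names (`parse_cmake_version`, `missing_tools`, `cmake_is_recent_enough`) and the `EXPECTED_TOOLS` list are hypothetical and are not part of this commit.

```python
import re
import shutil

# Hypothetical list: the utilities the script installs plus those it expects from the image.
EXPECTED_TOOLS = ["git", "wget", "unzip", "curl", "patchelf", "gcc", "g++", "cmake"]


def parse_cmake_version(output: str) -> tuple[int, ...]:
    """Extract the version tuple from the text printed by `cmake --version`."""
    match = re.search(r"cmake version (\d+)\.(\d+)(?:\.(\d+))?", output)
    if match is None:
        raise ValueError(f"unrecognized cmake output: {output!r}")
    return tuple(int(part) for part in match.groups() if part is not None)


def missing_tools(names: list[str]) -> list[str]:
    """Return the subset of `names` that cannot be found on PATH."""
    return [name for name in names if shutil.which(name) is None]


def cmake_is_recent_enough(output: str, minimum: tuple[int, ...] = (3, 22)) -> bool:
    """Compare the parsed version against the minimum named in the diff comment."""
    return parse_cmake_version(output) >= minimum
```

Inside the container, `missing_tools(EXPECTED_TOOLS)` would be expected to return an empty list if the assumption in the comments holds; the output of `cmake --version` can be passed to `cmake_is_recent_enough`.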

.github/workflows/test-linux-llm.yml

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ jobs:
 runner: "linux.g6.4xlarge.experimental.nvidia.gpu"
 # gpu-arch-type: cuda
 # gpu-arch-version: "11.7"
-docker-image: "nvidia/cudagl:11.4.0-base"
+docker-image: "pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel"
 timeout: 120
 script: |
 if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -45,7 +45,7 @@ jobs:

 set -euo pipefail
 export PYTHON_VERSION="3.9"
-export CU_VERSION="cu117"
+export CU_VERSION="cu128"
 export TAR_OPTIONS="--no-same-owner"
 export UPLOAD_CHANNEL="nightly"
 export TF_CPP_MIN_LOG_LEVEL=0
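
This hunk changes the Docker image and `CU_VERSION` together, so that the image tag (`cuda12.8`) and the variable (`cu128`) name the same CUDA release. A minimal Python sketch of deriving one from the other follows; the function `cu_version_from_image` is hypothetical, not part of the workflow, and it assumes the image tag contains a `cuda<major>.<minor>` component.

```python
import re


def cu_version_from_image(image: str) -> str:
    """Map a tag such as '...-cuda12.8-...' to the 'cu128' style used by CU_VERSION."""
    match = re.search(r"cuda(\d+)\.(\d+)", image)
    if match is None:
        raise ValueError(f"no CUDA version found in image tag: {image!r}")
    major, minor = match.groups()
    return f"cu{major}{minor}"
```

The previous image tag, `nvidia/cudagl:11.4.0-base`, has no `cuda<major>.<minor>` component, so this helper raises `ValueError` for it rather than guessing.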

torchrl/collectors/collectors.py

Lines changed: 12 additions & 0 deletions
@@ -475,6 +475,18 @@ def update_policy_weights_(
 # Apply to local policy
 if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
 strategy.apply_weights(self.policy, weights)
+elif (
+hasattr(self, "_original_policy")
+and isinstance(self._original_policy, nn.Module)
+and hasattr(self, "policy")
+and isinstance(self.policy, nn.Module)
+):
+# If no weights were provided, mirror weights from the original (trainer) policy
+from torchrl.weight_update.weight_sync_schemes import WeightStrategy
+
+strategy = WeightStrategy(extract_as="tensordict")
+weights = strategy.extract_weights(self._original_policy)
+strategy.apply_weights(self.policy, weights)
 # Otherwise, no action needed - policy is local and changes are immediately visible

 def __iter__(self) -> Iterator[TensorDictBase]:
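
The added branch mirrors weights from `_original_policy` into `policy` when no weights are passed. A minimal, dependency-free Python sketch of that control flow follows. `TinyPolicy` and `TinyCollector` are hypothetical stand-ins for `nn.Module` and the collector, and plain dictionaries replace `WeightStrategy`. The enclosing `if weights is not None` test is an assumption inferred from the added comment, since the hunk does not show it.

```python
from typing import Optional


class TinyPolicy:
    """Hypothetical stand-in for an nn.Module: holds named parameters in a dict."""

    def __init__(self, **params: float) -> None:
        self.params = dict(params)

    def state_dict(self) -> dict:
        # Return a copy so callers cannot mutate the parameters by accident.
        return dict(self.params)

    def load_state_dict(self, weights: dict) -> None:
        self.params.update(weights)


class TinyCollector:
    """Hypothetical collector with a trainer-side policy and a local copy."""

    def __init__(self, original_policy: TinyPolicy) -> None:
        self._original_policy = original_policy
        # The local policy starts as an independent copy of the trainer policy.
        self.policy = TinyPolicy(**original_policy.state_dict())

    def update_policy_weights_(self, weights: Optional[dict] = None) -> None:
        if weights is not None:
            # Explicit weights take priority and are applied to the local policy.
            self.policy.load_state_dict(weights)
        elif isinstance(getattr(self, "_original_policy", None), TinyPolicy):
            # No weights given: mirror the trainer policy into the local policy.
            self.policy.load_state_dict(self._original_policy.state_dict())
        # Otherwise there is nothing to synchronize.
```

In this sketch, calling `update_policy_weights_()` with no arguments after the trainer policy has changed brings the local copy back in sync, which is the behavior the new `elif` branch is meant to provide.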

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,6 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING

-
 import torch

 from torchrl._utils import logger as torchrl_logger
@@ -58,6 +57,7 @@ def _get_ray():
 "ray is not installed. Please install it with `pip install ray`."
 ) from e

+
 class _AsyncvLLMWorker:
 """Async vLLM worker for Ray with weight update capabilities.

