
Commit 7a28716

Committed: Update [ghstack-poisoned]
2 parents (ecc2eb3 + 4a7cb21), commit 7a28716

File tree: 26 files changed, +51 -116 lines changed

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: [ "3.9" ]
+        python_version: [ "3.12" ]
         cuda_arch_version: [ "12.8" ]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
       bash ./miniconda.sh -b -f -p "${conda_dir}"
       eval "$(${conda_dir}/bin/conda shell.bash hook)"
       printf "* Creating a test environment\n"
-      conda create --prefix "${env_dir}" -y python=3.9
+      conda create --prefix "${env_dir}" -y python=3.12
       printf "* Activating\n"
       conda activate "${env_dir}"

docs/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ sphinx_design
 torchvision
 dm_control
 mujoco<3.3.6
-gym[classic_control,accept-rom-license,ale-py,atari]
+gymnasium[classic_control,atari]
 pygame
 tqdm
 ipython
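
The swap above replaces the unmaintained gym package with gymnasium for the docs build. As a quick orientation, here is a minimal sketch of what the two extras provide; the environment names are illustrative, and the atari extra pulls in ale-py to register the ALE namespace:

import gymnasium as gym

# classic_control extra: CartPole, Pendulum, Acrobot, ...
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)  # gymnasium's reset returns (obs, info)

# atari extra (environments registered via ale-py):
# env = gym.make("ALE/Pong-v5")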

docs/source/reference/config.rst

Lines changed: 1 addition & 1 deletion

@@ -507,7 +507,7 @@ Training and Optimization Configurations
     SparseAdamConfig

 Logging Configurations
-~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~

 .. currentmodule:: torchrl.trainers.algorithms.configs.logging

docs/source/reference/envs.rst

Lines changed: 0 additions & 1 deletion

@@ -1123,7 +1123,6 @@ to be able to create this other composition:
     ExcludeTransform
     FiniteTensorDictCheck
     FlattenObservation
-    FlattenTensorDict
     FrameSkipTransform
     GrayScale
     Hash

docs/source/reference/llms.rst

Lines changed: 3 additions & 3 deletions

@@ -118,9 +118,9 @@ Usage
 Adding Custom Templates
 ^^^^^^^^^^^^^^^^^^^^^^^

-You can add custom chat templates for new model families using the :func:`torchrl.data.llm.chat.add_chat_template` function.
+You can add custom chat templates for new model families using the :func:`torchrl.data.llm.add_chat_template` function.

-.. autofunction:: torchrl.data.llm.chat.add_chat_template
+.. autofunction:: torchrl.data.llm.add_chat_template

 Usage Examples
 ^^^^^^^^^^^^^^
@@ -130,7 +130,7 @@ Adding a Llama Template

 .. code-block:: python

-    >>> from torchrl.data.llm.chat import add_chat_template, History
+    >>> from torchrl.data.llm import add_chat_template, History
     >>> from transformers import AutoTokenizer
     >>>
     >>> # Define the Llama chat template
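
The only substantive change here is the import path: add_chat_template and History moved up from torchrl.data.llm.chat to the torchrl.data.llm package level. A minimal sketch of the updated import (nothing beyond what the diff itself shows is assumed):

# Updated public import surface after this commit
from torchrl.data.llm import History, add_chat_template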

docs/source/reference/utils.rst

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 .. currentmodule:: torchrl

 torchrl._utils package
-====================
+======================

 Set of utility methods that are used internally by the library.

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -149,4 +149,4 @@ first_party_detection = false
 [project.entry-points."vllm.general_plugins"]
 # Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
 # the registry subprocess) before resolving model classes.
-fp32_overrides = "torchrl.modules.llm.backends.vllm_plugin:register_fp32_overrides"
+fp32_overrides = "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"

test/llm/test_vllm.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ class TestAsyncVLLMIntegration:
     @pytest.mark.slow
     def test_vllm_api_compatibility(self, sampling_params):
         """Test that AsyncVLLM supports the same inputs as vLLM.LLM.generate()."""
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM

         # Create AsyncVLLM service
         service = AsyncVLLM.from_pretrained(
@@ -113,7 +113,7 @@ def test_vllm_api_compatibility(self, sampling_params):
     def test_weight_updates_with_transformer(self, sampling_params):
         """Test weight updates using vLLMUpdater with a real transformer model."""
         from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM
         from torchrl.modules.llm.policies.transformers_wrapper import (
             TransformersWrapper,
         )

test/llm/test_wrapper.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from tensordict.utils import _zip_strict
 from torchrl.data.llm import History
 from torchrl.envs.llm.transforms.kl import KLComputation, RetrieveKL, RetrieveLogProb
-from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+from torchrl.modules.llm import AsyncVLLM
 from torchrl.modules.llm.policies.common import (
     _batching,
     ChatHistory,
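
Taken together with the test_vllm.py change above, these edits show AsyncVLLM being promoted from the private vllm_async module to two public namespaces. A minimal sketch, assuming both re-exports exist after this commit as the tests imply:

from torchrl.modules.llm import AsyncVLLM as AsyncVLLMTop
from torchrl.modules.llm.backends import AsyncVLLM

# Same class, two public entry points; the old
# torchrl.modules.llm.backends.vllm_async path is retired.
assert AsyncVLLM is AsyncVLLMTop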

torchrl/collectors/collectors.py

Lines changed: 11 additions & 10 deletions

@@ -282,12 +282,13 @@ def async_shutdown(
     ) -> None:
         """Shuts down the collector when started asynchronously with the `start` method.

-        Arg:
+        Args:
             timeout (float, optional): The maximum time to wait for the collector to shutdown.
             close_env (bool, optional): If True, the collector will close the contained environment.
                 Defaults to `True`.

         .. seealso:: :meth:`~.start`
+
         """
         return self.shutdown(timeout=timeout, close_env=close_env)
@@ -595,7 +596,7 @@ class SyncDataCollector(DataCollectorBase):
           - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

         .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-          pickled directly), the :arg:`policy_factory` should be used instead.
+          pickled directly), the ``policy_factory`` should be used instead.

     Keyword Args:
         policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -2082,7 +2083,7 @@ class _MultiDataCollector(DataCollectorBase):
           ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

         .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-          pickled directly), the :arg:`policy_factory` should be used instead.
+          pickled directly), the ``policy_factory`` should be used instead.

     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -3278,8 +3279,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-        >>> collector.shutdown()
-        >>> del collector
+        ... collector.shutdown()
+        ... del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3665,8 +3666,8 @@ class MultiaSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-        ...     collector.shutdown()
-        ...     del collector
+        ... collector.shutdown()
+        ... del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3901,7 +3902,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
           - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

         .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-          pickled directly), the :arg:`policy_factory` should be used instead.
+          pickled directly), the ``policy_factory`` should be used instead.

     Keyword Args:
         policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -3915,8 +3916,8 @@ class aSyncDataCollector(MultiaSyncDataCollector):
             total number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
             ``frames_per_batch``, an exception is raised.
-        Endless collectors can be created by passing ``total_frames=-1``.
-        Defaults to ``-1`` (never ending collector).
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (never ending collector).
         device (int, str or torch.device, optional): The generic device of the
             collector. The ``device`` args fills any non-specified device: if
             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
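
Since several of these docstring fixes revolve around ``policy_factory``, here is a minimal sketch of the pattern they describe: letting the collector build the policy itself so it never has to be pickled. The environment and policy choices are illustrative assumptions, not part of this commit:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

def make_policy():
    # Called by the collector, so the policy object itself
    # never needs to be serialized / pickled.
    return RandomPolicy(GymEnv("CartPole-v1").action_spec)

collector = SyncDataCollector(
    lambda: GymEnv("CartPole-v1"),
    policy=None,                 # no directly-passed policy...
    policy_factory=make_policy,  # ...the factory builds it instead
    frames_per_batch=64,
    total_frames=-1,             # endless collector, per the docstring
)
for data in collector:
    break  # one batch, then stop
collector.shutdown()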
