From 4fb3b4a7e82b04ecadd4a11ed8b322302e3d1c62 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:40:41 +0100 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- torchrl/envs/llm/transforms/dataloading.py | 10 +- torchrl/envs/llm/transforms/kl.py | 2 +- torchrl/envs/llm/transforms/tools.py | 2 +- torchrl/envs/transforms/__init__.py | 7 +- torchrl/envs/transforms/module.py | 179 ++++++++++++++++++ .../envs/{llm => }/transforms/ray_service.py | 62 +++--- torchrl/envs/transforms/transforms.py | 13 +- 7 files changed, 235 insertions(+), 40 deletions(-) create mode 100644 torchrl/envs/transforms/module.py rename torchrl/envs/{llm => }/transforms/ray_service.py (94%) diff --git a/torchrl/envs/llm/transforms/dataloading.py b/torchrl/envs/llm/transforms/dataloading.py index 2c10e4d4b32..40e2e35122d 100644 --- a/torchrl/envs/llm/transforms/dataloading.py +++ b/torchrl/envs/llm/transforms/dataloading.py @@ -15,14 +15,14 @@ from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec from torchrl.envs.common import EnvBase +from torchrl.envs.transforms import TensorDictPrimer, Transform # Import ray service components -from torchrl.envs.llm.transforms.ray_service import ( +from torchrl.envs.transforms.ray_service import ( _map_input_output_device, _RayServiceMetaClass, RayTransform, ) -from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform from torchrl.envs.utils import make_composite_from_td T = TypeVar("T") @@ -259,7 +259,7 @@ def primers(self): @primers.setter def primers(self, value: TensorSpec): """Set primers property.""" - self._ray.get(self._actor.set_attr.remote("primers", value)) + self._ray.get(self._actor._set_attr.remote("primers", value)) # TensorDictPrimer methods def init(self, tensordict: TensorDictBase | None): @@ -857,7 +857,3 @@ def _update_primers_batch_size(self, parent_batch_size): def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})" - - def set_attr(self, name, value): - """Set attribute on the remote actor or locally.""" - setattr(self, name, value) diff --git a/torchrl/envs/llm/transforms/kl.py b/torchrl/envs/llm/transforms/kl.py index b079c089e8e..311c9d72e55 100644 --- a/torchrl/envs/llm/transforms/kl.py +++ b/torchrl/envs/llm/transforms/kl.py @@ -18,7 +18,7 @@ from torchrl.data import Composite, Unbounded from torchrl.data.tensor_specs import DEVICE_TYPING from torchrl.envs import EnvBase, Transform -from torchrl.envs.llm.transforms.ray_service import _RayServiceMetaClass, RayTransform +from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform from torchrl.envs.transforms.transforms import Compose from torchrl.envs.transforms.utils import _set_missing_tolerance from torchrl.modules.llm.policies.common import LLMWrapperBase diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index a9fd7e28434..e5d8574af7d 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -17,7 +17,7 @@ import torch from tensordict import lazy_stack, TensorDictBase -from torchrl import torchrl_logger +from torchrl._utils import logger as torchrl_logger from torchrl.data.llm import History from torchrl.envs import Transform diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 7539dbe6844..b6da101eac3 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -5,9 +5,10 @@ from .gym_transforms import EndOfLifeTransform from .llm import KLRewardTransform +from .module import ModuleTransform from .r3m import R3MTransform +from .ray_service import RayTransform from .rb_transforms import MultiStepTransform - from .transforms import ( ActionDiscretizer, ActionMask, @@ -85,9 +86,9 @@ "CatFrames", "CatTensors", "CenterCrop", - "ConditionalPolicySwitch", "ClipTransform", "Compose", + "ConditionalPolicySwitch", "ConditionalSkip", "Crop", "DTypeCastTransform", @@ -104,6 +105,7 @@ "InitTracker", "KLRewardTransform", "LineariseRewards", + "ModuleTransform", "MultiAction", "MultiStepTransform", "NoopResetEnv", @@ -113,6 +115,7 @@ "PinMemoryTransform", "R3MTransform", "RandomCropTensorDict", + "RayTransform", "RemoveEmptySpecs", "RenameTransform", "Resize", diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py new file mode 100644 index 00000000000..fb4be15b7c0 --- /dev/null +++ b/torchrl/envs/transforms/module.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from contextlib import nullcontext +from typing import overload + +import torch +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform +from torchrl.envs.transforms.transforms import Transform + + +__all__ = ["ModuleTransform", "RayModuleTransform"] + + +class RayModuleTransform(RayTransform): + """Ray-based ModuleTransform for distributed processing. + + This transform creates a Ray actor that wraps a ModuleTransform, + allowing module execution in a separate Ray worker process. + """ + + def _create_actor(self, **kwargs): + return self._ray.remote(ModuleTransform).remote(**kwargs) + + @overload + def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + ... + + @overload + def update_weights(self, params: TensorDictBase) -> None: + ... + + def update_weights(self, *args, **kwargs) -> None: + import ray + + if self._update_weights_method == "tensordict": + try: + td = kwargs.get("params", args[0]) + except IndexError: + raise ValueError("params must be provided") + return ray.get(self._actor._update_weights_tensordict.remote(params=td)) + elif self._update_weights_method == "state_dict": + try: + state_dict = kwargs.get("state_dict", args[0]) + except IndexError: + raise ValueError("state_dict must be provided") + return ray.get( + self._actor._update_weights_state_dict.remote(state_dict=state_dict) + ) + else: + raise ValueError( + f"Invalid update_weights_method: {self._update_weights_method}" + ) + + +class ModuleTransform(Transform, metaclass=_RayServiceMetaClass): + """A transform that wraps a module. + + Keyword Args: + module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided. + module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided. + no_grad (bool, optional): Whether to use gradient computation. Default is `False`. + inverse (bool, optional): Whether to use the inverse of the module. Default is `False`. + device (torch.device, optional): The device to use. Default is `None`. + use_ray_service (bool, optional): Whether to use Ray service. Default is `False`. + actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and + an actor with this name already exists, the existing actor will be used. + + """ + + _RayServiceClass = RayModuleTransform + + def __init__( + self, + *, + module: TensorDictModuleBase | None = None, + module_factory: Callable[[], TensorDictModuleBase] | None = None, + no_grad: bool = False, + inverse: bool = False, + device: torch.device | None = None, + use_ray_service: bool = False, + actor_name: str | None = None, + ): + super().__init__() + if module is None and module_factory is None: + raise ValueError( + "At least one of `module` or `module_factory` must be provided." + ) + if module is not None and module_factory is not None: + raise ValueError( + "Only one of `module` or `module_factory` must be provided." + ) + self.module = module if module is not None else module_factory() + self.no_grad = no_grad + self.inverse = inverse + self.device = device + + @property + def in_keys(self) -> list[str]: + return self._in_keys() + + def _in_keys(self): + return self.module.in_keys if not self.inverse else [] + + @in_keys.setter + def in_keys(self, value: list[str] | None): + if value is not None: + raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform") + + @property + def out_keys(self) -> list[str]: + return self._out_keys() + + def _out_keys(self): + return self.module.out_keys if not self.inverse else [] + + @property + def in_keys_inv(self) -> list[str]: + return self._in_keys_inv() + + def _in_keys_inv(self): + return self.module.out_keys if self.inverse else [] + + @in_keys_inv.setter + def in_keys_inv(self, value: list[str]): + if value is not None: + raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform") + + @property + def out_keys_inv(self) -> list[str]: + return self._out_keys_inv() + + def _out_keys_inv(self): + return self.module.in_keys if self.inverse else [] + + @out_keys_inv.setter + def out_keys_inv(self, value: list[str] | None): + if value is not None: + raise RuntimeError( + f"out_keys_inv {value} cannot be set for ModuleTransform" + ) + + @out_keys.setter + def out_keys(self, value: list[str] | None): + if value is not None: + raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform") + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.inverse: + return tensordict + with torch.no_grad() if self.no_grad else nullcontext(): + with ( + tensordict.to(self.device) + if self.device is not None + else nullcontext(tensordict) + ) as td: + return self.module(td) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self.inverse: + return tensordict + with torch.no_grad() if self.no_grad else nullcontext(): + with ( + tensordict.to(self.device) + if self.device is not None + else nullcontext(tensordict) + ) as td: + return self.module(td) + + def _update_weights_tensordict(self, params: TensorDictBase) -> None: + params.to_module(self.module) + + def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: + self.module.load_state_dict(state_dict) diff --git a/torchrl/envs/llm/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py similarity index 94% rename from torchrl/envs/llm/transforms/ray_service.py rename to torchrl/envs/transforms/ray_service.py index 77848f4f45a..f244819d164 100644 --- a/torchrl/envs/llm/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -158,6 +158,17 @@ def _create_actor(self, **kwargs): ``` """ + @property + def _ray(self): + # Import ray here to avoid requiring it as a dependency + try: + import ray + except ImportError: + raise ImportError( + "Ray is required for RayTransform. Install with: pip install ray" + ) + return ray + def __init__( self, *, @@ -176,15 +187,6 @@ def __init__( actor_name: Name of the Ray actor (for reuse) **kwargs: Additional arguments passed to Transform """ - # Import ray here to avoid requiring it as a dependency - try: - import ray - except ImportError: - raise ImportError( - "Ray is required for RayTransform. Install with: pip install ray" - ) - self._ray = ray - super().__init__( in_keys=kwargs.get("in_keys", []), out_keys=kwargs.get("out_keys", []) ) @@ -237,10 +239,10 @@ def set_container(self, container: Transform | EnvBase) -> None: parent_batch_size = self.parent.batch_size # Set the batch size directly on the remote actor to override its initialization - self._ray.get(self._actor.set_attr.remote("batch_size", parent_batch_size)) + self._ray.get(self._actor._set_attr.remote("batch_size", parent_batch_size)) # Also disable validation on the remote actor since we'll handle consistency locally - self._ray.get(self._actor.set_attr.remote("_validated", True)) + self._ray.get(self._actor._set_attr.remote("_validated", True)) return result @@ -446,7 +448,7 @@ def primers(self, value): """Set primers.""" self.__dict__["_primers"] = value if hasattr(self, "_actor"): - self._ray.get(self._actor.set_attr.remote("primers", value)) + self._ray.get(self._actor._set_attr.remote("primers", value)) def to(self, *args, **kwargs): """Move to device.""" @@ -470,7 +472,7 @@ def in_keys(self, value): """Set in_keys property.""" self.__dict__["_in_keys"] = value if hasattr(self, "_actor"): - self._ray.get(self._actor.set_attr.remote("in_keys", value)) + self._ray.get(self._actor._set_attr.remote("in_keys", value)) @property def out_keys(self): @@ -482,7 +484,7 @@ def out_keys(self, value): """Set out_keys property.""" self.__dict__["_out_keys"] = value if hasattr(self, "_actor"): - self._ray.get(self._actor.set_attr.remote("out_keys", value)) + self._ray.get(self._actor._set_attr.remote("out_keys", value)) @property def in_keys_inv(self): @@ -494,7 +496,7 @@ def in_keys_inv(self, value): """Set in_keys_inv property.""" self.__dict__["_in_keys_inv"] = value if hasattr(self, "_actor"): - self._ray.get(self._actor.set_attr.remote("in_keys_inv", value)) + self._ray.get(self._actor._set_attr.remote("in_keys_inv", value)) @property def out_keys_inv(self): @@ -506,7 +508,7 @@ def out_keys_inv(self, value): """Set out_keys_inv property.""" self.__dict__["_out_keys_inv"] = value if hasattr(self, "_actor"): - self._ray.get(self._actor.set_attr.remote("out_keys_inv", value)) + self._ray.get(self._actor._set_attr.remote("out_keys_inv", value)) # Generic attribute access for any remaining attributes def __getattr__(self, name): @@ -606,7 +608,7 @@ def __setattr__(self, name, value): # Try to set on remote actor for other attributes try: if hasattr(self, "_actor") and self._actor is not None: - self._ray.get(self._actor.set_attr.remote(name, value)) + self._ray.get(self._actor._set_attr.remote(name, value)) else: super().__setattr__(name, value) except Exception: @@ -621,16 +623,22 @@ class _RayServiceMetaClass(type): alternative class when instantiated with use_ray_service=True. Usage: - class MyClass(metaclass=_RayServiceMetaClass[MyRayClass]): - def __init__(self, use_ray_service=False, **kwargs): - # Regular implementation - pass - - # Returns MyClass instance - obj1 = MyClass(use_ray_service=False) - - # Returns MyRayClass instance - obj2 = MyClass(use_ray_service=True) + >>> class MyRayClass(): + ... def __init__(self, **kwargs): + ... ... + ... + >>> class MyClass(metaclass=_RayServiceMetaClass): + ... _RayServiceClass = MyRayClass + ... + ... def __init__(self, use_ray_service=False, **kwargs): + ... # Regular implementation + ... pass + ... + >>> # Returns MyClass instance + >>> obj1 = MyClass(use_ray_service=False) + >>> + >>> # Returns MyRayClass instance + >>> obj2 = MyClass(use_ray_service=True) """ def __call__(cls, *args, use_ray_service=False, **kwargs): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8f218c218e0..9e86268cefb 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -348,6 +348,10 @@ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: def init(self, tensordict) -> None: """Runs init steps for the transform.""" + def _set_attr(self, name, value): + """Set attribute on the remote actor or locally.""" + setattr(self, name, value) + def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: """Applies the transform to a tensor or a leaf. @@ -11678,16 +11682,21 @@ class FlattenTensorDict(Transform): "use BatchSizeTransform instead." ) - def __init__(self): + def __init__(self, inverse: bool = True): super().__init__(in_keys=[], out_keys=[]) + self.inverse = inverse def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """Forward pass - identity operation.""" + if not self.inverse: + return tensordict.reshape(-1) return tensordict def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: """Inverse pass - flatten the tensordict.""" - return tensordict.reshape(-1) + if self.inverse: + return tensordict.reshape(-1) + return tensordict def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Forward pass - identity operation.""" From 21087a7a28aeceecb9191ddd1f6a87e094e1509c Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 18 Oct 2025 09:56:04 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- docs/source/reference/envs.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 0ee129fdcba..e997fcbd1bc 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1130,6 +1130,7 @@ to be able to create this other composition: InitTracker KLRewardTransform LineariseRewards + ModuleTransform MultiAction NoopResetEnv ObservationNorm From d7685dfaa0db7a14d7de04de92d75ef12b985676 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 18 Oct 2025 15:53:33 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchrl/data/datasets/d4rl.py | 2 ++ torchrl/envs/transforms/module.py | 1 + 2 files changed, 3 insertions(+) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index ceb9e55e5d9..6c3719d216a 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -279,6 +279,7 @@ def _get_dataset_direct(self, name, env_kwargs): # so we need to ensure we're using the gym backend with set_gym_backend("gym"): import gym + env = GymWrapper(gym.make(name)) with tempfile.TemporaryDirectory() as tmpdir: os.environ["D4RL_DATASET_DIR"] = tmpdir @@ -358,6 +359,7 @@ def _get_dataset_from_env(self, name, env_kwargs): # so we need to ensure we're using the gym backend with set_gym_backend("gym"), tempfile.TemporaryDirectory() as tmpdir: import gym + os.environ["D4RL_DATASET_DIR"] = tmpdir env = GymWrapper(gym.make(name)) dataset = make_tensordict( diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 7a52965bf05..288af9054cc 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from collections.abc import Callable from contextlib import nullcontext