diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 0ee129fdcba..e997fcbd1bc 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -1130,6 +1130,7 @@ to be able to create this other composition:
     InitTracker
     KLRewardTransform
     LineariseRewards
+    ModuleTransform
     MultiAction
     NoopResetEnv
     ObservationNorm
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 4455183d9bf..567c0995d20 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -34,7 +34,7 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import TensorDictSequential, WrapModule
+from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule
 from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
 from torch import multiprocessing as mp, nn, Tensor
 from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -122,8 +122,9 @@
 from torchrl.envs.libs.dm_control import _has_dm_control
 from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
 from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
-from torchrl.envs.transforms import VecNorm
+from torchrl.envs.transforms import ModuleTransform, VecNorm
 from torchrl.envs.transforms.llm import KLRewardTransform
+from torchrl.envs.transforms.module import RayModuleTransform
 from torchrl.envs.transforms.r3m import _R3MNet
 from torchrl.envs.transforms.transforms import (
     _has_tv,
@@ -198,6 +199,8 @@
     StateLessCountingEnv,
 )
 
+_has_ray = importlib.util.find_spec("ray") is not None
+
 IS_WIN = platform == "win32"
 if IS_WIN:
     mp_ctx = "spawn"
@@ -14888,6 +14891,130 @@ def test_transform_inverse(self):
         return
 
 
+class TestModuleTransform(TransformBase):
+    @property
+    def _module_factory_samespec(self):
+        return partial(
+            TensorDictModule,
+            nn.LazyLinear(7),
+            in_keys=["observation"],
+            out_keys=["observation"],
+        )
+
+    @property
+    def _module_factory_samespec_inverse(self):
+        return partial(
+            TensorDictModule, nn.LazyLinear(7), in_keys=["action"], out_keys=["action"]
+        )
+
+    def _single_env_maker(self):
+        base_env = ContinuousActionVecMockEnv()
+        t = ModuleTransform(module_factory=self._module_factory_samespec)
+        return base_env.append_transform(t)
+
+    def test_single_trans_env_check(self):
+        env = self._single_env_maker()
+        env.check_env_specs()
+
+    def test_serial_trans_env_check(self):
+        env = SerialEnv(2, self._single_env_maker)
+        try:
+            env.check_env_specs()
+        finally:
+            env.close(raise_if_closed=False)
+
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
+        env = maybe_fork_ParallelEnv(2, self._single_env_maker)
+        try:
+            env.check_env_specs()
+        finally:
+            env.close(raise_if_closed=False)
+
+    def test_trans_serial_env_check(self):
+        env = SerialEnv(2, ContinuousActionVecMockEnv)
+        try:
+            env = env.append_transform(
+                ModuleTransform(module_factory=self._module_factory_samespec)
+            )
+            env.check_env_specs()
+        finally:
+            env.close(raise_if_closed=False)
+
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
+        env = maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv)
+        try:
+            env = env.append_transform(
+                ModuleTransform(module_factory=self._module_factory_samespec)
+            )
+            env.check_env_specs()
+        finally:
+            env.close(raise_if_closed=False)
+
+    def test_transform_no_env(self):
+        t = ModuleTransform(module_factory=self._module_factory_samespec)
+        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
+        assert td["observation"].shape == (2, 7)
+
+    def test_transform_compose(self):
+        t = Compose(ModuleTransform(module_factory=self._module_factory_samespec))
+        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
+        assert td["observation"].shape == (2, 7)
+
+    def test_transform_env(self):
+        # TODO: We should give users the opportunity to modify the specs
+        env = self._single_env_maker()
+        env.check_env_specs()
+
+    def test_transform_model(self):
+        t = nn.Sequential(
+            Compose(ModuleTransform(module_factory=self._module_factory_samespec))
+        )
+        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
+        assert td["observation"].shape == (2, 7)
+
+    def test_transform_rb(self):
+        t = ModuleTransform(module_factory=self._module_factory_samespec)
+        rb = ReplayBuffer(transform=t)
+        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
+        assert rb._storage._storage[0]["observation"].shape == (3,)
+        s = rb.sample(2)
+        assert s["observation"].shape == (2, 7)
+
+        rb = ReplayBuffer()
+        rb.append_transform(t, invert=True)
+        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
+        assert rb._storage._storage[0]["observation"].shape == (7,)
+        s = rb.sample(2)
+        assert s["observation"].shape == (2, 7)
+
+    def test_transform_inverse(self):
+        t = ModuleTransform(
+            module_factory=self._module_factory_samespec_inverse, inverse=True
+        )
+        env = ContinuousActionVecMockEnv().append_transform(t)
+        env.check_env_specs()
+
+    @pytest.mark.skipif(not _has_ray, reason="ray required")
+    def test_ray_extension(self):
+        import ray
+
+        # Record whether ray was already initialized before the test
+        ray_init = ray.is_initialized()
+        try:
+            t = ModuleTransform(
+                module_factory=self._module_factory_samespec,
+                use_ray_service=True,
+                actor_name="my_transform",
+            )
+            env = ContinuousActionVecMockEnv().append_transform(t)
+            assert isinstance(t, RayModuleTransform)
+            env.check_env_specs()
+            assert ray.get_actor("my_transform") is not None
+        finally:
+            if not ray_init:
+                ray.shutdown()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/llm/transforms/dataloading.py b/torchrl/envs/llm/transforms/dataloading.py
index 2c10e4d4b32..40e2e35122d 100644
--- a/torchrl/envs/llm/transforms/dataloading.py
+++ b/torchrl/envs/llm/transforms/dataloading.py
@@ -15,14 +15,14 @@
 from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec
 from torchrl.envs.common import EnvBase
+from torchrl.envs.transforms import TensorDictPrimer, Transform
 
 # Import ray service components
-from torchrl.envs.llm.transforms.ray_service import (
+from torchrl.envs.transforms.ray_service import (
     _map_input_output_device,
     _RayServiceMetaClass,
     RayTransform,
 )
-from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
 from torchrl.envs.utils import make_composite_from_td
 
 T = TypeVar("T")
@@ -259,7 +259,7 @@ def primers(self):
     @primers.setter
     def primers(self, value: TensorSpec):
         """Set primers property."""
-        self._ray.get(self._actor.set_attr.remote("primers", value))
+        self._ray.get(self._actor._set_attr.remote("primers", value))
 
     # TensorDictPrimer methods
     def init(self, tensordict: TensorDictBase | None):
@@ -857,7 +857,3 @@ def _update_primers_batch_size(self, parent_batch_size):
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
-
-    def set_attr(self, name, value):
-        """Set attribute on the remote actor or locally."""
-        setattr(self, name, value)
diff --git a/torchrl/envs/llm/transforms/kl.py b/torchrl/envs/llm/transforms/kl.py
index b079c089e8e..311c9d72e55 100644
--- a/torchrl/envs/llm/transforms/kl.py
+++ b/torchrl/envs/llm/transforms/kl.py
@@ -18,7 +18,7 @@
 from torchrl.data import Composite, Unbounded
 from torchrl.data.tensor_specs import DEVICE_TYPING
 from torchrl.envs import EnvBase, Transform
-from torchrl.envs.llm.transforms.ray_service import _RayServiceMetaClass, RayTransform
+from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
 from torchrl.envs.transforms.transforms import Compose
 from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.modules.llm.policies.common import LLMWrapperBase
diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py
index f3eef9edea3..22fccb27cba 100644
--- a/torchrl/envs/llm/transforms/tools.py
+++ b/torchrl/envs/llm/transforms/tools.py
@@ -17,7 +17,7 @@
 import torch
 from tensordict import lazy_stack, TensorDictBase
 
-from torchrl import torchrl_logger
+from torchrl._utils import logger as torchrl_logger
 from torchrl.data.llm import History
 from torchrl.envs import Transform
 
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 7539dbe6844..b6da101eac3 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -5,9 +5,10 @@
 
 from .gym_transforms import EndOfLifeTransform
 from .llm import KLRewardTransform
+from .module import ModuleTransform
 from .r3m import R3MTransform
+from .ray_service import RayTransform
 from .rb_transforms import MultiStepTransform
-
 from .transforms import (
     ActionDiscretizer,
     ActionMask,
@@ -85,9 +86,9 @@
     "CatFrames",
     "CatTensors",
     "CenterCrop",
-    "ConditionalPolicySwitch",
     "ClipTransform",
     "Compose",
+    "ConditionalPolicySwitch",
     "ConditionalSkip",
     "Crop",
     "DTypeCastTransform",
@@ -104,6 +105,7 @@
     "InitTracker",
     "KLRewardTransform",
     "LineariseRewards",
+    "ModuleTransform",
     "MultiAction",
     "MultiStepTransform",
     "NoopResetEnv",
@@ -113,6 +115,7 @@
     "PinMemoryTransform",
     "R3MTransform",
     "RandomCropTensorDict",
+    "RayTransform",
     "RemoveEmptySpecs",
     "RenameTransform",
     "Resize",
diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py
new file mode 100644
index 00000000000..288af9054cc
--- /dev/null
+++ b/torchrl/envs/transforms/module.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections.abc import Callable
+from contextlib import nullcontext
+from typing import overload
+
+import torch
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+
+from torchrl.data.tensor_specs import TensorSpec
+from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
+from torchrl.envs.transforms.transforms import Transform
+
+
+__all__ = ["ModuleTransform", "RayModuleTransform"]
+
+
+class RayModuleTransform(RayTransform):
+    """Ray-based ModuleTransform for distributed processing.
+
+    This transform creates a Ray actor that wraps a ModuleTransform,
+    allowing module execution in a separate Ray worker process.
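+
+    Examples:
+        >>> # Illustrative sketch only (requires ray). Instances are normally obtained
+        >>> # by passing ``use_ray_service=True`` to :class:`ModuleTransform`, mirroring
+        >>> # the pattern exercised in the tests.
+        >>> from functools import partial
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torch import nn
+        >>> factory = partial(
+        ...     TensorDictModule,
+        ...     nn.LazyLinear(7),
+        ...     in_keys=["observation"],
+        ...     out_keys=["observation"],
+        ... )
+        >>> t = ModuleTransform(
+        ...     module_factory=factory, use_ray_service=True, actor_name="my_transform"
+        ... )
+        >>> isinstance(t, RayModuleTransform)
+        True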
+ """ + + def _create_actor(self, **kwargs): + import ray + + remote = self._ray.remote(ModuleTransform) + ray_kwargs = {} + num_gpus = self._num_gpus + if num_gpus is not None: + ray_kwargs["num_gpus"] = num_gpus + num_cpus = self._num_cpus + if num_cpus is not None: + ray_kwargs["num_cpus"] = num_cpus + actor_name = self._actor_name + if actor_name is not None: + ray_kwargs["name"] = actor_name + if ray_kwargs: + remote = remote.options(**ray_kwargs) + actor = remote.remote(**kwargs) + # wait till the actor is ready + ray.get(actor._ready.remote()) + return actor + + @overload + def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + ... + + @overload + def update_weights(self, params: TensorDictBase) -> None: + ... + + def update_weights(self, *args, **kwargs) -> None: + import ray + + if self._update_weights_method == "tensordict": + try: + td = kwargs.get("params", args[0]) + except IndexError: + raise ValueError("params must be provided") + return ray.get(self._actor._update_weights_tensordict.remote(params=td)) + elif self._update_weights_method == "state_dict": + try: + state_dict = kwargs.get("state_dict", args[0]) + except IndexError: + raise ValueError("state_dict must be provided") + return ray.get( + self._actor._update_weights_state_dict.remote(state_dict=state_dict) + ) + else: + raise ValueError( + f"Invalid update_weights_method: {self._update_weights_method}" + ) + + +class ModuleTransform(Transform, metaclass=_RayServiceMetaClass): + """A transform that wraps a module. + + Keyword Args: + module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided. + module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided. + no_grad (bool, optional): Whether to use gradient computation. Default is `False`. + inverse (bool, optional): Whether to use the inverse of the module. Default is `False`. + device (torch.device, optional): The device to use. Default is `None`. + use_ray_service (bool, optional): Whether to use Ray service. Default is `False`. + num_gpus (int, optional): The number of GPUs to use if using Ray. Default is `None`. + num_cpus (int, optional): The number of CPUs to use if using Ray. Default is `None`. + actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and + an actor with this name already exists, the existing actor will be used. + observation_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the observation + after it has been transformed by the module, or a function that modifies the existing spec. + Defaults to `None` (observation specs remain unchanged). + done_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the done + after it has been transformed by the module, or a function that modifies the existing spec. + Defaults to `None` (done specs remain unchanged). + reward_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the reward + after it has been transformed by the module, or a function that modifies the existing spec. + Defaults to `None` (reward specs remain unchanged). + state_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the state + after it has been transformed by the module, or a function that modifies the existing spec. 
+    """
+
+    _RayServiceClass = RayModuleTransform
+
+    def __init__(
+        self,
+        *,
+        module: TensorDictModuleBase | None = None,
+        module_factory: Callable[[], TensorDictModuleBase] | None = None,
+        no_grad: bool = False,
+        inverse: bool = False,
+        device: torch.device | None = None,
+        use_ray_service: bool = False,  # noqa
+        actor_name: str | None = None,  # noqa
+        num_gpus: int | None = None,
+        num_cpus: int | None = None,
+        observation_spec_transform: TensorSpec
+        | Callable[[TensorSpec], TensorSpec]
+        | None = None,
+        action_spec_transform: TensorSpec
+        | Callable[[TensorSpec], TensorSpec]
+        | None = None,
+        reward_spec_transform: TensorSpec
+        | Callable[[TensorSpec], TensorSpec]
+        | None = None,
+        done_spec_transform: TensorSpec
+        | Callable[[TensorSpec], TensorSpec]
+        | None = None,
+        state_spec_transform: TensorSpec
+        | Callable[[TensorSpec], TensorSpec]
+        | None = None,
+    ):
+        super().__init__()
+        if module is None and module_factory is None:
+            raise ValueError(
+                "At least one of `module` or `module_factory` must be provided."
+            )
+        if module is not None and module_factory is not None:
+            raise ValueError(
+                "Only one of `module` or `module_factory` must be provided."
+            )
+        self.module = module if module is not None else module_factory()
+        self.no_grad = no_grad
+        self.inverse = inverse
+        self.device = device
+        self.observation_spec_transform = observation_spec_transform
+        self.action_spec_transform = action_spec_transform
+        self.reward_spec_transform = reward_spec_transform
+        self.done_spec_transform = done_spec_transform
+        self.state_spec_transform = state_spec_transform
+
+    @property
+    def in_keys(self) -> list[str]:
+        return self._in_keys()
+
+    def _in_keys(self):
+        return self.module.in_keys if not self.inverse else []
+
+    @in_keys.setter
+    def in_keys(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform")
+
+    @property
+    def out_keys(self) -> list[str]:
+        return self._out_keys()
+
+    def _out_keys(self):
+        return self.module.out_keys if not self.inverse else []
+
+    @property
+    def in_keys_inv(self) -> list[str]:
+        return self._in_keys_inv()
+
+    def _in_keys_inv(self):
+        return self.module.out_keys if self.inverse else []
+
+    @in_keys_inv.setter
+    def in_keys_inv(self, value: list[str]):
+        if value is not None:
+            raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform")
+
+    @property
+    def out_keys_inv(self) -> list[str]:
+        return self._out_keys_inv()
+
+    def _out_keys_inv(self):
+        return self.module.in_keys if self.inverse else []
+
+    @out_keys_inv.setter
+    def out_keys_inv(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(
+                f"out_keys_inv {value} cannot be set for ModuleTransform"
+            )
+
+    @out_keys.setter
+    def out_keys(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform")
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self.inverse:
+            return tensordict
+        with torch.no_grad() if self.no_grad else nullcontext():
+            with (
+                tensordict.to(self.device)
+                if self.device is not None
+                else nullcontext(tensordict)
+            ) as td:
+                return self.module(td)
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if not self.inverse:
+            return tensordict
+        with torch.no_grad() if self.no_grad else nullcontext():
+            with (
+                tensordict.to(self.device)
+                if self.device is not None
+                else nullcontext(tensordict)
+            ) as td:
+                return self.module(td)
+
+    def _update_weights_tensordict(self, params: TensorDictBase) -> None:
+        params.to_module(self.module)
+
+    def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
+        self.module.load_state_dict(state_dict)
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        if self.observation_spec_transform is not None:
+            if isinstance(self.observation_spec_transform, TensorSpec):
+                return self.observation_spec_transform
+            else:
+                return self.observation_spec_transform(observation_spec)
+        return observation_spec
+
+    def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
+        if self.action_spec_transform is not None:
+            if isinstance(self.action_spec_transform, TensorSpec):
+                return self.action_spec_transform
+            else:
+                return self.action_spec_transform(action_spec)
+        return action_spec
+
+    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
+        if self.reward_spec_transform is not None:
+            if isinstance(self.reward_spec_transform, TensorSpec):
+                return self.reward_spec_transform
+            else:
+                return self.reward_spec_transform(reward_spec)
+        return reward_spec
+
+    def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec:
+        if self.done_spec_transform is not None:
+            if isinstance(self.done_spec_transform, TensorSpec):
+                return self.done_spec_transform
+            else:
+                return self.done_spec_transform(done_spec)
+        return done_spec
+
+    def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
+        if self.state_spec_transform is not None:
+            if isinstance(self.state_spec_transform, TensorSpec):
+                return self.state_spec_transform
+            else:
+                return self.state_spec_transform(state_spec)
+        return state_spec
diff --git a/torchrl/envs/llm/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py
similarity index 94%
rename from torchrl/envs/llm/transforms/ray_service.py
rename to torchrl/envs/transforms/ray_service.py
index 77848f4f45a..5eca9d1e66d 100644
--- a/torchrl/envs/llm/transforms/ray_service.py
+++ b/torchrl/envs/transforms/ray_service.py
@@ -158,11 +158,22 @@ def _create_actor(self, **kwargs):
         ```
     """
 
+    @property
+    def _ray(self):
+        # Import ray here to avoid requiring it as a dependency
+        try:
+            import ray
+        except ImportError:
+            raise ImportError(
+                "Ray is required for RayTransform. Install with: pip install ray"
+            )
+        return ray
+
     def __init__(
         self,
         *,
         num_cpus: int | None = None,
-        num_gpus: int = 0,
+        num_gpus: int | None = None,
        device: DEVICE_TYPING | None = None,
        actor_name: str | None = None,
        **kwargs,
     ):
@@ -176,20 +187,11 @@ def __init__(
             actor_name: Name of the Ray actor (for reuse)
             **kwargs: Additional arguments passed to Transform
         """
-        # Import ray here to avoid requiring it as a dependency
-        try:
-            import ray
-        except ImportError:
-            raise ImportError(
-                "Ray is required for RayTransform. Install with: pip install ray"
-            )
-        self._ray = ray
-
         super().__init__(
             in_keys=kwargs.get("in_keys", []), out_keys=kwargs.get("out_keys", [])
         )
-        self._num_cpus = num_cpus or 1
+        self._num_cpus = num_cpus
         self._num_gpus = num_gpus
         self._device = device
         self._actor_name = actor_name
@@ -237,10 +239,10 @@ def set_container(self, container: Transform | EnvBase) -> None:
             parent_batch_size = self.parent.batch_size
 
             # Set the batch size directly on the remote actor to override its initialization
-            self._ray.get(self._actor.set_attr.remote("batch_size", parent_batch_size))
+            self._ray.get(self._actor._set_attr.remote("batch_size", parent_batch_size))
 
             # Also disable validation on the remote actor since we'll handle consistency locally
-            self._ray.get(self._actor.set_attr.remote("_validated", True))
+            self._ray.get(self._actor._set_attr.remote("_validated", True))
 
         return result
 
@@ -446,7 +448,7 @@ def primers(self, value):
         """Set primers."""
         self.__dict__["_primers"] = value
         if hasattr(self, "_actor"):
-            self._ray.get(self._actor.set_attr.remote("primers", value))
+            self._ray.get(self._actor._set_attr.remote("primers", value))
 
     def to(self, *args, **kwargs):
         """Move to device."""
@@ -470,7 +472,7 @@ def in_keys(self, value):
         """Set in_keys property."""
         self.__dict__["_in_keys"] = value
         if hasattr(self, "_actor"):
-            self._ray.get(self._actor.set_attr.remote("in_keys", value))
+            self._ray.get(self._actor._set_attr.remote("in_keys", value))
 
     @property
     def out_keys(self):
@@ -482,7 +484,7 @@ def out_keys(self, value):
         """Set out_keys property."""
         self.__dict__["_out_keys"] = value
         if hasattr(self, "_actor"):
-            self._ray.get(self._actor.set_attr.remote("out_keys", value))
+            self._ray.get(self._actor._set_attr.remote("out_keys", value))
 
     @property
     def in_keys_inv(self):
@@ -494,7 +496,7 @@ def in_keys_inv(self, value):
         """Set in_keys_inv property."""
         self.__dict__["_in_keys_inv"] = value
         if hasattr(self, "_actor"):
-            self._ray.get(self._actor.set_attr.remote("in_keys_inv", value))
+            self._ray.get(self._actor._set_attr.remote("in_keys_inv", value))
 
     @property
     def out_keys_inv(self):
@@ -506,7 +508,7 @@ def out_keys_inv(self, value):
         """Set out_keys_inv property."""
         self.__dict__["_out_keys_inv"] = value
         if hasattr(self, "_actor"):
-            self._ray.get(self._actor.set_attr.remote("out_keys_inv", value))
+            self._ray.get(self._actor._set_attr.remote("out_keys_inv", value))
 
     # Generic attribute access for any remaining attributes
     def __getattr__(self, name):
@@ -606,7 +608,7 @@ def __setattr__(self, name, value):
         # Try to set on remote actor for other attributes
         try:
             if hasattr(self, "_actor") and self._actor is not None:
-                self._ray.get(self._actor.set_attr.remote(name, value))
+                self._ray.get(self._actor._set_attr.remote(name, value))
             else:
                 super().__setattr__(name, value)
         except Exception:
@@ -621,16 +623,22 @@ class _RayServiceMetaClass(type):
     alternative class when instantiated with use_ray_service=True.
 
     Usage:
-        class MyClass(metaclass=_RayServiceMetaClass[MyRayClass]):
-            def __init__(self, use_ray_service=False, **kwargs):
-                # Regular implementation
-                pass
-
-        # Returns MyClass instance
-        obj1 = MyClass(use_ray_service=False)
-
-        # Returns MyRayClass instance
-        obj2 = MyClass(use_ray_service=True)
+        >>> class MyRayClass():
+        ...     def __init__(self, **kwargs):
+        ...         ...
+        ...
+        >>> class MyClass(metaclass=_RayServiceMetaClass):
+        ...     _RayServiceClass = MyRayClass
+        ...
+        ...     def __init__(self, use_ray_service=False, **kwargs):
+        ...         # Regular implementation
+        ...         pass
+        ...
+        >>> # Returns MyClass instance
+        >>> obj1 = MyClass(use_ray_service=False)
+        >>>
+        >>> # Returns MyRayClass instance
+        >>> obj2 = MyClass(use_ray_service=True)
     """
 
     def __call__(cls, *args, use_ray_service=False, **kwargs):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 8f218c218e0..5af2373e996 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -256,6 +256,10 @@ def __init__(
         self.__dict__["_container"] = None
         self.__dict__["_parent"] = None
 
+    def _ready(self):
+        # Used to block ray until the actor is ready, see RayTransform
+        return True
+
     def close(self):
         """Close the transform."""
 
@@ -348,6 +352,10 @@ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
     def init(self, tensordict) -> None:
         """Runs init steps for the transform."""
 
+    def _set_attr(self, name, value):
+        """Set attribute on the remote actor or locally."""
+        setattr(self, name, value)
+
     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
         """Applies the transform to a tensor or a leaf.
 
@@ -11678,16 +11686,21 @@ class FlattenTensorDict(Transform):
         "use BatchSizeTransform instead."
     )
 
-    def __init__(self):
+    def __init__(self, inverse: bool = True):
         super().__init__(in_keys=[], out_keys=[])
+        self.inverse = inverse
 
     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        """Forward pass - identity operation."""
+        """Forward pass - flattens the tensordict when ``inverse=False``."""
+        if not self.inverse:
+            return tensordict.reshape(-1)
         return tensordict
 
     def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        """Inverse pass - flatten the tensordict."""
-        return tensordict.reshape(-1)
+        """Inverse pass - flattens the tensordict when ``inverse=True``."""
+        if self.inverse:
+            return tensordict.reshape(-1)
+        return tensordict
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Forward pass - identity operation."""