Skip to content

Commit 7b85c71

Browse files
committed
[Feature] Transform Module - ModuleTransform and Ray Service Refactor
ghstack-source-id: ef0ff23 Pull-Request: #3184
1 parent 61c178e commit 7b85c71

File tree

9 files changed

+473
-44
lines changed

9 files changed

+473
-44
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,7 @@ to be able to create this other composition:
11301130
InitTracker
11311131
KLRewardTransform
11321132
LineariseRewards
1133+
ModuleTransform
11331134
MultiAction
11341135
NoopResetEnv
11351136
ObservationNorm

test/test_transforms.py

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
TensorDictBase,
3535
unravel_key,
3636
)
37-
from tensordict.nn import TensorDictSequential, WrapModule
37+
from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule
3838
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
3939
from torch import multiprocessing as mp, nn, Tensor
4040
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -122,8 +122,9 @@
122122
from torchrl.envs.libs.dm_control import _has_dm_control
123123
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
124124
from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
125-
from torchrl.envs.transforms import VecNorm
125+
from torchrl.envs.transforms import ModuleTransform, VecNorm
126126
from torchrl.envs.transforms.llm import KLRewardTransform
127+
from torchrl.envs.transforms.module import RayModuleTransform
127128
from torchrl.envs.transforms.r3m import _R3MNet
128129
from torchrl.envs.transforms.transforms import (
129130
_has_tv,
@@ -198,6 +199,8 @@
198199
StateLessCountingEnv,
199200
)
200201

202+
_has_ray = importlib.util.find_spec("ray") is not None
203+
201204
IS_WIN = platform == "win32"
202205
if IS_WIN:
203206
mp_ctx = "spawn"
@@ -14888,6 +14891,130 @@ def test_transform_inverse(self):
1488814891
return
1488914892

1489014893

14894+
class TestModuleTransform(TransformBase):
14895+
@property
14896+
def _module_factory_samespec(self):
14897+
return partial(
14898+
TensorDictModule,
14899+
nn.LazyLinear(7),
14900+
in_keys=["observation"],
14901+
out_keys=["observation"],
14902+
)
14903+
14904+
@property
14905+
def _module_factory_samespec_inverse(self):
14906+
return partial(
14907+
TensorDictModule, nn.LazyLinear(7), in_keys=["action"], out_keys=["action"]
14908+
)
14909+
14910+
def _single_env_maker(self):
14911+
base_env = ContinuousActionVecMockEnv()
14912+
t = ModuleTransform(module_factory=self._module_factory_samespec)
14913+
return base_env.append_transform(t)
14914+
14915+
def test_single_trans_env_check(self):
14916+
env = self._single_env_maker()
14917+
env.check_env_specs()
14918+
14919+
def test_serial_trans_env_check(self):
14920+
env = SerialEnv(2, self._single_env_maker)
14921+
try:
14922+
env.check_env_specs()
14923+
finally:
14924+
env.close(raise_if_closed=False)
14925+
14926+
def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
14927+
env = maybe_fork_ParallelEnv(2, self._single_env_maker)
14928+
try:
14929+
env.check_env_specs()
14930+
finally:
14931+
env.close(raise_if_closed=False)
14932+
14933+
def test_trans_serial_env_check(self):
14934+
env = SerialEnv(2, ContinuousActionVecMockEnv)
14935+
try:
14936+
env = env.append_transform(
14937+
ModuleTransform(module_factory=self._module_factory_samespec)
14938+
)
14939+
env.check_env_specs()
14940+
finally:
14941+
env.close(raise_if_closed=False)
14942+
14943+
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
14944+
env = maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv)
14945+
try:
14946+
env = env.append_transform(
14947+
ModuleTransform(module_factory=self._module_factory_samespec)
14948+
)
14949+
env.check_env_specs()
14950+
finally:
14951+
env.close(raise_if_closed=False)
14952+
14953+
def test_transform_no_env(self):
14954+
t = ModuleTransform(module_factory=self._module_factory_samespec)
14955+
td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
14956+
assert td["observation"].shape == (2, 7)
14957+
14958+
def test_transform_compose(self):
14959+
t = Compose(ModuleTransform(module_factory=self._module_factory_samespec))
14960+
td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
14961+
assert td["observation"].shape == (2, 7)
14962+
14963+
def test_transform_env(self):
14964+
# TODO: We should give users the opportunity to modify the specs
14965+
env = self._single_env_maker()
14966+
env.check_env_specs()
14967+
14968+
def test_transform_model(self):
14969+
t = nn.Sequential(
14970+
Compose(ModuleTransform(module_factory=self._module_factory_samespec))
14971+
)
14972+
td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
14973+
assert td["observation"].shape == (2, 7)
14974+
14975+
def test_transform_rb(self):
14976+
t = ModuleTransform(module_factory=self._module_factory_samespec)
14977+
rb = ReplayBuffer(transform=t)
14978+
rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
14979+
assert rb._storage._storage[0]["observation"].shape == (3,)
14980+
s = rb.sample(2)
14981+
assert s["observation"].shape == (2, 7)
14982+
14983+
rb = ReplayBuffer()
14984+
rb.append_transform(t, invert=True)
14985+
rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
14986+
assert rb._storage._storage[0]["observation"].shape == (7,)
14987+
s = rb.sample(2)
14988+
assert s["observation"].shape == (2, 7)
14989+
14990+
def test_transform_inverse(self):
14991+
t = ModuleTransform(
14992+
module_factory=self._module_factory_samespec_inverse, inverse=True
14993+
)
14994+
env = ContinuousActionVecMockEnv().append_transform(t)
14995+
env.check_env_specs()
14996+
14997+
@pytest.mark.skipif(not _has_ray, reason="ray required")
14998+
def test_ray_extension(self):
14999+
import ray
15000+
15001+
# Check if ray is initialized
15002+
ray_init = ray.is_initialized
15003+
try:
15004+
t = ModuleTransform(
15005+
module_factory=self._module_factory_samespec,
15006+
use_ray_service=True,
15007+
actor_name="my_transform",
15008+
)
15009+
env = ContinuousActionVecMockEnv().append_transform(t)
15010+
assert isinstance(t, RayModuleTransform)
15011+
env.check_env_specs()
15012+
assert ray.get_actor("my_transform") is not None
15013+
finally:
15014+
if not ray_init:
15015+
ray.stop()
15016+
15017+
1489115018
if __name__ == "__main__":
1489215019
args, unknown = argparse.ArgumentParser().parse_known_args()
1489315020
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/llm/transforms/dataloading.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec
1717
from torchrl.envs.common import EnvBase
18+
from torchrl.envs.transforms import TensorDictPrimer, Transform
1819

1920
# Import ray service components
20-
from torchrl.envs.llm.transforms.ray_service import (
21+
from torchrl.envs.transforms.ray_service import (
2122
_map_input_output_device,
2223
_RayServiceMetaClass,
2324
RayTransform,
2425
)
25-
from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
2626
from torchrl.envs.utils import make_composite_from_td
2727

2828
T = TypeVar("T")
@@ -259,7 +259,7 @@ def primers(self):
259259
@primers.setter
260260
def primers(self, value: TensorSpec):
261261
"""Set primers property."""
262-
self._ray.get(self._actor.set_attr.remote("primers", value))
262+
self._ray.get(self._actor._set_attr.remote("primers", value))
263263

264264
# TensorDictPrimer methods
265265
def init(self, tensordict: TensorDictBase | None):
@@ -857,7 +857,3 @@ def _update_primers_batch_size(self, parent_batch_size):
857857
def __repr__(self) -> str:
858858
class_name = self.__class__.__name__
859859
return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
860-
861-
def set_attr(self, name, value):
862-
"""Set attribute on the remote actor or locally."""
863-
setattr(self, name, value)

torchrl/envs/llm/transforms/kl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torchrl.data import Composite, Unbounded
1919
from torchrl.data.tensor_specs import DEVICE_TYPING
2020
from torchrl.envs import EnvBase, Transform
21-
from torchrl.envs.llm.transforms.ray_service import _RayServiceMetaClass, RayTransform
21+
from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
2222
from torchrl.envs.transforms.transforms import Compose
2323
from torchrl.envs.transforms.utils import _set_missing_tolerance
2424
from torchrl.modules.llm.policies.common import LLMWrapperBase

torchrl/envs/llm/transforms/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818

1919
from tensordict import lazy_stack, TensorDictBase
20-
from torchrl import torchrl_logger
20+
from torchrl._utils import logger as torchrl_logger
2121
from torchrl.data.llm import History
2222

2323
from torchrl.envs import Transform

torchrl/envs/transforms/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from .gym_transforms import EndOfLifeTransform
77
from .llm import KLRewardTransform
8+
from .module import ModuleTransform
89
from .r3m import R3MTransform
10+
from .ray_service import RayTransform
911
from .rb_transforms import MultiStepTransform
10-
1112
from .transforms import (
1213
ActionDiscretizer,
1314
ActionMask,
@@ -85,9 +86,9 @@
8586
"CatFrames",
8687
"CatTensors",
8788
"CenterCrop",
88-
"ConditionalPolicySwitch",
8989
"ClipTransform",
9090
"Compose",
91+
"ConditionalPolicySwitch",
9192
"ConditionalSkip",
9293
"Crop",
9394
"DTypeCastTransform",
@@ -104,6 +105,7 @@
104105
"InitTracker",
105106
"KLRewardTransform",
106107
"LineariseRewards",
108+
"ModuleTransform",
107109
"MultiAction",
108110
"MultiStepTransform",
109111
"NoopResetEnv",
@@ -113,6 +115,7 @@
113115
"PinMemoryTransform",
114116
"R3MTransform",
115117
"RandomCropTensorDict",
118+
"RayTransform",
116119
"RemoveEmptySpecs",
117120
"RenameTransform",
118121
"Resize",

0 commit comments

Comments
 (0)