|
34 | 34 | TensorDictBase, |
35 | 35 | unravel_key, |
36 | 36 | ) |
37 | | -from tensordict.nn import TensorDictSequential, WrapModule |
| 37 | +from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule |
38 | 38 | from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td |
39 | 39 | from torch import multiprocessing as mp, nn, Tensor |
40 | 40 | from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env |
|
122 | 122 | from torchrl.envs.libs.dm_control import _has_dm_control |
123 | 123 | from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend |
124 | 124 | from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents |
125 | | -from torchrl.envs.transforms import VecNorm |
| 125 | +from torchrl.envs.transforms import ModuleTransform, VecNorm |
126 | 126 | from torchrl.envs.transforms.llm import KLRewardTransform |
| 127 | +from torchrl.envs.transforms.module import RayModuleTransform |
127 | 128 | from torchrl.envs.transforms.r3m import _R3MNet |
128 | 129 | from torchrl.envs.transforms.transforms import ( |
129 | 130 | _has_tv, |
|
198 | 199 | StateLessCountingEnv, |
199 | 200 | ) |
200 | 201 |
|
| 202 | +_has_ray = importlib.util.find_spec("ray") is not None |
| 203 | + |
201 | 204 | IS_WIN = platform == "win32" |
202 | 205 | if IS_WIN: |
203 | 206 | mp_ctx = "spawn" |
@@ -14888,6 +14891,130 @@ def test_transform_inverse(self): |
14888 | 14891 | return |
14889 | 14892 |
|
14890 | 14893 |
|
class TestModuleTransform(TransformBase):
    """Test-suite for :class:`ModuleTransform`.

    The transform is always built from a ``module_factory`` that wraps an
    ``nn.LazyLinear(7)`` in a :class:`TensorDictModule` reading and writing the
    same key, and is exercised standalone, inside ``Compose``/``nn.Sequential``,
    appended to single and batched environments, attached to a replay buffer,
    and (when ray is installed) through its ray-backed variant.
    """

    @property
    def _module_factory_samespec(self):
        """Factory for a module that overwrites ``"observation"`` in place.

        The ``nn.LazyLinear(7)`` instance is created when the property is
        accessed and bound into the partial, so every call of one returned
        factory wraps that same linear layer.

        NOTE(review): "samespec" presumably means the 7 output features match
        the mock env's observation spec -- confirm against
        ``ContinuousActionVecMockEnv``.
        """
        return partial(
            TensorDictModule,
            nn.LazyLinear(7),
            in_keys=["observation"],
            out_keys=["observation"],
        )

    @property
    def _module_factory_samespec_inverse(self):
        """Factory for a module that overwrites ``"action"`` in place.

        Counterpart of :attr:`_module_factory_samespec` used by the inverse
        test; it also wraps a freshly created ``nn.LazyLinear(7)``.
        """
        return partial(
            TensorDictModule, nn.LazyLinear(7), in_keys=["action"], out_keys=["action"]
        )

    def _single_env_maker(self):
        """Build a ``ContinuousActionVecMockEnv`` with a ``ModuleTransform`` appended."""
        base_env = ContinuousActionVecMockEnv()
        t = ModuleTransform(module_factory=self._module_factory_samespec)
        return base_env.append_transform(t)

    def test_single_trans_env_check(self):
        """Spec check on a single transformed env."""
        env = self._single_env_maker()
        env.check_env_specs()

    def test_serial_trans_env_check(self):
        """Spec check on a ``SerialEnv`` whose workers are transformed envs."""
        env = SerialEnv(2, self._single_env_maker)
        try:
            env.check_env_specs()
        finally:
            # Always release the sub-envs, even if the spec check fails.
            env.close(raise_if_closed=False)

    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
        """Spec check on a parallel env whose workers are transformed envs."""
        env = maybe_fork_ParallelEnv(2, self._single_env_maker)
        try:
            env.check_env_specs()
        finally:
            # Always release the worker processes, even if the spec check fails.
            env.close(raise_if_closed=False)

    def test_trans_serial_env_check(self):
        """Spec check with the transform appended on top of a ``SerialEnv``."""
        env = SerialEnv(2, ContinuousActionVecMockEnv)
        try:
            env = env.append_transform(
                ModuleTransform(module_factory=self._module_factory_samespec)
            )
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
        """Spec check with the transform appended on top of a parallel env."""
        env = maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            env = env.append_transform(
                ModuleTransform(module_factory=self._module_factory_samespec)
            )
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_transform_no_env(self):
        """Calling the transform directly maps a (2, 3) observation to (2, 7)."""
        t = ModuleTransform(module_factory=self._module_factory_samespec)
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_compose(self):
        """Same as :meth:`test_transform_no_env`, wrapped in a ``Compose``."""
        t = Compose(ModuleTransform(module_factory=self._module_factory_samespec))
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_env(self):
        """Spec check on a single transformed env."""
        # TODO: We should give users the opportunity to modify the specs
        env = self._single_env_maker()
        env.check_env_specs()

    def test_transform_model(self):
        """The composed transform can be used as a layer of ``nn.Sequential``."""
        t = nn.Sequential(
            Compose(ModuleTransform(module_factory=self._module_factory_samespec))
        )
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_rb(self):
        """Check the transform on a replay buffer, both on read and on write."""
        t = ModuleTransform(module_factory=self._module_factory_samespec)

        # Passed as ``transform=``: the stored items keep their raw 3 features
        # and only the sampled batch carries the 7-feature output.
        rb = ReplayBuffer(transform=t)
        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert rb._storage._storage[0]["observation"].shape == (3,)
        s = rb.sample(2)
        assert s["observation"].shape == (2, 7)

        # Appended with ``invert=True``: the stored items already carry the
        # 7-feature output, and so does the sampled batch.
        rb = ReplayBuffer()
        rb.append_transform(t, invert=True)
        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert rb._storage._storage[0]["observation"].shape == (7,)
        s = rb.sample(2)
        assert s["observation"].shape == (2, 7)

    def test_transform_inverse(self):
        """Spec check on an env using the transform with ``inverse=True`` on ``"action"``."""
        t = ModuleTransform(
            module_factory=self._module_factory_samespec_inverse, inverse=True
        )
        env = ContinuousActionVecMockEnv().append_transform(t)
        env.check_env_specs()

    @pytest.mark.skipif(not _has_ray, reason="ray required")
    def test_ray_extension(self):
        """With ``use_ray_service=True`` the transform is a ``RayModuleTransform``.

        Also checks that the env specs are valid and that an actor named
        ``"my_transform"`` can be looked up through ray afterwards.
        """
        import ray

        # ``ray.is_initialized`` is a function and has to be called: storing the
        # function object itself is always truthy, which would make the cleanup
        # below unreachable.
        ray_was_initialized = ray.is_initialized()
        try:
            t = ModuleTransform(
                module_factory=self._module_factory_samespec,
                use_ray_service=True,
                actor_name="my_transform",
            )
            env = ContinuousActionVecMockEnv().append_transform(t)
            assert isinstance(t, RayModuleTransform)
            env.check_env_specs()
            assert ray.get_actor("my_transform") is not None
        finally:
            # Only tear down a ray runtime that was not already running before
            # this test; ``ray.shutdown`` is the Python-API teardown call.
            if not ray_was_initialized:
                ray.shutdown()
| 15017 | + |
if __name__ == "__main__":
    # Run this file through pytest, forwarding any command-line flags that the
    # (empty) argument parser does not recognise.
    parser = argparse.ArgumentParser()
    args, unknown = parser.parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
0 commit comments