From 38462574ce54d205a64f38b4fe4db3cb279f5901 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:40:49 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- examples/collectors/multi_weight_updates.py | 115 ++++ test/test_collector.py | 180 ++++- test/test_distributed.py | 234 +++---- torchrl/collectors/collectors.py | 689 ++++++++++++++++---- torchrl/collectors/distributed/generic.py | 91 ++- torchrl/collectors/distributed/ray.py | 108 ++- torchrl/collectors/distributed/rpc.py | 96 ++- 7 files changed, 1203 insertions(+), 310 deletions(-) create mode 100644 examples/collectors/multi_weight_updates.py diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py new file mode 100644 index 00000000000..7011e7f4879 --- /dev/null +++ b/examples/collectors/multi_weight_updates.py @@ -0,0 +1,115 @@ +"""Example of updating weights of several models at once in a multiprocessed data collector. + +This example demonstrates: +1. Using different weight sync schemes for different models +2. Updating the policy (via pipes with MultiProcessWeightSyncScheme) +3. Updating Ray-based transforms in env and replay buffer (via RayModuleTransformScheme) +4. Atomic multi-model weight updates using weights_dict + +Note: +- Ray actors are shared across all workers, so RayModuleTransformScheme uses a + single transport rather than per-worker pipes. +- When using transform_factory with a replay buffer, delayed_init automatically defaults + to True for proper serialization in multiprocessing contexts. +- extend_buffer defaults to True in all collectors, extending the buffer with entire + rollouts rather than individual frames for better compatibility with postprocessing. +""" + +from functools import partial + +import torch.nn as nn +from tensordict import TensorDict +from tensordict.nn import TensorDictModule + +from torchrl.collectors import MultiSyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms.module import ModuleTransform +from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme + + +def make_module(): + # A module that transforms the observations + return TensorDictModule( + nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"] + ) + + +def policy_factory(): + # A module that produces the actions + return TensorDictModule( + nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"] + ) + + +def make_env(): + env_module = ModuleTransform( + module_factory=make_module, inverse=False, no_grad=True + ) + return GymEnv("Pendulum-v1").append_transform(env_module) + + +def main(): + rb = ReplayBuffer( + storage=LazyTensorStorage(10000, shared_init=True), + transform_factory=partial( + ModuleTransform, + module_factory=make_module, + inverse=True, + no_grad=True, + ), + # delayed_init automatically defaults to True when transform_factory is provided + ) + + policy = policy_factory() + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "replay_buffer.transform[0].module": MultiProcessWeightSyncScheme( + strategy="tensordict" + ), + "env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"), + } + + collector = MultiSyncDataCollector( + create_env_fn=[make_env, make_env], + policy_factory=policy_factory, + total_frames=2000, + max_frames_per_traj=50, + frames_per_batch=200, + init_random_frames=-1, + device="cpu", + storing_device="cpu", + 
weight_sync_schemes=weight_sync_schemes, + replay_buffer=rb, + local_init_rb=True, + # extend_buffer=True is the default for MultiSyncDataCollector + ) + + policy_weights = TensorDict.from_module(policy).data + env_module_weights = TensorDict.from_module(make_module()).data + rb_module_weights = TensorDict.from_module(make_module()).data + + for i, _data in enumerate(collector): + env_module_weights.zero_() + rb_module_weights.zero_() + policy_weights.zero_() + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "env.transform[0].module": env_module_weights, + "replay_buffer.transform[0].module": rb_module_weights, + } + ) + + assert len(rb) == i * 200 + 200 + + if i >= 10: + break + + collector.shutdown() + + +if __name__ == "__main__": + main() diff --git a/test/test_collector.py b/test/test_collector.py index a47d0a8aba0..bb0c0330bf7 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1486,6 +1486,10 @@ def env_fn(seed): @pytest.mark.parametrize("cudagraph", [False, True]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_update_weights(self, use_async, cudagraph): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + def create_env(): return ContinuousActionVecMockEnv() @@ -1506,6 +1510,7 @@ def create_env(): frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: # collect state_dict @@ -1553,6 +1558,69 @@ def create_env(): collector.shutdown() del collector + @pytest.mark.parametrize( + "use_async", [True] + ) # MultiSync has known indexing issues with SharedMem + def test_update_weights_shared_mem(self, use_async): + """Test shared memory weight synchronization scheme.""" + from tensordict import TensorDict + from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme + + def create_env(): + return ContinuousActionVecMockEnv() + + n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] + policy = SafeModule( + torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] + ) + policy(create_env().reset()) + + # Get policy weights and put them in shared memory + policy_weights = TensorDict.from_module(policy) + policy_weights.share_memory_() + + # Create shared memory weight sync scheme + weight_sync_scheme = SharedMemWeightSyncScheme() + weight_sync_scheme.register_shared_weights("policy", policy_weights) + + collector_class = ( + MultiSyncDataCollector if not use_async else MultiaSyncDataCollector + ) + collector = collector_class( + [create_env] * 3, + policy=policy, + frames_per_batch=20, + cat_results="stack", + weight_sync_schemes={"policy": weight_sync_scheme}, + ) + try: + # Collect first batch + for _ in collector: + break + + # Change policy weights + old_weight = policy.module.weight.data.clone() + for p in policy.parameters(): + p.data += torch.randn_like(p) + new_weight = policy.module.weight.data.clone() + + # Verify weights changed + assert not torch.allclose(old_weight, new_weight) + + # Update weights using shared memory + collector.update_policy_weights_() + + # Collect another batch - should use new weights + for _ in collector: + break + + # Verify shared memory was updated + assert torch.allclose(policy_weights["module", "weight"], new_weight) + + finally: + collector.shutdown() + del collector + @pytest.mark.parametrize("num_env", [1, 2]) @pytest.mark.parametrize("env_name", ["vec"]) 
@pytest.mark.parametrize("frames_per_batch_worker", [[10, 10], [15, 5]]) @@ -2209,23 +2277,23 @@ def env_fn(seed): @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @pytest.mark.parametrize( - "collector_class", + "collector_class,num_envs", [ - SyncDataCollector, - MultiaSyncDataCollector, - functools.partial(MultiSyncDataCollector, cat_results="stack"), + (SyncDataCollector, 1), + (MultiaSyncDataCollector, 1), + (functools.partial(MultiSyncDataCollector, cat_results="stack"), 1), + (MultiaSyncDataCollector, 2), + (functools.partial(MultiSyncDataCollector, cat_results="stack"), 2), ], ) class TestAutoWrap: - num_envs = 1 - @pytest.fixture def env_maker(self): from torchrl.envs.libs.gym import GymEnv return lambda: GymEnv(PENDULUM_VERSIONED()) - def _create_collector_kwargs(self, env_maker, collector_class, policy): + def _create_collector_kwargs(self, env_maker, collector_class, policy, num_envs): collector_kwargs = { "create_env_fn": env_maker, "policy": policy, @@ -2235,7 +2303,7 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): if collector_class is not SyncDataCollector: collector_kwargs["create_env_fn"] = [ - collector_kwargs["create_env_fn"] for _ in range(self.num_envs) + collector_kwargs["create_env_fn"] for _ in range(num_envs) ] return collector_kwargs @@ -2243,7 +2311,7 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): @pytest.mark.parametrize("multiple_outputs", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) def test_auto_wrap_modules( - self, collector_class, multiple_outputs, env_maker, device + self, collector_class, multiple_outputs, env_maker, device, num_envs ): policy = WrappablePolicy( out_features=env_maker().action_spec.shape[-1], @@ -2253,33 +2321,40 @@ def test_auto_wrap_modules( policy(env_maker().reset().get("observation")) collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy), + **self._create_collector_kwargs( + env_maker, collector_class, policy, num_envs + ), device=device, ) - out_keys = ["action"] - if multiple_outputs: - out_keys.extend(f"output{i}" for i in range(1, 4)) - - if collector_class is SyncDataCollector: - assert isinstance(collector.policy, TensorDictModule) - assert collector.policy.out_keys == out_keys - # this does not work now that we force the device of the policy - # assert collector.policy.module is policy + try: + out_keys = ["action"] + if multiple_outputs: + out_keys.extend(f"output{i}" for i in range(1, 4)) - for i, data in enumerate(collector): - if i == 0: - assert (data["action"] != 0).any() - for p in policy.parameters(): - p.data.zero_() - assert p.device == torch.device("cpu") - collector.update_policy_weights_() - elif i == 4: - assert (data["action"] == 0).all() - break + if collector_class is SyncDataCollector: + assert isinstance(collector._wrapped_policy, TensorDictModule) + assert collector._wrapped_policy.out_keys == out_keys + # this does not work now that we force the device of the policy + # assert collector.policy.module is policy - collector.shutdown() - del collector + for i, data in enumerate(collector): + # Debug: iteration {i} + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + # Debug: updating policy weights + collector.update_policy_weights_() + # Debug: updated policy weights + elif i == 4: + assert (data["action"] == 0).all() + break + finally: + # Debug: shutting down collector + 
collector.shutdown() + del collector # Deprecated as from v0.3 # def test_no_wrap_compatible_module(self, collector_class, env_maker): @@ -2314,14 +2389,16 @@ def test_auto_wrap_modules( # collector.shutdown() # del collector - def test_auto_wrap_error(self, collector_class, env_maker): + def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, match=("Arguments to policy.forward are incompatible with entries in"), ): collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) + **self._create_collector_kwargs( + env_maker, collector_class, policy, num_envs + ) ) @@ -2779,13 +2856,22 @@ def forward(self, td): ], ) def test_param_sync(self, give_weights, collector, policy_device, env_device): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + policy = TestUpdateParams.Policy().to(policy_device) env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device)) device = env().device env = [env] col = collector( - env, policy, device=device, total_frames=200, frames_per_batch=10 + env, + policy, + device=device, + total_frames=200, + frames_per_batch=10, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: for i, data in enumerate(col): @@ -2833,6 +2919,10 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device): def test_param_sync_mixed_device( self, give_weights, collector, policy_device, env_device ): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + with torch.device("cpu"): policy = TestUpdateParams.Policy() policy.param = nn.Parameter(policy.param.data.to(policy_device)) @@ -2842,7 +2932,12 @@ def test_param_sync_mixed_device( device = env().device env = [env] col = collector( - env, policy, device=device, total_frames=200, frames_per_batch=10 + env, + policy, + device=device, + total_frames=200, + frames_per_batch=10, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: for i, data in enumerate(col): @@ -3865,6 +3960,10 @@ def test_start_multi(self, total_frames, cls): "cls", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] ) def test_start_update_policy(self, total_frames, cls): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() m = nn.Linear(env.observation_spec["observation"].shape[-1], 1) @@ -3882,12 +3981,19 @@ def test_start_update_policy(self, total_frames, cls): td = TensorDict.from_module(policy).data.clone() if cls != SyncDataCollector: env = [CountingEnv] * 2 + + # Add weight sync schemes for multi-process collectors + kwargs = {} + if cls != SyncDataCollector: + kwargs["weight_sync_schemes"] = {"policy": MultiProcessWeightSyncScheme()} + collector = cls( env, policy, replay_buffer=rb, total_frames=total_frames, frames_per_batch=16, + **kwargs, ) try: collector.start() @@ -3913,4 +4019,6 @@ def test_start_update_policy(self, total_frames, cls): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [__file__, "--capture", "no", "--exitfirst", "--timeout", "180"] + unknown + ) diff --git a/test/test_distributed.py b/test/test_distributed.py index 1f03d385607..6183132394e 100644 --- a/test/test_distributed.py +++ 
b/test/test_distributed.py @@ -102,26 +102,29 @@ def _start_worker(cls): @classmethod def _test_distributed_collector_basic(cls, queue, frames_per_batch): - cls._start_worker() - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - torchrl_logger.info("creating collector") - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=1000, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - torchrl_logger.info("getting data...") - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - assert data.names[-1] == "time" - collector.shutdown() - assert total == 1000 - queue.put("passed") + try: + cls._start_worker() + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + torchrl_logger.info("creating collector") + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=1000, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + torchrl_logger.info("getting data...") + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + assert data.names[-1] == "time" + collector.shutdown() + assert total == 1000 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -143,23 +146,26 @@ def test_distributed_collector_basic(self, frames_per_batch): @classmethod def _test_distributed_collector_mult(cls, queue, frames_per_batch): - cls._start_worker() - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=1000, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == -frames_per_batch * (1000 // -frames_per_batch) - queue.put("passed") + try: + cls._start_worker() + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=1000, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == -frames_per_batch * (1000 // -frames_per_batch) + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {e}") def test_distributed_collector_mult(self, frames_per_batch=200): """Testing multiple nodes.""" @@ -181,24 +187,27 @@ def test_distributed_collector_mult(self, frames_per_batch=200): @classmethod def _test_distributed_collector_sync(cls, queue, sync): - frames_per_batch = 50 - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=200, - frames_per_batch=frames_per_batch, - sync=sync, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == 200 - queue.put("passed") + try: + frames_per_batch = 50 + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=200, + frames_per_batch=frames_per_batch, + sync=sync, + **cls.distributed_kwargs(), 
+ ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == 200 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync): @@ -220,24 +229,27 @@ def test_distributed_collector_sync(self, sync): @classmethod def _test_distributed_collector_class(cls, queue, collector_class): - frames_per_batch = 50 - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - collector_class=collector_class, - total_frames=200, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == 200 - queue.put("passed") + try: + frames_per_batch = 50 + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + collector_class=collector_class, + total_frames=200, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == 200 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize( "collector_class", @@ -266,42 +278,45 @@ def test_distributed_collector_class(self, collector_class): @classmethod def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): - frames_per_batch = 50 - total_frames = 300 - env = CountingEnv - policy = CountingPolicy() - if collector_class is MultiaSyncDataCollector: - # otherwise we may collect data from a collector that has not yet been - # updated - n_collectors = 1 - else: - n_collectors = 2 - collector = cls.distributed_class()( - [env] * n_collectors, - policy, - collector_class=collector_class, - total_frames=total_frames, - frames_per_batch=frames_per_batch, - sync=sync, - **cls.distributed_kwargs(), - ) - total = 0 - first_batch = None - last_batch = None - for i, data in enumerate(collector): - total += data.numel() - assert data.numel() == frames_per_batch - if i == 0: - first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() - elif total == total_frames - frames_per_batch: - last_batch = data - assert (first_batch["action"] == 1).all(), first_batch["action"] - assert (last_batch["action"] == 2).all(), last_batch["action"] - collector.shutdown() - assert total == total_frames - queue.put("passed") + try: + frames_per_batch = 50 + total_frames = 300 + env = CountingEnv + policy = CountingPolicy() + if collector_class is MultiaSyncDataCollector: + # otherwise we may collect data from a collector that has not yet been + # updated + n_collectors = 1 + else: + n_collectors = 2 + collector = cls.distributed_class()( + [env] * n_collectors, + policy, + collector_class=collector_class, + total_frames=total_frames, + frames_per_batch=frames_per_batch, + sync=sync, + **cls.distributed_kwargs(), + ) + total = 0 + first_batch = None + last_batch = None + for i, data in enumerate(collector): + total += data.numel() + assert data.numel() == frames_per_batch + if i == 0: + first_batch = data + policy.weight.data += 1 + collector.update_policy_weights_() + elif total == total_frames - frames_per_batch: + last_batch = data + 
assert (first_batch["action"] == 1).all(), first_batch["action"] + assert (last_batch["action"] == 2).all(), last_batch["action"] + collector.shutdown() + assert total == total_frames + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize( "collector_class", @@ -470,7 +485,6 @@ def distributed_kwargs(cls) -> dict: ray_init_config["runtime_env"] = { "working_dir": os.path.dirname(__file__), "env_vars": {"PYTHONPATH": os.path.dirname(__file__)}, - "pip": ["ray"], } # for ray workers remote_configs = { "num_cpus": 1, diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 34937c22d47..355e6e98db0 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -51,11 +51,7 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories -from torchrl.collectors.weight_update import ( - MultiProcessedWeightUpdater, - VanillaWeightUpdater, - WeightUpdaterBase, -) +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase @@ -70,6 +66,13 @@ RandomPolicy, set_exploration_type, ) +from torchrl.weight_update.weight_sync_schemes import ( + _resolve_model, + MultiProcessWeightSyncScheme, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) try: from torch.compiler import cudagraph_mark_step_begin @@ -157,6 +160,9 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): compiled_policy: bool cudagraphed_policy: bool _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + _weight_senders: dict[str, WeightSender] | None = None + _weight_receivers: dict[str, WeightReceiver] | None = None verbose: bool = False @property @@ -197,15 +203,6 @@ def _get_policy_and_device( if policy_device is NO_DEFAULT: policy_device = self.policy_device - if not self.trust_policy: - env = getattr(self, "env", None) - policy = _make_compatible_policy( - policy, - getattr(env, "observation_spec", None), - env=env, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) if not policy_device: return policy, None @@ -294,11 +291,59 @@ def async_shutdown( """ return self.shutdown(timeout=timeout, close_env=close_env) + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. 
+ """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + if scheme is None: + return weights + + from torchrl.weight_update.weight_sync_schemes import ( + _resolve_model, + WeightStrategy, + ) + + strategy = WeightStrategy(extract_as=scheme.strategy) + + if isinstance(weights, nn.Module): + return strategy.extract_weights(weights) + elif isinstance(weights, str): + model = _resolve_model(self, weights) + return strategy.extract_weights(model) + else: + return weights + def update_policy_weights_( self, policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, **kwargs, ) -> None: """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. @@ -317,9 +362,15 @@ def update_policy_weights_( worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the workers that need to be updated. This is relevant when the collector has more than one worker associated with it. + model_id (str | None, optional): The model identifier to update. If provided, only updates this specific + model. Cannot be used together with weights_dict. + weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. + Cannot be used together with model_id or policy_or_weights. Raises: TypeError: If `worker_ids` is provided but no `weight_updater` is configured. + ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). .. note:: Users should extend the `WeightUpdaterBase` classes to customize the weight update logic for specific use cases. This method should not be overwritten. @@ -335,9 +386,83 @@ def update_policy_weights_( ) policy_or_weights = kwargs.pop("policy_weights") - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + # Priority: new weight sync schemes > old weight updater system + if self._weight_senders: + if weights_dict is not None: + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_senders: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight senders. " + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + self._weight_senders[target_model_id].update_weights( + processed_weights + ) + elif model_id is not None: + if model_id not in self._weight_senders: + raise KeyError( + f"Model '{model_id}' not found in registered weight senders. 
" + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + policy_or_weights, model_id + ) + self._weight_senders[model_id].update_weights(processed_weights) + else: + if "policy" in self._weight_senders: + processed_weights = self._extract_weights_if_needed( + policy_or_weights, "policy" + ) + self._weight_senders["policy"].update_weights(processed_weights) + elif len(self._weight_senders) == 1: + single_model_id = next(iter(self._weight_senders.keys())) + single_sender = self._weight_senders[single_model_id] + processed_weights = self._extract_weights_if_needed( + policy_or_weights, single_model_id + ) + single_sender.update_weights(processed_weights) + else: + for target_model_id, sender in self._weight_senders.items(): + processed_weights = self._extract_weights_if_needed( + policy_or_weights, target_model_id + ) + sender.update_weights(processed_weights) + + elif self._weight_updater is not None: + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + else: + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible def __iter__(self) -> Iterator[TensorDictBase]: try: @@ -547,14 +672,19 @@ class SyncDataCollector(DataCollectorBase): but populate the buffer instead. Defaults to ``None``. - .. seealso:: By default, the buffer is populated every time a (batch of) frames is collected. - If the buffer needs to be extended with entire rollouts, set `extend_buffer` to `True`. + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). - .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` is prohibited unless + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `False`. + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. 
@@ -636,6 +766,8 @@ class SyncDataCollector(DataCollectorBase): """ + _ignore_rb: bool = False + def __init__( self, create_env_fn: ( @@ -664,7 +796,8 @@ def __init__( set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = False, + extend_buffer: bool = True, + local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, @@ -672,6 +805,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, **kwargs, ): @@ -794,33 +928,38 @@ def __init__( # Policy version tracking setup self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." - ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr( - track_policy_version, "increment_version" - ): # Check if it's a PolicyVersion instance - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr( + track_policy_version, "increment_version" + ): # Check if it's a PolicyVersion instance + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: self.policy_version_tracker = None self.replay_buffer = replay_buffer self.extend_buffer = extend_buffer - if self.replay_buffer is not None: + + # Handle local_init_rb deprecation for SyncDataCollector + if local_init_rb is None: + local_init_rb = False # Default for SyncDataCollector + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + if self.replay_buffer is not None and not self._ignore_rb: if postproc is not None and not self.extend_buffer: raise TypeError( "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." 
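A minimal sketch of how the constructor changes above play out together. The Pendulum env, linear policy and frame counts are placeholders; only the keyword names, the ``extend_buffer`` default and the ``FutureWarning`` behavior are taken from the code and docstring above:

import warnings

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
rb = ReplayBuffer(storage=LazyTensorStorage(1_000))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy,
        frames_per_batch=50,
        total_frames=100,
        replay_buffer=rb,   # batches are written to the buffer instead of being yielded
        # extend_buffer=True is now the default: whole rollouts per write
    )
# Not opting into local_init_rb keeps the legacy fake-data init and warns about v0.12.
assert any(issubclass(w.category, FutureWarning) for w in caught)

for _ in collector:         # iteration drives collection; the data lands in ``rb``
    pass
assert len(rb) == 100       # two batches of 50 frames
collector.shutdown()

# Passing local_init_rb=True instead defers initialization to the storage and silences
# the deprecation warning.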
@@ -852,23 +991,37 @@ def __init__( if hasattr(self.env, "register_collector"): self.env.register_collector(self) - (self.policy, self.get_weights_fn,) = self._get_policy_and_device( + self._original_policy = policy + (policy, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, ) - if isinstance(self.policy, nn.Module): + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + if isinstance(self._wrapped_policy, nn.Module): self.policy_weights = TensorDict.from_module( - self.policy, as_module=True + self._wrapped_policy, as_module=True ).data else: self.policy_weights = TensorDict() if self.compiled_policy: - self.policy = compile_with_warmup( - self.policy, **self.compiled_policy_kwargs + self._wrapped_policy = compile_with_warmup( + self._wrapped_policy, **self.compiled_policy_kwargs ) if self.cudagraphed_policy: - self.policy = CudaGraphModule( - self.policy, + self._wrapped_policy = CudaGraphModule( + self._wrapped_policy, in_keys=[], out_keys=[], device=self.policy_device, @@ -975,16 +1128,44 @@ def __init__( self._frames = 0 self._iter = -1 - if weight_updater is None: - weight_updater = VanillaWeightUpdater( - weight_getter=self.get_weights_fn, policy_weights=self.policy_weights - ) - elif not isinstance(weight_updater, WeightUpdaterBase): - raise TypeError( - f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." - ) + # Set up weight synchronization - prefer new schemes over legacy updater + # For single-process SyncDataCollector, no weight sync is needed (policy is local) + # Weight sync schemes are only needed for multi-process/distributed collectors + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # For single-process collectors, we don't need senders/receivers + # The policy is local and changes are immediately visible + # Senders will be set up in multiprocess collectors during _run_processes + + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." + ) - self.weight_updater = weight_updater + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. 
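The hunk above snapshots the (wrapped) policy parameters with ``TensorDict.from_module(..., as_module=True).data``. In isolation the idiom behaves as below; the linear module is a stand-in:

import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
weights = TensorDict.from_module(policy, as_module=True).data

# ``.data`` detaches from autograd but still aliases the parameter storage, so
# in-place edits on the snapshot are visible through the live module; this is the
# view the collectors above hold on to as ``policy_weights``.
weights.zero_()
assert (policy.module.weight == 0).all()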
" + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + else: + # No weight sync needed for single-process collectors + # The policy is local and changes are immediately visible + self.weight_updater = None + self._weight_sync_schemes = None + self._weight_senders = {} @property def _traj_pool(self): @@ -1034,18 +1215,18 @@ def _maybe_make_final_rollout(self, make_rollout: bool): self._policy_output_keys = set() if ( make_rollout - and hasattr(self.policy, "spec") - and self.policy.spec is not None - and all(v is not None for v in self.policy.spec.values(True, True)) + and hasattr(self._wrapped_policy, "spec") + and self._wrapped_policy.spec is not None + and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) ): if any( key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self.policy.spec.keys(True, True) + for key in self._wrapped_policy.spec.keys(True, True) ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information # the policy could have more keys than the env: - policy_spec = self.policy.spec + policy_spec = self._wrapped_policy.spec if policy_spec.ndim < self._final_rollout.ndim: policy_spec = policy_spec.expand(self._final_rollout.shape) for key, spec in policy_spec.items(True, True): @@ -1055,10 +1236,10 @@ def _maybe_make_final_rollout(self, make_rollout: bool): self._final_rollout.set(key, spec.zero()) elif ( not make_rollout - and hasattr(self.policy, "out_keys") - and self.policy.out_keys + and hasattr(self._wrapped_policy, "out_keys") + and self._wrapped_policy.out_keys ): - self._policy_output_keys = list(self.policy.out_keys) + self._policy_output_keys = list(self._wrapped_policy.out_keys) else: if make_rollout: # otherwise, we perform a small number of steps with the policy to @@ -1078,7 +1259,7 @@ def _maybe_make_final_rollout(self, make_rollout: bool): ) # to test if values have changed in-place if self.compiled_policy: cudagraph_mark_step_begin() - policy_output = self.policy(policy_input) + policy_output = self._wrapped_policy(policy_input) # check that we don't have exclusive keys, because they don't appear in keys def check_exclusive(val): @@ -1319,7 +1500,7 @@ def cuda_check(tensor: torch.Tensor): event.record() event.synchronize() yield tensordict_out - elif self.replay_buffer is not None: + elif self.replay_buffer is not None and not self._ignore_rb: self.replay_buffer.extend(tensordict_out) if self.verbose: torchrl_logger.info( @@ -1539,7 +1720,7 @@ def rollout(self) -> TensorDictBase: # we still do the assignment for security if self.compiled_policy: cudagraph_mark_step_begin() - policy_output = self.policy(policy_input) + policy_output = self._wrapped_policy(policy_input) if self.compiled_policy: policy_output = policy_output.clone() if self._shuttle is not policy_output: @@ -1575,7 +1756,11 @@ def rollout(self) -> TensorDictBase: next_data.clear_device_() self._shuttle.set("next", next_data) - if self.replay_buffer is not None and not self.extend_buffer: + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): self.replay_buffer.add(self._shuttle) if self._increment_frames(self._shuttle.numel()): return @@ -1606,7 +1791,11 @@ def rollout(self) -> TensorDictBase: self.interruptor is not None and self.interruptor.collection_stopped() ): - if 
self.replay_buffer is not None and not self.extend_buffer: + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): return result = self._final_rollout if self._use_buffers: @@ -1643,7 +1832,11 @@ def rollout(self) -> TensorDictBase: self._final_rollout.ndim - 1, out=self._final_rollout, ) - elif self.replay_buffer is not None and not self.extend_buffer: + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): return else: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) @@ -1775,7 +1968,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: def __repr__(self) -> str: try: env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self.policy}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") td_out_str = indent( f"td_out={getattr(self, '_final_rollout', None)}", 4 * " " ) @@ -1820,7 +2013,7 @@ def get_policy_version(self) -> str | int | None: def getattr_policy(self, attr): """Get an attribute from the policy.""" # send command to policy to return the attr - return getattr(self.policy, attr) + return getattr(self._wrapped_policy, attr) def getattr_env(self, attr): """Get an attribute from the environment.""" @@ -2002,6 +2195,11 @@ class _MultiDataCollector(DataCollectorBase): but populate the buffer instead. Defaults to ``None``. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not with single steps. Defaults to `True` for multiprocessed data collectors. + local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior). If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. @@ -2021,6 +2219,8 @@ class _MultiDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, which handles weight synchronization across multiple processes. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track @@ -2064,6 +2264,7 @@ def __init__( replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, @@ -2071,6 +2272,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, ): self.closed = True @@ -2142,6 +2344,21 @@ def __init__( self._use_buffers = use_buffers self.replay_buffer = replay_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + # v0.11: Default to False (current behavior), show deprecation warning + # v0.12: Default to True (new behavior) + local_init_rb = False # Will become True in 0.12 + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + + self.local_init_rb = local_init_rb + self._check_replay_buffer_init() if replay_buffer_chunk is not None: if extend_buffer is None: @@ -2186,17 +2403,20 @@ def __init__( if type(policy_new_device) is not type(policy): policy = policy_new_device weights = ( - TensorDict.from_module(policy_new_device).data + TensorDict.from_module(policy_new_device) if isinstance(policy_new_device, nn.Module) else TensorDict() ) self._policy_weights_dict[policy_device] = weights self._get_weights_fn = get_weights_fn if weight_updater is None: - weight_updater = MultiProcessedWeightUpdater( - get_server_weights=self._get_weights_fn, - policy_weights=self._policy_weights_dict, - ) + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + # Don't create legacy weight updater if we have schemes + else: + # Legacy weight updater was explicitly provided + pass elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " @@ -2207,7 +2427,18 @@ def __init__( "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
) - self.weight_updater = weight_updater + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # Senders will be created in _run_processes when pipes are available + self.weight_updater = None # Don't use legacy system + else: + # Fall back to legacy weight updater system + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} # Policy version tracking setup self.policy_version_tracker = track_policy_version @@ -2301,8 +2532,17 @@ def __init__( def _check_replay_buffer_init(self): if self.replay_buffer is None: return - is_init = getattr(self.replay_buffer._storage, "initialized", True) + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization if isinstance(self.create_env_fn[0], EnvCreator): fake_td = self.create_env_fn[0].meta_data.tensordict elif isinstance(self.create_env_fn[0], EnvBase): @@ -2394,6 +2634,15 @@ def _run_processes(self) -> None: 1, torch.get_num_threads() - total_workers ) # 1 more thread for this proc + # Initialize weight senders for multiprocess collectors + if self._weight_sync_schemes: + # Create one sender per model using scheme's factory method + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + self._weight_senders[model_id] = sender torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] @@ -2465,6 +2714,7 @@ def _run_processes(self) -> None: "postproc": self.postprocs if self.replay_buffer is not None else None, + "weight_sync_schemes": self._weight_sync_schemes, } proc = _ProcessNoWarn( target=_main_async_collector, @@ -2474,6 +2724,21 @@ def _run_processes(self) -> None: # proc.daemon can't be set as daemonic processes may be launched by the process itself try: proc.start() + except TypeError as err: + if "cannot pickle" in str(err): + raise RuntimeError( + "A non-serializable object was passed to the collector workers." + ) from err + except RuntimeError as err: + if "Cowardly refusing to serialize non-leaf tensor" in str(err): + raise RuntimeError( + "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " + "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" + "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" + "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." 
+ ) from err + else: + raise err except _pickle.PicklingError as err: if "" in str(err): raise RuntimeError( @@ -2488,11 +2753,62 @@ def _run_processes(self) -> None: pipe_child.close() self.procs.append(proc) self.pipes.append(pipe_parent) - for pipe_parent in self.pipes: + + # Register worker with senders + if self._weight_senders: + for _, sender in self._weight_senders.items(): + sender.register_worker(i, pipe_parent) + + for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - msg = pipe_parent.recv() + try: + msg = pipe_parent.recv() + except EOFError as e: + raise RuntimeError( + f"Worker {i} failed to initialize and closed the connection before sending status. " + f"This typically indicates that the worker process crashed during initialization. " + f"Check the worker process logs for the actual error." + ) from e if msg != "instantiated": - raise RuntimeError(msg) + # Check if it's an error dict from worker + if isinstance(msg, dict) and msg.get("error"): + # Reconstruct the exception from the worker + exc_type_name = msg["exception_type"] + exc_msg = msg["exception_msg"] + traceback_str = msg["traceback"] + + # Try to get the actual exception class + exc_class = None + exc_module = msg["exception_module"] + + if exc_module == "builtins": + # Get from builtins + import builtins + + exc_class = getattr(builtins, exc_type_name, None) + else: + # Try to import from the module + try: + import importlib + + mod = importlib.import_module(exc_module) + exc_class = getattr(mod, exc_type_name, None) + except Exception: + pass + + # Re-raise with original exception type if possible + if exc_class is not None: + raise exc_class( + f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Fall back to RuntimeError if we can't get the original type + raise RuntimeError( + f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Legacy string error message + raise RuntimeError(msg) self.queue_out = queue_out self.closed = False @@ -3057,6 +3373,7 @@ def iterator(self) -> Iterator[TensorDictBase]: msg = "continue_random" else: msg = "continue" + # Debug: sending 'continue' self.pipes[idx].send((None, msg)) self._iter += 1 @@ -3120,13 +3437,13 @@ def iterator(self) -> Iterator[TensorDictBase]: # mask buffers if cat, and create a mask if stack if cat_results != "stack": buffers = {} - for idx, buffer in self.buffers.items(): + for worker_idx, buffer in self.buffers.items(): valid = buffer.get(("collector", "traj_ids")) != -1 if valid.ndim > 2: valid = valid.flatten(0, -2) if valid.ndim == 2: valid = valid.any(0) - buffers[idx] = buffer[..., valid] + buffers[worker_idx] = buffer[..., valid] else: for buffer in self.buffers.values(): with buffer.unlock_(): @@ -3138,6 +3455,11 @@ def iterator(self) -> Iterator[TensorDictBase]: else: buffers = self.buffers + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() if workers_frames[idx] >= self.total_frames: @@ -3156,7 +3478,7 @@ def iterator(self) -> Iterator[TensorDictBase]: # we have to correct the traj_ids to make sure that they don't overlap # We can count the number of frames collected for free in this loop n_collected = 0 - for idx in range(self.num_workers): + for idx in buffers.keys(): buffer = buffers[idx] traj_ids = buffer.get(("collector", "traj_ids")) if preempt: 
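The worker-to-parent error protocol added in ``_run_processes`` above can be read in isolation roughly as follows. The dict field names mirror the patch; the two helper functions and the sample exception are illustrative only:

import builtins
import importlib
import traceback

def pack_exception(exc: BaseException) -> dict:
    # What the worker sends through the pipe instead of "instantiated".
    return {
        "error": True,
        "exception_type": type(exc).__name__,
        "exception_module": type(exc).__module__,
        "exception_msg": str(exc),
        "traceback": traceback.format_exc(),
    }

def reraise(msg: dict) -> None:
    # What the parent does on receipt: rebuild the original exception type when
    # possible, fall back to RuntimeError otherwise, and append the worker traceback.
    if msg["exception_module"] == "builtins":
        exc_class = getattr(builtins, msg["exception_type"], None)
    else:
        try:
            mod = importlib.import_module(msg["exception_module"])
            exc_class = getattr(mod, msg["exception_type"], None)
        except Exception:
            exc_class = None
    if exc_class is None:
        exc_class = RuntimeError
    raise exc_class(f"{msg['exception_msg']}\n\nWorker traceback:\n{msg['traceback']}")

try:
    raise ValueError("worker failed during env construction")
except ValueError as exc:
    packed = pack_exception(exc)

try:
    reraise(packed)
except ValueError as exc:
    # The parent sees the original ValueError with the worker traceback attached.
    print(type(exc).__name__, "reconstructed in the parent process")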
@@ -3775,6 +4097,7 @@ def _main_async_collector( policy_factory: Callable | None = None, collector_class: type | Callable[[], DataCollectorBase] | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ) -> None: if collector_class is None: collector_class = SyncDataCollector @@ -3782,40 +4105,77 @@ def _main_async_collector( # init variables that will be cleared when closing collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer if not extend_buffer else None, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - ) - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + weight_sync_schemes=weight_sync_schemes, + ) + + # Set up weight receivers for worker process + if weight_sync_schemes: + inner_collector._weight_receivers = {} + for model_id, scheme in weight_sync_schemes.items(): + receiver = scheme.create_receiver() + receiver.set_context(inner_collector) + receiver.register_worker_transport(pipe_child) + + model = _resolve_model(inner_collector, model_id) + receiver.register_model(model) + + inner_collector._weight_receivers[model_id] = receiver + else: + inner_collector._weight_receivers = {} + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.info("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + 
try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return has_timed_out = False counter = 0 @@ -3888,7 +4248,81 @@ def _main_async_collector( data_in = None # TODO: this does not work with random frames msg = "continue" + # Note: The "continue" message handling has been moved below after update_weights handling + # to allow falling through from update_weights to continue + + if msg == "update": + torchrl_logger.info(f"worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + if msg == "register_shared_weights": + # Shared memory lazy registration: main process sends buffer reference + if verbose: + torchrl_logger.info( + f"worker {idx} received shared memory buffer registration" + ) + model_id, shared_buffer = data_in + + # Store the shared buffer reference for this model + # The receiver will use this buffer for all future weight accesses + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + # Update receiver's buffer reference + receiver = inner_collector._weight_receivers[model_id] + # Store the shared buffer - the model's parameters should point to this + if hasattr(receiver, "_shared_weights"): + receiver._shared_weights[model_id] = shared_buffer + + # Apply the buffer to the model immediately + receiver.apply_weights(shared_buffer) + + if verbose: + torchrl_logger.info( + f"worker {idx} registered shared buffer for model '{model_id}'" + ) + else: + torchrl_logger.warning( + f"worker {idx} received shared buffer for unknown model '{model_id}'" + ) + + # Send acknowledgment back to main process + pipe_child.send((None, "registered")) + has_timed_out = False + continue + + if msg == "update_weights": + # New weight update protocol for simplified weight sync system + if verbose: + torchrl_logger.info( + f"worker {idx} received weight update via new protocol" + ) + model_id, weights = data_in + + # Apply weights using the appropriate receiver for this model + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + inner_collector._weight_receivers[model_id].apply_weights(weights) + else: + torchrl_logger.warning( + f"worker {idx} received weights for unknown model '{model_id}'" + ) + + # After applying weights, we continue collecting immediately as if we received + # a "continue" message. This ensures the worker keeps collecting data without + # waiting for an explicit continue from the main process. 
+ has_timed_out = False + msg = "continue" + # Now check if we should continue collecting + if msg in ("continue", "continue_random"): + # This block handles both explicit continue messages and implicit ones after weight updates if msg == "continue_random": inner_collector.init_random_frames = float("inf") else: @@ -3980,14 +4414,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): has_timed_out = True continue - elif msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - elif msg == "seed": + if msg == "seed": data_in, static_seed = data_in new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) torch.manual_seed(data_in) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index df35a782d75..494af927f4e 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -37,6 +37,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None try: @@ -461,6 +462,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): if collector_class == "async": @@ -566,14 +568,61 @@ def __init__( self._init_workers() self._make_container() - if weight_updater is None: - weight_updater = DistributedWeightUpdater( - store=self._store, - policy_weights=self.policy_weights, - num_workers=self.num_workers, - sync=self._sync, + + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to Distributed weight sync scheme for distributed collectors + from torchrl.weight_update.weight_sync_schemes import ( + DistributedWeightSyncScheme, ) - self.weight_updater = weight_updater + + weight_sync_schemes = { + "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) + } + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + for i in range(self.num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = scheme.create_transport((self._store, rank)) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(self, "policy") + and self.policy is not None + ): + sender._source_model = self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = DistributedWeightUpdater( + store=self._store, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + sync=self._sync, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} @property def device(self) -> 
list[torch.device]: @@ -928,6 +977,34 @@ def _next_async(self, total_frames, trackers): break return data, total_frames + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For distributed collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + # For distributed collectors, we need TensorDict format for isend/irecv + from tensordict import TensorDict + + return TensorDict.from_module(sender._source_model).data.lock_() + + # Fall back to base class implementation + return super()._extract_weights_if_needed(weights, model_id) + def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 97955c7c8b5..f2e00828dd3 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -28,6 +28,7 @@ from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme RAY_ERR = None try: @@ -258,11 +259,15 @@ class RayCollector(DataCollectorBase): .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a :class:`~torchrl.data.RayReplayBuffer` instance should be used here. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + weight_updater (WeightUpdaterBase or constructor, optional): (Deprecated) An instance of :class:`~torchrl.collectors.WeightUpdaterBase` or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray. If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging Ray's distributed capabilities. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary mapping model identifiers to + :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances. + This is the recommended way to configure weight synchronization. If not provided, + defaults to ``{"policy": RayWeightSyncScheme()}``. 
Examples: >>> from torch import nn @@ -326,6 +331,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): self.frames_per_batch = frames_per_batch if remote_configs is None: @@ -436,9 +442,9 @@ def check_list_length_consistency(*lists): if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory] * len(create_env_fn) self.policy_factory = policy_factory - self._local_policy = policy - if isinstance(self._local_policy, nn.Module): - policy_weights = TensorDict.from_module(self._local_policy) + self.policy = policy # Store policy for weight extraction + if isinstance(policy, nn.Module): + policy_weights = TensorDict.from_module(policy) policy_weights = policy_weights.data.lock_() else: policy_weights = TensorDict(lock=True) @@ -476,7 +482,10 @@ def check_list_length_consistency(*lists): # update collector kwargs for i, collector_kwarg in enumerate(self.collector_kwargs): - collector_kwarg["policy_factory"] = policy_factory[i] + # Don't pass policy_factory if we have a policy - remote collectors need the policy object + # to be able to apply weight updates + if policy is None: + collector_kwarg["policy_factory"] = policy_factory[i] collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_collectors @@ -510,19 +519,84 @@ def check_list_length_consistency(*lists): collector_kwargs, remote_configs, ) - if weight_updater is None: - weight_updater = RayWeightUpdater( - policy_weights=policy_weights, - remote_collectors=self.remote_collectors, - max_interval=self.max_weight_update_interval, - ) - self.weight_updater = weight_updater + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to Ray weight sync scheme for Ray collectors + from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme - # Print info of all remote workers - pending_samples = [ - e.print_remote_collector_info.remote() for e in self.remote_collectors - ] - ray.wait(pending_samples) + weight_sync_schemes = {"policy": RayWeightSyncScheme()} + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Register each remote collector as a separate worker + # This follows the same pattern as multiprocess collectors + for worker_idx, remote_collector in enumerate(self.remote_collectors): + # Create a transport for this specific collector + # Pass the collector as context so the transport knows which one to talk to + sender.register_worker(worker_idx, remote_collector) + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if model_id == "policy": + sender._source_model = self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None # Don't use legacy system + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = RayWeightUpdater( + policy_weights=policy_weights, + remote_collectors=self.remote_collectors, + 
max_interval=self.max_weight_update_interval, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + + # Print info of all remote workers (fire and forget - no need to wait) + for e in self.remote_collectors: + e.print_remote_collector_info.remote() + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For Ray collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as=scheme.strategy) + return strategy.extract_weights(sender._source_model) + + # Fall back to base class behavior + return super()._extract_weights_if_needed(weights, model_id) @property def num_workers(self): diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 5336b588232..9d2cf36c0cf 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -42,6 +42,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None try: @@ -308,6 +309,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): if collector_class == "async": collector_class = MultiaSyncDataCollector @@ -413,15 +415,63 @@ def __init__( tensorpipe_options ) self._init() - if weight_updater is None: - weight_updater = RPCWeightUpdater( - collector_infos=self.collector_infos, - collector_class=self.collector_class, - collector_rrefs=self.collector_rrefs, - policy_weights=self.policy_weights, - num_workers=self.num_workers, - ) - self.weight_updater = weight_updater + + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to RPC weight sync scheme for RPC collectors + from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + + weight_sync_schemes = {"policy": RPCWeightSyncScheme()} + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + for i in range(self.num_workers): + transport = scheme.create_transport( + ( + self.collector_infos[i], + self.collector_rrefs[i], + self.collector_class, + ) + ) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(self, "policy") + and self.policy is not None + ): + sender._source_model = 
self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = RPCWeightUpdater( + collector_infos=self.collector_infos, + collector_class=self.collector_class, + collector_rrefs=self.collector_rrefs, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} @property def device(self) -> list[torch.device]: @@ -764,6 +814,34 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For RPC collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as=scheme.strategy) + return strategy.extract_weights(sender._source_model) + + # Fall back to base class implementation + return super()._extract_weights_if_needed(weights, model_id) + def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,))
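
The distributed, Ray, and RPC collectors above all override _extract_weights_if_needed so that update_policy_weights_() called with no explicit weights pulls fresh parameters from the tracked policy, in the format declared by the scheme's strategy. Below is a minimal, self-contained sketch of that extraction path, assuming the WeightStrategy API from torchrl/weight_update/weight_sync_schemes.py and a toy TensorDictModule standing in for the collector's policy; it is an illustration of the mechanism, not code taken from the patch.

import torch.nn as nn
from tensordict.nn import TensorDictModule

from torchrl.weight_update.weight_sync_schemes import WeightStrategy

# Toy policy standing in for the model tracked by the sender.
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

# "tensordict" mirrors what the distributed transports ship between processes;
# "state_dict" would yield a plain PyTorch state dict instead.
strategy = WeightStrategy(extract_as="tensordict")
fresh_weights = strategy.extract_weights(policy)

# This is what a sender forwards to each worker; the matching receiver on the
# worker side applies it with receiver.apply_weights(fresh_weights).
print(fresh_weights)

With the default schemes these collectors install when neither weight_updater nor weight_sync_schemes is supplied (DistributedWeightSyncScheme, RayWeightSyncScheme, RPCWeightSyncScheme), calling update_policy_weights_() with no arguments is therefore enough to broadcast the current policy parameters to every remote worker.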