From 5d330734d6130ffe0f18c9a81c6ca52acce63117 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 14 Oct 2025 09:40:59 +0100
Subject: [PATCH 1/3] Update

[ghstack-poisoned]
---
 torchrl/objectives/value/advantages.py |   2 +
 torchrl/trainers/algorithms/ppo.py     | 116 ++++++++++++++++++++++---
 2 files changed, 105 insertions(+), 13 deletions(-)

diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index 05ef48e00cd..abd961631cd 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -1324,6 +1324,8 @@ class GAE(ValueEstimatorBase):
         `"is_init"` of the `"next"` entry, such that trajectories are well separated
         both for root and `"next"` data.
     """
+    value_network: TensorDictModule | None
+
     def __init__(
         self,
         *,
diff --git a/torchrl/trainers/algorithms/ppo.py b/torchrl/trainers/algorithms/ppo.py
index 4b09788894a..98e2f34ae0a 100644
--- a/torchrl/trainers/algorithms/ppo.py
+++ b/torchrl/trainers/algorithms/ppo.py
@@ -66,6 +66,42 @@ class PPOTrainer(Trainer):
     Logging can be configured via constructor parameters to enable/disable
     specific metrics.
 
+    Args:
+        collector (DataCollectorBase): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            collector's weight_sync_schemes) to trainer source paths. Required if collector has
+            weight_sync_schemes configured.
+            Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default is False.
+
     Examples:
         >>> # Basic usage with manual configuration
         >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
@@ -111,6 +147,10 @@ def __init__(
         log_actions: bool = True,
         log_observations: bool = False,
         async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
     ) -> None:
         warnings.warn(
             "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -135,17 +175,21 @@ def __init__(
             save_trainer_file=save_trainer_file,
             num_epochs=num_epochs,
             async_collection=async_collection,
+            log_timings=log_timings,
         )
         self.replay_buffer = replay_buffer
         self.async_collection = async_collection
 
-        gae = GAE(
-            gamma=gamma,
-            lmbda=lmbda,
-            value_network=self.loss_module.critic_network,
-            average_gae=True,
-        )
-        self.register_op("pre_epoch", gae)
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
 
         if (
             not self.async_collection
@@ -167,16 +211,62 @@ def __init__(
             )
 
         if not self.async_collection:
+            # rb has been extended by the collector
+            raise RuntimeError
             self.register_op("pre_epoch", rb_trainer.extend)
             self.register_op("process_optim_batch", rb_trainer.sample)
             self.register_op("post_loss", rb_trainer.update_priority)
 
-        policy_weights_getter = partial(
-            TensorDict.from_module, self.loss_module.actor_network
-        )
-        update_weights = UpdateWeights(
-            self.collector, 1, policy_weights_getter=policy_weights_getter
-        )
+        # Set up weight updates
+        # Validate weight_update_map if collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f"  Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f"  Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
         self.register_op("post_steps", update_weights)
 
         # Store logging configuration

From 5364dfff4026fe163bcce82d24f2f6d7befecf43 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Oct 2025 17:54:52 -0700
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 test/test_configs.py                        |  5 ++-
 .../trainers/algorithms/configs/modules.py  | 44 +++++++++++++++++--
 torchrl/trainers/algorithms/ppo.py          |  1 -
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/test/test_configs.py b/test/test_configs.py
index c6882cc6a6b..9dff306d18e 100644
--- a/test/test_configs.py
+++ b/test/test_configs.py
@@ -835,7 +835,10 @@ def test_tensor_dict_module_config(self):
             in_keys=["observation"],
             out_keys=["action"],
         )
-        assert cfg._target_ == "tensordict.nn.TensorDictModule"
+        assert (
+            cfg._target_
+            == "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
+        )
         assert cfg.module._target_ == "torchrl.modules.MLP"
         assert cfg.in_keys == ["observation"]
         assert cfg.out_keys == ["action"]
diff --git a/torchrl/trainers/algorithms/configs/modules.py b/torchrl/trainers/algorithms/configs/modules.py
index 189d47c2561..00c6f8e1868 100644
--- a/torchrl/trainers/algorithms/configs/modules.py
+++ b/torchrl/trainers/algorithms/configs/modules.py
@@ -222,7 +222,7 @@ class TensorDictModuleConfig(ModelConfig):
     """
 
     module: MLPConfig = MISSING
-    _target_: str = "tensordict.nn.TensorDictModule"
+    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
     _partial_: bool = False
 
     def __post_init__(self) -> None:
@@ -292,6 +292,30 @@ def __post_init__(self) -> None:
         super().__post_init__()
 
 
+def _make_tensordict_module(*args, **kwargs):
+    """Helper function to create a TensorDictModule."""
+    from hydra.utils import instantiate
+    from tensordict.nn import TensorDictModule
+
+    module = kwargs.pop("module")
+    shared = kwargs.pop("shared", False)
+
+    # Instantiate the module if it's a config
+    if hasattr(module, "_target_"):
+        module = instantiate(module)
+    elif callable(module) and hasattr(module, "func"):  # partial function
+        module = module()
+
+    # Create the TensorDictModule
+    tensordict_module = TensorDictModule(module, **kwargs)
+
+    # Apply share_memory if needed
+    if shared:
+        tensordict_module = tensordict_module.share_memory()
+
+    return tensordict_module
+
+
 def _make_tanh_normal_model(*args, **kwargs):
     """Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
     from hydra.utils import instantiate
@@ -351,10 +375,24 @@ def _make_tanh_normal_model(*args, **kwargs):
 
 def _make_value_model(*args, **kwargs):
     """Helper function to create a ValueOperator with the given network."""
+    from hydra.utils import instantiate
+
     from torchrl.modules import ValueOperator
 
     network = kwargs.pop("network")
     shared = kwargs.pop("shared", False)
+
+    # Instantiate the network if it's a config
+    if hasattr(network, "_target_"):
+        network = instantiate(network)
+    elif callable(network) and hasattr(network, "func"):  # partial function
+        network = network()
+
+    # Create the ValueOperator
+    value_operator = ValueOperator(network, **kwargs)
+
+    # Apply share_memory if needed
     if shared:
-        network = network.share_memory()
-    return ValueOperator(network, **kwargs)
+        value_operator = value_operator.share_memory()
+
+    return value_operator
diff --git a/torchrl/trainers/algorithms/ppo.py b/torchrl/trainers/algorithms/ppo.py
index 98e2f34ae0a..130cb633bb0 100644
--- a/torchrl/trainers/algorithms/ppo.py
+++ b/torchrl/trainers/algorithms/ppo.py
@@ -212,7 +212,6 @@ def __init__(
 
         if not self.async_collection:
             # rb has been extended by the collector
-            raise RuntimeError
             self.register_op("pre_epoch", rb_trainer.extend)
             self.register_op("process_optim_batch", rb_trainer.sample)
             self.register_op("post_loss", rb_trainer.update_priority)

From 78f04e52d16d29d881608811742d4c0476a56091 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Oct 2025 18:09:15 -0700
Subject: [PATCH 3/3] Update

[ghstack-poisoned]
---
 torchrl/trainers/algorithms/configs/modules.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchrl/trainers/algorithms/configs/modules.py b/torchrl/trainers/algorithms/configs/modules.py
index 00c6f8e1868..dbfef71ffd0 100644
--- a/torchrl/trainers/algorithms/configs/modules.py
+++ b/torchrl/trainers/algorithms/configs/modules.py
@@ -222,7 +222,9 @@ class TensorDictModuleConfig(ModelConfig):
     """
 
     module: MLPConfig = MISSING
-    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
+    _target_: str = (
+        "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
+    )
     _partial_: bool = False
 
     def __post_init__(self) -> None:
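
Reviewer note (illustrative, not part of the patch series): a minimal sketch of how the constructor
arguments introduced in PATCH 1/3 might be used. The `collector`, `loss_module`, `optimizer` and
`replay_buffer` objects, and the "policy" destination key, are assumptions for the example; only the
keyword arguments themselves come from the diff.

    from torchrl.trainers.algorithms.ppo import PPOTrainer

    trainer = PPOTrainer(
        collector=collector,          # assumed: built with weight_sync_schemes keyed by "policy"
        total_frames=1_000_000,
        frame_skip=1,
        optim_steps_per_batch=32,
        loss_module=loss_module,      # assumed: PPO loss with actor_network / critic_network
        optimizer=optimizer,
        replay_buffer=replay_buffer,
        # Must cover every destination in the collector's weight_sync_schemes;
        # missing or extra keys raise a ValueError at construction time.
        weight_update_map={"policy": "loss_module.actor_network"},
        add_gae=True,     # with gae=None, a default GAE(gamma, lmbda) is registered on "pre_epoch"
        log_timings=True,  # registers timing metrics under the "time/" prefix
    )

    # Passing a custom advantage module while disabling the default is rejected:
    # PPOTrainer(..., add_gae=False, gae=my_gae)  -> ValueError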
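
Reviewer note (illustrative, not part of the patch series): PATCH 2/3 points TensorDictModuleConfig._target_
at _make_tensordict_module so that Hydra instantiation resolves the nested module config before wrapping it
in a tensordict.nn.TensorDictModule (optionally shared in memory). A rough sketch of the caller side, under
the assumption that MLPConfig exposes fields mirroring torchrl.modules.MLP (in_features, out_features,
num_cells, depth); those field names are not taken from the diff.

    import torch
    from hydra.utils import instantiate
    from tensordict import TensorDict

    from torchrl.trainers.algorithms.configs.modules import (
        MLPConfig,
        TensorDictModuleConfig,
    )

    cfg = TensorDictModuleConfig(
        module=MLPConfig(in_features=4, out_features=2, num_cells=64, depth=2),  # field names assumed
        in_keys=["observation"],
        out_keys=["action"],
    )

    # instantiate() now routes through _make_tensordict_module, which instantiates
    # the nested MLP config and wraps it in a TensorDictModule.
    policy = instantiate(cfg)
    out = policy(TensorDict({"observation": torch.randn(4)}, batch_size=[]))
    assert "action" in out.keys()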