@@ -66,6 +66,42 @@ class PPOTrainer(Trainer):
 
     Logging can be configured via constructor parameters to enable/disable specific metrics.
 
+    Args:
+        collector (DataCollectorBase): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            the collector's weight_sync_schemes) to trainer source paths. Required if the collector has
+            weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}.
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with the prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default: False.
+
     Examples:
         >>> # Basic usage with manual configuration
         >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
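For orientation, here is a rough usage sketch exercising the new arguments documented in the Args block above (add_gae, gamma/lmbda, log_timings, weight_update_map). It is only a sketch: `collector`, `loss_module`, and `optimizer` are placeholders assumed to be built elsewhere and are not part of this commit.

```python
# Hypothetical usage sketch; `collector`, `loss_module`, and `optimizer` are placeholders
# (a DataCollectorBase, a LossModule, and an optim.Optimizer built elsewhere).
from torchrl.trainers.algorithms.ppo import PPOTrainer

trainer = PPOTrainer(
    collector=collector,
    total_frames=1_000_000,
    frame_skip=1,
    optim_steps_per_batch=64,
    loss_module=loss_module,
    optimizer=optimizer,
    num_epochs=4,
    add_gae=True,          # build and register the default GAE "pre_epoch" hook
    gamma=0.99,
    lmbda=0.95,
    log_timings=True,      # log "time/..." metrics for every registered hook
    enable_logging=True,
    log_rewards=True,
    # Only needed when the collector is configured with weight_sync_schemes:
    # weight_update_map={"policy": "loss_module.actor_network"},
)
trainer.train()
```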
@@ -111,6 +147,10 @@ def __init__(
         log_actions: bool = True,
         log_observations: bool = False,
         async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
     ) -> None:
         warnings.warn(
             "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -135,17 +175,21 @@ def __init__(
             save_trainer_file=save_trainer_file,
             num_epochs=num_epochs,
             async_collection=async_collection,
+            log_timings=log_timings,
         )
         self.replay_buffer = replay_buffer
         self.async_collection = async_collection
 
-        gae = GAE(
-            gamma=gamma,
-            lmbda=lmbda,
-            value_network=self.loss_module.critic_network,
-            average_gae=True,
-        )
-        self.register_op("pre_epoch", gae)
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
 
         if (
             not self.async_collection
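Since the accepted (add_gae, gae) combinations are easy to misread from the diff above, here is a minimal standalone restatement of that branch. `resolve_gae` is a hypothetical helper written only for illustration; it mirrors the logic shown but is not part of torchrl.

```python
from typing import Callable, Optional


def resolve_gae(
    add_gae: bool,
    gae: Optional[Callable],
    make_default_gae: Callable[[], Callable],
) -> Optional[Callable]:
    """Mirror of the branch above: return the module to register as a pre_epoch hook, if any."""
    if add_gae and gae is None:
        # Default path: build GAE(gamma, lmbda) over the critic and register it.
        return make_default_gae()
    if not add_gae and gae is not None:
        # Contradictory arguments are rejected.
        raise ValueError("gae must not be provided if add_gae is False")
    # add_gae=True with a user-supplied module, or add_gae=False without one:
    # this branch registers nothing.
    return None


# The four combinations:
# resolve_gae(True,  None, make) -> default module (registered as a pre_epoch hook)
# resolve_gae(True,  mod,  make) -> None (the supplied module is not auto-registered here)
# resolve_gae(False, None, make) -> None
# resolve_gae(False, mod,  make) -> ValueError
```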
@@ -167,16 +211,62 @@ def __init__(
             )
 
         if not self.async_collection:
+            # With async collection the replay buffer has already been extended by the
+            # collector, so the extend hook is only registered in synchronous mode.
             self.register_op("pre_epoch", rb_trainer.extend)
         self.register_op("process_optim_batch", rb_trainer.sample)
         self.register_op("post_loss", rb_trainer.update_priority)
 
-        policy_weights_getter = partial(
-            TensorDict.from_module, self.loss_module.actor_network
-        )
-        update_weights = UpdateWeights(
-            self.collector, 1, policy_weights_getter=policy_weights_getter
-        )
+        # Set up weight updates
+        # Validate weight_update_map if the collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f"  Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f"  Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to the legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
         self.register_op("post_steps", update_weights)
 
         # Store logging configuration
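To make the weight_update_map validation concrete, the sketch below pairs a collector whose weight_sync_schemes has two destinations with a matching map. The scheme values are stand-ins; the map itself is the example from the docstring, and the only requirement enforced above is exact key-set equality.

```python
# Hypothetical configuration; the scheme objects are stand-ins (Ellipsis). Only the
# key-matching requirement is taken from the validation code above.
weight_sync_schemes = {
    "policy": ...,                        # destination paths configured on the collector
    "replay_buffer.transforms[0]": ...,
}

weight_update_map = {
    "policy": "loss_module.actor_network",                       # trainer source paths
    "replay_buffer.transforms[0]": "loss_module.critic_network",
}

# The trainer raises ValueError if these key sets differ (missing or extra destinations).
assert set(weight_update_map) == set(weight_sync_schemes)
```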