
Commit 3f00b13

[Feature] PPO Trainer Updates

ghstack-source-id: 526f04d
Pull-Request: #3190

Parent: 0a4ead1

2 files changed (+105, -13 lines)

torchrl/objectives/value/advantages.py (2 additions, 0 deletions)

@@ -1324,6 +1324,8 @@ class GAE(ValueEstimatorBase):
         `"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
     """
 
+    value_network: TensorDictModule | None
+
     def __init__(
         self,
         *,
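
For context, the annotation above makes the existing `value_network` attribute of `GAE` explicit in the class body. A minimal sketch of how that attribute gets populated, assuming a small TensorDictModule critic (the module shape and key names are illustrative, not part of the diff):

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.value import GAE

    # Hypothetical critic mapping "observation" -> "state_value"
    value_net = TensorDictModule(
        nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
    )
    gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, average_gae=True)
    print(gae.value_network)  # the attribute the annotation refers to; typed as TensorDictModule | None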

torchrl/trainers/algorithms/ppo.py (103 additions, 13 deletions)
@@ -66,6 +66,42 @@ class PPOTrainer(Trainer):
 
     Logging can be configured via constructor parameters to enable/disable specific metrics.
 
+    Args:
+        collector (DataCollectorBase): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            collector's weight_sync_schemes) to trainer source paths. Required if collector has
+            weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default is False.
+
     Examples:
         >>> # Basic usage with manual configuration
         >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
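
As a quick illustration of the new knobs documented above, here is a hedged sketch of how they might be passed to the constructor. The collector, loss module, and the other required arguments are assumed to exist elsewhere and are not built here:

    # Sketch only: collector, loss_module, etc. are assumed to be constructed elsewhere.
    new_kwargs = dict(
        gamma=0.99,              # discount used when the default GAE is built
        lmbda=0.95,              # lambda for the default GAE estimator
        add_gae=True,            # build a default GAE and register it as a "pre_epoch" hook
        gae=None,                # alternatively, a custom Callable[[TensorDictBase], TensorDictBase]
        weight_update_map=None,  # required only if the collector defines weight_sync_schemes
        log_timings=True,        # log per-hook timings under the "time/" prefix
    )
    # trainer = PPOTrainer(collector=collector, loss_module=loss_module, ..., **new_kwargs)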
@@ -111,6 +147,10 @@ def __init__(
         log_actions: bool = True,
         log_observations: bool = False,
         async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
     ) -> None:
         warnings.warn(
             "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -135,17 +175,21 @@ def __init__(
             save_trainer_file=save_trainer_file,
             num_epochs=num_epochs,
             async_collection=async_collection,
+            log_timings=log_timings,
         )
         self.replay_buffer = replay_buffer
         self.async_collection = async_collection
 
-        gae = GAE(
-            gamma=gamma,
-            lmbda=lmbda,
-            value_network=self.loss_module.critic_network,
-            average_gae=True,
-        )
-        self.register_op("pre_epoch", gae)
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
 
         if (
             not self.async_collection
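
The branch above only constructs the default GAE when none is supplied, and rejects a `gae` passed together with `add_gae=False`. A small sketch of a callable matching the advertised `Callable[[TensorDictBase], TensorDictBase]` signature follows; the key names are assumptions for illustration, not taken from the diff:

    import torch
    from tensordict import TensorDict, TensorDictBase

    def constant_advantage(td: TensorDictBase) -> TensorDictBase:
        # Illustrative only: write a flat advantage and copy the reward as the value target.
        reward = td.get(("next", "reward"))
        td.set("advantage", torch.ones_like(reward))
        td.set("value_target", reward.clone())
        return td

    # Quick smoke test on a toy batch
    td = TensorDict({"next": TensorDict({"reward": torch.randn(8, 1)}, [8])}, [8])
    constant_advantage(td)

    # Per the logic above: PPOTrainer(..., add_gae=False, gae=constant_advantage) raises a
    # ValueError; the default estimator is only built when add_gae=True and gae is None.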
@@ -167,16 +211,62 @@ def __init__(
         )
 
         if not self.async_collection:
+            # rb has been extended by the collector
+            raise RuntimeError
             self.register_op("pre_epoch", rb_trainer.extend)
         self.register_op("process_optim_batch", rb_trainer.sample)
         self.register_op("post_loss", rb_trainer.update_priority)
 
-        policy_weights_getter = partial(
-            TensorDict.from_module, self.loss_module.actor_network
-        )
-        update_weights = UpdateWeights(
-            self.collector, 1, policy_weights_getter=policy_weights_getter
-        )
+        # Set up weight updates
+        # Validate weight_update_map if collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f" Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f" Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
         self.register_op("post_steps", update_weights)
 
         # Store logging configuration