5 changes: 4 additions & 1 deletion test/test_configs.py
@@ -835,7 +835,10 @@ def test_tensor_dict_module_config(self):
in_keys=["observation"],
out_keys=["action"],
)
assert cfg._target_ == "tensordict.nn.TensorDictModule"
assert (
cfg._target_
== "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
)
assert cfg.module._target_ == "torchrl.modules.MLP"
assert cfg.in_keys == ["observation"]
assert cfg.out_keys == ["action"]
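Note: the updated assertion reflects that TensorDictModuleConfig._target_ now names a factory function instead of tensordict.nn.TensorDictModule itself. A minimal, hedged sketch of how such a config resolves, assuming cfg is the config built in the test above:

    from hydra.utils import instantiate

    # Hydra reads cfg._target_ ("torchrl.trainers.algorithms.configs.modules._make_tensordict_module")
    # and calls that factory with the config's remaining fields as keyword arguments.
    td_module = instantiate(cfg)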
2 changes: 2 additions & 0 deletions torchrl/objectives/value/advantages.py
@@ -1324,6 +1324,8 @@ class GAE(ValueEstimatorBase):
`"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
"""

value_network: TensorDictModule | None

def __init__(
self,
*,
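Note: the new class-level annotation documents that GAE.value_network may be None. For reference, a minimal construction sketch using the same keyword arguments PPOTrainer passes for its default GAE later in this diff; the stand-in critic below is purely illustrative:

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.value import GAE

    # Illustrative critic: any TensorDictModule writing "state_value" works here,
    # and the new annotation also allows value_network to be None.
    critic = TensorDictModule(nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"])

    gae = GAE(
        gamma=0.9,             # PPOTrainer's documented default
        lmbda=0.99,            # PPOTrainer's documented default
        value_network=critic,
        average_gae=True,
    )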
46 changes: 43 additions & 3 deletions torchrl/trainers/algorithms/configs/modules.py
@@ -222,7 +222,9 @@ class TensorDictModuleConfig(ModelConfig):
"""

module: MLPConfig = MISSING
_target_: str = "tensordict.nn.TensorDictModule"
_target_: str = (
"torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
)
_partial_: bool = False

def __post_init__(self) -> None:
@@ -292,6 +294,30 @@ def __post_init__(self) -> None:
super().__post_init__()


def _make_tensordict_module(*args, **kwargs):
"""Helper function to create a TensorDictModule."""
from hydra.utils import instantiate
from tensordict.nn import TensorDictModule

module = kwargs.pop("module")
shared = kwargs.pop("shared", False)

# Instantiate the module if it's a config
if hasattr(module, "_target_"):
module = instantiate(module)
elif callable(module) and hasattr(module, "func"): # partial function
module = module()

# Create the TensorDictModule
tensordict_module = TensorDictModule(module, **kwargs)

# Apply share_memory if needed
if shared:
tensordict_module = tensordict_module.share_memory()

return tensordict_module
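Note: a hedged sketch of calling the new factory directly with an already-built module, mirroring what Hydra does when it resolves the _target_ above; the import path follows that _target_ string, and the module and keys are illustrative:

    import torch.nn as nn
    from torchrl.trainers.algorithms.configs.modules import _make_tensordict_module

    net = nn.Linear(4, 2)  # plain nn.Module: no `_target_` attribute and no `.func`, so it is wrapped as-is
    td_module = _make_tensordict_module(
        module=net,
        in_keys=["observation"],
        out_keys=["action"],
        shared=False,  # True would return td_module.share_memory()
    )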


def _make_tanh_normal_model(*args, **kwargs):
"""Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
from hydra.utils import instantiate
@@ -351,10 +377,24 @@ def _make_tanh_normal_model(*args, **kwargs):

def _make_value_model(*args, **kwargs):
"""Helper function to create a ValueOperator with the given network."""
from hydra.utils import instantiate

from torchrl.modules import ValueOperator

network = kwargs.pop("network")
shared = kwargs.pop("shared", False)

# Instantiate the network if it's a config
if hasattr(network, "_target_"):
network = instantiate(network)
elif callable(network) and hasattr(network, "func"): # partial function
network = network()

# Create the ValueOperator
value_operator = ValueOperator(network, **kwargs)

# Apply share_memory if needed
if shared:
network = network.share_memory()
return ValueOperator(network, **kwargs)
value_operator = value_operator.share_memory()

return value_operator
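Note: analogously, a hedged sketch of calling _make_value_model directly; the refactor above now applies share_memory() to the resulting ValueOperator rather than to the raw network. The network and keys below are illustrative:

    import torch.nn as nn
    from torchrl.trainers.algorithms.configs.modules import _make_value_model

    value_net = nn.Linear(4, 1)  # already-built module, so the instantiate branch is skipped
    value_op = _make_value_model(
        network=value_net,
        in_keys=["observation"],  # forwarded to ValueOperator
        shared=True,              # shares the whole ValueOperator in memory
    )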
115 changes: 102 additions & 13 deletions torchrl/trainers/algorithms/ppo.py
@@ -66,6 +66,42 @@ class PPOTrainer(Trainer):

Logging can be configured via constructor parameters to enable/disable specific metrics.

Args:
collector (DataCollectorBase): The data collector for gathering training data.
total_frames (int): Total number of frames to train for.
frame_skip (int): Frame skip value for the environment.
optim_steps_per_batch (int): Number of optimization steps per batch.
loss_module (LossModule): The loss module for computing policy and value losses.
optimizer (optim.Optimizer, optional): The optimizer for training.
logger (Logger, optional): Logger for tracking training metrics.
clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
clip_norm (float, optional): Maximum gradient norm value.
progress_bar (bool, optional): Whether to show a progress bar. Default: True.
seed (int, optional): Random seed for reproducibility.
save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
log_interval (int, optional): Interval for logging metrics. Default: 10000.
save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
num_epochs (int, optional): Number of epochs per batch. Default: 4.
replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
batch_size (int, optional): Batch size for optimization.
gamma (float, optional): Discount factor for GAE. Default: 0.9.
lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
enable_logging (bool, optional): Whether to enable logging. Default: True.
log_rewards (bool, optional): Whether to log rewards. Default: True.
log_actions (bool, optional): Whether to log actions. Default: True.
log_observations (bool, optional): Whether to log observations. Default: False.
async_collection (bool, optional): Whether to use async collection. Default: False.
add_gae (bool, optional): Whether to add GAE computation. Default: True.
gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
collector's weight_sync_schemes) to trainer source paths. Required if collector has
weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
"replay_buffer.transforms[0]": "loss_module.critic_network"}
log_timings (bool, optional): If True, automatically register a LogTiming hook to log
timing information for all hooks to the logger (e.g., wandb, tensorboard).
Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
Default is False.

Examples:
>>> # Basic usage with manual configuration
>>> from torchrl.trainers.algorithms.ppo import PPOTrainer
@@ -111,6 +147,10 @@ def __init__(
log_actions: bool = True,
log_observations: bool = False,
async_collection: bool = False,
add_gae: bool = True,
gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
weight_update_map: dict[str, str] | None = None,
log_timings: bool = False,
) -> None:
warnings.warn(
"PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -135,17 +175,21 @@ def __init__(
save_trainer_file=save_trainer_file,
num_epochs=num_epochs,
async_collection=async_collection,
log_timings=log_timings,
)
self.replay_buffer = replay_buffer
self.async_collection = async_collection

gae = GAE(
gamma=gamma,
lmbda=lmbda,
value_network=self.loss_module.critic_network,
average_gae=True,
)
self.register_op("pre_epoch", gae)
if add_gae and gae is None:
gae = GAE(
gamma=gamma,
lmbda=lmbda,
value_network=self.loss_module.critic_network,
average_gae=True,
)
self.register_op("pre_epoch", gae)
elif not add_gae and gae is not None:
raise ValueError("gae must not be provided if add_gae is False")

if (
not self.async_collection
@@ -167,16 +211,61 @@ def __init__(
)

if not self.async_collection:
# rb has been extended by the collector
self.register_op("pre_epoch", rb_trainer.extend)
self.register_op("process_optim_batch", rb_trainer.sample)
self.register_op("post_loss", rb_trainer.update_priority)

policy_weights_getter = partial(
TensorDict.from_module, self.loss_module.actor_network
)
update_weights = UpdateWeights(
self.collector, 1, policy_weights_getter=policy_weights_getter
)
# Set up weight updates
# Validate weight_update_map if collector has weight_sync_schemes
if (
hasattr(self.collector, "_weight_sync_schemes")
and self.collector._weight_sync_schemes
):
if weight_update_map is None:
raise ValueError(
"Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
)

# Validate that all scheme destinations are covered in the map
scheme_destinations = set(self.collector._weight_sync_schemes.keys())
map_destinations = set(weight_update_map.keys())

if scheme_destinations != map_destinations:
missing = scheme_destinations - map_destinations
extra = map_destinations - scheme_destinations
error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
if missing:
error_msg += f" Missing destinations: {missing}\n"
if extra:
error_msg += f" Extra destinations: {extra}\n"
raise ValueError(error_msg)

# Use the weight_update_map approach
update_weights = UpdateWeights(
self.collector,
1,
weight_update_map=weight_update_map,
trainer=self,
)
else:
# Fall back to legacy approach for backward compatibility
if weight_update_map is not None:
warnings.warn(
"weight_update_map was provided but collector has no weight_sync_schemes. "
"Ignoring weight_update_map and using legacy policy_weights_getter.",
UserWarning,
stacklevel=2,
)

policy_weights_getter = partial(
TensorDict.from_module, self.loss_module.actor_network
)
update_weights = UpdateWeights(
self.collector, 1, policy_weights_getter=policy_weights_getter
)

self.register_op("post_steps", update_weights)

# Store logging configuration
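Note: to illustrate the mapping check added in __init__ above, a standalone sketch of the destination/source validation. The dictionaries below are illustrative, echoing the docstring example; real destinations come from the collector's weight_sync_schemes:

    scheme_destinations = {"policy", "replay_buffer.transforms[0]"}
    weight_update_map = {
        "policy": "loss_module.actor_network",
        "replay_buffer.transforms[0]": "loss_module.critic_network",
    }

    missing = scheme_destinations - set(weight_update_map)
    extra = set(weight_update_map) - scheme_destinations
    if missing or extra:
        raise ValueError(f"weight_update_map mismatch: missing={missing}, extra={extra}")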