Added an experimental Muon optimizer to PET #977
base: main
Changes from all commits: a896d58, 82e9641, 65723f5, 70311db, 101ddd4, 89c9f76, d40f5f4, 0af2b19, 96b0ecd, 2865824, d6eb0b2, 3d608b2, 7fa1b5a, 0603a26, 6cef387, 7b1d2b8
New file, @@ -0,0 +1,191 @@:

```python
import logging
import math
from typing import Dict, Tuple, Union

import torch
from packaging import version
from torch.optim.lr_scheduler import LambdaLR

from ..documentation import TrainerHypers
from ..model import PET


def get_optimizer(model: PET, hypers: TrainerHypers) -> torch.optim.Optimizer:
    """
    Get the optimizer based on the hyperparameters.

    :param model: The model to optimize.
    :param hypers: The training hyperparameters.
    :return: The optimizer.
    """
    if hypers["weight_decay"] is None:
        weight_decay = 0.0
    else:
        weight_decay = hypers["weight_decay"]
    lr = hypers.get("learning_rate", 1e-4)
    if hypers["optimizer"].lower() == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )
    elif hypers["optimizer"].lower() == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )
    elif hypers["optimizer"].lower() == "muon":
        if version.parse(torch.__version__) < version.parse("2.9.1"):
            raise ValueError(
                f"The Muon optimizer requires PyTorch >= 2.9.1, but you have "
                f"{torch.__version__}. This feature is experimental and so far "
                "not well tested. Please manually update PyTorch to use the "
                "Muon optimizer."
            )
        logging.warning(
            "Using the Muon optimizer with auxiliary AdamW for non-matrix "
            "parameters. This feature is experimental and so far not well tested. "
            "Please use it with caution or set the optimizer to Adam or AdamW in the "
            "options.yaml."
        )
        # Separate parameters into Muon and Adam groups.
        # By design, Muon should only be used for the matrix-type parameters
        # (i.e. those with ndim >= 2), and only for optimizing the hidden
        # layers of the model (in our case, the GNN layers). All other
        # parameters, including biases, embeddings, and readout layers (heads),
        # should be optimized with Adam or AdamW.
        muon_params = []
        adam_params = []
        for n, p in model.named_parameters():
            if p.ndim >= 2 and (
                ("gnn_layers" in n and "neighbor_embedder" not in n)
                or "combination_mlps" in n
            ):
                muon_params.append(p)
            else:
                adam_params.append(p)
        adam_group = dict(params=adam_params, use_muon=False)
        muon_group = dict(params=muon_params, use_muon=True)
        optimizer = MuonWithAuxAdamW(
            [muon_group, adam_group],
            lr=lr,
            weight_decay=weight_decay,
        )
    else:
        raise ValueError(
            f"Unknown optimizer: {hypers['optimizer']}. Please choose Adam, "
            "AdamW or Muon."
        )

    return optimizer
```
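To make the routing rule concrete, here is a minimal sketch with a toy module whose attribute names mimic the patterns matched in `get_optimizer` (`gnn_layers`, `neighbor_embedder`, `combination_mlps`); the toy classes are purely illustrative and not the real PET architecture:

```python
import torch
from torch import nn


# Toy stand-ins whose parameter names mimic the patterns matched in
# get_optimizer; they are NOT the real PET architecture.
class ToyGNNLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.neighbor_embedder = nn.Embedding(10, 16)


class ToyPET(nn.Module):
    def __init__(self):
        super().__init__()
        self.gnn_layers = nn.ModuleList([ToyGNNLayer() for _ in range(2)])
        self.combination_mlps = nn.ModuleList([nn.Linear(16, 16)])
        self.heads = nn.Linear(16, 1)


model = ToyPET()
muon_names, adam_names = [], []
for n, p in model.named_parameters():
    if p.ndim >= 2 and (
        ("gnn_layers" in n and "neighbor_embedder" not in n)
        or "combination_mlps" in n
    ):
        muon_names.append(n)
    else:
        adam_names.append(n)

# Hidden-layer weight matrices go to Muon; biases, embeddings (including the
# neighbor embedder nested inside the GNN layers), and the head go to AdamW.
print("Muon :", muon_names)
print("AdamW:", adam_names)
```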
The same file continues with the scheduler helper:

```python
def get_scheduler(
    optimizer: torch.optim.Optimizer,
    train_hypers: TrainerHypers,
    steps_per_epoch: int,
) -> LambdaLR:
    """
    Get a CosineAnnealing learning-rate scheduler with warmup.

    :param optimizer: The optimizer for which to create the scheduler.
    :param train_hypers: The training hyperparameters.
    :param steps_per_epoch: The number of steps per epoch.
    :return: The learning rate scheduler.
    """
    total_steps = train_hypers["num_epochs"] * steps_per_epoch
    warmup_steps = int(train_hypers["warmup_fraction"] * total_steps)
    min_lr_ratio = 1e-4  # hardcoded minimum LR ratio

    logging.info(
        f"Using cosine decay from {train_hypers['learning_rate']} to "
        f"{train_hypers['learning_rate'] * min_lr_ratio} after "
        f"{warmup_steps} warmup optimizer steps and {total_steps} "
        "total steps."
    )

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        else:
            # Cosine decay
            progress = (current_step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    return scheduler
```
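As a quick sanity check of the schedule shape, the `lr_lambda` above can be evaluated standalone; the step counts below are illustrative stand-ins, not the trainer's defaults:

```python
import math

# Illustrative values only: 100 total optimizer steps, 10% warmup,
# and the hardcoded min_lr_ratio from get_scheduler.
total_steps = 100
warmup_steps = 10
min_lr_ratio = 1e-4


def lr_lambda(current_step: int) -> float:
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = (current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay


for step in (0, 5, 10, 55, 100):
    print(step, lr_lambda(step))
# 0 -> 0.0 and 5 -> 0.5 (linear warmup), 10 -> 1.0 (warmup ends),
# 55 -> ~0.5 (halfway through the cosine), 100 -> 1e-4 (the floor).
```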
The same file ends with the combined optimizer class:

```python
class MuonWithAuxAdamW(torch.optim.Optimizer):
    """
    Combined optimizer with Muon and AdamW for different parameter groups.

    :param param_groups: Parameter groups for the optimizer.
    :param lr: Learning rate.
    :param weight_decay: Weight decay.
    :param momentum: Momentum for Muon.
    :param eps: Epsilon for AdamW.
    :param betas: Betas for AdamW.
    """

    def __init__(
        self,
        param_groups: list,
        lr: Union[float, torch.Tensor] = 0.001,
        weight_decay: float = 0.0,
        momentum: float = 0.95,
        eps: float = 1e-10,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        # Set defaults that will be merged into param_groups
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eps=eps,
            betas=betas,
        )

        # Initialize the base optimizer first (this merges defaults into
        # param_groups)
        super().__init__(param_groups, defaults)

        # Now create the internal optimizers using the fully initialized
        # param_groups. Exactly one group with use_muon=True and one with
        # use_muon=False are expected; a later group of the same kind would
        # overwrite the corresponding internal optimizer.
        for group in self.param_groups:
            assert "use_muon" in group
            params = group["params"]
            if group["use_muon"]:
                self.muon_optimizer = torch.optim.Muon(
                    params,
                    lr=group["lr"],
                    momentum=group["momentum"],
                )
            else:
                self.adamw_optimizer = torch.optim.AdamW(
                    params,
                    lr=group["lr"],
                    betas=group["betas"],
                    eps=group["eps"],
                    weight_decay=group["weight_decay"],
                )

    @torch.no_grad()
    def step(self) -> None:
        # LR schedulers (e.g. the LambdaLR above) only update the outer
        # self.param_groups, so copy the learning rates over to the internal
        # optimizers before stepping.
        for group in self.param_groups:
            internal = (
                self.muon_optimizer if group["use_muon"] else self.adamw_optimizer
            )
            for internal_group in internal.param_groups:
                internal_group["lr"] = group["lr"]
        self.muon_optimizer.step()
        self.adamw_optimizer.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        self.muon_optimizer.zero_grad(set_to_none=set_to_none)
        self.adamw_optimizer.zero_grad(set_to_none=set_to_none)

    def load_state_dict(self, state_dict: Dict) -> None:
        self.muon_optimizer.load_state_dict(state_dict["muon_optimizer"])
        self.adamw_optimizer.load_state_dict(state_dict["adamw_optimizer"])

    def state_dict(self) -> Dict:
        return {
            "muon_optimizer": self.muon_optimizer.state_dict(),
            "adamw_optimizer": self.adamw_optimizer.state_dict(),
        }
```
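A minimal, self-contained usage sketch of the class above on a toy model; it assumes a PyTorch build that actually ships `torch.optim.Muon` (>= 2.9.1 per the version check in `get_optimizer`), and the model, data, and hyperparameter values are placeholders, not PET:

```python
import torch
from torch import nn

# Toy model and data; placeholders only, not PET.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]

# Requires torch.optim.Muon to be available (PyTorch >= 2.9.1 per the PR).
optimizer = MuonWithAuxAdamW(
    [
        dict(params=muon_params, use_muon=True),
        dict(params=adam_params, use_muon=False),
    ],
    lr=1e-3,
    weight_decay=0.01,
)

x, y = torch.randn(64, 8), torch.randn(64, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```

In the trainer itself, `get_optimizer` builds these groups from the PET parameter names, and `get_scheduler` wraps the result in the warmup-plus-cosine `LambdaLR`.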
Contributor: Can this be regenerated with different hyperparameters?

Member: Sure - just the PET one or the others too?

Contributor: It would be great if you could do it for all the newly generated checkpoints!
I think we want two separate learning rates for the Adam and Muon parameter groups. If you look at the example in the README of https://github.com/KellerJordan/Muon, the Adam LR is closer to what we'd normally expect, but the Muon one can be pushed much higher.
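If separate learning rates were wanted, one hypothetical way to express that with the param-group layout used in `get_optimizer` is per-group `lr` overrides; the split below is by `ndim` only as a stand-in for the name-based split, the two LR values are illustrative and not tuned, and the thread below ends up preferring a single shared LR:

```python
import torch
from torch import nn

# Toy split by ndim only, standing in for the name-based split in
# get_optimizer; the per-group learning rates are illustrative only.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
muon_group = dict(
    params=[p for p in model.parameters() if p.ndim >= 2], use_muon=True, lr=2e-2
)
adam_group = dict(
    params=[p for p in model.parameters() if p.ndim < 2], use_muon=False, lr=3e-4
)
optimizer = MuonWithAuxAdamW([muon_group, adam_group], weight_decay=0.01)
```

Because `torch.optim.Optimizer` only fills in defaults for keys missing from a group, the explicit per-group `lr` survives the merge and is picked up when the internal Muon and AdamW optimizers are constructed.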
I'm a bit sceptical about setting the LR values like this. I mean, they should be highly architecture-dependent, right? At the same time, @sirmarcel has tested Muon for PET and noticed that it works nicely even with a common LR of ~1e-3 for both the Adam and Muon parameters.
Hmm, ok - I noticed in my tests that I could push the Muon LR even to 1e-1 and it was still stable, but as soon as the Adam LR went above 1e-3, training diverged. But again, an extra hyperparameter is more complexity, so let's keep it simple and have a single one, as you say, for now.