-
Notifications
You must be signed in to change notification settings - Fork 29
[MUON] add Muon optimizer #408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
585e6eb
bf3cbba
e491a54
d6efb2d
8c88a3c
82990ac
f5be732
6c29ad9
83611ed
87da505
6491f89
927e8c0
9236c07
e0ed61b
cadbaef
0376db9
dfc06cb
89e330b
4ca5c46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,22 +3,24 @@ | |||||||||||||||||||||
| # ************************************************** | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||
| from torch.optim import ASGD as TorchASGD | ||||||||||||||||||||||
| from torch.optim import LBFGS as TorchLBFGS | ||||||||||||||||||||||
| from torch.optim import SGD as TorchSGD | ||||||||||||||||||||||
| from torch.optim import Adadelta as TorchAdadelta | ||||||||||||||||||||||
| from torch.optim import Adagrad as TorchAdagrad | ||||||||||||||||||||||
| from torch.optim import Adam as TorchAdam | ||||||||||||||||||||||
| from torch.optim import Adamax as TorchAdamax | ||||||||||||||||||||||
| from torch.optim import AdamW as TorchAdamW | ||||||||||||||||||||||
| from torch.optim import Muon as TorchMuon | ||||||||||||||||||||||
| from torch.optim import NAdam as TorchNAdam | ||||||||||||||||||||||
| from torch.optim import Optimizer | ||||||||||||||||||||||
| from torch.optim.adadelta import Adadelta as TorchAdadelta | ||||||||||||||||||||||
| from torch.optim.adagrad import Adagrad as TorchAdagrad | ||||||||||||||||||||||
| from torch.optim.adam import Adam as TorchAdam | ||||||||||||||||||||||
| from torch.optim.adamax import Adamax as TorchAdamax | ||||||||||||||||||||||
| from torch.optim.adamw import AdamW as TorchAdamW | ||||||||||||||||||||||
| from torch.optim.asgd import ASGD as TorchASGD | ||||||||||||||||||||||
| from torch.optim.lbfgs import LBFGS as TorchLBFGS | ||||||||||||||||||||||
| from torch.optim.nadam import NAdam as TorchNAdam | ||||||||||||||||||||||
| from torch.optim.radam import RAdam as TorchRAdam | ||||||||||||||||||||||
| from torch.optim.rmsprop import RMSprop as TorchRMSprop | ||||||||||||||||||||||
| from torch.optim.rprop import Rprop as TorchRprop | ||||||||||||||||||||||
| from torch.optim.sgd import SGD as TorchSGD | ||||||||||||||||||||||
| from torch.optim import RAdam as TorchRAdam | ||||||||||||||||||||||
| from torch.optim import RMSprop as TorchRMSprop | ||||||||||||||||||||||
| from torch.optim import Rprop as TorchRprop | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer | ||||||||||||||||||||||
| from ..enums import ParamsGroupMethod | ||||||||||||||||||||||
| from ..hf_models import get_optimizer_split_function | ||||||||||||||||||||||
| from .params_group import get_param_groups_list | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -31,6 +33,7 @@ | |||||||||||||||||||||
| "TorchAdamW": TorchAdamW, | ||||||||||||||||||||||
| "TorchASGD": TorchASGD, | ||||||||||||||||||||||
| "TorchLBFGS": TorchLBFGS, | ||||||||||||||||||||||
| "TorchMuon": TorchMuon, | ||||||||||||||||||||||
| "TorchNAdam": TorchNAdam, | ||||||||||||||||||||||
| "TorchRAdam": TorchRAdam, | ||||||||||||||||||||||
| "TorchRMSprop": TorchRMSprop, | ||||||||||||||||||||||
|
|
@@ -39,6 +42,66 @@ | |||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head) | ||||||||||||||||||||||
| _MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool: | ||||||||||||||||||||||
| """Returns True if this param should use AdamW when the main optimizer is Muon.""" | ||||||||||||||||||||||
| if param.ndim == 1: | ||||||||||||||||||||||
| return True | ||||||||||||||||||||||
| return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class _MuonWithAdamW(Optimizer): | ||||||||||||||||||||||
| """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None: | ||||||||||||||||||||||
| self.muon = muon | ||||||||||||||||||||||
| self.adamw = adamw | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @property | ||||||||||||||||||||||
| def param_groups(self) -> list[dict]: | ||||||||||||||||||||||
| groups = [] | ||||||||||||||||||||||
| if self.muon is not None: | ||||||||||||||||||||||
| groups.extend(self.muon.param_groups) | ||||||||||||||||||||||
| if self.adamw is not None: | ||||||||||||||||||||||
| groups.extend(self.adamw.param_groups) | ||||||||||||||||||||||
| return groups | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def step(self) -> None: | ||||||||||||||||||||||
| if self.muon is not None: | ||||||||||||||||||||||
| self.muon.step() | ||||||||||||||||||||||
| if self.adamw is not None: | ||||||||||||||||||||||
| self.adamw.step() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def zero_grad(self) -> None: | ||||||||||||||||||||||
| if self.muon is not None: | ||||||||||||||||||||||
| self.muon.zero_grad() | ||||||||||||||||||||||
| if self.adamw is not None: | ||||||||||||||||||||||
| self.adamw.zero_grad() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def state_dict(self) -> dict: | ||||||||||||||||||||||
| return { | ||||||||||||||||||||||
| "muon": self.muon.state_dict() if self.muon is not None else None, | ||||||||||||||||||||||
| "adamw": self.adamw.state_dict() if self.adamw is not None else None, | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def load_state_dict(self, state_dict: dict) -> None: | ||||||||||||||||||||||
| if self.muon is not None and state_dict["muon"] is not None: | ||||||||||||||||||||||
| self.muon.load_state_dict(state_dict["muon"]) | ||||||||||||||||||||||
| if self.adamw is not None and state_dict["adamw"] is not None: | ||||||||||||||||||||||
| self.adamw.load_state_dict(state_dict["adamw"]) | ||||||||||||||||||||||
|
Comment on lines
+95
to
+99
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __repr__(self) -> str: | ||||||||||||||||||||||
| return f"MuonWithAdamW(\n muon={self.muon},\n adamw={self.adamw}\n)" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_optimizer_container( | ||||||||||||||||||||||
| optimizer_class_name: str, | ||||||||||||||||||||||
| optimizer_class_args: dict, | ||||||||||||||||||||||
|
|
@@ -71,6 +134,9 @@ def get_optimizer_container( | |||||||||||||||||||||
| if use_optimizer_with_backward_hook: | ||||||||||||||||||||||
| for model, params_groups in zip(model_container, params_groups_list): | ||||||||||||||||||||||
| for param_name, param in model.named_parameters(): | ||||||||||||||||||||||
| if get_optimizer_split_function(param) is not None: | ||||||||||||||||||||||
| assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assertion check will always pass because
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mayank31398 - also looks like a bug: in line 45 you have: |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for group in params_groups.params_groups: | ||||||||||||||||||||||
| if param_name in group.parameter_name_map: | ||||||||||||||||||||||
| param._optimizer = optimizer_class( | ||||||||||||||||||||||
|
|
@@ -86,12 +152,48 @@ def _step(p: nn.Parameter) -> None: | |||||||||||||||||||||
| break | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container)) | ||||||||||||||||||||||
| elif optimizer_class_name == "TorchMuon": | ||||||||||||||||||||||
| adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)} | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When using Muon, the AdamW optimizer used for 1D parameters and embeddings currently only inherits the learning rate. Other important hyperparameters like
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should make sure we can pass all the adamw configs |
||||||||||||||||||||||
| optimizer_list_entries = [] | ||||||||||||||||||||||
| for params_groups in params_groups_list: | ||||||||||||||||||||||
| muon_groups = [] | ||||||||||||||||||||||
| adamw_groups = [] | ||||||||||||||||||||||
| for group in params_groups.params_groups: | ||||||||||||||||||||||
| muon_params = [] | ||||||||||||||||||||||
| adamw_params = [] | ||||||||||||||||||||||
| for param_name, param in group.parameter_name_map.items(): | ||||||||||||||||||||||
| if _is_muon_adamw_param(param_name, param): | ||||||||||||||||||||||
| adamw_params.append(param) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| split_fn = get_optimizer_split_function(param) | ||||||||||||||||||||||
| muon_params.extend(split_fn(param) if split_fn is not None else [param]) | ||||||||||||||||||||||
| if muon_params: | ||||||||||||||||||||||
| muon_groups.append({"params": muon_params, **group.params_group_kwargs}) | ||||||||||||||||||||||
| if adamw_params: | ||||||||||||||||||||||
| adamw_groups.append({"params": adamw_params, **group.params_group_kwargs}) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| optimizer_list_entries.append( | ||||||||||||||||||||||
| _MuonWithAdamW( | ||||||||||||||||||||||
| muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None, | ||||||||||||||||||||||
| adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| optimizer_list = OptimizerContainer(optimizer_list_entries) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| optimizer_list = OptimizerContainer( | ||||||||||||||||||||||
| [ | ||||||||||||||||||||||
| optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args) | ||||||||||||||||||||||
| for params_groups in params_groups_list | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| optimizer_list_entries = [] | ||||||||||||||||||||||
| for params_groups in params_groups_list: | ||||||||||||||||||||||
| torch_params_groups = params_groups.to_torch_compatible_params_groups() | ||||||||||||||||||||||
| for group in torch_params_groups: | ||||||||||||||||||||||
| split_params = [] | ||||||||||||||||||||||
| for param in group["params"]: | ||||||||||||||||||||||
| split_fn = get_optimizer_split_function(param) | ||||||||||||||||||||||
| split_params.extend(split_fn(param) if split_fn is not None else [param]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| group["params"] = split_params | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| optimizer_list = OptimizerContainer(optimizer_list_entries) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return optimizer_list | ||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.