diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py
index e7cd2e04a..bf9af9159 100644
--- a/lm_engine/hf_models/__init__.py
+++ b/lm_engine/hf_models/__init__.py
@@ -22,6 +22,7 @@
 )
 from .parameter import (
     _INIT_MARKER,
+    get_optimizer_split_function,
     get_parameter_marker_maps,
     is_parameter_initialized,
     is_parameter_with_mup_learning_rate,
diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
index 15d38ded9..ab5c96a60 100644
--- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -4,10 +4,12 @@

 from __future__ import annotations

+from functools import partial
+
 import torch
 import torch.nn as nn

-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import get_activation_function, is_glu
 from ..dropout import Dropout
 from ..init_utils import _get_std_for_linear
@@ -77,6 +79,10 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_fc.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)

+        set_optimizer_split_function(
+            self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.use_interleaved_weights)
+        )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.c_fc(x)
         x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x)
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index 80020f0af..a9d4efbaf 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import math
+from functools import partial

 import torch
 import torch.nn.functional as F
@@ -14,7 +15,7 @@
 from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
 from ...cache import GenerationCache, GenerationState, LinearCache
 from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import sigmoid
 from ..chunk import contiguous_split
 from ..dropout import Dropout
@@ -198,6 +199,15 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_attn.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)

+        set_optimizer_split_function(
+            self.c_attn.weight,
+            partial(
+                split_query_key_value_tensor_for_attention,
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_key_value_heads,
+            ),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index 0a816f0a0..eb35e9c8f 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************

+from typing import Callable
+
 import torch.nn as nn


@@ -77,3 +79,14 @@

         for marker, value in _marker_map[param_name].items():
             setattr(parameter, marker, value)
+
+
+def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None:
+    if parameter is not None:
+        parameter._optimizer_split_function = function
+
+    return parameter
+
+
+def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None:
+    return getattr(parameter, "_optimizer_split_function", None)
diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index d03fdea39..c2e25d852 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -3,22 +3,24 @@
 # **************************************************

 import torch.nn as nn
+from torch.optim import ASGD as TorchASGD
+from torch.optim import LBFGS as TorchLBFGS
+from torch.optim import SGD as TorchSGD
+from torch.optim import Adadelta as TorchAdadelta
+from torch.optim import Adagrad as TorchAdagrad
+from torch.optim import Adam as TorchAdam
+from torch.optim import Adamax as TorchAdamax
+from torch.optim import AdamW as TorchAdamW
+from torch.optim import Muon as TorchMuon
+from torch.optim import NAdam as TorchNAdam
 from torch.optim import Optimizer
-from torch.optim.adadelta import Adadelta as TorchAdadelta
-from torch.optim.adagrad import Adagrad as TorchAdagrad
-from torch.optim.adam import Adam as TorchAdam
-from torch.optim.adamax import Adamax as TorchAdamax
-from torch.optim.adamw import AdamW as TorchAdamW
-from torch.optim.asgd import ASGD as TorchASGD
-from torch.optim.lbfgs import LBFGS as TorchLBFGS
-from torch.optim.nadam import NAdam as TorchNAdam
-from torch.optim.radam import RAdam as TorchRAdam
-from torch.optim.rmsprop import RMSprop as TorchRMSprop
-from torch.optim.rprop import Rprop as TorchRprop
-from torch.optim.sgd import SGD as TorchSGD
+from torch.optim import RAdam as TorchRAdam
+from torch.optim import RMSprop as TorchRMSprop
+from torch.optim import Rprop as TorchRprop

 from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
 from ..enums import ParamsGroupMethod
+from ..hf_models import get_optimizer_split_function
 from .params_group import get_param_groups_list


@@ -31,6 +33,7 @@
     "TorchAdamW": TorchAdamW,
     "TorchASGD": TorchASGD,
     "TorchLBFGS": TorchLBFGS,
+    "TorchMuon": TorchMuon,
     "TorchNAdam": TorchNAdam,
     "TorchRAdam": TorchRAdam,
     "TorchRMSprop": TorchRMSprop,
@@ -39,6 +42,66 @@
 }


+_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]
+
+# Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head)
+_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"}
+
+
+def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool:
+    """Returns True if this param should use AdamW when the main optimizer is Muon."""
+    if param.ndim == 1:
+        return True
+    return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES)
+
+
+class _MuonWithAdamW(Optimizer):
+    """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object.
+
+    Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params.
+    """
+
+    def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None:
+        self.muon = muon
+        self.adamw = adamw
+
+    @property
+    def param_groups(self) -> list[dict]:
+        groups = []
+        if self.muon is not None:
+            groups.extend(self.muon.param_groups)
+        if self.adamw is not None:
+            groups.extend(self.adamw.param_groups)
+        return groups
+
+    def step(self) -> None:
+        if self.muon is not None:
+            self.muon.step()
+        if self.adamw is not None:
+            self.adamw.step()
+
+    def zero_grad(self) -> None:
+        if self.muon is not None:
+            self.muon.zero_grad()
+        if self.adamw is not None:
+            self.adamw.zero_grad()
+
+    def state_dict(self) -> dict:
+        return {
+            "muon": self.muon.state_dict() if self.muon is not None else None,
+            "adamw": self.adamw.state_dict() if self.adamw is not None else None,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        if self.muon is not None and state_dict["muon"] is not None:
+            self.muon.load_state_dict(state_dict["muon"])
+        if self.adamw is not None and state_dict["adamw"] is not None:
+            self.adamw.load_state_dict(state_dict["adamw"])
+
+    def __repr__(self) -> str:
+        return f"MuonWithAdamW(\n muon={self.muon},\n adamw={self.adamw}\n)"
+
+
 def get_optimizer_container(
     optimizer_class_name: str,
     optimizer_class_args: dict,
@@ -71,6 +134,9 @@ def get_optimizer_container(
     if use_optimizer_with_backward_hook:
         for model, params_groups in zip(model_container, params_groups_list):
             for param_name, param in model.named_parameters():
+                if get_optimizer_split_function(param) is not None:
+                    assert optimizer_class_name not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
+
                 for group in params_groups.params_groups:
                     if param_name in group.parameter_name_map:
                         param._optimizer = optimizer_class(
@@ -86,12 +152,48 @@ def _step(p: nn.Parameter) -> None:
                         break

         optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
+    elif optimizer_class_name == "TorchMuon":
+        adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
+        optimizer_list_entries = []
+        for params_groups in params_groups_list:
+            muon_groups = []
+            adamw_groups = []
+            for group in params_groups.params_groups:
+                muon_params = []
+                adamw_params = []
+                for param_name, param in group.parameter_name_map.items():
+                    if _is_muon_adamw_param(param_name, param):
+                        adamw_params.append(param)
+                    else:
+                        split_fn = get_optimizer_split_function(param)
+                        muon_params.extend(split_fn(param) if split_fn is not None else [param])
+                if muon_params:
+                    muon_groups.append({"params": muon_params, **group.params_group_kwargs})
+                if adamw_params:
+                    adamw_groups.append({"params": adamw_params, **group.params_group_kwargs})
+
+            optimizer_list_entries.append(
+                _MuonWithAdamW(
+                    muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
+                    adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
+                )
+            )
+
+        optimizer_list = OptimizerContainer(optimizer_list_entries)
     else:
-        optimizer_list = OptimizerContainer(
-            [
-                optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
-                for params_groups in params_groups_list
-            ]
-        )
+        optimizer_list_entries = []
+        for params_groups in params_groups_list:
+            torch_params_groups = params_groups.to_torch_compatible_params_groups()
+            for group in torch_params_groups:
+                split_params = []
+                for param in group["params"]:
+                    split_fn = get_optimizer_split_function(param)
+                    split_params.extend(split_fn(param) if split_fn is not None else [param])
+
+                group["params"] = split_params
+
+            optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args))
+
+        optimizer_list = OptimizerContainer(optimizer_list_entries)

     return optimizer_list
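
Reviewer note (not part of the patch): a minimal sketch of how the new split-function hooks compose end to end. The `toy_qkv_split` helper and the head counts below are hypothetical stand-ins for the repo's `split_query_key_value_tensor_for_attention`, whose implementation is not shown in this diff; the expansion at the bottom mirrors what `get_optimizer_container` now does before handing param groups to the optimizer.

    from functools import partial

    import torch.nn as nn

    from lm_engine.hf_models import get_optimizer_split_function
    from lm_engine.hf_models.parameter import set_optimizer_split_function


    def toy_qkv_split(weight, num_heads, num_key_value_heads):
        # Hypothetical stand-in: slice a fused QKV weight into Q, K and V row blocks so a
        # matrix-wise update rule (e.g. Muon) sees three separate matrices.
        head_dim = weight.shape[0] // (num_heads + 2 * num_key_value_heads)
        sizes = [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim]
        return list(weight.split(sizes, dim=0))


    c_attn = nn.Linear(64, 96, bias=False)  # fused QKV: 4 query heads, 1 KV head, head_dim=16
    set_optimizer_split_function(c_attn.weight, partial(toy_qkv_split, num_heads=4, num_key_value_heads=1))

    # Mirrors the expansion in get_optimizer_container: a parameter with a registered split
    # function is replaced by whatever tensors its split function returns.
    split_fn = get_optimizer_split_function(c_attn.weight)
    params = split_fn(c_attn.weight) if split_fn is not None else [c_attn.weight]
    print([tuple(p.shape) for p in params])  # [(64, 64), (16, 64), (16, 64)]

The sketch only exercises registration and lookup; whether the real split helpers return plain views like this or re-wrapped parameters is up to their implementations elsewhere in the repo.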