1 change: 1 addition & 0 deletions lm_engine/hf_models/__init__.py
@@ -22,6 +22,7 @@
)
from .parameter import (
_INIT_MARKER,
get_optimizer_split_function,
get_parameter_marker_maps,
is_parameter_initialized,
is_parameter_with_mup_learning_rate,
8 changes: 7 additions & 1 deletion lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -4,10 +4,12 @@

from __future__ import annotations

from functools import partial

import torch
import torch.nn as nn

from ...parameter import mark_parameter_as_mup_learning_rate
from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
from ..activations import get_activation_function, is_glu
from ..dropout import Dropout
from ..init_utils import _get_std_for_linear
@@ -77,6 +79,10 @@ def __init__(
mark_parameter_as_mup_learning_rate(self.c_fc.weight)
mark_parameter_as_mup_learning_rate(self.c_proj.weight)

set_optimizer_split_function(
self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.use_interleaved_weights)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.c_fc(x)
x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x)
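The hunk above registers split_up_gate_tensor_for_mlp as the optimizer split function for the fused c_fc weight, but the helper itself is not part of this diff. Below is a minimal sketch of what such a split could look like, assuming the fused weight stacks the gate and up projections along dim 0 and interleaves their rows pairwise when is_interleaved is set (both layout details are assumptions, not taken from the diff):

import torch


# Hypothetical sketch only; the real split_up_gate_tensor_for_mlp is defined
# elsewhere in the repo and may use a different layout convention.
def split_up_gate_tensor_for_mlp(weight: torch.Tensor, is_interleaved: bool) -> list[torch.Tensor]:
    if is_interleaved:
        # assumed layout: gate and up rows alternate (g0, u0, g1, u1, ...)
        gate, up = weight[0::2], weight[1::2]
    else:
        # assumed layout: gate rows stacked on top of up rows
        gate, up = torch.chunk(weight, 2, dim=0)
    return [gate, up]

One plausible motivation for the split: optimizers like Muon operate on whole 2D matrices, so a fused gate/up tensor is better handed to the optimizer as its two logical sub-matrices.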
Changes to the attention sequence mixer (file path not shown in this capture)
@@ -5,6 +5,7 @@
from __future__ import annotations

import math
from functools import partial

import torch
import torch.nn.functional as F
@@ -14,7 +15,7 @@
from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
from ...cache import GenerationCache, GenerationState, LinearCache
from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
from ...parameter import mark_parameter_as_mup_learning_rate
from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
from ..activations import sigmoid
from ..chunk import contiguous_split
from ..dropout import Dropout
@@ -198,6 +199,15 @@ def __init__(
mark_parameter_as_mup_learning_rate(self.c_attn.weight)
mark_parameter_as_mup_learning_rate(self.c_proj.weight)

set_optimizer_split_function(
self.c_attn.weight,
partial(
split_query_key_value_tensor_for_attention,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
),
)

def forward(
self,
x: torch.Tensor,
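As with the MLP, the fused c_attn weight gets a split function, here parameterized by the head counts. split_query_key_value_tensor_for_attention is likewise not shown in this diff; a minimal sketch, assuming c_attn stacks the projections row-wise as [Q | K | V] with num_heads query heads, num_key_value_heads key/value heads, and a shared head_dim (all assumptions):

import torch


# Hypothetical sketch; the real helper lives elsewhere in the repo and the
# [Q | K | V] row layout is an assumption, not taken from this diff.
def split_query_key_value_tensor_for_attention(
    weight: torch.Tensor, num_heads: int, num_key_value_heads: int
) -> list[torch.Tensor]:
    head_dim = weight.size(0) // (num_heads + 2 * num_key_value_heads)
    return list(
        torch.split(
            weight,
            [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim],
            dim=0,
        )
    )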
13 changes: 13 additions & 0 deletions lm_engine/hf_models/parameter.py
@@ -2,6 +2,8 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from typing import Callable

import torch.nn as nn


@@ -77,3 +79,14 @@ def set_parameter_marker_maps(

for marker, value in _marker_map[param_name].items():
setattr(parameter, marker, value)


def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None:
if parameter is not None:
parameter._optimizer_split_function = function

return parameter


def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None:
return getattr(parameter, "_optimizer_split_function", None)
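A quick usage sketch of the two helpers above: set_optimizer_split_function is called at module-construction time and get_optimizer_split_function is read back when building optimizer param groups. The Linear module and the chunk-based split below are stand-ins, and the import paths assume the package layout shown in the file headers:

from functools import partial

import torch
import torch.nn as nn

from lm_engine.hf_models import get_optimizer_split_function
from lm_engine.hf_models.parameter import set_optimizer_split_function

# stand-in for a fused projection whose two halves should be optimized separately
fused = nn.Linear(16, 2 * 64, bias=False)
set_optimizer_split_function(fused.weight, partial(torch.chunk, chunks=2, dim=0))

# at optimizer-construction time the split function is read back off the parameter
split_fn = get_optimizer_split_function(fused.weight)
pieces = [fused.weight] if split_fn is None else list(split_fn(fused.weight))
print([tuple(p.shape) for p in pieces])  # [(64, 16), (64, 16)]

The same None check appears in the optimizer changes below, deciding whether a parameter is passed through whole or as its split pieces.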
138 changes: 120 additions & 18 deletions lm_engine/optimization/optimizer.py
@@ -3,22 +3,24 @@
# **************************************************

import torch.nn as nn
from torch.optim import ASGD as TorchASGD
from torch.optim import LBFGS as TorchLBFGS
from torch.optim import SGD as TorchSGD
from torch.optim import Adadelta as TorchAdadelta
from torch.optim import Adagrad as TorchAdagrad
from torch.optim import Adam as TorchAdam
from torch.optim import Adamax as TorchAdamax
from torch.optim import AdamW as TorchAdamW
from torch.optim import Muon as TorchMuon
from torch.optim import NAdam as TorchNAdam
from torch.optim import Optimizer
from torch.optim.adadelta import Adadelta as TorchAdadelta
from torch.optim.adagrad import Adagrad as TorchAdagrad
from torch.optim.adam import Adam as TorchAdam
from torch.optim.adamax import Adamax as TorchAdamax
from torch.optim.adamw import AdamW as TorchAdamW
from torch.optim.asgd import ASGD as TorchASGD
from torch.optim.lbfgs import LBFGS as TorchLBFGS
from torch.optim.nadam import NAdam as TorchNAdam
from torch.optim.radam import RAdam as TorchRAdam
from torch.optim.rmsprop import RMSprop as TorchRMSprop
from torch.optim.rprop import Rprop as TorchRprop
from torch.optim.sgd import SGD as TorchSGD
from torch.optim import RAdam as TorchRAdam
from torch.optim import RMSprop as TorchRMSprop
from torch.optim import Rprop as TorchRprop

from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
from ..enums import ParamsGroupMethod
from ..hf_models import get_optimizer_split_function
from .params_group import get_param_groups_list


@@ -31,6 +33,7 @@
"TorchAdamW": TorchAdamW,
"TorchASGD": TorchASGD,
"TorchLBFGS": TorchLBFGS,
"TorchMuon": TorchMuon,
"TorchNAdam": TorchNAdam,
"TorchRAdam": TorchRAdam,
"TorchRMSprop": TorchRMSprop,
@@ -39,6 +42,66 @@
}


_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]

# Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head)
_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"}


def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool:
"""Returns True if this param should use AdamW when the main optimizer is Muon."""
if param.ndim == 1:
return True
return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES)


class _MuonWithAdamW(Optimizer):
"""Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object.

Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params.
"""

def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None:
self.muon = muon
self.adamw = adamw

@property
def param_groups(self) -> list[dict]:
groups = []
if self.muon is not None:
groups.extend(self.muon.param_groups)
if self.adamw is not None:
groups.extend(self.adamw.param_groups)
return groups

def step(self) -> None:
if self.muon is not None:
self.muon.step()
if self.adamw is not None:
self.adamw.step()

def zero_grad(self) -> None:
if self.muon is not None:
self.muon.zero_grad()
if self.adamw is not None:
self.adamw.zero_grad()

def state_dict(self) -> dict:
return {
"muon": self.muon.state_dict() if self.muon is not None else None,
"adamw": self.adamw.state_dict() if self.adamw is not None else None,
}

def load_state_dict(self, state_dict: dict) -> None:
if self.muon is not None and state_dict["muon"] is not None:
self.muon.load_state_dict(state_dict["muon"])
if self.adamw is not None and state_dict["adamw"] is not None:
self.adamw.load_state_dict(state_dict["adamw"])
Contributor comment on lines +95 to +99 (medium):

The load_state_dict method in the _MuonWithAdamW wrapper assumes the provided state_dict always contains "muon" and "adamw" keys. If a standard optimizer state dict is passed (e.g., during a transition or from a different checkpoint format), this will raise a KeyError. Using .get() would make this more robust.

Suggested change
def load_state_dict(self, state_dict: dict) -> None:
if self.muon is not None and state_dict["muon"] is not None:
self.muon.load_state_dict(state_dict["muon"])
if self.adamw is not None and state_dict["adamw"] is not None:
self.adamw.load_state_dict(state_dict["adamw"])
def load_state_dict(self, state_dict: dict) -> None:
if self.muon is not None and state_dict.get("muon") is not None:
self.muon.load_state_dict(state_dict["muon"])
if self.adamw is not None and state_dict.get("adamw") is not None:
self.adamw.load_state_dict(state_dict["adamw"])


def __repr__(self) -> str:
return f"MuonWithAdamW(\n muon={self.muon},\n adamw={self.adamw}\n)"


def get_optimizer_container(
optimizer_class_name: str,
optimizer_class_args: dict,
@@ -71,6 +134,9 @@ def get_optimizer_container(
if use_optimizer_with_backward_hook:
for model, params_groups in zip(model_container, params_groups_list):
for param_name, param in model.named_parameters():
if get_optimizer_split_function(param) is not None:
assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
Contributor comment (high):

The assertion check will always pass because optimizer_class is a class object, while _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS is a list of strings (e.g., ["TorchMuon"]). You should check against optimizer_class_name instead.

Suggested change
assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
assert optimizer_class_name not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS

Reply:

@mayank31398 - this also looks like a bug: on line 45 you have _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"], and on line 155 you have optimizer_class_name == "TorchMuon".


for group in params_groups.params_groups:
if param_name in group.parameter_name_map:
param._optimizer = optimizer_class(
@@ -86,12 +152,48 @@ def _step(p: nn.Parameter) -> None:
break

optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
elif optimizer_class_name == "TorchMuon":
adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
Contributor comment (medium):

When using Muon, the AdamW optimizer used for 1D parameters and embeddings currently only inherits the learning rate. Other important hyperparameters like weight_decay, betas, and eps provided in optimizer_class_args are ignored for the AdamW part. It is recommended to pass these parameters to ensure consistent optimization behavior for non-Muon parameters.

Suggested change
adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
adamw_args = {k: v for k, v in optimizer_class_args.items() if k in ("lr", "betas", "eps", "weight_decay")}
adamw_args.setdefault("lr", 1e-3)

Reply:

@mayank31398 -

optimizer_list_entries.append(
                _MuonWithAdamW(
                    muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
                    adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
                )

We should make sure we can pass all the AdamW configs.

optimizer_list_entries = []
for params_groups in params_groups_list:
muon_groups = []
adamw_groups = []
for group in params_groups.params_groups:
muon_params = []
adamw_params = []
for param_name, param in group.parameter_name_map.items():
if _is_muon_adamw_param(param_name, param):
adamw_params.append(param)
else:
split_fn = get_optimizer_split_function(param)
muon_params.extend(split_fn(param) if split_fn is not None else [param])
if muon_params:
muon_groups.append({"params": muon_params, **group.params_group_kwargs})
if adamw_params:
adamw_groups.append({"params": adamw_params, **group.params_group_kwargs})

optimizer_list_entries.append(
_MuonWithAdamW(
muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
)
)

optimizer_list = OptimizerContainer(optimizer_list_entries)
else:
optimizer_list = OptimizerContainer(
[
optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
for params_groups in params_groups_list
]
)
optimizer_list_entries = []
for params_groups in params_groups_list:
torch_params_groups = params_groups.to_torch_compatible_params_groups()
for group in torch_params_groups:
split_params = []
for param in group["params"]:
split_fn = get_optimizer_split_function(param)
split_params.extend(split_fn(param) if split_fn is not None else [param])

group["params"] = split_params

optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args))

optimizer_list = OptimizerContainer(optimizer_list_entries)

return optimizer_list