From 585e6eb2e72306ad2d71511f64c96840e45e6637 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 16:54:39 -0700
Subject: [PATCH 01/18] add muon

Signed-off-by: Mayank Mishra
---
 accelerated-model-architectures     |  2 +-
 lm_engine/optimization/optimizer.py | 27 ++++++++++++++-------------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/accelerated-model-architectures b/accelerated-model-architectures
index f9cf82992..fa12c4a7c 160000
--- a/accelerated-model-architectures
+++ b/accelerated-model-architectures
@@ -1 +1 @@
-Subproject commit f9cf829925e0ebcc1b54cf7ee6200ebc3eeb06c3
+Subproject commit fa12c4a7c07098a5ab0fa928f369a266bdd0db58
diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index d03fdea39..c715909db 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -3,19 +3,19 @@
 # **************************************************
 
 import torch.nn as nn
-from torch.optim import Optimizer
-from torch.optim.adadelta import Adadelta as TorchAdadelta
-from torch.optim.adagrad import Adagrad as TorchAdagrad
-from torch.optim.adam import Adam as TorchAdam
-from torch.optim.adamax import Adamax as TorchAdamax
-from torch.optim.adamw import AdamW as TorchAdamW
-from torch.optim.asgd import ASGD as TorchASGD
-from torch.optim.lbfgs import LBFGS as TorchLBFGS
-from torch.optim.nadam import NAdam as TorchNAdam
-from torch.optim.radam import RAdam as TorchRAdam
-from torch.optim.rmsprop import RMSprop as TorchRMSprop
-from torch.optim.rprop import Rprop as TorchRprop
-from torch.optim.sgd import SGD as TorchSGD
+from torch.optim import ASGD as TorchASGD
+from torch.optim import LBFGS as TorchLBFGS
+from torch.optim import SGD as TorchSGD
+from torch.optim import Adadelta as TorchAdadelta
+from torch.optim import Adagrad as TorchAdagrad
+from torch.optim import Adam as TorchAdam
+from torch.optim import Adamax as TorchAdamax
+from torch.optim import AdamW as TorchAdamW
+from torch.optim import Muon as TorchMuon
+from torch.optim import NAdam as TorchNAdam
+from torch.optim import RAdam as TorchRAdam
+from torch.optim import RMSprop as TorchRMSprop
+from torch.optim import Rprop as TorchRprop
 
 from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
 from ..enums import ParamsGroupMethod
@@ -31,6 +31,7 @@
     "TorchAdamW": TorchAdamW,
     "TorchASGD": TorchASGD,
     "TorchLBFGS": TorchLBFGS,
+    "TorchMuon": TorchMuon,
     "TorchNAdam": TorchNAdam,
     "TorchRAdam": TorchRAdam,
     "TorchRMSprop": TorchRMSprop,
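Note: this patch assumes a PyTorch build whose torch.optim exports Muon. For background, Muon replaces the raw momentum update of a 2D weight with an approximately orthogonalized one, usually computed with a quintic Newton-Schulz iteration. A minimal sketch of that iteration, following Keller Jordan's reference implementation rather than the torch.optim code itself:

    import torch

    def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
        # Coefficients from the Muon reference implementation; they push the
        # singular values of G toward 1 while preserving its singular vectors.
        a, b, c = 3.4445, -4.7750, 2.0315
        X = G / (G.norm() + 1e-7)  # Frobenius norm bounds the spectral norm by 1
        transposed = X.size(0) > X.size(1)
        if transposed:
            X = X.T  # keep X wide so A = X @ X.T is the smaller Gram matrix
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * A @ A) @ X
        return X.T if transposed else X

This is also why later patches in the series route only 2D weight matrices to Muon: the orthogonalization step is only meaningful for matrices.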
From bf3cbba77dc596dce42ab7d397f62754db74effd Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:04:35 -0700
Subject: [PATCH 02/18] add muon

Signed-off-by: Mayank Mishra
---
 .../sequence_mixer_blocks/attention.py | 12 +++++++++++-
 lm_engine/hf_models/parameter.py       |  9 +++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index 069d5df49..3d204911c 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import math
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -14,7 +15,7 @@
 from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
 from ...cache import GenerationCache
 from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function
 from ..activations import sigmoid
 from ..chunk import contiguous_split
 from ..dropout import Dropout
@@ -198,6 +199,15 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_attn.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
+        register_optimizer_split_function(
+            self.c_proj.weight,
+            partial(
+                split_query_key_value_tensor_for_attention,
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_key_value_heads,
+            ),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index 0a816f0a0..93b9b16a7 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from typing import Callable
+
 import torch.nn as nn
 
 
@@ -77,3 +79,10 @@ def set_parameter_marker_maps(
 
     for marker, value in _marker_map[param_name].items():
         setattr(parameter, marker, value)
+
+
+def register_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None:
+    if parameter is not None:
+        parameter._optimizer_split_function = function
+
+    return parameter
From e491a542468a0c769b6499d37162a3eab4ce6cec Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:07:17 -0700
Subject: [PATCH 03/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
index 15d38ded9..be9006bc4 100644
--- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -4,10 +4,12 @@
 
 from __future__ import annotations
 
+from functools import partial
+
 import torch
 import torch.nn as nn
 
-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function
 from ..activations import get_activation_function, is_glu
 from ..dropout import Dropout
 from ..init_utils import _get_std_for_linear
@@ -77,6 +79,10 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_fc.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
+        register_optimizer_split_function(
+            self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads)
+        )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.c_fc(x)
         x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x)
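The two split functions referenced above (split_query_key_value_tensor_for_attention, split_up_gate_tensor_for_mlp) are defined outside this series, presumably in the accelerated-model-architectures submodule bumped in patch 01. The motivation: a fused projection is a single 2D parameter whose rows concatenate logically distinct matrices (up+gate in the MLP, Q+K+V in attention), and Muon's per-matrix orthogonalization should see each sub-matrix separately. A hypothetical sketch of the MLP splitter; the real function must also arrange for gradients to follow the same views:

    import torch

    def split_up_gate_tensor_for_mlp(weight: torch.Tensor, is_interleaved: bool) -> list[torch.Tensor]:
        # The fused MLP weight stacks the up and gate projections along dim 0.
        if is_interleaved:
            up, gate = weight[0::2], weight[1::2]  # rows alternate up/gate
        else:
            up, gate = weight.chunk(2, dim=0)
        return [up, gate]

Two wrinkles in these patches as written, both corrected later in the series: the MLP splitter is registered with is_interleaved=self.num_heads (fixed to use_interleaved_weights in patch 05), and the attention splitter is registered on c_proj.weight instead of the fused c_attn.weight (fixed in patch 17).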
From d6efb2d0f7e093dbb705b54b31c8dc0835304676 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:23:05 -0700
Subject: [PATCH 04/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/__init__.py                       | 1 +
 lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py  | 4 ++--
 .../modeling_utils/sequence_mixer_blocks/attention.py | 4 ++--
 lm_engine/hf_models/parameter.py                      | 6 +++++-
 lm_engine/optimization/optimizer.py                   | 3 +++
 5 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py
index e7cd2e04a..bf9af9159 100644
--- a/lm_engine/hf_models/__init__.py
+++ b/lm_engine/hf_models/__init__.py
@@ -22,6 +22,7 @@
 )
 from .parameter import (
     _INIT_MARKER,
+    get_optimizer_split_function,
     get_parameter_marker_maps,
     is_parameter_initialized,
     is_parameter_with_mup_learning_rate,
diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
index be9006bc4..6192dca22 100644
--- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn as nn
 
-from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import get_activation_function, is_glu
 from ..dropout import Dropout
 from ..init_utils import _get_std_for_linear
@@ -79,7 +79,7 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_fc.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
-        register_optimizer_split_function(
+        set_optimizer_split_function(
             self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads)
         )
 
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index 3d204911c..841f61d75 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -15,7 +15,7 @@
 from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
 from ...cache import GenerationCache
 from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
-from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import sigmoid
 from ..chunk import contiguous_split
 from ..dropout import Dropout
@@ -199,7 +199,7 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_attn.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
-        register_optimizer_split_function(
+        set_optimizer_split_function(
             self.c_proj.weight,
             partial(
                 split_query_key_value_tensor_for_attention,
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index 93b9b16a7..c10d68013 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -81,8 +81,12 @@ def set_parameter_marker_maps(
         setattr(parameter, marker, value)
 
 
-def register_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None:
+def set_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None:
     if parameter is not None:
         parameter._optimizer_split_function = function
 
     return parameter
+
+
+def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None:
+    return getattr(parameter, "_optimizer_split_function", None)
diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index c715909db..2eea66c34 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -19,6 +19,7 @@
 
 from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
 from ..enums import ParamsGroupMethod
+from ..hf_models import get_optimizer_split_function
 from .params_group import get_param_groups_list
 
 
@@ -72,6 +73,8 @@ def get_optimizer_container(
     if use_optimizer_with_backward_hook:
         for model, params_groups in zip(model_container, params_groups_list):
             for param_name, param in model.named_parameters():
+                assert get_optimizer_split_function(param) is None
+
                 for group in params_groups.params_groups:
                     if param_name in group.parameter_name_map:
                         param._optimizer = optimizer_class(
From 8c88a3c7ee841de5ed2b0f0690b967f4ca9095e7 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:28:34 -0700
Subject: [PATCH 05/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
index 6192dca22..ab5c96a60 100644
--- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -80,7 +80,7 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
         set_optimizer_split_function(
-            self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads)
+            self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.use_interleaved_weights)
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
From 82990ac72b80a0ef65e79a7e5300343bdac0af13 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:30:56 -0700
Subject: [PATCH 06/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index 2eea66c34..e3b0129fb 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -41,6 +41,9 @@
 }
 
 
+_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]
+
+
 def get_optimizer_container(
     optimizer_class_name: str,
     optimizer_class_args: dict,
@@ -73,7 +76,8 @@ def get_optimizer_container(
     if use_optimizer_with_backward_hook:
         for model, params_groups in zip(model_container, params_groups_list):
             for param_name, param in model.named_parameters():
-                assert get_optimizer_split_function(param) is None
+                if get_optimizer_split_function(param) is not None:
+                    assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
 
                 for group in params_groups.params_groups:
                     if param_name in group.parameter_name_map:
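One caveat in the guard above: optimizer_class is a class object while _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS holds class-name strings, so the membership test can never fail as written. Comparing by name would make the assert meaningful; a sketch of the corrected fragment, assuming it sits inside get_optimizer_container where optimizer_class_name is in scope:

    if get_optimizer_split_function(param) is not None:
        # test membership against the configured name string, not the class object
        assert optimizer_class_name not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS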
From f5be732d761ab6211778f32d86340d6a603e1593 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 17:42:11 -0700
Subject: [PATCH 07/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index e3b0129fb..2b256cad9 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -95,11 +95,16 @@ def _step(p: nn.Parameter) -> None:
 
         optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
     else:
-        optimizer_list = OptimizerContainer(
-            [
-                optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
-                for params_groups in params_groups_list
-            ]
-        )
+        optimizer_list_entries = []
+        for model, params_groups in zip(model_container, params_groups_list):
+            torch_params_groups = params_groups.to_torch_compatible_params_groups()
+            for group in torch_params_groups:
+                split_params = []
+                for param in group["params"]:
+                    split_fn = get_optimizer_split_function(param)
+                    split_params.extend(split_fn(param) if split_fn is not None else [param])
+                group["params"] = split_params
+            optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args))
+        optimizer_list = OptimizerContainer(optimizer_list_entries)
 
     return optimizer_list
From 6c29ad93ae82ac9e3302d57126bd6e9b6ae21219 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 19:34:49 -0700
Subject: [PATCH 08/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 82 ++++++++++++++++++++++++++++-
 1 file changed, 81 insertions(+), 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index 2b256cad9..d65275a46 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -43,6 +43,63 @@
 _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]
 
 
+# Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head)
+_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"}
+
+
+def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool:
+    """Returns True if this param should use AdamW when the main optimizer is Muon."""
+    if param.ndim == 1:
+        return True
+    return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES)
+
+
+class _MuonWithAdamW:
+    """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object.
+
+    Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params.
+    """
+
+    def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None:
+        self.muon = muon
+        self.adamw = adamw
+
+    @property
+    def param_groups(self) -> list[dict]:
+        groups = []
+        if self.muon is not None:
+            groups.extend(self.muon.param_groups)
+        if self.adamw is not None:
+            groups.extend(self.adamw.param_groups)
+        return groups
+
+    def step(self) -> None:
+        if self.muon is not None:
+            self.muon.step()
+        if self.adamw is not None:
+            self.adamw.step()
+
+    def zero_grad(self) -> None:
+        if self.muon is not None:
+            self.muon.zero_grad()
+        if self.adamw is not None:
+            self.adamw.zero_grad()
+
+    def state_dict(self) -> dict:
+        return {
+            "muon": self.muon.state_dict() if self.muon is not None else None,
+            "adamw": self.adamw.state_dict() if self.adamw is not None else None,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        if self.muon is not None and state_dict["muon"] is not None:
+            self.muon.load_state_dict(state_dict["muon"])
+        if self.adamw is not None and state_dict["adamw"] is not None:
+            self.adamw.load_state_dict(state_dict["adamw"])
+
+    def __repr__(self) -> str:
+        return f"MuonWithAdamW(\n  muon={self.muon},\n  adamw={self.adamw}\n)"
+
 def get_optimizer_container(
     optimizer_class_name: str,
     optimizer_class_args: dict,
@@ -94,9 +151,32 @@ def _step(p: nn.Parameter) -> None:
             break
 
         optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
+    elif optimizer_class_name == "TorchMuon":
+        adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
+        optimizer_list_entries = []
+        for params_groups in params_groups_list:
+            muon_groups = []
+            adamw_groups = []
+            for group in params_groups.params_groups:
+                muon_params = []
+                adamw_params = []
+                for param_name, param in group.parameter_name_map.items():
+                    if _is_muon_adamw_param(param_name, param):
+                        adamw_params.append(param)
+                    else:
+                        split_fn = get_optimizer_split_function(param)
+                        muon_params.extend(split_fn(param) if split_fn is not None else [param])
+                if muon_params:
+                    muon_groups.append({"params": muon_params, **group.params_group_kwargs})
+                if adamw_params:
+                    adamw_groups.append({"params": adamw_params, **group.params_group_kwargs})
+            muon = TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None
+            adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None
+            optimizer_list_entries.append(_MuonWithAdamW(muon, adamw))
+        optimizer_list = OptimizerContainer(optimizer_list_entries)
     else:
         optimizer_list_entries = []
-        for model, params_groups in zip(model_container, params_groups_list):
+        for params_groups in params_groups_list:
             torch_params_groups = params_groups.to_torch_compatible_params_groups()
             for group in torch_params_groups:
                 split_params = []
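With the wrapper in place, routing follows _is_muon_adamw_param: 1D tensors (biases, norms) and anything whose name contains "wte" or "lm_head" go to AdamW, every other matrix goes to Muon (after optional splitting). A small smoke test of the heuristic, with hypothetical parameter names mimicking lm_engine's module paths:

    import torch
    import torch.nn as nn

    # _is_muon_adamw_param is the helper added in this patch
    params = {
        "transformer.h.0.mlp.c_fc.weight": nn.Parameter(torch.randn(16, 4)),  # 2D        -> Muon
        "transformer.h.0.ln_1.weight": nn.Parameter(torch.randn(4)),          # 1D        -> AdamW
        "transformer.wte.weight": nn.Parameter(torch.randn(8, 4)),            # "wte"     -> AdamW
        "lm_head.weight": nn.Parameter(torch.randn(8, 4)),                    # "lm_head" -> AdamW
    }
    for name, p in params.items():
        print(name, "->", "adamw" if _is_muon_adamw_param(name, p) else "muon")

Note also that adamw_args forwards only lr, so the AdamW side runs with default betas and weight decay regardless of what optimizer_class_args carries.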
From 83611ed901d2458bf19010b2591a0b796d65964d Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 19:36:33 -0700
Subject: [PATCH 09/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index d65275a46..ca98ec0d4 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -170,9 +170,12 @@ def _step(p: nn.Parameter) -> None:
                     muon_groups.append({"params": muon_params, **group.params_group_kwargs})
                 if adamw_params:
                     adamw_groups.append({"params": adamw_params, **group.params_group_kwargs})
+
             muon = TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None
             adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None
+
             optimizer_list_entries.append(_MuonWithAdamW(muon, adamw))
+
         optimizer_list = OptimizerContainer(optimizer_list_entries)
     else:
         optimizer_list_entries = []
@@ -183,8 +186,11 @@ def _step(p: nn.Parameter) -> None:
                 for param in group["params"]:
                     split_fn = get_optimizer_split_function(param)
                     split_params.extend(split_fn(param) if split_fn is not None else [param])
+
                 group["params"] = split_params
+
             optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args))
+
         optimizer_list = OptimizerContainer(optimizer_list_entries)
 
     return optimizer_list
From 87da505aa3feb1154b09eb36bc5336d5b24e13b0 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 19:38:24 -0700
Subject: [PATCH 10/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index ca98ec0d4..15a0fd89b 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -13,6 +13,7 @@
 from torch.optim import AdamW as TorchAdamW
 from torch.optim import Muon as TorchMuon
 from torch.optim import NAdam as TorchNAdam
+from torch.optim import Optimizer
 from torch.optim import RAdam as TorchRAdam
 from torch.optim import RMSprop as TorchRMSprop
 from torch.optim import Rprop as TorchRprop
@@ -54,7 +55,7 @@ def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool:
     return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES)
 
 
-class _MuonWithAdamW:
+class _MuonWithAdamW(Optimizer):
     """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object.
 
     Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params.
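Subclassing Optimizer here is presumably to satisfy isinstance checks in LR schedulers and checkpoint utilities; note that Optimizer.__init__ is never invoked, so base-class state (defaults, state, the real param_groups list) does not exist, and only the overridden surface is safe to touch. The param_groups property still composes with in-place mutation, since it returns the underlying group dicts. A hypothetical check against the wrapper defined above:

    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    wrapper = _MuonWithAdamW(muon=None, adamw=AdamW([nn.Parameter(torch.randn(4))], lr=1e-3))
    for group in wrapper.param_groups:
        group["lr"] = 3e-4  # aliases the real AdamW group dict, so the change sticks
    print(wrapper.adamw.param_groups[0]["lr"])  # 3e-4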
From 6491f89ffb82dc24420f6188fac8a848320fc0a9 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 19:48:30 -0700
Subject: [PATCH 11/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index 15a0fd89b..c2e25d852 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -172,10 +172,12 @@ def _step(p: nn.Parameter) -> None:
                 if adamw_params:
                     adamw_groups.append({"params": adamw_params, **group.params_group_kwargs})
 
-            muon = TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None
-            adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None
-
-            optimizer_list_entries.append(_MuonWithAdamW(muon, adamw))
+            optimizer_list_entries.append(
+                _MuonWithAdamW(
+                    muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
+                    adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
+                )
+            )
 
         optimizer_list = OptimizerContainer(optimizer_list_entries)
     else:
From 927e8c0ca94b9f2ae8468ca8a65279d6f8e39077 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 20:00:53 -0700
Subject: [PATCH 12/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index c2e25d852..4b4cae413 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+import torch
 import torch.nn as nn
 from torch.optim import ASGD as TorchASGD
 from torch.optim import LBFGS as TorchLBFGS
@@ -74,6 +75,7 @@ def param_groups(self) -> list[dict]:
         groups.extend(self.adamw.param_groups)
         return groups
 
+    @torch.compile(fullgraph=True, mode="reduce_overhead")
     def step(self) -> None:
         if self.muon is not None:
             self.muon.step()
From 9236c074bb20f3e52f3e48fc488499b4446c2c2c Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 20:08:02 -0700
Subject: [PATCH 13/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index 4b4cae413..498972756 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -75,7 +75,7 @@ def param_groups(self) -> list[dict]:
         groups.extend(self.adamw.param_groups)
         return groups
 
-    @torch.compile(fullgraph=True, mode="reduce_overhead")
+    @torch.compile(fullgraph=True, mode="reduce-overhead")
     def step(self) -> None:
         if self.muon is not None:
             self.muon.step()
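Patches 12-15 are a short torch.compile detour: the first attempt used mode="reduce_overhead", but torch.compile's named modes are hyphenated ("default", "reduce-overhead", "max-autotune"), hence the fix here; fullgraph=True is dropped in the next patch, plausibly because the data-dependent `is not None` branches over the two child optimizers graph-break, and the decorator is removed entirely two patches later. A minimal illustration of the mode string:

    import torch

    @torch.compile(mode="reduce-overhead")  # valid named modes are hyphenated
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2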
From e0ed61bba50e1164a7ee01ae2123213293bc96b7 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 20:10:39 -0700
Subject: [PATCH 14/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index 498972756..d962a2353 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -75,7 +75,7 @@ def param_groups(self) -> list[dict]:
         groups.extend(self.adamw.param_groups)
         return groups
 
-    @torch.compile(fullgraph=True, mode="reduce-overhead")
+    @torch.compile(mode="reduce-overhead")
     def step(self) -> None:
         if self.muon is not None:
             self.muon.step()
From cadbaefe7f762945736ca810b69790c693bd7afb Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 20:19:42 -0700
Subject: [PATCH 15/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index d962a2353..ed401bdbe 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -75,7 +75,6 @@ def param_groups(self) -> list[dict]:
         groups.extend(self.adamw.param_groups)
         return groups
 
-    @torch.compile(mode="reduce-overhead")
     def step(self) -> None:
         if self.muon is not None:
             self.muon.step()
From 0376db98ce8d2937cf318e03dfa43cab5d9f626e Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 7 Apr 2026 21:47:44 -0700
Subject: [PATCH 16/18] add muon

Signed-off-by: Mayank Mishra
---
 lm_engine/optimization/optimizer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index ed401bdbe..c2e25d852 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -2,7 +2,6 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
-import torch
 import torch.nn as nn
 from torch.optim import ASGD as TorchASGD
 from torch.optim import LBFGS as TorchLBFGS
From 89e330b2f368b9ad51ba2499ea701b7af94d746a Mon Sep 17 00:00:00 2001
From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
Date: Wed, 15 Apr 2026 11:41:32 -0700
Subject: [PATCH 17/18] Apply suggestion from @gemini-code-assist[bot]

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../hf_models/modeling_utils/sequence_mixer_blocks/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index 4f957e729..a9d4efbaf 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -200,7 +200,7 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)
 
         set_optimizer_split_function(
-            self.c_proj.weight,
+            self.c_attn.weight,
             partial(
                 split_query_key_value_tensor_for_attention,
                 num_heads=self.num_heads,
From 4ca5c46696bff5a052d755a9b614c8b9e846a414 Mon Sep 17 00:00:00 2001
From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
Date: Wed, 15 Apr 2026 11:42:07 -0700
Subject: [PATCH 18/18] Apply suggestion from @gemini-code-assist[bot]

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 lm_engine/hf_models/parameter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index c10d68013..eb35e9c8f 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -81,7 +81,7 @@ def set_parameter_marker_maps(
         setattr(parameter, marker, value)
 
 
-def set_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None:
+def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None:
     if parameter is not None:
         parameter._optimizer_split_function = function