From 585e6eb2e72306ad2d71511f64c96840e45e6637 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 16:54:39 -0700 Subject: [PATCH 01/38] add muon Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- lm_engine/optimization/optimizer.py | 27 ++++++++++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index f9cf82992..fa12c4a7c 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit f9cf829925e0ebcc1b54cf7ee6200ebc3eeb06c3 +Subproject commit fa12c4a7c07098a5ab0fa928f369a266bdd0db58 diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d03fdea39..c715909db 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -3,19 +3,19 @@ # ************************************************** import torch.nn as nn -from torch.optim import Optimizer -from torch.optim.adadelta import Adadelta as TorchAdadelta -from torch.optim.adagrad import Adagrad as TorchAdagrad -from torch.optim.adam import Adam as TorchAdam -from torch.optim.adamax import Adamax as TorchAdamax -from torch.optim.adamw import AdamW as TorchAdamW -from torch.optim.asgd import ASGD as TorchASGD -from torch.optim.lbfgs import LBFGS as TorchLBFGS -from torch.optim.nadam import NAdam as TorchNAdam -from torch.optim.radam import RAdam as TorchRAdam -from torch.optim.rmsprop import RMSprop as TorchRMSprop -from torch.optim.rprop import Rprop as TorchRprop -from torch.optim.sgd import SGD as TorchSGD +from torch.optim import ASGD as TorchASGD +from torch.optim import LBFGS as TorchLBFGS +from torch.optim import SGD as TorchSGD +from torch.optim import Adadelta as TorchAdadelta +from torch.optim import Adagrad as TorchAdagrad +from torch.optim import Adam as TorchAdam +from torch.optim import Adamax as TorchAdamax +from torch.optim import AdamW as TorchAdamW +from torch.optim import Muon as TorchMuon +from torch.optim import NAdam as TorchNAdam +from torch.optim import RAdam as TorchRAdam +from torch.optim import RMSprop as TorchRMSprop +from torch.optim import Rprop as TorchRprop from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer from ..enums import ParamsGroupMethod @@ -31,6 +31,7 @@ "TorchAdamW": TorchAdamW, "TorchASGD": TorchASGD, "TorchLBFGS": TorchLBFGS, + "TorchMuon": TorchMuon, "TorchNAdam": TorchNAdam, "TorchRAdam": TorchRAdam, "TorchRMSprop": TorchRMSprop, From bf3cbba77dc596dce42ab7d397f62754db74effd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:04:35 -0700 Subject: [PATCH 02/38] add muon Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/attention.py | 12 +++++++++++- lm_engine/hf_models/parameter.py | 9 +++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 069d5df49..3d204911c 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -5,6 +5,7 @@ from __future__ import annotations import math +from functools import partial import torch import torch.nn.functional as F @@ -14,7 +15,7 @@ from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available from ...cache import GenerationCache from ...config.sequence_mixer import 
ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD -from ...parameter import mark_parameter_as_mup_learning_rate +from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function from ..activations import sigmoid from ..chunk import contiguous_split from ..dropout import Dropout @@ -198,6 +199,15 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_attn.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) + register_optimizer_split_function( + self.c_proj.weight, + partial( + split_query_key_value_tensor_for_attention, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + ), + ) + def forward( self, x: torch.Tensor, diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 0a816f0a0..93b9b16a7 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -2,6 +2,8 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +from typing import Callable + import torch.nn as nn @@ -77,3 +79,10 @@ def set_parameter_marker_maps( for marker, value in _marker_map[param_name].items(): setattr(parameter, marker, value) + + +def register_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None: + if parameter is not None: + parameter._optimizer_split_function = function + + return parameter From e491a542468a0c769b6499d37162a3eab4ce6cec Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:07:17 -0700 Subject: [PATCH 03/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index 15d38ded9..be9006bc4 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -4,10 +4,12 @@ from __future__ import annotations +from functools import partial + import torch import torch.nn as nn -from ...parameter import mark_parameter_as_mup_learning_rate +from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function from ..activations import get_activation_function, is_glu from ..dropout import Dropout from ..init_utils import _get_std_for_linear @@ -77,6 +79,10 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_fc.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) + register_optimizer_split_function( + self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads) + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.c_fc(x) x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x) From d6efb2d0f7e093dbb705b54b31c8dc0835304676 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:23:05 -0700 Subject: [PATCH 04/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/__init__.py | 1 + lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 4 ++-- .../modeling_utils/sequence_mixer_blocks/attention.py | 4 ++-- lm_engine/hf_models/parameter.py | 6 +++++- lm_engine/optimization/optimizer.py | 3 +++ 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index e7cd2e04a..bf9af9159 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -22,6 +22,7 @@ ) from 
.parameter import ( _INIT_MARKER, + get_optimizer_split_function, get_parameter_marker_maps, is_parameter_initialized, is_parameter_with_mup_learning_rate, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index be9006bc4..6192dca22 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function +from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function from ..activations import get_activation_function, is_glu from ..dropout import Dropout from ..init_utils import _get_std_for_linear @@ -79,7 +79,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_fc.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) - register_optimizer_split_function( + set_optimizer_split_function( self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads) ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 3d204911c..841f61d75 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -15,7 +15,7 @@ from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available from ...cache import GenerationCache from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD -from ...parameter import mark_parameter_as_mup_learning_rate, register_optimizer_split_function +from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function from ..activations import sigmoid from ..chunk import contiguous_split from ..dropout import Dropout @@ -199,7 +199,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_attn.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) - register_optimizer_split_function( + set_optimizer_split_function( self.c_proj.weight, partial( split_query_key_value_tensor_for_attention, diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 93b9b16a7..c10d68013 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -81,8 +81,12 @@ def set_parameter_marker_maps( setattr(parameter, marker, value) -def register_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None: +def set_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None: if parameter is not None: parameter._optimizer_split_function = function return parameter + + +def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None: + return getattr(parameter, "_optimizer_split_function", None) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index c715909db..2eea66c34 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -19,6 +19,7 @@ from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer from ..enums import ParamsGroupMethod +from ..hf_models import get_optimizer_split_function from .params_group import get_param_groups_list @@ -72,6 +73,8 @@ def get_optimizer_container( if use_optimizer_with_backward_hook: for model, params_groups in zip(model_container, 
params_groups_list): for param_name, param in model.named_parameters(): + assert get_optimizer_split_function(param) is None + for group in params_groups.params_groups: if param_name in group.parameter_name_map: param._optimizer = optimizer_class( From 8c88a3c7ee841de5ed2b0f0690b967f4ca9095e7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:28:34 -0700 Subject: [PATCH 05/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index 6192dca22..ab5c96a60 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -80,7 +80,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_proj.weight) set_optimizer_split_function( - self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.num_heads) + self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.use_interleaved_weights) ) def forward(self, x: torch.Tensor) -> torch.Tensor: From 82990ac72b80a0ef65e79a7e5300343bdac0af13 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:30:56 -0700 Subject: [PATCH 06/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 2eea66c34..e3b0129fb 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -41,6 +41,9 @@ } +_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"] + + def get_optimizer_container( optimizer_class_name: str, optimizer_class_args: dict, @@ -73,7 +76,8 @@ def get_optimizer_container( if use_optimizer_with_backward_hook: for model, params_groups in zip(model_container, params_groups_list): for param_name, param in model.named_parameters(): - assert get_optimizer_split_function(param) is None + if get_optimizer_split_function(param) is not None: + assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS for group in params_groups.params_groups: if param_name in group.parameter_name_map: From f5be732d761ab6211778f32d86340d6a603e1593 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 17:42:11 -0700 Subject: [PATCH 07/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index e3b0129fb..2b256cad9 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -95,11 +95,16 @@ def _step(p: nn.Parameter) -> None: optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container)) else: - optimizer_list = OptimizerContainer( - [ - optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args) - for params_groups in params_groups_list - ] - ) + optimizer_list_entries = [] + for model, params_groups in zip(model_container, params_groups_list): + torch_params_groups = params_groups.to_torch_compatible_params_groups() + for group in torch_params_groups: + split_params = [] + for param in group["params"]: + split_fn = get_optimizer_split_function(param) + split_params.extend(split_fn(param) if split_fn is not None else [param]) + group["params"] = 
split_params + optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args)) + optimizer_list = OptimizerContainer(optimizer_list_entries) return optimizer_list From 6c29ad93ae82ac9e3302d57126bd6e9b6ae21219 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 19:34:49 -0700 Subject: [PATCH 08/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 82 ++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 2b256cad9..d65275a46 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -43,6 +43,63 @@ _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"] +# Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head) +_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"} + + +def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool: + """Returns True if this param should use AdamW when the main optimizer is Muon.""" + if param.ndim == 1: + return True + return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES) + + +class _MuonWithAdamW: + """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object. + + Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params. + """ + + def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None: + self.muon = muon + self.adamw = adamw + + @property + def param_groups(self) -> list[dict]: + groups = [] + if self.muon is not None: + groups.extend(self.muon.param_groups) + if self.adamw is not None: + groups.extend(self.adamw.param_groups) + return groups + + def step(self) -> None: + if self.muon is not None: + self.muon.step() + if self.adamw is not None: + self.adamw.step() + + def zero_grad(self) -> None: + if self.muon is not None: + self.muon.zero_grad() + if self.adamw is not None: + self.adamw.zero_grad() + + def state_dict(self) -> dict: + return { + "muon": self.muon.state_dict() if self.muon is not None else None, + "adamw": self.adamw.state_dict() if self.adamw is not None else None, + } + + def load_state_dict(self, state_dict: dict) -> None: + if self.muon is not None and state_dict["muon"] is not None: + self.muon.load_state_dict(state_dict["muon"]) + if self.adamw is not None and state_dict["adamw"] is not None: + self.adamw.load_state_dict(state_dict["adamw"]) + + def __repr__(self) -> str: + return f"MuonWithAdamW(\n muon={self.muon},\n adamw={self.adamw}\n)" + def get_optimizer_container( optimizer_class_name: str, @@ -94,9 +151,32 @@ def _step(p: nn.Parameter) -> None: break optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container)) + elif optimizer_class_name == "TorchMuon": + adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)} + optimizer_list_entries = [] + for params_groups in params_groups_list: + muon_groups = [] + adamw_groups = [] + for group in params_groups.params_groups: + muon_params = [] + adamw_params = [] + for param_name, param in group.parameter_name_map.items(): + if _is_muon_adamw_param(param_name, param): + adamw_params.append(param) + else: + split_fn = get_optimizer_split_function(param) + muon_params.extend(split_fn(param) if split_fn is not None else [param]) + if muon_params: + muon_groups.append({"params": muon_params, **group.params_group_kwargs}) + if adamw_params: + adamw_groups.append({"params": adamw_params, **group.params_group_kwargs}) + muon = 
TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None + adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None + optimizer_list_entries.append(_MuonWithAdamW(muon, adamw)) + optimizer_list = OptimizerContainer(optimizer_list_entries) else: optimizer_list_entries = [] - for model, params_groups in zip(model_container, params_groups_list): + for params_groups in params_groups_list: torch_params_groups = params_groups.to_torch_compatible_params_groups() for group in torch_params_groups: split_params = [] From 83611ed901d2458bf19010b2591a0b796d65964d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 19:36:33 -0700 Subject: [PATCH 09/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d65275a46..ca98ec0d4 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -170,9 +170,12 @@ def _step(p: nn.Parameter) -> None: muon_groups.append({"params": muon_params, **group.params_group_kwargs}) if adamw_params: adamw_groups.append({"params": adamw_params, **group.params_group_kwargs}) + muon = TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None + optimizer_list_entries.append(_MuonWithAdamW(muon, adamw)) + optimizer_list = OptimizerContainer(optimizer_list_entries) else: optimizer_list_entries = [] @@ -183,8 +186,11 @@ def _step(p: nn.Parameter) -> None: for param in group["params"]: split_fn = get_optimizer_split_function(param) split_params.extend(split_fn(param) if split_fn is not None else [param]) + group["params"] = split_params + optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args)) + optimizer_list = OptimizerContainer(optimizer_list_entries) return optimizer_list From 87da505aa3feb1154b09eb36bc5336d5b24e13b0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 19:38:24 -0700 Subject: [PATCH 10/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index ca98ec0d4..15a0fd89b 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -13,6 +13,7 @@ from torch.optim import AdamW as TorchAdamW from torch.optim import Muon as TorchMuon from torch.optim import NAdam as TorchNAdam +from torch.optim import Optimizer from torch.optim import RAdam as TorchRAdam from torch.optim import RMSprop as TorchRMSprop from torch.optim import Rprop as TorchRprop @@ -54,7 +55,7 @@ def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool: return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES) -class _MuonWithAdamW: +class _MuonWithAdamW(Optimizer): """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object. Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params. 
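The `_MuonWithAdamW` wrapper above pairs two optimizers: Muon for weight matrices of rank 2 or higher, and AdamW for embeddings, lm_head, and 1D parameters, with `_is_muon_adamw_param` deciding the split. The following is a minimal, self-contained sketch of that routing rule; the toy module, its parameter names, and the choice to construct only the AdamW half (since a Muon implementation may not be available in a given torch build) are illustrative assumptions, not part of the patch.

# Minimal sketch of the Muon/AdamW routing rule; the toy model below is illustrative only.
import torch.nn as nn
from torch.optim import AdamW

_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"}


def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool:
    # 1D tensors (biases, norm weights) and embedding/lm_head weights go to AdamW
    if param.ndim == 1:
        return True
    return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES)


# hypothetical toy model, names chosen to exercise every branch of the rule
model = nn.ModuleDict(
    {
        "wte": nn.Embedding(32, 8),               # routed to AdamW by name
        "c_fc": nn.Linear(8, 16),                 # 2D weight -> Muon, 1D bias -> AdamW
        "ln": nn.LayerNorm(8),                    # 1D params -> AdamW
        "lm_head": nn.Linear(8, 32, bias=False),  # routed to AdamW by name
    }
)

muon_params, adamw_params = [], []
for name, param in model.named_parameters():
    (adamw_params if _is_muon_adamw_param(name, param) else muon_params).append(param)

# AdamW half of the pairing; the Muon half would be built from muon_params by
# whatever Muon implementation is available (not constructed here).
adamw = AdamW(adamw_params, lr=1e-3)

for name, param in model.named_parameters():
    print(name, "adamw" if _is_muon_adamw_param(name, param) else "muon")
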
From 6491f89ffb82dc24420f6188fac8a848320fc0a9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 19:48:30 -0700 Subject: [PATCH 11/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 15a0fd89b..c2e25d852 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -172,10 +172,12 @@ def _step(p: nn.Parameter) -> None: if adamw_params: adamw_groups.append({"params": adamw_params, **group.params_group_kwargs}) - muon = TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None - adamw = TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None - - optimizer_list_entries.append(_MuonWithAdamW(muon, adamw)) + optimizer_list_entries.append( + _MuonWithAdamW( + muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None, + adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None, + ) + ) optimizer_list = OptimizerContainer(optimizer_list_entries) else: From 927e8c0ca94b9f2ae8468ca8a65279d6f8e39077 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 20:00:53 -0700 Subject: [PATCH 12/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index c2e25d852..4b4cae413 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -2,6 +2,7 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +import torch import torch.nn as nn from torch.optim import ASGD as TorchASGD from torch.optim import LBFGS as TorchLBFGS @@ -74,6 +75,7 @@ def param_groups(self) -> list[dict]: groups.extend(self.adamw.param_groups) return groups + @torch.compile(fullgraph=True, mode="reduce_overhead") def step(self) -> None: if self.muon is not None: self.muon.step() From 9236c074bb20f3e52f3e48fc488499b4446c2c2c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 20:08:02 -0700 Subject: [PATCH 13/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 4b4cae413..498972756 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -75,7 +75,7 @@ def param_groups(self) -> list[dict]: groups.extend(self.adamw.param_groups) return groups - @torch.compile(fullgraph=True, mode="reduce_overhead") + @torch.compile(fullgraph=True, mode="reduce-overhead") def step(self) -> None: if self.muon is not None: self.muon.step() From e0ed61bba50e1164a7ee01ae2123213293bc96b7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 20:10:39 -0700 Subject: [PATCH 14/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 498972756..d962a2353 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -75,7 +75,7 @@ def param_groups(self) -> list[dict]: groups.extend(self.adamw.param_groups) return groups - @torch.compile(fullgraph=True, mode="reduce-overhead") + @torch.compile(mode="reduce-overhead") def 
step(self) -> None: if self.muon is not None: self.muon.step() From cadbaefe7f762945736ca810b69790c693bd7afb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 20:19:42 -0700 Subject: [PATCH 15/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d962a2353..ed401bdbe 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -75,7 +75,6 @@ def param_groups(self) -> list[dict]: groups.extend(self.adamw.param_groups) return groups - @torch.compile(mode="reduce-overhead") def step(self) -> None: if self.muon is not None: self.muon.step() From 0376db98ce8d2937cf318e03dfa43cab5d9f626e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 7 Apr 2026 21:47:44 -0700 Subject: [PATCH 16/38] add muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index ed401bdbe..c2e25d852 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -2,7 +2,6 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -import torch import torch.nn as nn from torch.optim import ASGD as TorchASGD from torch.optim import LBFGS as TorchLBFGS From d50e93a94e25727384bacb4b83259391dcad26a5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 11:43:44 -0700 Subject: [PATCH 17/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 87 ----------------------------- 1 file changed, 87 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index c2e25d852..bb95a53f3 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -11,7 +11,6 @@ from torch.optim import Adam as TorchAdam from torch.optim import Adamax as TorchAdamax from torch.optim import AdamW as TorchAdamW -from torch.optim import Muon as TorchMuon from torch.optim import NAdam as TorchNAdam from torch.optim import Optimizer from torch.optim import RAdam as TorchRAdam @@ -33,7 +32,6 @@ "TorchAdamW": TorchAdamW, "TorchASGD": TorchASGD, "TorchLBFGS": TorchLBFGS, - "TorchMuon": TorchMuon, "TorchNAdam": TorchNAdam, "TorchRAdam": TorchRAdam, "TorchRMSprop": TorchRMSprop, @@ -44,63 +42,6 @@ _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"] -# Parameter name substrings that must use AdamW instead of Muon (embeddings and lm_head) -_MUON_ADAMW_PARAM_NAMES = {"wte", "lm_head"} - - -def _is_muon_adamw_param(param_name: str, param: nn.Parameter) -> bool: - """Returns True if this param should use AdamW when the main optimizer is Muon.""" - if param.ndim == 1: - return True - return any(name in param_name for name in _MUON_ADAMW_PARAM_NAMES) - - -class _MuonWithAdamW(Optimizer): - """Wraps a Muon optimizer and an AdamW optimizer into a single optimizer-like object. - - Muon handles 2D+ weight matrices; AdamW handles embeddings, lm_head, and 1D params. 
- """ - - def __init__(self, muon: TorchMuon | None, adamw: TorchAdamW | None) -> None: - self.muon = muon - self.adamw = adamw - - @property - def param_groups(self) -> list[dict]: - groups = [] - if self.muon is not None: - groups.extend(self.muon.param_groups) - if self.adamw is not None: - groups.extend(self.adamw.param_groups) - return groups - - def step(self) -> None: - if self.muon is not None: - self.muon.step() - if self.adamw is not None: - self.adamw.step() - - def zero_grad(self) -> None: - if self.muon is not None: - self.muon.zero_grad() - if self.adamw is not None: - self.adamw.zero_grad() - - def state_dict(self) -> dict: - return { - "muon": self.muon.state_dict() if self.muon is not None else None, - "adamw": self.adamw.state_dict() if self.adamw is not None else None, - } - - def load_state_dict(self, state_dict: dict) -> None: - if self.muon is not None and state_dict["muon"] is not None: - self.muon.load_state_dict(state_dict["muon"]) - if self.adamw is not None and state_dict["adamw"] is not None: - self.adamw.load_state_dict(state_dict["adamw"]) - - def __repr__(self) -> str: - return f"MuonWithAdamW(\n muon={self.muon},\n adamw={self.adamw}\n)" - def get_optimizer_container( optimizer_class_name: str, @@ -152,34 +93,6 @@ def _step(p: nn.Parameter) -> None: break optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container)) - elif optimizer_class_name == "TorchMuon": - adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)} - optimizer_list_entries = [] - for params_groups in params_groups_list: - muon_groups = [] - adamw_groups = [] - for group in params_groups.params_groups: - muon_params = [] - adamw_params = [] - for param_name, param in group.parameter_name_map.items(): - if _is_muon_adamw_param(param_name, param): - adamw_params.append(param) - else: - split_fn = get_optimizer_split_function(param) - muon_params.extend(split_fn(param) if split_fn is not None else [param]) - if muon_params: - muon_groups.append({"params": muon_params, **group.params_group_kwargs}) - if adamw_params: - adamw_groups.append({"params": adamw_params, **group.params_group_kwargs}) - - optimizer_list_entries.append( - _MuonWithAdamW( - muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None, - adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None, - ) - ) - - optimizer_list = OptimizerContainer(optimizer_list_entries) else: optimizer_list_entries = [] for params_groups in params_groups_list: From 6d6d6b54e7956286d37cb5631a4a1ab86b96f537 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 11:45:41 -0700 Subject: [PATCH 18/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index bb95a53f3..d127e03f7 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -12,7 +12,6 @@ from torch.optim import Adamax as TorchAdamax from torch.optim import AdamW as TorchAdamW from torch.optim import NAdam as TorchNAdam -from torch.optim import Optimizer from torch.optim import RAdam as TorchRAdam from torch.optim import RMSprop as TorchRMSprop from torch.optim import Rprop as TorchRprop From 76e9bca49d40ac33883527361c2409b0429893eb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 11:47:26 -0700 Subject: [PATCH 19/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 6 +----- 1 file changed, 1 
insertion(+), 5 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d127e03f7..0b83ed935 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -39,9 +39,6 @@ } -_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"] - - def get_optimizer_container( optimizer_class_name: str, optimizer_class_args: dict, @@ -74,8 +71,7 @@ def get_optimizer_container( if use_optimizer_with_backward_hook: for model, params_groups in zip(model_container, params_groups_list): for param_name, param in model.named_parameters(): - if get_optimizer_split_function(param) is not None: - assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS + assert get_optimizer_split_function(param) is None for group in params_groups.params_groups: if param_name in group.parameter_name_map: From 744cdc323125763acde1171cbb67c20ce0c9cca6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 14:28:26 -0700 Subject: [PATCH 20/38] drop muon Signed-off-by: Mayank Mishra --- .../modeling_utils/mlp_blocks/mlp.py | 8 +-- .../optimization/split_param_optimizer.py | 62 +++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 lm_engine/optimization/split_param_optimizer.py diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index ab5c96a60..1217090d2 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -124,11 +124,11 @@ def split_up_gate_tensor_for_mlp( ) -> tuple[torch.Tensor, torch.Tensor]: if is_interleaved: if dim == 0: - u = c_fc_weight[1::2].contiguous() - g = c_fc_weight[::2].contiguous() + u = c_fc_weight[1::2] + g = c_fc_weight[::2] elif dim == 1: - u = c_fc_weight[:, 1::2].contiguous() - g = c_fc_weight[:, ::2].contiguous() + u = c_fc_weight[:, 1::2] + g = c_fc_weight[:, ::2] else: raise ValueError else: diff --git a/lm_engine/optimization/split_param_optimizer.py b/lm_engine/optimization/split_param_optimizer.py new file mode 100644 index 000000000..42fc727a9 --- /dev/null +++ b/lm_engine/optimization/split_param_optimizer.py @@ -0,0 +1,62 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +from typing import Callable + +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class SplitParamOptimizer: + def __init__( + self, + inner: Optimizer, + proxy_grad_fns: dict[int, tuple[nn.Parameter, Callable]], + split_params: set[nn.Parameter], + ) -> SplitParamOptimizer: + self._inner = inner + self._proxy_grad_fns = proxy_grad_fns + self._split_params = split_params + + @property + def param_groups(self) -> list[dict]: + return self._inner.param_groups + + @property + def state(self) -> dict: + return self._inner.state + + def state_dict(self) -> dict: + return self._inner.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + self._inner.load_state_dict(state_dict) + + def add_param_group(self, param_group: dict) -> None: + self._inner.add_param_group(param_group) + + def step(self, closure: Callable | None = None) -> torch.Tensor | None: + for group in self._inner.param_groups: + for p in group["params"]: + info = self._proxy_grad_fns.get(id(p)) + if info is None: + continue + + orig_param, grad_slice_fn = info + if orig_param.grad is not None: + p.grad = 
grad_slice_fn(orig_param.grad) + + return self._inner.step(closure) + + def zero_grad(self, set_to_none: bool = True) -> None: + self._inner.zero_grad(set_to_none) + + for param in self._split_params: + if set_to_none: + param.grad = None + elif param.grad is not None: + param.grad.zero_() From 78c41869c1aa6c98c5abeb1f8d5b1b2a0428a109 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 14:31:43 -0700 Subject: [PATCH 21/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 65 ++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 0b83ed935..99cd26b48 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -12,6 +12,7 @@ from torch.optim import Adamax as TorchAdamax from torch.optim import AdamW as TorchAdamW from torch.optim import NAdam as TorchNAdam +from torch.optim import Optimizer from torch.optim import RAdam as TorchRAdam from torch.optim import RMSprop as TorchRMSprop from torch.optim import Rprop as TorchRprop @@ -20,6 +21,7 @@ from ..enums import ParamsGroupMethod from ..hf_models import get_optimizer_split_function from .params_group import get_param_groups_list +from .split_param_optimizer import SplitParamOptimizer # https://pytorch.org/docs/stable/optim.html @@ -39,6 +41,45 @@ } +def _build_optimizer( + optimizer_class, torch_params_groups: list[dict], optimizer_class_args: dict +) -> SplitParamOptimizer | Optimizer: + proxy_grad_fns: dict[int, tuple] = {} + split_params: set[nn.Parameter] = set() + modified_groups = [] + + for group in torch_params_groups: + group_kwargs = {k: v for k, v in group.items() if k != "params"} + new_params = [] + for param in group["params"]: + split_fn = get_optimizer_split_function(param) + if split_fn is None: + new_params.append(param) + else: + pieces = split_fn(param.data) + assert all( + p.untyped_storage().data_ptr() == param.data.untyped_storage().data_ptr() for p in pieces + ), ( + f"Optimizer split function for {param.shape} must return views " + "(tensors sharing storage with the original). " + "Use the *_for_optimizer variant, which skips .contiguous()/.reshape()." 
+ ) + for i, piece in enumerate(pieces): + proxy = nn.Parameter(piece) + new_params.append(proxy) + proxy_grad_fns[id(proxy)] = (param, lambda g, fn=split_fn, idx=i: fn(g)[idx]) + split_params.add(param) + + modified_groups.append({"params": new_params, **group_kwargs}) + + inner = optimizer_class(modified_groups, **optimizer_class_args) + + if split_params: + inner = SplitParamOptimizer(inner=inner, proxy_grad_fns=proxy_grad_fns, split_params=split_params) + + return inner + + def get_optimizer_container( optimizer_class_name: str, optimizer_class_args: dict, @@ -89,19 +130,15 @@ def _step(p: nn.Parameter) -> None: optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container)) else: - optimizer_list_entries = [] - for params_groups in params_groups_list: - torch_params_groups = params_groups.to_torch_compatible_params_groups() - for group in torch_params_groups: - split_params = [] - for param in group["params"]: - split_fn = get_optimizer_split_function(param) - split_params.extend(split_fn(param) if split_fn is not None else [param]) - - group["params"] = split_params - - optimizer_list_entries.append(optimizer_class(torch_params_groups, **optimizer_class_args)) - - optimizer_list = OptimizerContainer(optimizer_list_entries) + optimizer_list = OptimizerContainer( + [ + _build_optimizer( + optimizer_class=optimizer_class, + torch_params_groups=params_groups.to_torch_compatible_params_groups(), + optimizer_class_args=optimizer_class_args, + ) + for params_groups in params_groups_list + ] + ) return optimizer_list From 816894103785f7ab3d17edb65f370010f14f495a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 14:34:00 -0700 Subject: [PATCH 22/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/__init__.py | 1 + .../sequence_mixer_blocks/attention.py | 12 ++---------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index 3404b87d3..360ecbb22 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -25,5 +25,6 @@ get_sequence_mixer, interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, + split_query_key_value_tensor_for_attention_for_optimizer, ) from .TP import tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 4f957e729..a242c2003 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -58,16 +58,8 @@ def split_query_key_value_tensor_for_attention( query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: query_heads_per_group = num_heads // num_key_value_heads - original_shape = query_key_value_weight.shape - - query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1) - + query_key_value_weight = query_key_value_weight.view(num_key_value_heads, query_heads_per_group + 2, -1) query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1) - - query_weight = query_weight.reshape(-1, *original_shape[1:]) - key_weight = key_weight.reshape(-1, *original_shape[1:]) - value_weight = value_weight.reshape(-1, 
*original_shape[1:]) - return query_weight, key_weight, value_weight @@ -200,7 +192,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_proj.weight) set_optimizer_split_function( - self.c_proj.weight, + self.c_attn.weight, partial( split_query_key_value_tensor_for_attention, num_heads=self.num_heads, From ce4c605a9e8e490a318f0ab344e3372efdd4293f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 14:36:51 -0700 Subject: [PATCH 23/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/arguments.py | 2 ++ lm_engine/finetune.py | 1 + lm_engine/optimization/optimizer.py | 8 +++++++- lm_engine/pretrain.py | 1 + 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py index 5a172781d..577d64c11 100644 --- a/lm_engine/arguments.py +++ b/lm_engine/arguments.py @@ -175,6 +175,8 @@ class OptimizerArgs(BaseArgs): params_group_method: ParamsGroupMethod | None = None # backward hooked optimizer use_optimizer_with_backward_hook: bool = False + # whether to split params for the optimizer using model-defined split functions + split_params_for_optimizer: bool = False # class args for optimizer class_args: dict = { "lr": 1e-5, diff --git a/lm_engine/finetune.py b/lm_engine/finetune.py index a266fd66b..bb5c8c5ae 100644 --- a/lm_engine/finetune.py +++ b/lm_engine/finetune.py @@ -250,6 +250,7 @@ def main() -> None: model_container=model_container, params_group_method=args.optimizer_args.params_group_method, use_optimizer_with_backward_hook=args.optimizer_args.use_optimizer_with_backward_hook, + split_params_for_optimizer=args.optimizer_args.split_params_for_optimizer, ) lr_scheduler_container = get_scheduler_container( diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 99cd26b48..74a5edd8f 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -42,8 +42,11 @@ def _build_optimizer( - optimizer_class, torch_params_groups: list[dict], optimizer_class_args: dict + optimizer_class, torch_params_groups: list[dict], optimizer_class_args: dict, split_params_for_optimizer: bool ) -> SplitParamOptimizer | Optimizer: + if not split_params_for_optimizer: + return optimizer_class(torch_params_groups, **optimizer_class_args) + proxy_grad_fns: dict[int, tuple] = {} split_params: set[nn.Parameter] = set() modified_groups = [] @@ -86,6 +89,7 @@ def get_optimizer_container( model_container: ModelContainer, params_group_method: ParamsGroupMethod, use_optimizer_with_backward_hook: bool, + split_params_for_optimizer: bool, ) -> OptimizerContainer: """setup list of optimizers for the model @@ -95,6 +99,7 @@ def get_optimizer_container( model_container (ModelContainer): model container params_group_method (ParamsGroupMethod): the params grouping to use use_optimizer_with_backward_hook (bool): whether to use optimizer as a backward hook + split_params_for_optimizer (bool): whether to split params using model-defined split functions Returns: OptimizerContainer: optimizer container @@ -136,6 +141,7 @@ def _step(p: nn.Parameter) -> None: optimizer_class=optimizer_class, torch_params_groups=params_groups.to_torch_compatible_params_groups(), optimizer_class_args=optimizer_class_args, + split_params_for_optimizer=split_params_for_optimizer, ) for params_groups in params_groups_list ] diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index 65f499b18..678fafd7d 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -646,6 +646,7 @@ def main(args_class: 
type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No model_container=model_container, params_group_method=args.optimizer_args.params_group_method, use_optimizer_with_backward_hook=args.optimizer_args.use_optimizer_with_backward_hook, + split_params_for_optimizer=args.optimizer_args.split_params_for_optimizer, ) lr_scheduler_container = get_scheduler_container( From 0f468d49fa2b9510423c716944db8cb78cc52aa9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 14:40:59 -0700 Subject: [PATCH 24/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index 360ecbb22..3404b87d3 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -25,6 +25,5 @@ get_sequence_mixer, interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, - split_query_key_value_tensor_for_attention_for_optimizer, ) from .TP import tensor_parallel_split_safetensor_slice From f3e8b4f17a18dd210db32d60c09c99e0044b8c38 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:09:16 -0700 Subject: [PATCH 25/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 74a5edd8f..f901184f0 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -2,6 +2,8 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +import logging + import torch.nn as nn from torch.optim import ASGD as TorchASGD from torch.optim import LBFGS as TorchLBFGS @@ -20,7 +22,8 @@ from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer from ..enums import ParamsGroupMethod from ..hf_models import get_optimizer_split_function -from .params_group import get_param_groups_list +from ..utils import log_rank_0 +from .params_group import _ParamsGroupsList, get_param_groups_list from .split_param_optimizer import SplitParamOptimizer @@ -42,23 +45,26 @@ def _build_optimizer( - optimizer_class, torch_params_groups: list[dict], optimizer_class_args: dict, split_params_for_optimizer: bool + optimizer_class, params_groups: _ParamsGroupsList, optimizer_class_args: dict, split_params_for_optimizer: bool ) -> SplitParamOptimizer | Optimizer: if not split_params_for_optimizer: - return optimizer_class(torch_params_groups, **optimizer_class_args) + return optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args) proxy_grad_fns: dict[int, tuple] = {} split_params: set[nn.Parameter] = set() modified_groups = [] - for group in torch_params_groups: + for pg in params_groups.params_groups: + group = pg.to_param_group() + names = pg.get_param_names() group_kwargs = {k: v for k, v in group.items() if k != "params"} new_params = [] - for param in group["params"]: + for param, name in zip(group["params"], names): split_fn = get_optimizer_split_function(param) if split_fn is None: new_params.append(param) else: + log_rank_0(logging.INFO, f"splitting {name}") pieces = split_fn(param.data) assert all( p.untyped_storage().data_ptr() == param.data.untyped_storage().data_ptr() for p in pieces @@ -139,7 +145,7 @@ def _step(p: nn.Parameter) -> None: [ 
_build_optimizer( optimizer_class=optimizer_class, - torch_params_groups=params_groups.to_torch_compatible_params_groups(), + params_groups=params_groups, optimizer_class_args=optimizer_class_args, split_params_for_optimizer=split_params_for_optimizer, ) From acdedb84d961b895f50ca32b966ff8d3280e4466 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:11:23 -0700 Subject: [PATCH 26/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index f901184f0..18d5382d2 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -121,10 +121,10 @@ def get_optimizer_container( params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method) if use_optimizer_with_backward_hook: + assert not split_params_for_optimizer + for model, params_groups in zip(model_container, params_groups_list): for param_name, param in model.named_parameters(): - assert get_optimizer_split_function(param) is None - for group in params_groups.params_groups: if param_name in group.parameter_name_map: param._optimizer = optimizer_class( From 82983b9677210a4d1db8581351a4799926f2d54d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:15:49 -0700 Subject: [PATCH 27/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index c10d68013..26ad38ded 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -8,7 +8,7 @@ _INIT_MARKER = "_is_initialized" -_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"] +_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_optimizer_split_function"] _ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER] From 17c0d197597e61fa9a398ec3172681446f8bc6ca Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:32:42 -0700 Subject: [PATCH 28/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 1 + lm_engine/hf_models/__init__.py | 1 + lm_engine/hf_models/parameter.py | 9 ++++++--- lm_engine/optimization/optimizer.py | 17 ++++++++++------- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index d753840ff..7c7f0c393 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -32,6 +32,7 @@ from .gradient_checkpointing import apply_gradient_checkpointing from .hf_models import ( _INIT_MARKER, + _OPTIMIZER_SPLIT_FUNCTION, CausalLMOutputWithPast, get_parameter_marker_maps, is_parameter_initialized, diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index bf9af9159..1f876ae80 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -22,6 +22,7 @@ ) from .parameter import ( _INIT_MARKER, + _OPTIMIZER_SPLIT_FUNCTION, get_optimizer_split_function, get_parameter_marker_maps, is_parameter_initialized, diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 26ad38ded..bd91db161 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -8,8 +8,9 @@ _INIT_MARKER = "_is_initialized" -_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_optimizer_split_function"] -_ALL_MARKERS = 
_METADATA_MARKERS + [_INIT_MARKER] +_OPTIMIZER_SPLIT_FUNCTION = "_optimizer_split_function" +_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"] +_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER, _OPTIMIZER_SPLIT_FUNCTION] def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: @@ -55,7 +56,9 @@ def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: l for param_name, param in model.named_parameters(): marker_maps[-1][param_name] = {} for marker in _METADATA_MARKERS + extra_markers: - marker_maps[-1][param_name][marker] = getattr(param, marker, False) + marker_maps[-1][param_name][marker] = getattr( + param, marker, None if marker == _OPTIMIZER_SPLIT_FUNCTION else False + ) return marker_maps diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 18d5382d2..485b62dbd 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -66,13 +66,16 @@ def _build_optimizer( else: log_rank_0(logging.INFO, f"splitting {name}") pieces = split_fn(param.data) - assert all( - p.untyped_storage().data_ptr() == param.data.untyped_storage().data_ptr() for p in pieces - ), ( - f"Optimizer split function for {param.shape} must return views " - "(tensors sharing storage with the original). " - "Use the *_for_optimizer variant, which skips .contiguous()/.reshape()." - ) + try: + assert all( + p.untyped_storage().data_ptr() == param.data.untyped_storage().data_ptr() for p in pieces + ), ( + f"Optimizer split function for {param.shape} must return views " + "(tensors sharing storage with the original). " + "Use the *_for_optimizer variant, which skips .contiguous()/.reshape()." + ) + except RuntimeError: + pass # storage pointer inaccessible in distributed contexts (e.g. FSDP) for i, piece in enumerate(pieces): proxy = nn.Parameter(piece) new_params.append(proxy) From 18da9a6a8f431ae8e142b5948d20d62f5c2c7987 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:33:15 -0700 Subject: [PATCH 29/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 485b62dbd..465fa7fb2 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -66,16 +66,7 @@ def _build_optimizer( else: log_rank_0(logging.INFO, f"splitting {name}") pieces = split_fn(param.data) - try: - assert all( - p.untyped_storage().data_ptr() == param.data.untyped_storage().data_ptr() for p in pieces - ), ( - f"Optimizer split function for {param.shape} must return views " - "(tensors sharing storage with the original). " - "Use the *_for_optimizer variant, which skips .contiguous()/.reshape()." - ) - except RuntimeError: - pass # storage pointer inaccessible in distributed contexts (e.g. 
FSDP) + for i, piece in enumerate(pieces): proxy = nn.Parameter(piece) new_params.append(proxy) From 20881ce35a5b1e05a75ec6e0436eb42396c5e339 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:37:55 -0700 Subject: [PATCH 30/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/parameter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index bd91db161..d1a993894 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -9,8 +9,8 @@ _INIT_MARKER = "_is_initialized" _OPTIMIZER_SPLIT_FUNCTION = "_optimizer_split_function" -_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"] -_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER, _OPTIMIZER_SPLIT_FUNCTION] +_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", _OPTIMIZER_SPLIT_FUNCTION] +_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER] def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: @@ -92,4 +92,4 @@ def set_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None: - return getattr(parameter, "_optimizer_split_function", None) + return getattr(parameter, _OPTIMIZER_SPLIT_FUNCTION, None) From dcd5a55f02904c64f619c6353107282cea1bacfa Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:38:42 -0700 Subject: [PATCH 31/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/split_param_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/optimization/split_param_optimizer.py b/lm_engine/optimization/split_param_optimizer.py index 42fc727a9..bd91f1021 100644 --- a/lm_engine/optimization/split_param_optimizer.py +++ b/lm_engine/optimization/split_param_optimizer.py @@ -11,7 +11,7 @@ from torch.optim import Optimizer -class SplitParamOptimizer: +class SplitParamOptimizer(Optimizer): def __init__( self, inner: Optimizer, From b3872712b057866d900ff5efec3076165dc7b6dd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:41:07 -0700 Subject: [PATCH 32/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index 465fa7fb2..7842493a4 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -64,7 +64,7 @@ def _build_optimizer( if split_fn is None: new_params.append(param) else: - log_rank_0(logging.INFO, f"splitting {name}") + log_rank_0(logging.INFO, f"splitting {name} for optimizer") pieces = split_fn(param.data) for i, piece in enumerate(pieces): From b35d924937c1e0ee18abdd92e7428aee01086cf7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:43:23 -0700 Subject: [PATCH 33/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/split_param_optimizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lm_engine/optimization/split_param_optimizer.py b/lm_engine/optimization/split_param_optimizer.py index bd91f1021..6d01d2799 100644 --- a/lm_engine/optimization/split_param_optimizer.py +++ b/lm_engine/optimization/split_param_optimizer.py @@ -60,3 +60,7 @@ def zero_grad(self, set_to_none: bool = True) -> None: param.grad = None elif param.grad is not None: param.grad.zero_() + + def __repr__(self) -> str: + x = 
super().__repr__() + return f"{x}({self._inner})" From 66329eb0098a45b7070eea3caeadad6e4d1b8e33 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:44:13 -0700 Subject: [PATCH 34/38] drop muon Signed-off-by: Mayank Mishra --- lm_engine/optimization/split_param_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/optimization/split_param_optimizer.py b/lm_engine/optimization/split_param_optimizer.py index 6d01d2799..be407a391 100644 --- a/lm_engine/optimization/split_param_optimizer.py +++ b/lm_engine/optimization/split_param_optimizer.py @@ -63,4 +63,4 @@ def zero_grad(self, set_to_none: bool = True) -> None: def __repr__(self) -> str: x = super().__repr__() - return f"{x}({self._inner})" + return f"{self._inner.__class__.__name__}({x})" From eee310ce76446323c8d1f37fd5a0fe04ce61d326 Mon Sep 17 00:00:00 2001 From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com> Date: Wed, 15 Apr 2026 15:50:11 -0700 Subject: [PATCH 35/38] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lm_engine/hf_models/parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index d1a993894..d67b8bc06 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -84,7 +84,7 @@ def set_parameter_marker_maps( setattr(parameter, marker, value) -def set_optimizer_split_function(parameter: nn.Parameter, function: Callable) -> None: +def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None: if parameter is not None: parameter._optimizer_split_function = function From 26348f1cc38e04a88aeda6a4e47c79cb669998e3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:52:50 -0700 Subject: [PATCH 36/38] drop muon Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/mlp_blocks/mlp.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index 1217090d2..cdd444498 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -80,7 +80,8 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_proj.weight) set_optimizer_split_function( - self.c_fc.weight, partial(split_up_gate_tensor_for_mlp, is_interleaved=self.use_interleaved_weights) + self.c_fc.weight, + partial(_split_up_gate_tensor_for_mlp_for_optimizer, is_interleaved=self.use_interleaved_weights), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -119,7 +120,7 @@ def interleave_up_gate_tensor_for_mlp( return W -def split_up_gate_tensor_for_mlp( +def _split_up_gate_tensor_for_mlp_for_optimizer( c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0 ) -> tuple[torch.Tensor, torch.Tensor]: if is_interleaved: @@ -135,3 +136,14 @@ def split_up_gate_tensor_for_mlp( u, g = c_fc_weight.chunk(2, dim=dim) return u, g + + +def split_up_gate_tensor_for_mlp( + c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0 +) -> tuple[torch.Tensor, torch.Tensor]: + u, g = _split_up_gate_tensor_for_mlp_for_optimizer(c_fc_weight=c_fc_weight, is_interleaved=is_interleaved, dim=dim) + if is_interleaved: + u = u.contiguous() + g = g.contiguous() + + return u, g From 7ac8d80a8731ef472fab2dfe1746fc8965e48a0b Mon Sep 17 00:00:00 
2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 15:54:47 -0700 Subject: [PATCH 37/38] drop muon Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/attention.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index a242c2003..ee4d1caf5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -54,7 +54,7 @@ def interleave_query_key_value_tensor_for_attention( return torch.cat(interleaved) -def split_query_key_value_tensor_for_attention( +def _split_query_key_value_tensor_for_attention_for_optimizer( query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: query_heads_per_group = num_heads // num_key_value_heads @@ -63,6 +63,21 @@ def split_query_key_value_tensor_for_attention( return query_weight, key_weight, value_weight +def split_query_key_value_tensor_for_attention( + query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + original_shape = query_key_value_weight.shape + query_weight, key_weight, value_weight = _split_query_key_value_tensor_for_attention_for_optimizer( + query_key_value_weight=query_key_value_weight, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + + query_weight = query_weight.reshape(-1, *original_shape[1:]) + key_weight = key_weight.reshape(-1, *original_shape[1:]) + value_weight = value_weight.reshape(-1, *original_shape[1:]) + + return query_weight, key_weight, value_weight + + class Attention(DTensorModule): def __init__( self, @@ -194,7 +209,7 @@ def __init__( set_optimizer_split_function( self.c_attn.weight, partial( - split_query_key_value_tensor_for_attention, + _split_query_key_value_tensor_for_attention_for_optimizer, num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads, ), From 7a2fc35c60a6a603af275fa17cd169911716d0e4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 16 Apr 2026 14:02:29 -0700 Subject: [PATCH 38/38] add w Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index cdd444498..a8aecd417 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -79,10 +79,11 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_fc.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) - set_optimizer_split_function( - self.c_fc.weight, - partial(_split_up_gate_tensor_for_mlp_for_optimizer, is_interleaved=self.use_interleaved_weights), - ) + if self.is_glu: + set_optimizer_split_function( + self.c_fc.weight, + partial(_split_up_gate_tensor_for_mlp_for_optimizer, is_interleaved=self.use_interleaved_weights), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.c_fc(x)
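
The patches above converge on one contract: a fused weight (query/key/value, or up/gate) can carry an `_optimizer_split_function` attribute, and the optimizer builder uses it to hand the optimizer one proxy parameter per logical matrix while the underlying storage stays shared — which is why the `*_for_optimizer` variants skip `.contiguous()`/`.reshape()`, unlike their public counterparts. Below is a minimal, self-contained sketch of that contract in plain PyTorch. It is illustrative only: `set_optimizer_split_function` mirrors the helper added in lm_engine/hf_models/parameter.py, but the fused layer, the split function, and the builder loop are stand-ins, not the repository code.

from __future__ import annotations

from functools import partial

import torch
import torch.nn as nn

_OPTIMIZER_SPLIT_FUNCTION = "_optimizer_split_function"


def set_optimizer_split_function(parameter: nn.Parameter | None, function) -> nn.Parameter | None:
    # Same shape as the repository helper: attach the split function to the parameter.
    if parameter is not None:
        setattr(parameter, _OPTIMIZER_SPLIT_FUNCTION, function)
    return parameter


def _split_up_gate_for_optimizer(c_fc_weight: torch.Tensor, dim: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    # chunk() returns views, so both halves share storage with the fused weight;
    # this is the property the "*_for_optimizer" variants preserve by skipping
    # .contiguous()/.reshape().
    return c_fc_weight.chunk(2, dim=dim)


# A fused up/gate projection, as in the GLU MLP block: out_features = 2 * hidden.
c_fc = nn.Linear(16, 2 * 64, bias=False)
set_optimizer_split_function(c_fc.weight, partial(_split_up_gate_for_optimizer, dim=0))

# What a builder can then do with the marker: expose each logical matrix as its
# own parameter while optimizer updates still land in the fused tensor.
split_fn = getattr(c_fc.weight, _OPTIMIZER_SPLIT_FUNCTION, None)
if split_fn is None:
    optimizer_params = [c_fc.weight]
else:
    pieces = split_fn(c_fc.weight.data)
    # Every piece is a view of the fused parameter (same untyped storage).
    assert all(
        p.untyped_storage().data_ptr() == c_fc.weight.data.untyped_storage().data_ptr() for p in pieces
    )
    optimizer_params = [nn.Parameter(p) for p in pieces]

print([tuple(p.shape) for p in optimizer_params])  # [(64, 16), (64, 16)]

Splitting into views rather than copies keeps a single storage for checkpointing and weight export, while the optimizer sees each logical matrix as a separate 2-D parameter.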