Merged
Changes from all commits (40 commits)
585e6eb  add muon (mayank31398, Apr 7, 2026)
bf3cbba  add muon (mayank31398, Apr 8, 2026)
e491a54  add muon (mayank31398, Apr 8, 2026)
d6efb2d  add muon (mayank31398, Apr 8, 2026)
8c88a3c  add muon (mayank31398, Apr 8, 2026)
82990ac  add muon (mayank31398, Apr 8, 2026)
f5be732  add muon (mayank31398, Apr 8, 2026)
6c29ad9  add muon (mayank31398, Apr 8, 2026)
83611ed  add muon (mayank31398, Apr 8, 2026)
87da505  add muon (mayank31398, Apr 8, 2026)
6491f89  add muon (mayank31398, Apr 8, 2026)
927e8c0  add muon (mayank31398, Apr 8, 2026)
9236c07  add muon (mayank31398, Apr 8, 2026)
e0ed61b  add muon (mayank31398, Apr 8, 2026)
cadbaef  add muon (mayank31398, Apr 8, 2026)
0376db9  add muon (mayank31398, Apr 8, 2026)
dfc06cb  Merge branch 'main' into muon (mayank31398, Apr 15, 2026)
d50e93a  drop muon (mayank31398, Apr 15, 2026)
6d6d6b5  drop muon (mayank31398, Apr 15, 2026)
76e9bca  drop muon (mayank31398, Apr 15, 2026)
744cdc3  drop muon (mayank31398, Apr 15, 2026)
78c4186  drop muon (mayank31398, Apr 15, 2026)
8168941  drop muon (mayank31398, Apr 15, 2026)
ce4c605  drop muon (mayank31398, Apr 15, 2026)
0f468d4  drop muon (mayank31398, Apr 15, 2026)
f3e8b4f  drop muon (mayank31398, Apr 15, 2026)
acdedb8  drop muon (mayank31398, Apr 15, 2026)
82983b9  drop muon (mayank31398, Apr 15, 2026)
17c0d19  drop muon (mayank31398, Apr 15, 2026)
18da9a6  drop muon (mayank31398, Apr 15, 2026)
20881ce  drop muon (mayank31398, Apr 15, 2026)
dcd5a55  drop muon (mayank31398, Apr 15, 2026)
b387271  drop muon (mayank31398, Apr 15, 2026)
b35d924  drop muon (mayank31398, Apr 15, 2026)
66329eb  drop muon (mayank31398, Apr 15, 2026)
eee310c  Apply suggestion from @gemini-code-assist[bot] (mayank31398, Apr 15, 2026)
26348f1  drop muon (mayank31398, Apr 15, 2026)
911c2aa  Merge remote-tracking branch 'refs/remotes/origin/split-param' into s… (mayank31398, Apr 15, 2026)
7ac8d80  drop muon (mayank31398, Apr 15, 2026)
7a2fc35  add w (mayank31398, Apr 16, 2026)
2 changes: 2 additions & 0 deletions lm_engine/arguments.py
@@ -175,6 +175,8 @@ class OptimizerArgs(BaseArgs):
params_group_method: ParamsGroupMethod | None = None
# backward hooked optimizer
use_optimizer_with_backward_hook: bool = False
# whether to split params for the optimizer using model-defined split functions
split_params_for_optimizer: bool = False
# class args for optimizer
class_args: dict = {
"lr": 1e-5,
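The new flag is a plain boolean field on `OptimizerArgs`, so turning the behaviour on is a one-line config change. A hedged sketch of the Python side, using only the field names visible in this hunk (the keyword-style construction assumes `BaseArgs` accepts it, which this PR does not show):

```python
# Hedged sketch only: field names come from the hunk above; keyword construction of
# OptimizerArgs assumes BaseArgs behaves like a typical args/dataclass base.
from lm_engine.arguments import OptimizerArgs

optimizer_args = OptimizerArgs(
    split_params_for_optimizer=True,          # new flag added by this PR
    use_optimizer_with_backward_hook=False,   # must stay False: optimizer.py asserts the two are exclusive
    class_args={"lr": 1e-5},
)
```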
1 change: 1 addition & 0 deletions lm_engine/distributed.py
@@ -32,6 +32,7 @@
from .gradient_checkpointing import apply_gradient_checkpointing
from .hf_models import (
_INIT_MARKER,
_OPTIMIZER_SPLIT_FUNCTION,
CausalLMOutputWithPast,
get_parameter_marker_maps,
is_parameter_initialized,
1 change: 1 addition & 0 deletions lm_engine/finetune.py
@@ -250,6 +250,7 @@ def main() -> None:
model_container=model_container,
params_group_method=args.optimizer_args.params_group_method,
use_optimizer_with_backward_hook=args.optimizer_args.use_optimizer_with_backward_hook,
split_params_for_optimizer=args.optimizer_args.split_params_for_optimizer,
)

lr_scheduler_container = get_scheduler_container(
2 changes: 2 additions & 0 deletions lm_engine/hf_models/__init__.py
@@ -22,6 +22,8 @@
)
from .parameter import (
_INIT_MARKER,
_OPTIMIZER_SPLIT_FUNCTION,
get_optimizer_split_function,
get_parameter_marker_maps,
is_parameter_initialized,
is_parameter_with_mup_learning_rate,
31 changes: 25 additions & 6 deletions lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -4,10 +4,12 @@

from __future__ import annotations

from functools import partial

import torch
import torch.nn as nn

from ...parameter import mark_parameter_as_mup_learning_rate
from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
from ..activations import get_activation_function, is_glu
from ..dropout import Dropout
from ..init_utils import _get_std_for_linear
@@ -77,6 +79,12 @@ def __init__(
mark_parameter_as_mup_learning_rate(self.c_fc.weight)
mark_parameter_as_mup_learning_rate(self.c_proj.weight)

if self.is_glu:
set_optimizer_split_function(
self.c_fc.weight,
partial(_split_up_gate_tensor_for_mlp_for_optimizer, is_interleaved=self.use_interleaved_weights),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.c_fc(x)
x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x)
@@ -113,19 +121,30 @@ def interleave_up_gate_tensor_for_mlp(
return W


def split_up_gate_tensor_for_mlp(
def _split_up_gate_tensor_for_mlp_for_optimizer(
c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
if is_interleaved:
if dim == 0:
u = c_fc_weight[1::2].contiguous()
g = c_fc_weight[::2].contiguous()
u = c_fc_weight[1::2]
g = c_fc_weight[::2]
elif dim == 1:
u = c_fc_weight[:, 1::2].contiguous()
g = c_fc_weight[:, ::2].contiguous()
u = c_fc_weight[:, 1::2]
g = c_fc_weight[:, ::2]
else:
raise ValueError
else:
u, g = c_fc_weight.chunk(2, dim=dim)

return u, g


def split_up_gate_tensor_for_mlp(
c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
u, g = _split_up_gate_tensor_for_mlp_for_optimizer(c_fc_weight=c_fc_weight, is_interleaved=is_interleaved, dim=dim)
if is_interleaved:
u = u.contiguous()
g = g.contiguous()

return u, g
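The point of the new `_for_optimizer` variant is that it skips the `.contiguous()` copies: the up and gate pieces stay strided views into the fused `c_fc` weight, so an optimizer stepping on them writes straight back into the original parameter storage. A rough sanity check of that behaviour (the tensor size is illustrative; the import path follows the file header above):

```python
# Rough sketch: the optimizer-facing split returns strided views that alias the fused
# weight's storage, while the public split still returns contiguous copies.
import torch

from lm_engine.hf_models.modeling_utils.mlp_blocks.mlp import (
    _split_up_gate_tensor_for_mlp_for_optimizer,
    split_up_gate_tensor_for_mlp,
)

c_fc_weight = torch.randn(8, 4)  # fused up+gate projection, interleaved along dim 0

u, g = _split_up_gate_tensor_for_mlp_for_optimizer(c_fc_weight, is_interleaved=True)
print(u.shape, g.shape)                           # torch.Size([4, 4]) torch.Size([4, 4])
print(u.is_contiguous())                          # False: a strided view, not a copy
print(u.data_ptr() == c_fc_weight[1].data_ptr())  # True: u starts at row 1 of the fused weight

u2, _ = split_up_gate_tensor_for_mlp(c_fc_weight, is_interleaved=True)
print(u2.is_contiguous())                         # True: the public API still returns copies
```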
@@ -5,6 +5,7 @@
from __future__ import annotations

import math
from functools import partial

import torch
import torch.nn.functional as F
@@ -14,7 +15,7 @@
from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
from ...cache import GenerationCache, GenerationState, LinearCache
from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
from ...parameter import mark_parameter_as_mup_learning_rate
from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
from ..activations import sigmoid
from ..chunk import contiguous_split
from ..dropout import Dropout
@@ -53,15 +54,22 @@ def interleave_query_key_value_tensor_for_attention(
return torch.cat(interleaved)


def split_query_key_value_tensor_for_attention(
def _split_query_key_value_tensor_for_attention_for_optimizer(
query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
query_heads_per_group = num_heads // num_key_value_heads
original_shape = query_key_value_weight.shape
query_key_value_weight = query_key_value_weight.view(num_key_value_heads, query_heads_per_group + 2, -1)
query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1)
return query_weight, key_weight, value_weight

query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1)

query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1)
def split_query_key_value_tensor_for_attention(
query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
original_shape = query_key_value_weight.shape
query_weight, key_weight, value_weight = _split_query_key_value_tensor_for_attention_for_optimizer(
query_key_value_weight=query_key_value_weight, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)

query_weight = query_weight.reshape(-1, *original_shape[1:])
key_weight = key_weight.reshape(-1, *original_shape[1:])
@@ -198,6 +206,15 @@ def __init__(
mark_parameter_as_mup_learning_rate(self.c_attn.weight)
mark_parameter_as_mup_learning_rate(self.c_proj.weight)

set_optimizer_split_function(
self.c_attn.weight,
partial(
_split_query_key_value_tensor_for_attention_for_optimizer,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
),
)

def forward(
self,
x: torch.Tensor,
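The optimizer-facing QKV split works the same way: the fused `c_attn` weight is viewed as `(num_key_value_heads, queries_per_group + 2, -1)` and sliced along the group dimension, and the public `split_query_key_value_tensor_for_attention` then reshapes those slices back to full-width matrices. A rough shape check; the function body is copied (trimmed) from the diff above and the sizes are made up for illustration:

```python
# Shape-only sanity check of the optimizer-facing QKV split; sizes are illustrative.
import torch

def _split_query_key_value_tensor_for_attention_for_optimizer(qkv_weight, num_heads, num_key_value_heads):
    # trimmed copy of the function in the diff above
    query_heads_per_group = num_heads // num_key_value_heads
    qkv_weight = qkv_weight.view(num_key_value_heads, query_heads_per_group + 2, -1)
    return qkv_weight.split((query_heads_per_group, 1, 1), 1)

num_heads, num_key_value_heads, head_dim, hidden = 8, 2, 16, 128             # illustrative sizes
qkv = torch.randn((num_heads + 2 * num_key_value_heads) * head_dim, hidden)  # fused weight [192, 128]

q, k, v = _split_query_key_value_tensor_for_attention_for_optimizer(qkv, num_heads, num_key_value_heads)
print(q.shape)  # torch.Size([2, 4, 2048]): 2 kv groups x 4 query heads per group x (head_dim * hidden)
print(k.shape)  # torch.Size([2, 1, 2048])
print(v.shape)  # torch.Size([2, 1, 2048])
```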
20 changes: 18 additions & 2 deletions lm_engine/hf_models/parameter.py
@@ -2,11 +2,14 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from typing import Callable

import torch.nn as nn


_INIT_MARKER = "_is_initialized"
_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"]
_OPTIMIZER_SPLIT_FUNCTION = "_optimizer_split_function"
_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", _OPTIMIZER_SPLIT_FUNCTION]
_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER]


@@ -53,7 +56,9 @@ def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: l
for param_name, param in model.named_parameters():
marker_maps[-1][param_name] = {}
for marker in _METADATA_MARKERS + extra_markers:
marker_maps[-1][param_name][marker] = getattr(param, marker, False)
marker_maps[-1][param_name][marker] = getattr(
param, marker, None if marker == _OPTIMIZER_SPLIT_FUNCTION else False
)

return marker_maps

@@ -77,3 +82,14 @@ def set_parameter_marker_maps(

for marker, value in _marker_map[param_name].items():
setattr(parameter, marker, value)


def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None:
if parameter is not None:
parameter._optimizer_split_function = function

return parameter


def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None:
return getattr(parameter, _OPTIMIZER_SPLIT_FUNCTION, None)
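Because `_OPTIMIZER_SPLIT_FUNCTION` is appended to `_METADATA_MARKERS`, the split callable also travels through `get_parameter_marker_maps` / `set_parameter_marker_maps`, with `None` (rather than `False`) as the missing-value default. A small usage sketch of the two new helpers; the linear layer and the `chunk`-based split function are illustrative, not taken from the PR:

```python
# Minimal sketch of the new helpers in this file; the layer and split function are illustrative.
import torch
import torch.nn as nn

from lm_engine.hf_models.parameter import get_optimizer_split_function, set_optimizer_split_function

fused_proj = nn.Linear(64, 128, bias=False)  # pretend this weight fuses two sub-projections
set_optimizer_split_function(fused_proj.weight, lambda w: w.chunk(2, dim=0))

split_fn = get_optimizer_split_function(fused_proj.weight)
u, g = split_fn(fused_proj.weight.data)
print(u.shape, g.shape)                                            # torch.Size([64, 64]) torch.Size([64, 64])
print(get_optimizer_split_function(nn.Parameter(torch.zeros(3))))  # None for unmarked parameters
```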
81 changes: 67 additions & 14 deletions lm_engine/optimization/optimizer.py
@@ -2,24 +2,29 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import logging

import torch.nn as nn
from torch.optim import ASGD as TorchASGD
from torch.optim import LBFGS as TorchLBFGS
from torch.optim import SGD as TorchSGD
from torch.optim import Adadelta as TorchAdadelta
from torch.optim import Adagrad as TorchAdagrad
from torch.optim import Adam as TorchAdam
from torch.optim import Adamax as TorchAdamax
from torch.optim import AdamW as TorchAdamW
from torch.optim import NAdam as TorchNAdam
from torch.optim import Optimizer
from torch.optim.adadelta import Adadelta as TorchAdadelta
from torch.optim.adagrad import Adagrad as TorchAdagrad
from torch.optim.adam import Adam as TorchAdam
from torch.optim.adamax import Adamax as TorchAdamax
from torch.optim.adamw import AdamW as TorchAdamW
from torch.optim.asgd import ASGD as TorchASGD
from torch.optim.lbfgs import LBFGS as TorchLBFGS
from torch.optim.nadam import NAdam as TorchNAdam
from torch.optim.radam import RAdam as TorchRAdam
from torch.optim.rmsprop import RMSprop as TorchRMSprop
from torch.optim.rprop import Rprop as TorchRprop
from torch.optim.sgd import SGD as TorchSGD
from torch.optim import RAdam as TorchRAdam
from torch.optim import RMSprop as TorchRMSprop
from torch.optim import Rprop as TorchRprop

from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
from ..enums import ParamsGroupMethod
from .params_group import get_param_groups_list
from ..hf_models import get_optimizer_split_function
from ..utils import log_rank_0
from .params_group import _ParamsGroupsList, get_param_groups_list
from .split_param_optimizer import SplitParamOptimizer


# https://pytorch.org/docs/stable/optim.html
@@ -39,12 +44,52 @@
}


def _build_optimizer(
optimizer_class, params_groups: _ParamsGroupsList, optimizer_class_args: dict, split_params_for_optimizer: bool
) -> SplitParamOptimizer | Optimizer:
if not split_params_for_optimizer:
return optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)

proxy_grad_fns: dict[int, tuple] = {}
split_params: set[nn.Parameter] = set()
modified_groups = []

for pg in params_groups.params_groups:
group = pg.to_param_group()
names = pg.get_param_names()
group_kwargs = {k: v for k, v in group.items() if k != "params"}
new_params = []
for param, name in zip(group["params"], names):
split_fn = get_optimizer_split_function(param)
if split_fn is None:
new_params.append(param)
else:
log_rank_0(logging.INFO, f"splitting {name} for optimizer")
pieces = split_fn(param.data)

for i, piece in enumerate(pieces):
proxy = nn.Parameter(piece)
new_params.append(proxy)
proxy_grad_fns[id(proxy)] = (param, lambda g, fn=split_fn, idx=i: fn(g)[idx])
split_params.add(param)

modified_groups.append({"params": new_params, **group_kwargs})

inner = optimizer_class(modified_groups, **optimizer_class_args)

if split_params:
inner = SplitParamOptimizer(inner=inner, proxy_grad_fns=proxy_grad_fns, split_params=split_params)

return inner


def get_optimizer_container(
optimizer_class_name: str,
optimizer_class_args: dict,
model_container: ModelContainer,
params_group_method: ParamsGroupMethod,
use_optimizer_with_backward_hook: bool,
split_params_for_optimizer: bool,
) -> OptimizerContainer:
"""setup list of optimizers for the model

@@ -54,6 +99,7 @@ def get_optimizer_container(
model_container (ModelContainer): model container
params_group_method (ParamsGroupMethod): the params grouping to use
use_optimizer_with_backward_hook (bool): whether to use optimizer as a backward hook
split_params_for_optimizer (bool): whether to split params using model-defined split functions

Returns:
OptimizerContainer: optimizer container
@@ -69,6 +115,8 @@
params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method)

if use_optimizer_with_backward_hook:
assert not split_params_for_optimizer

for model, params_groups in zip(model_container, params_groups_list):
for param_name, param in model.named_parameters():
for group in params_groups.params_groups:
@@ -89,7 +137,12 @@ def _step(p: nn.Parameter) -> None:
else:
optimizer_list = OptimizerContainer(
[
optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
_build_optimizer(
optimizer_class=optimizer_class,
params_groups=params_groups,
optimizer_class_args=optimizer_class_args,
split_params_for_optimizer=split_params_for_optimizer,
)
for params_groups in params_groups_list
]
)
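One detail in `_build_optimizer` above that is easy to miss: the gradient-slicing closure binds `fn=split_fn, idx=i` as default arguments. Without those defaults, Python's late binding would make every closure created inside the loop see only the last `split_fn` and `i`. A tiny standalone illustration of the pitfall (toy functions, not from the PR):

```python
# Why the `fn=..., idx=...` defaults matter: late binding vs. binding at definition time.
def make_slicers_broken(fns):
    return [lambda g: fn(g) for fn in fns]          # every lambda ends up seeing the *last* fn

def make_slicers_fixed(fns):
    return [lambda g, fn=fn: fn(g) for fn in fns]   # each lambda captures its own fn

fns = [lambda g: g * 1, lambda g: g * 10, lambda g: g * 100]
print([f(1) for f in make_slicers_broken(fns)])     # [100, 100, 100]
print([f(1) for f in make_slicers_fixed(fns)])      # [1, 10, 100]
```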
66 changes: 66 additions & 0 deletions lm_engine/optimization/split_param_optimizer.py
@@ -0,0 +1,66 @@
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

from typing import Callable

import torch
import torch.nn as nn
from torch.optim import Optimizer


class SplitParamOptimizer(Optimizer):
def __init__(
self,
inner: Optimizer,
proxy_grad_fns: dict[int, tuple[nn.Parameter, Callable]],
split_params: set[nn.Parameter],
) -> None:
self._inner = inner
self._proxy_grad_fns = proxy_grad_fns
self._split_params = split_params

@property
def param_groups(self) -> list[dict]:
return self._inner.param_groups

@property
def state(self) -> dict:
return self._inner.state

def state_dict(self) -> dict:
return self._inner.state_dict()

def load_state_dict(self, state_dict: dict) -> None:
self._inner.load_state_dict(state_dict)

def add_param_group(self, param_group: dict) -> None:
self._inner.add_param_group(param_group)

def step(self, closure: Callable | None = None) -> torch.Tensor | None:
for group in self._inner.param_groups:
for p in group["params"]:
info = self._proxy_grad_fns.get(id(p))
if info is None:
continue

orig_param, grad_slice_fn = info
if orig_param.grad is not None:
p.grad = grad_slice_fn(orig_param.grad)

return self._inner.step(closure)

def zero_grad(self, set_to_none: bool = True) -> None:
self._inner.zero_grad(set_to_none)

for param in self._split_params:
if set_to_none:
param.grad = None
elif param.grad is not None:
param.grad.zero_()

def __repr__(self) -> str:
x = super().__repr__()
return f"{self._inner.__class__.__name__}({x})"