diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py
index 5a172781d..577d64c11 100644
--- a/lm_engine/arguments.py
+++ b/lm_engine/arguments.py
@@ -175,6 +175,8 @@ class OptimizerArgs(BaseArgs):
     params_group_method: ParamsGroupMethod | None = None
     # backward hooked optimizer
     use_optimizer_with_backward_hook: bool = False
+    # whether to split params for the optimizer using model-defined split functions
+    split_params_for_optimizer: bool = False
     # class args for optimizer
     class_args: dict = {
         "lr": 1e-5,
diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index d753840ff..7c7f0c393 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -32,6 +32,7 @@
 from .gradient_checkpointing import apply_gradient_checkpointing
 from .hf_models import (
     _INIT_MARKER,
+    _OPTIMIZER_SPLIT_FUNCTION,
     CausalLMOutputWithPast,
     get_parameter_marker_maps,
     is_parameter_initialized,
diff --git a/lm_engine/finetune.py b/lm_engine/finetune.py
index a266fd66b..bb5c8c5ae 100644
--- a/lm_engine/finetune.py
+++ b/lm_engine/finetune.py
@@ -250,6 +250,7 @@ def main() -> None:
         model_container=model_container,
         params_group_method=args.optimizer_args.params_group_method,
         use_optimizer_with_backward_hook=args.optimizer_args.use_optimizer_with_backward_hook,
+        split_params_for_optimizer=args.optimizer_args.split_params_for_optimizer,
     )

     lr_scheduler_container = get_scheduler_container(
diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py
index e7cd2e04a..1f876ae80 100644
--- a/lm_engine/hf_models/__init__.py
+++ b/lm_engine/hf_models/__init__.py
@@ -22,6 +22,8 @@
 )
 from .parameter import (
     _INIT_MARKER,
+    _OPTIMIZER_SPLIT_FUNCTION,
+    get_optimizer_split_function,
     get_parameter_marker_maps,
     is_parameter_initialized,
     is_parameter_with_mup_learning_rate,
diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
index 15d38ded9..a8aecd417 100644
--- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -4,10 +4,12 @@

 from __future__ import annotations

+from functools import partial
+
 import torch
 import torch.nn as nn

-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import get_activation_function, is_glu
 from ..dropout import Dropout
 from ..init_utils import _get_std_for_linear
@@ -77,6 +79,12 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_fc.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)

+        if self.is_glu:
+            set_optimizer_split_function(
+                self.c_fc.weight,
+                partial(_split_up_gate_tensor_for_mlp_for_optimizer, is_interleaved=self.use_interleaved_weights),
+            )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.c_fc(x)
         x = self.act(x, is_interleaved=self.use_interleaved_weights) if self.is_glu else self.act(x)
@@ -113,19 +121,30 @@ def interleave_up_gate_tensor_for_mlp(
     return W


-def split_up_gate_tensor_for_mlp(
+def _split_up_gate_tensor_for_mlp_for_optimizer(
     c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
 ) -> tuple[torch.Tensor, torch.Tensor]:
     if is_interleaved:
         if dim == 0:
-            u = c_fc_weight[1::2].contiguous()
-            g = c_fc_weight[::2].contiguous()
+            u = c_fc_weight[1::2]
+            g = c_fc_weight[::2]
         elif dim == 1:
-            u = c_fc_weight[:, 1::2].contiguous()
-            g = c_fc_weight[:, ::2].contiguous()
+            u = c_fc_weight[:, 1::2]
+            g = c_fc_weight[:, ::2]
         else:
             raise ValueError
     else:
         u, g = c_fc_weight.chunk(2, dim=dim)

     return u, g
+
+
+def split_up_gate_tensor_for_mlp(
+    c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
+) -> tuple[torch.Tensor, torch.Tensor]:
+    u, g = _split_up_gate_tensor_for_mlp_for_optimizer(c_fc_weight=c_fc_weight, is_interleaved=is_interleaved, dim=dim)
+    if is_interleaved:
+        u = u.contiguous()
+        g = g.contiguous()
+
+    return u, g
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index 80020f0af..ee4d1caf5 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import math
+from functools import partial

 import torch
 import torch.nn.functional as F
@@ -14,7 +15,7 @@
 from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available
 from ...cache import GenerationCache, GenerationState, LinearCache
 from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD
-from ...parameter import mark_parameter_as_mup_learning_rate
+from ...parameter import mark_parameter_as_mup_learning_rate, set_optimizer_split_function
 from ..activations import sigmoid
 from ..chunk import contiguous_split
 from ..dropout import Dropout
@@ -53,15 +54,22 @@ def interleave_query_key_value_tensor_for_attention(
     return torch.cat(interleaved)


-def split_query_key_value_tensor_for_attention(
+def _split_query_key_value_tensor_for_attention_for_optimizer(
     query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     query_heads_per_group = num_heads // num_key_value_heads

-    original_shape = query_key_value_weight.shape
+    query_key_value_weight = query_key_value_weight.view(num_key_value_heads, query_heads_per_group + 2, -1)
+    query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1)
+    return query_weight, key_weight, value_weight

-    query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1)
-    query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1)
+def split_query_key_value_tensor_for_attention(
+    query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    original_shape = query_key_value_weight.shape
+    query_weight, key_weight, value_weight = _split_query_key_value_tensor_for_attention_for_optimizer(
+        query_key_value_weight=query_key_value_weight, num_heads=num_heads, num_key_value_heads=num_key_value_heads
+    )

     query_weight = query_weight.reshape(-1, *original_shape[1:])
     key_weight = key_weight.reshape(-1, *original_shape[1:])
@@ -198,6 +206,15 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_attn.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)

+        set_optimizer_split_function(
+            self.c_attn.weight,
+            partial(
+                _split_query_key_value_tensor_for_attention_for_optimizer,
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_key_value_heads,
+            ),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index 0a816f0a0..d67b8bc06 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -2,11 +2,14 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************

+from typing import Callable
+
 import torch.nn as nn


 _INIT_MARKER = "_is_initialized"
-_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"]
+_OPTIMIZER_SPLIT_FUNCTION = "_optimizer_split_function"
+_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", _OPTIMIZER_SPLIT_FUNCTION]
 _ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER]


@@ -53,7 +56,9 @@ def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: l
         for param_name, param in model.named_parameters():
             marker_maps[-1][param_name] = {}
             for marker in _METADATA_MARKERS + extra_markers:
-                marker_maps[-1][param_name][marker] = getattr(param, marker, False)
+                marker_maps[-1][param_name][marker] = getattr(
+                    param, marker, None if marker == _OPTIMIZER_SPLIT_FUNCTION else False
+                )

     return marker_maps

@@ -77,3 +82,14 @@ def set_parameter_marker_maps(
             for marker, value in _marker_map[param_name].items():
                 setattr(parameter, marker, value)
+
+
+def set_optimizer_split_function(parameter: nn.Parameter | None, function: Callable) -> nn.Parameter | None:
+    if parameter is not None:
+        parameter._optimizer_split_function = function
+
+    return parameter
+
+
+def get_optimizer_split_function(parameter: nn.Parameter) -> Callable | None:
+    return getattr(parameter, _OPTIMIZER_SPLIT_FUNCTION, None)
diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py
index d03fdea39..7842493a4 100644
--- a/lm_engine/optimization/optimizer.py
+++ b/lm_engine/optimization/optimizer.py
@@ -2,24 +2,29 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************

+import logging
+
 import torch.nn as nn
+from torch.optim import ASGD as TorchASGD
+from torch.optim import LBFGS as TorchLBFGS
+from torch.optim import SGD as TorchSGD
+from torch.optim import Adadelta as TorchAdadelta
+from torch.optim import Adagrad as TorchAdagrad
+from torch.optim import Adam as TorchAdam
+from torch.optim import Adamax as TorchAdamax
+from torch.optim import AdamW as TorchAdamW
+from torch.optim import NAdam as TorchNAdam
 from torch.optim import Optimizer
-from torch.optim.adadelta import Adadelta as TorchAdadelta
-from torch.optim.adagrad import Adagrad as TorchAdagrad
-from torch.optim.adam import Adam as TorchAdam
-from torch.optim.adamax import Adamax as TorchAdamax
-from torch.optim.adamw import AdamW as TorchAdamW
-from torch.optim.asgd import ASGD as TorchASGD
-from torch.optim.lbfgs import LBFGS as TorchLBFGS
-from torch.optim.nadam import NAdam as TorchNAdam
-from torch.optim.radam import RAdam as TorchRAdam
-from torch.optim.rmsprop import RMSprop as TorchRMSprop
-from torch.optim.rprop import Rprop as TorchRprop
-from torch.optim.sgd import SGD as TorchSGD
+from torch.optim import RAdam as TorchRAdam
+from torch.optim import RMSprop as TorchRMSprop
+from torch.optim import Rprop as TorchRprop

 from ..containers import BackwardHookOptimizerContainer, ModelContainer, OptimizerContainer
 from ..enums import ParamsGroupMethod
-from .params_group import get_param_groups_list
+from ..hf_models import get_optimizer_split_function
+from ..utils import log_rank_0
+from .params_group import _ParamsGroupsList, get_param_groups_list
+from .split_param_optimizer import SplitParamOptimizer


 # https://pytorch.org/docs/stable/optim.html
@@ -39,12 +44,52 @@
 }


+def _build_optimizer(
+    optimizer_class, params_groups: _ParamsGroupsList, optimizer_class_args: dict, split_params_for_optimizer: bool
+) -> SplitParamOptimizer | Optimizer:
+    if not split_params_for_optimizer:
+        return optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
+
+    proxy_grad_fns: dict[int, tuple] = {}
+    split_params: set[nn.Parameter] = set()
+    modified_groups = []
+
+    for pg in params_groups.params_groups:
+        group = pg.to_param_group()
+        names = pg.get_param_names()
+        group_kwargs = {k: v for k, v in group.items() if k != "params"}
+        new_params = []
+        for param, name in zip(group["params"], names):
+            split_fn = get_optimizer_split_function(param)
+            if split_fn is None:
+                new_params.append(param)
+            else:
+                log_rank_0(logging.INFO, f"splitting {name} for optimizer")
+                pieces = split_fn(param.data)
+
+                for i, piece in enumerate(pieces):
+                    proxy = nn.Parameter(piece)
+                    new_params.append(proxy)
+                    proxy_grad_fns[id(proxy)] = (param, lambda g, fn=split_fn, idx=i: fn(g)[idx])
+                split_params.add(param)
+
+        modified_groups.append({"params": new_params, **group_kwargs})
+
+    inner = optimizer_class(modified_groups, **optimizer_class_args)
+
+    if split_params:
+        inner = SplitParamOptimizer(inner=inner, proxy_grad_fns=proxy_grad_fns, split_params=split_params)
+
+    return inner
+
+
 def get_optimizer_container(
     optimizer_class_name: str,
     optimizer_class_args: dict,
     model_container: ModelContainer,
     params_group_method: ParamsGroupMethod,
     use_optimizer_with_backward_hook: bool,
+    split_params_for_optimizer: bool,
 ) -> OptimizerContainer:
     """setup list of optimizers for the model

@@ -54,6 +99,7 @@ def get_optimizer_container(
         model_container (ModelContainer): model container
         params_group_method (ParamsGroupMethod): the params grouping to use
        use_optimizer_with_backward_hook (bool): whether to use optimizer as a backward hook
+        split_params_for_optimizer (bool): whether to split params using model-defined split functions

     Returns:
         OptimizerContainer: optimizer container
@@ -69,6 +115,8 @@ def get_optimizer_container(
     params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method)

     if use_optimizer_with_backward_hook:
+        assert not split_params_for_optimizer
+
         for model, params_groups in zip(model_container, params_groups_list):
             for param_name, param in model.named_parameters():
                 for group in params_groups.params_groups:
@@ -89,7 +137,12 @@ def _step(p: nn.Parameter) -> None:
     else:
         optimizer_list = OptimizerContainer(
             [
-                optimizer_class(params_groups.to_torch_compatible_params_groups(), **optimizer_class_args)
+                _build_optimizer(
+                    optimizer_class=optimizer_class,
+                    params_groups=params_groups,
+                    optimizer_class_args=optimizer_class_args,
+                    split_params_for_optimizer=split_params_for_optimizer,
+                )
                 for params_groups in params_groups_list
             ]
         )
diff --git a/lm_engine/optimization/split_param_optimizer.py b/lm_engine/optimization/split_param_optimizer.py
new file mode 100644
index 000000000..be407a391
--- /dev/null
+++ b/lm_engine/optimization/split_param_optimizer.py
@@ -0,0 +1,66 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+
+class SplitParamOptimizer(Optimizer):
+    def __init__(
+        self,
+        inner: Optimizer,
+        proxy_grad_fns: dict[int, tuple[nn.Parameter, Callable]],
+        split_params: set[nn.Parameter],
+    ) -> SplitParamOptimizer:
+        self._inner = inner
+        self._proxy_grad_fns = proxy_grad_fns
+        self._split_params = split_params
+
+    @property
+    def param_groups(self) -> list[dict]:
+        return self._inner.param_groups
+
+    @property
+    def state(self) -> dict:
+        return self._inner.state
+
+    def state_dict(self) -> dict:
+        return self._inner.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        self._inner.load_state_dict(state_dict)
+
+    def add_param_group(self, param_group: dict) -> None:
+        self._inner.add_param_group(param_group)
+
+    def step(self, closure: Callable | None = None) -> torch.Tensor | None:
+        for group in self._inner.param_groups:
+            for p in group["params"]:
+                info = self._proxy_grad_fns.get(id(p))
+                if info is None:
+                    continue
+
+                orig_param, grad_slice_fn = info
+                if orig_param.grad is not None:
+                    p.grad = grad_slice_fn(orig_param.grad)
+
+        return self._inner.step(closure)
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        self._inner.zero_grad(set_to_none)
+
+        for param in self._split_params:
+            if set_to_none:
+                param.grad = None
+            elif param.grad is not None:
+                param.grad.zero_()
+
+    def __repr__(self) -> str:
+        x = super().__repr__()
+        return f"{self._inner.__class__.__name__}({x})"
diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py
index 65f499b18..678fafd7d 100644
--- a/lm_engine/pretrain.py
+++ b/lm_engine/pretrain.py
@@ -646,6 +646,7 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No
         model_container=model_container,
         params_group_method=args.optimizer_args.params_group_method,
         use_optimizer_with_backward_hook=args.optimizer_args.use_optimizer_with_backward_hook,
+        split_params_for_optimizer=args.optimizer_args.split_params_for_optimizer,
     )

     lr_scheduler_container = get_scheduler_container(
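
Note on the mechanism (not part of the patch): a fused parameter such as the packed QKV weight or the packed up/gate MLP weight stays fused for the forward and backward passes, while the optimizer is handed per-slice "proxy" parameters that wrap views of the fused storage. At step() time the fused gradient is sliced and routed onto the proxies, so the inner optimizer keeps its state (e.g. Adam moments) per logical weight rather than per fused tensor. The sketch below is a standalone illustration of that idea under those assumptions; names like `fused`, `split_fn`, and `proxies` are hypothetical and only plain PyTorch is assumed, nothing from lm_engine.

# Standalone sketch of the split-parameter mechanism used by _build_optimizer and
# SplitParamOptimizer above. All names here are illustrative.
import torch
import torch.nn as nn


def split_fn(t: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # model-defined split, e.g. packed up/gate -> (up, gate); returns views, not copies
    return t.chunk(2, dim=0)


fused = nn.Parameter(torch.randn(8, 4))  # stand-in for a packed c_fc / c_attn weight

# proxies wrap views of the fused storage, so in-place optimizer updates hit `fused`
proxies = [nn.Parameter(piece) for piece in split_fn(fused.data)]
opt = torch.optim.AdamW(proxies, lr=1e-2)

loss = fused.square().sum()  # autograd only ever sees the fused parameter
loss.backward()

# what SplitParamOptimizer.step() does before delegating to the inner optimizer:
# route slices of the fused gradient onto the matching proxies
for proxy, grad_slice in zip(proxies, split_fn(fused.grad)):
    proxy.grad = grad_slice

opt.step()  # optimizer state is now tracked per slice, not for the fused tensor
print(fused[:4].allclose(proxies[0]))  # True: slice and fused weight share storage

This also appears to be why the `_for_optimizer` split helpers drop the `.contiguous()` calls that the public `split_up_gate_tensor_for_mlp` and `split_query_key_value_tensor_for_attention` still apply: a contiguous copy would stop the proxy from sharing storage with the fused weight.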