From 22b07c82fea83646a91ecaad84390d7887187965 Mon Sep 17 00:00:00 2001 From: tfaod <8447104+tfaod@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:44:07 +0000 Subject: [PATCH 1/4] Add lion algorithm --- .../paper_baselines/lion/__init__.py | 0 .../paper_baselines/lion/pytorch/__init__.py | 0 .../lion/pytorch/submission.py | 284 ++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 reference_algorithms/paper_baselines/lion/__init__.py create mode 100644 reference_algorithms/paper_baselines/lion/pytorch/__init__.py create mode 100644 reference_algorithms/paper_baselines/lion/pytorch/submission.py diff --git a/reference_algorithms/paper_baselines/lion/__init__.py b/reference_algorithms/paper_baselines/lion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/paper_baselines/lion/pytorch/__init__.py b/reference_algorithms/paper_baselines/lion/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/paper_baselines/lion/pytorch/submission.py b/reference_algorithms/paper_baselines/lion/pytorch/submission.py new file mode 100644 index 000000000..4ff746242 --- /dev/null +++ b/reference_algorithms/paper_baselines/lion/pytorch/submission.py @@ -0,0 +1,284 @@ +from __future__ import annotations +import collections +from typing import Tuple, Callable, Any, Dict, Iterator, List, Optional + +from absl import logging +import torch +from torch.optim.optimizer import Optimizer + +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +# default Lion parameters +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 2e-4, + "one_minus_beta1": 0.05, + "beta2": 0.98, + "weight_decay": 0.5, + "warmup_factor": 0.02 +} +HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) + +# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py. +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + ): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + beta1, beta2 = group['betas'] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + + p.add_(update.sign_(), alpha=-group['lr']) + + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a Lion optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + Lion( + model_params.parameters(), + lr=HPARAMS.learning_rate, + betas=(1.0 - HPARAMS.one_minus_beta1, + HPARAMS.beta2), + weight_decay=HPARAMS.weight_decay) + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, HPARAMS, optimizer_state['optimizer']) + optimizer_state['hyperparameters'] = hyperparameters + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(HPARAMS, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct 
loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if hasattr(HPARAMS, "batch_size"): + return HPARAMS.batch_size + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch From f557301ebccca37f8634108cc3d3c40971fa033c Mon Sep 17 00:00:00 2001 From: tfaod <8447104+tfaod@users.noreply.github.com> Date: Thu, 7 Aug 2025 16:11:19 +0000 Subject: [PATCH 2/4] [reformat] ruff reformat lion submission --- .../lion/pytorch/submission.py | 170 ++++++++++-------- 1 file changed, 95 insertions(+), 75 deletions(-) diff --git a/reference_algorithms/paper_baselines/lion/pytorch/submission.py b/reference_algorithms/paper_baselines/lion/pytorch/submission.py index 4ff746242..557ba9c89 100644 --- a/reference_algorithms/paper_baselines/lion/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lion/pytorch/submission.py @@ -18,15 +18,16 @@ # default Lion parameters HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 2e-4, - "one_minus_beta1": 0.05, - "beta2": 0.98, - "weight_decay": 0.5, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 2e-4, + 'one_minus_beta1': 0.05, + 'beta2': 0.98, + 'weight_decay': 0.5, + 'warmup_factor': 0.02, } HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) + # Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py. class Lion(Optimizer): def __init__( @@ -90,11 +91,13 @@ def step(self, closure=None): return loss -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Lion optimizer and a learning rate schedule.""" del model_state del rng @@ -103,44 +106,47 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters = HPARAMS optimizer_state = { - 'optimizer': - Lion( - model_params.parameters(), - lr=HPARAMS.learning_rate, - betas=(1.0 - HPARAMS.one_minus_beta1, - HPARAMS.beta2), - weight_decay=HPARAMS.weight_decay) + 'optimizer': Lion( + model_params.parameters(), + lr=HPARAMS.learning_rate, + betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2), + weight_decay=HPARAMS.weight_decay, + ) } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, HPARAMS, optimizer_state['optimizer']) + workload.step_hint, HPARAMS, optimizer_state['optimizer'] + ) optimizer_state['hyperparameters'] = hyperparameters return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - 
loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -155,26 +161,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(HPARAMS, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(HPARAMS, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -187,7 +197,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -196,31 +207,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + 
current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -234,8 +252,8 @@ def prepare_for_eval(workload: spec.Workload, def get_batch_size(workload_name): # Return the global batch size. - if hasattr(HPARAMS, "batch_size"): - return HPARAMS.batch_size + if hasattr(HPARAMS, 'batch_size'): + return HPARAMS.batch_size if workload_name == 'criteo1tb': return 262_144 elif workload_name == 'fastmri': @@ -262,14 +280,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ From 44c663ccd3d470e0a6394d5887c7fc998295e655 Mon Sep 17 00:00:00 2001 From: tfaod <8447104+tfaod@users.noreply.github.com> Date: Thu, 7 Aug 2025 16:14:38 +0000 Subject: [PATCH 3/4] [nit] remove unused Callable type --- reference_algorithms/paper_baselines/lion/pytorch/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reference_algorithms/paper_baselines/lion/pytorch/submission.py b/reference_algorithms/paper_baselines/lion/pytorch/submission.py index 557ba9c89..567eeb206 100644 --- a/reference_algorithms/paper_baselines/lion/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lion/pytorch/submission.py @@ -1,6 +1,6 @@ from __future__ import annotations import collections -from typing import Tuple, Callable, Any, Dict, Iterator, List, Optional +from typing import Tuple, Any, Dict, Iterator, List, Optional from absl import logging import torch From a95ed2328d213c920e008a5b71d6838c84729384 Mon Sep 17 00:00:00 2001 From: tfaod <8447104+tfaod@users.noreply.github.com> Date: Thu, 7 Aug 2025 16:17:32 +0000 Subject: [PATCH 4/4] [ruff] reformat imports --- .../paper_baselines/lion/pytorch/submission.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/reference_algorithms/paper_baselines/lion/pytorch/submission.py b/reference_algorithms/paper_baselines/lion/pytorch/submission.py index 567eeb206..2154a820d 100644 --- a/reference_algorithms/paper_baselines/lion/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lion/pytorch/submission.py @@ -1,15 +1,13 @@ from __future__ import annotations + import collections -from typing import Tuple, Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch.optim.optimizer 
import Optimizer - import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR +from torch.optim.optimizer import Optimizer from algoperf import spec from algoperf.pytorch_utils import pytorch_setup
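
The update implemented in `Lion.step` above has three parts: decoupled weight decay applied multiplicatively to the parameters, a sign update computed from an interpolation of the momentum buffer and the current gradient (coefficient `beta1`), and a momentum-buffer update with the slower coefficient `beta2`. A minimal standalone sketch of one such step, using only `torch` and the default values from `HPARAMS` (the tensors `param`, `grad`, and `exp_avg` are illustrative stand-ins, not names from the patch):

import torch

# Defaults from HPARAMS above: beta1 = 1 - one_minus_beta1 = 0.95.
lr, beta1, beta2, weight_decay = 2e-4, 0.95, 0.98, 0.5

param = torch.randn(4)             # illustrative parameter tensor
grad = torch.randn(4)              # illustrative gradient for this step
exp_avg = torch.zeros_like(param)  # momentum buffer, zero-initialized

# 1) Decoupled weight decay, folded into the parameter before the update.
param.mul_(1 - lr * weight_decay)

# 2) Sign update: interpolate momentum and gradient, keep only the sign.
update = exp_avg * beta1 + grad * (1 - beta1)
param.add_(update.sign(), alpha=-lr)

# 3) Decay the momentum buffer toward the current gradient with beta2.
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

Because the update is a sign vector, every coordinate moves by exactly `lr` per step (before weight decay), which is why the defaults pair a smaller learning rate (2e-4) with a larger weight decay (0.5) than a typical AdamW configuration would use.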
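
The learning-rate schedule built in `init_optimizer_state` chains a linear warmup over `warmup_factor * step_hint` steps into a cosine decay via `SequentialLR`. A self-contained sketch of the same wiring, assuming an illustrative `step_hint` of 1000 and a throwaway SGD optimizer standing in for Lion:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

step_hint = 1000                      # illustrative; the workload supplies the real value
warmup_steps = int(0.02 * step_hint)  # warmup_factor = 0.02 from HPARAMS

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=2e-4)
warmup = LinearLR(
  opt, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
)
cosine = CosineAnnealingLR(opt, T_max=max(step_hint - warmup_steps, 1))
scheduler = SequentialLR(
  opt, schedulers=[warmup, cosine], milestones=[warmup_steps]
)

for _ in range(step_hint):
  opt.step()        # optimizer step first, as in update_params
  scheduler.step()  # then advance the schedule once per step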