diff --git a/reference_algorithms/paper_baselines/lion/__init__.py b/reference_algorithms/paper_baselines/lion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/paper_baselines/lion/pytorch/__init__.py b/reference_algorithms/paper_baselines/lion/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/paper_baselines/lion/pytorch/submission.py b/reference_algorithms/paper_baselines/lion/pytorch/submission.py new file mode 100644 index 000000000..2154a820d --- /dev/null +++ b/reference_algorithms/paper_baselines/lion/pytorch/submission.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import collections +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR +from torch.optim.optimizer import Optimizer + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +# default Lion parameters +HPARAMS = { + 'dropout_rate': 0.1, + 'learning_rate': 2e-4, + 'one_minus_beta1': 0.05, + 'beta2': 0.98, + 'weight_decay': 0.5, + 'warmup_factor': 0.02, +} +HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) + + +# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py. +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + ): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + beta1, beta2 = group['betas'] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + + p.add_(update.sign_(), alpha=-group['lr']) + + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a Lion optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': Lion( + model_params.parameters(), + lr=HPARAMS.learning_rate, + betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2), + weight_decay=HPARAMS.weight_decay, + ) + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, HPARAMS, optimizer_state['optimizer'] + ) + optimizer_state['hyperparameters'] = hyperparameters + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(HPARAMS, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if hasattr(HPARAMS, 'batch_size'): + return HPARAMS.batch_size + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch