302 changes: 302 additions & 0 deletions reference_algorithms/paper_baselines/lion/pytorch/submission.py
@@ -0,0 +1,302 @@
from __future__ import annotations

import collections
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.distributed.nn as dist_nn
from absl import logging
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.optim.optimizer import Optimizer

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

# Default Lion hyperparameters.
HPARAMS = {
'dropout_rate': 0.1,
'learning_rate': 2e-4,
'one_minus_beta1': 0.05,
'beta2': 0.98,
'weight_decay': 0.5,
'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)


# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
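# A brief summary of the update the class below implements: Lion keeps a
# single momentum buffer m per parameter and applies only the *sign* of an
# interpolated update, with decoupled weight decay:
#   c_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
#   theta_t = theta_{t-1} - lr * (sign(c_t) + weight_decay * theta_{t-1})
#   m_t     = beta2 * m_{t-1} + (1 - beta2) * g_t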
class Lion(Optimizer):
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
):
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.

Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.

Returns:
the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

# Perform decoupled weight decay: scale weights by (1 - lr * weight_decay).
p.data.mul_(1 - group['lr'] * group['weight_decay'])

grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)

exp_avg = state['exp_avg']
beta1, beta2 = group['betas']

# Weight update
update = exp_avg * beta1 + grad * (1 - beta1)
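# Only the sign of this interpolation is applied; sign_() mutates `update`
# in place, which is safe because `update` is a freshly allocated tensor.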

p.add_(update.sign_(), alpha=-group['lr'])

# Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

return loss
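
# Illustrative standalone usage of the Lion class above (a sketch outside the
# AlgoPerf submission API; the toy model and data are placeholders):
#
#   model = torch.nn.Linear(10, 1)
#   optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.95, 0.98),
#                    weight_decay=0.5)
#   loss = model(torch.randn(8, 10)).pow(2).mean()
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()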


def init_optimizer_state(
workload: spec.Workload,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
rng: spec.RandomState,
) -> spec.OptimizerState:
"""Creates a Lion optimizer and a learning rate schedule."""
del model_state
del rng
del hyperparameters

hyperparameters = HPARAMS

optimizer_state = {
'optimizer': Lion(
model_params.parameters(),
lr=HPARAMS.learning_rate,
betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
weight_decay=HPARAMS.weight_decay,
)
}
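# With the defaults in HPARAMS, betas = (1 - 0.05, 0.98) = (0.95, 0.98) and
# the peak learning rate is 2e-4.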

def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
)
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
)
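# Worked example (assuming a hypothetical step_hint of 10_000): warmup spans
# int(0.02 * 10_000) = 200 steps, rising linearly from ~0 to the peak LR of
# 2e-4, followed by a cosine decay toward 0 over the remaining 9_800 steps.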

optimizer_state['scheduler'] = pytorch_cosine_warmup(
workload.step_hint, HPARAMS, optimizer_state['optimizer']
)
optimizer_state['hyperparameters'] = hyperparameters

return optimizer_state


def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None,
) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

hyperparameters = HPARAMS

current_model = current_param_container
current_model.train()
optimizer_state['optimizer'].zero_grad()

logits_batch, new_model_state = workload.model_fn(
params=current_model,
augmented_and_preprocessed_input_batch=batch,
model_state=model_state,
mode=spec.ForwardPassMode.TRAIN,
rng=rng,
update_batch_norm=True,
)

label_smoothing = (
hyperparameters.label_smoothing
if hasattr(hyperparameters, 'label_smoothing')
else 0.0
)
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None

loss_dict = workload.loss_fn(
label_batch=batch['targets'],
logits_batch=logits_batch,
mask_batch=batch.get('weights'),
label_smoothing=label_smoothing,
)
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
# Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
summed_loss = dist_nn.all_reduce(summed_loss)
n_valid_examples = dist_nn.all_reduce(n_valid_examples)
loss = summed_loss / n_valid_examples
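# `loss` is now the mean over all valid examples in the global batch, since
# both the summed loss and the example count were aggregated across replicas.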

loss.backward()

if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
current_model.parameters(), max_norm=grad_clip
)
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()

# Log training metrics: loss and gradient norm.
if global_step <= 100 or global_step % 500 == 0:
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
)
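# This is the l2 norm of the stacked per-parameter gradient norms, i.e. the
# overall l2 norm of all gradient entries concatenated.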
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
{
'loss': loss.item(),
'grad_norm': grad_norm.item(),
},
global_step,
)
logging.info(
'%d) loss = %0.3f, grad_norm = %0.3f',
global_step,
loss.item(),
grad_norm.item(),
)

return (optimizer_state, current_param_container, new_model_state)


def prepare_for_eval(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
del current_params_types
del loss_type
del eval_results
del global_step
del rng
return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
# Return the global batch size.
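# For example, get_batch_size('wmt') returns 128 and get_batch_size('ogbg')
# returns 512; unknown workload names raise a ValueError.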
if hasattr(HPARAMS, 'batch_size'):
return HPARAMS.batch_size
if workload_name == 'criteo1tb':
return 262_144
elif workload_name == 'fastmri':
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
return 256
elif workload_name == 'librispeech_deepspeech':
return 256
elif workload_name == 'ogbg':
return 512
elif workload_name == 'wmt':
return 128
elif workload_name == 'mnist':
return 16
else:
raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(
workload: spec.Workload,
input_queue: Iterator[Dict[str, spec.Tensor]],
optimizer_state: spec.OptimizerState,
current_param_container: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
global_step: int,
rng: spec.RandomState,
) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
del workload
del optimizer_state
del current_param_container
del model_state
del hyperparameters
del global_step
del rng
batch = next(input_queue)
return batch