Changes from all commits

43 commits
4fea359
Create Clipper base class for shared functionality of state_dicts and…
HesitantlyHuman Jul 8, 2022
78148cd
Create QuantileClip class implementation with both parameter-wise and…
HesitantlyHuman Jul 8, 2022
4998965
Placeholder class for future StandardClipper
HesitantlyHuman Jul 8, 2022
2619a9d
Initial commit and basic dir structure
HesitantlyHuman Jul 8, 2022
514116d
Move tensorflow implementation into appropriate folder, modify tensor…
HesitantlyHuman Jul 8, 2022
72ded99
Update defaults for QuantileClip
HesitantlyHuman Jul 8, 2022
b5c810e
Implement StandardClip
HesitantlyHuman Jul 8, 2022
c42ed87
Delete old autoclip implementations.
HesitantlyHuman Jul 8, 2022
65b1b7f
Add loading and saving state to clipper base class
HesitantlyHuman Jul 9, 2022
2bfb0ae
Add repr to base clipper class, update add_parameter_group to accept …
HesitantlyHuman Jul 9, 2022
6af7894
Remove unnecessary line from mnist example
HesitantlyHuman Jul 9, 2022
f474e09
Update readme to reflect new API and package information
HesitantlyHuman Jul 9, 2022
862b774
Change parameter names to match torch API, add __init__.py
HesitantlyHuman Jul 9, 2022
baf75fa
Remove unecessary clip prefix to clipping parameter values in Quantil…
HesitantlyHuman Jul 9, 2022
8492775
Update mnist example to match new parameter names
HesitantlyHuman Jul 9, 2022
4b5dbd2
Fix type hints typo
HesitantlyHuman Jul 9, 2022
d00b39d
Update setup.py
HesitantlyHuman Jul 9, 2022
48e6e6f
bump version
HesitantlyHuman Jul 9, 2022
67e11e7
Spelling errors and phrasing updates on readme
HesitantlyHuman Jul 9, 2022
b34068f
Fix broken link
HesitantlyHuman Jul 9, 2022
f56f21e
fix broken link
HesitantlyHuman Jul 9, 2022
a6d016b
shorten global_clipping example view width
HesitantlyHuman Jul 10, 2022
ad55068
Bump patch version
HesitantlyHuman Jul 15, 2022
eac0586
Update mnist example to show new optimizer wrapping pattern
HesitantlyHuman Jul 15, 2022
3dc169e
Create new base class for optimizer wrapping, add class method to cli…
HesitantlyHuman Jul 15, 2022
372dc78
Add as_optimizer for QuantileClip and StandardClip
HesitantlyHuman Jul 15, 2022
15729a6
bump minor version
HesitantlyHuman Jul 15, 2022
1fa9bac
Update README.md
HesitantlyHuman Jul 15, 2022
412eb1d
add optimizer wrapping example code
HesitantlyHuman Jul 15, 2022
7225c7c
Add __repr__ for OptimizerWithClipping wrapper
HesitantlyHuman Jul 16, 2022
9611e34
Remove TODO in readme
HesitantlyHuman Jul 16, 2022
c3c209e
Merge pull request #5 from HesitantlyHuman/optimizer-addons
HesitantlyHuman Jul 16, 2022
0b27bd9
write unit tests
HesitantlyHuman Jul 16, 2022
8893f4b
fix spelling error in test function name
HesitantlyHuman Jul 16, 2022
1ff5c98
Fix string recursion error
HesitantlyHuman Jul 16, 2022
8df4eaa
reformat test for readability
HesitantlyHuman Jul 16, 2022
d9e0835
Add parameter checking base function, call parameter checking functio…
HesitantlyHuman Jul 16, 2022
13ba540
Implement parameter checking for QuantileClip and StandardClip
HesitantlyHuman Jul 16, 2022
9dee079
Merge branch 'master' of github.com:HesitantlyHuman/autoclip
HesitantlyHuman Jul 16, 2022
b53b358
Fix optimizer wrapper docstring, return optimizer step output for ca…
HesitantlyHuman Sep 29, 2022
754cbe3
Fix optimizer wrapper pickling error, add unit test coverage for pick…
HesitantlyHuman Sep 29, 2022
3196f94
Clipper pickling tests
HesitantlyHuman Sep 29, 2022
34badef
Bump patch, update README.md
HesitantlyHuman Sep 29, 2022
133 changes: 110 additions & 23 deletions README.md
@@ -1,39 +1,126 @@
# AutoClip: Adaptive Gradient Clipping

This repository accompanies the [paper](https://arxiv.org/abs/2007.14469):

> Prem Seetharaman, Gordon Wichern, Bryan Pardo, Jonathan Le Roux. "AutoClip: Adaptive Gradient Clipping for Source Separation Networks." 2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP). IEEE, 2020.

At the moment it contains a [sample implementation of AutoClip](autoclip.py) that can be integrated into an ML project based on PyTorch easily.
Soon it will come as a Python package that can be installed and attached to a training script more easily.

## Abstract
> Clipping the gradient is a known approach to improving gradient descent, but requires hand selection of a clipping threshold hyperparameter. We present AutoClip, a simple method for automatically and adaptively choosing a gradient clipping threshold, based on the history of gradient norms observed during training. Experimental results show that applying AutoClip results in improved generalization performance for audio source separation networks. Observation of the training dynamics of a separation network trained with and without AutoClip show that AutoClip guides optimization into smoother parts of the loss landscape. AutoClip is very simple to implement and can be integrated readily into a variety of applications across multiple domains.

## Presentation

This work was presented at MLSP2020 in a special session. If you missed my talk, no worries, there's a pandemic happening so it's recorded! [Here it is](https://share.descript.com/view/18725e02-95fe-4fb0-b32d-26c63617d482).

## Citation
```
@inproceedings{seetharaman2020autoclip,
    title={AutoClip: Adaptive Gradient Clipping for Source Separation Networks},
    author={Seetharaman, Prem, and Wichern, Gordon, and Pardo, Bryan, and Le Roux, Jonathan},
    booktitle={2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP)},
    year={2020},
    organization={IEEE}
}
```

## Training dynamics

### Mask-inference loss

![](images/mi.gif)

### Whitened K-Means loss

![](images/wkm.gif)

Training dynamics of a smaller mask inference network (2 BLSTM layers with 300 hidden units) with mask-inference loss and whitened k-means loss, with and without AutoClip. The top left figure shows the norm of the step size taken on the model parameters. The top right figure shows the training loss over time, showing that AutoClip leads to better optimization. The bottom figures show the relationship between gradient norm and a measure of smoothness along the training trajectory. Statistics were recorded every 20 iterations during training. With AutoClip, we observe a stronger correlation (r-value of .86), compared to without (r-value of .62). All gradients to the right of the dashed black line in the bottom right plot are clipped. We show the location of the AutoClip threshold at the end of training. The threshold changes during training.

# AutoClip
Pytorch and tensorflow implementations (and variations) of the AutoClip gradient clipping procedure from [Seetharaman et al.](https://arxiv.org/abs/2007.14469).

## About

While training your model, AutoClip keeps a running history of your model's gradient magnitudes. Using this history, the gradient clipper can adaptively clamp outlier gradient values before they reach the optimizer of your choice.

While AutoClip is great as a preventative measure against exploding gradients, it can also shorten training time and encourage the optimizer to find better models. At an intuitive level, AutoClip compensates for the stochastic nature of training over batches, regularizing the resulting gradient noise.

## Installation

AutoClip is available on PyPI. To install it, simply run
```
pip install autoclip
```
and the `autoclip` package will be installed in your currently active environment.
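
Conceptually, the procedure is simple: keep a history of observed gradient norms and clip each new gradient to a chosen quantile of that history. The following is a simplified sketch of the idea only, not the package's internal implementation (`model` and `quantile` are placeholders, and the history here is a plain Python list):
```python
import numpy as np
import torch

grad_history = []

def autoclip_step(model: torch.nn.Module, quantile: float = 0.9) -> None:
    # Record the current total gradient norm...
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    grad_history.append(total_norm ** 0.5)
    # ...then clip to the chosen quantile of the norms seen so far.
    clip_value = np.quantile(grad_history, quantile)
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
```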

## Torch API

Below are some examples of how to use `autoclip`'s torch API.

### Clippers as Optimizer Wrappers
Using the optimizer wrapping pattern is the recommended way to use AutoClip, and `autoclip`'s torch API supports wrapping arbitrary pytorch optimizers. The wrapping pattern allows you to avoid changing your training code when you want to use an AutoClip clipper. This is especially useful if you do not own the training code for whatever reason (say, for example, you are using someone else's `Trainer` class, as is often the case with frameworks like `huggingface`).
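In that situation, you can construct and wrap the optimizer yourself and hand it to the framework. A hedged sketch using `transformers.Trainer`, which accepts an externally built optimizer through its `optimizers` argument (`model`, `training_args`, and `train_data` are assumed to already exist and are not part of `autoclip`):
```python
import torch
from transformers import Trainer
from autoclip.torch import QuantileClip

optimizer = torch.optim.AdamW(model.parameters())
optimizer = QuantileClip.as_optimizer(
    optimizer=optimizer, quantile=0.9, history_length=1000
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    optimizers=(optimizer, None),  # None: let the Trainer build its own LR scheduler
)
trainer.train()
```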

The following is an example of how to integrate AutoClip into your model training using this pattern:
```python
import torch
from autoclip.torch import QuantileClip

model = torch.nn.Sequential(
    torch.nn.Linear(100, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 2)
)

optimizer = torch.optim.AdamW(model.parameters())
optimizer = QuantileClip.as_optimizer(optimizer=optimizer, quantile=0.9, history_length=1000)
```
Now you can use the optimizer just like you would have before adding the clipper, and the clipping will be applied automatically.
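For example, a standard training loop needs no changes once the optimizer is wrapped (a minimal sketch; `training_dataset` and `loss_function` are placeholders):
```python
for batch in training_dataset:
    optimizer.zero_grad()
    loss = loss_function(model(batch["data"]), batch["targets"])
    loss.backward()
    optimizer.step()  # the wrapper clips the gradients before the parameter update
```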

### Raw AutoClip Clippers
You can still use the clipper manually if you would like. In that case, you would create your clipper like this:
```python
import torch
from autoclip.torch import QuantileClip

model = torch.nn.Sequential(
    torch.nn.Linear(100, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 2)
)

clipper = QuantileClip(model.parameters(), quantile=0.9, history_length=1000)
```
Then, to clip the model's gradients, simply run the clipper's `.step()` function during your training loop. Note that you should call the clipper's `step` before you call your optimizer's `step`. Calling it after would mean that your clipping will have no effect, since the model will have already been updated using the unclipped gradients. For example:
```python
for batch_num, batch in enumerate(training_dataset):
    model_prediction = model(batch['data'])
    loss = loss_function(model_prediction, batch['targets'])
    loss.backward()
    clipper.step()  # clipper comes before optimizer
    optimizer.step()
    optimizer.zero_grad()  # clear gradients before the next batch
```

### Global vs Local Clipping
`autoclip`'s torch clippers support two clipping modes. The first is `global_clipping`, which is the original AutoClip as described in Seetharaman et al. The second is local, or parameter-wise, clipping. In this mode a separate history is kept for every parameter, and each is clipped according to its own history. By default, the `autoclip` clippers use parameter-wise clipping.
To use the global mode, simply pass the appropriate flag:
```python
clipper = QuantileClip(
    model.parameters(),
    quantile=0.9,
    history_length=1000,
    global_clipping=True
)
```

### Checkpointing
The torch clippers also support checkpointing through `state_dict()` and `load_state_dict()`, just like torch models and optimizers. For example, if you want to checkpoint a clipper to `clipper.pth`:
```python
clipper = QuantileClip(model.parameters())
torch.save(clipper.state_dict(), 'clipper.pth')

# Then later
clipper = QuantileClip(model.parameters())
clipper.load_state_dict(torch.load('clipper.pth'))
```
Keep in mind that, just like a torch optimizer, this will raise an error if you give the clipper differently sized model parameters.
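
If you are already checkpointing your model and optimizer, the clipper's state can be bundled into the same file. A minimal sketch (the file name and dictionary keys here are arbitrary choices, not an `autoclip` convention):
```python
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "clipper": clipper.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Later, after recreating the model, optimizer, and clipper:
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
clipper.load_state_dict(checkpoint["clipper"])
```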

While it is generally recommended to use `state_dict`s instead (see the [pytorch documentation](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model) on this subject for more info), you may also use `torch.save` and `torch.load` directly to pickle the entire clipper object.

## Tensorflow
`autoclip`'s tensorflow API does not currently have feature parity with the torch API (if you want to change this, feel free to [contribute](https://github.com/HesitantlyHuman/autoclip/issues/2)).
As it is, the tensorflow API only supports the original AutoClip algorithm and does not support checkpointing. Below is a short example:
```python
import tensorflow as tf
from autoclip.tf import QuantileClip

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Dense(50),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(
            2,
            activation=tf.keras.activations.tanh
        ),
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=0.001,
        gradient_transformers=[
            QuantileClip(
                quantile=0.9,
                history_length=1000
            )
        ]
    ),
    loss="mean_absolute_error",
    metrics=["accuracy"],
)
model.fit(train_data, train_targets)
```
Binary file removed autoclip.pdf
29 changes: 0 additions & 29 deletions autoclip.py

This file was deleted.

3 changes: 3 additions & 0 deletions autoclip/__init__.py
@@ -0,0 +1,3 @@
import os

__location__ = os.path.dirname(__file__)
1 change: 1 addition & 0 deletions autoclip/tf/__init__.py
@@ -0,0 +1 @@
from autoclip.tf.quantile import QuantileClip
31 changes: 31 additions & 0 deletions autoclip/tf/quantile.py
@@ -0,0 +1,31 @@
import tensorflow as tf
import tensorflow_probability as tfp


class QuantileClip:
    """Keras gradient transformer that clips gradients to a quantile of their norm history."""

    def __init__(self, quantile: float = 0.9, history_length: int = 1000):
        # tfp.stats.percentile expects a percentile in [0, 100]
        self.quantile = quantile * 100
        # Circular buffer of recently observed global gradient norms
        self.grad_history = tf.Variable(tf.zeros(history_length), trainable=False)
        self.i = tf.Variable(0, trainable=False)
        self.history_size = history_length

    def __call__(self, grads_and_vars):
        # Record the current global gradient norm in the circular buffer
        grad_norms = [self._get_grad_norm(g) for g, _ in grads_and_vars]
        total_norm = tf.norm(grad_norms)
        assign_idx = tf.math.mod(self.i, self.history_size)
        self.grad_history = self.grad_history[assign_idx].assign(total_norm)
        self.i = self.i.assign_add(1)
        # Clip every gradient to the chosen quantile of the norms seen so far
        clip_value = tfp.stats.percentile(self.grad_history[: self.i], q=self.quantile)
        return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars]

    def _get_grad_norm(self, t, axes=None, name=None):
        values = tf.convert_to_tensor(
            t.values if isinstance(t, tf.IndexedSlices) else t, name="t"
        )

        # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
        l2sum = tf.math.reduce_sum(values * values, axes, keepdims=True)
        pred = l2sum > 0
        # Two-tap tf.where trick to bypass NaN gradients
        l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum))
        return tf.squeeze(tf.where(pred, tf.math.sqrt(l2sum_safe), l2sum))
2 changes: 2 additions & 0 deletions autoclip/torch/__init__.py
@@ -0,0 +1,2 @@
from autoclip.torch.quantile import QuantileClip
from autoclip.torch.std import StandardClip