This repository contains the code for implementing the methods presented in the paper *Training Neural Networks on Data Sources with Unknown Reliability*.
To instead reproduce the results of the experiments in the paper, please see the repository unreliable-sources.
To install the package, clone the repository and run the following command:

```bash
pip install loss_adapted_plasticity
```
To use this method on new data:
```python
import torch.nn as nn
from loss_adapted_plasticity import LossAdaptedPlasticity

# define the loss weighting with the desired parameters
loss_weighting = LossAdaptedPlasticity()

# ensure that your loss function returns a loss for each sample
# in the batch. In PyTorch this is done by setting reduction="none"
criterion = nn.LossFunction(reduction="none")  # placeholder: e.g. nn.CrossEntropyLoss

for epoch in range(epochs):
    for data, target, sources in train_loader:
        output = model(data)
        losses = criterion(output, target)
        # compute the mean weighted loss
        loss = loss_weighting(losses, sources).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
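To make the loop above concrete, here is a minimal, self-contained sketch on synthetic data. The model, optimizer, dimensions, and source labels are placeholders invented for this example; only `LossAdaptedPlasticity` comes from the package.

```python
import torch
import torch.nn as nn
from loss_adapted_plasticity import LossAdaptedPlasticity

# hypothetical dimensions for this sketch
n_sources, batch_size, n_features, n_classes = 10, 32, 20, 4

model = nn.Linear(n_features, n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample losses
loss_weighting = LossAdaptedPlasticity()

for step in range(200):
    # synthetic batch, with a source label for each example
    data = torch.randn(batch_size, n_features)
    target = torch.randint(0, n_classes, (batch_size,))
    sources = torch.randint(0, n_sources, (batch_size,))

    output = model(data)
    losses = criterion(output, target)             # shape: (batch_size,)
    loss = loss_weighting(losses, sources).mean()  # weight, then reduce

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```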
The `loss_weighting` object will keep track of the loss history for each source and compute the weighted loss for each batch. It can be used with any loss function that returns a loss for each sample in the batch, and it will return a weighted loss for each sample, which can then be reduced (usually by the mean or sum) to get the batch loss used in back-propagation. This assumes that the smaller the loss value, the better the prediction.
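For example, a custom loss only needs to keep the batch dimension when reducing, so that one value is returned per sample. Below is a minimal sketch for a regression setting that plugs into the loop above; the function name is invented for illustration.

```python
import torch

def per_sample_mse(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # average over the feature dimension only, keeping the batch
    # dimension, so one loss value is returned per sample
    return ((output - target) ** 2).mean(dim=-1)

# assuming output, target, sources from the training loop above
weighted = loss_weighting(per_sample_mse(output, target), sources)
loss = weighted.mean()  # reduce to the scalar batch loss
```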
This package contains a single class, `LossAdaptedPlasticity`, which can be used to weight the loss of each sample in a batch based on the history of the loss values for each source. The documentation for this is as follows:

This class calculates the loss weighting for each source based on the loss history of each source. Its usage is as follows:
```python
loss_weighting = LossAdaptedPlasticity()
```

During the training loop:

```python
outputs = model(inputs)

# get loss values for each sample
losses = loss_fn(outputs, targets, reduction="none")

# reweight the loss using LAP
losses = loss_weighting(losses=losses, sources=sources)

# get mean loss and backpropagate
loss = torch.mean(losses)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
```python
__init__(self, history_length: int = 50, warmup_iters: int = 100, depression_strength: float = 1, discrete_amount: float = 0.005, leniency: float = 1.0, device="cpu")
```
- `history_length`: int, optional: The number of previous loss values for each source to be used in the loss adapted plasticity calculations. Defaults to `50`.
- `warmup_iters`: int, optional: The number of iterations before the loss weighting starts to be applied. Defaults to `100`.
- `depression_strength`: float, optional: This float determines the strength of the depression applied to the gradients. It is the value of `m` in `dep = 1 - tanh(m*d)**2` (see the sketch after this list). Defaults to `1`.
- `discrete_amount`: float, optional: The step size used when calculating the depression. Defaults to `0.005`.
- `leniency`: float, optional: The number of standard deviations away from the mean loss a mean source loss has to be before depression is applied. Defaults to `1.0`.
- `device`: str, optional: The device to use for the calculations. If this is a different device to the one which the model is on, the loss values will be moved to the device specified here and returned to the original device afterwards. Defaults to `"cpu"`.
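As a quick illustration of these parameters, the following sketch constructs the weighting with arbitrary non-default values and evaluates the depression curve `dep = 1 - tanh(m*d)**2` directly: the curve equals `1` at `d = 0` and decays towards `0` as the depression amount `d` grows, with larger `m` giving a faster decay.

```python
import torch
from loss_adapted_plasticity import LossAdaptedPlasticity

# illustrative, non-default configuration (values are arbitrary)
loss_weighting = LossAdaptedPlasticity(
    history_length=25,        # shorter loss history per source
    warmup_iters=50,          # start weighting after fewer iterations
    depression_strength=0.5,  # m in dep = 1 - tanh(m*d)**2
    discrete_amount=0.005,    # step size for the depression amount d
    leniency=2.0,             # tolerate sources up to 2 std devs from the mean
    device="cpu",
)

# the depression curve itself: 1 at d = 0, decaying towards 0 as d grows
m = 0.5
for d in [0.0, 0.5, 1.0, 2.0, 5.0]:
    dep = 1 - torch.tanh(torch.tensor(m * d)) ** 2
    print(f"d = {d:.1f} -> dep = {dep.item():.3f}")
```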
```python
forward(self, losses: torch.Tensor, sources: torch.Tensor, writer=None, writer_prefix: typing.Optional[str] = None) -> torch.Tensor
```
This function calculates the weighted loss for each sample and should be called during the training loop as shown in the example above.
- `losses`: torch.Tensor of shape (batch_size,): The losses for each example in the batch.
- `sources`: torch.Tensor of shape (batch_size,): The source for each example in the batch.
- `writer`: torch.utils.tensorboard.SummaryWriter, optional: A TensorBoard writer can be passed into this function to track metrics. Defaults to `None`.
- `writer_prefix`: str, optional: A prefix to add to the writer metrics. Defaults to `None`.

Returns:

- `output`: torch.Tensor of shape (batch_size,): The weighted losses for each example in the batch.
The class also tracks the calculated unreliability of each source; the larger the value, the less reliable the source.
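For example, to log the LAP metrics to TensorBoard during training, pass a writer to `forward`. In this sketch the log directory and prefix are arbitrary choices, and the losses are random placeholders standing in for real per-sample losses.

```python
import torch
from torch.utils.tensorboard import SummaryWriter
from loss_adapted_plasticity import LossAdaptedPlasticity

writer = SummaryWriter(log_dir="./runs/lap_example")  # arbitrary log dir
loss_weighting = LossAdaptedPlasticity()

# inside the training loop, pass the writer so LAP metrics are tracked
losses = torch.rand(32)                # placeholder per-sample losses
sources = torch.randint(0, 10, (32,))  # a source label per example
weighted = loss_weighting(
    losses=losses,
    sources=sources,
    writer=writer,
    writer_prefix="train",             # prefix added to the metric names
)
loss = weighted.mean()
```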