From c2da93a9a3aedf297242df80e78d0292b4fb3549 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 20 Mar 2025 17:33:52 +0100 Subject: [PATCH 01/69] :bug: Fix #152 --- torch_uncertainty/models/mlp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index cda24d09..43d31757 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -82,10 +82,13 @@ def __init__( ) self.layers = layers + self.fc_dropout = nn.Dropout(p=dropout_rate) + self.last_fc_dropout = nn.Dropout(p=dropout_rate) def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]: - for layer in self.layers: - x = F.dropout(layer(x), p=self.dropout_rate, training=self.training) + for i, layer in enumerate(self.layers): + dropout = self.fc_dropout if i < len(self.layers) - 1 else self.last_fc_dropout + x = dropout(layer(x), p=self.dropout_rate, training=self.training) x = self.activation(x) return self.final_layer(x) From 4cd6743243360070e655bd7192ddccc9990b9bfc Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 20 Mar 2025 17:43:00 +0100 Subject: [PATCH 02/69] :bug: Fix `nn.Dropout.forward()` argument in `_MLP.forward()` --- torch_uncertainty/models/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 43d31757..52c25b1a 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -88,7 +88,7 @@ def __init__( def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]: for i, layer in enumerate(self.layers): dropout = self.fc_dropout if i < len(self.layers) - 1 else self.last_fc_dropout - x = dropout(layer(x), p=self.dropout_rate, training=self.training) + x = dropout(layer(x)) x = self.activation(x) return self.final_layer(x) From a3f53b3d3a26dd36df1fe4506f28fdef38718388 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 20 Mar 2025 17:46:26 +0100 Subject: [PATCH 03/69] :hammer: Rework OOD criteria --- docs/source/quickstart.rst | 5 +- tests/_dummies/baseline.py | 3 +- tests/routines/test_classification.py | 36 +++++---- .../classification/deep_ensembles.py | 3 +- .../baselines/classification/resnet.py | 10 +-- .../baselines/classification/vgg.py | 10 +-- .../baselines/classification/wideresnet.py | 10 +-- torch_uncertainty/ood_criteria.py | 74 +++++++++++++++++++ torch_uncertainty/routines/classification.py | 65 +++++++--------- 9 files changed, 138 insertions(+), 78 deletions(-) create mode 100644 torch_uncertainty/ood_criteria.py diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 26ebcb89..4e076379 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -46,12 +46,9 @@ and its parameters. # ... eval_ood: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: TUOODCriterion | None = None, log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, ) -> None: ... diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 59291f14..5d72d4da 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -3,6 +3,7 @@ from torch import nn from torch_uncertainty.models import EMA, SWA, deep_ensembles +from torch_uncertainty.ood_criteria import TUOODCriterion from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.routines import ( @@ -25,7 +26,7 @@ def __new__( baseline_type: str = "single", optim_recipe=optim_cifar10_resnet18, with_feats: bool = True, - ood_criterion: str = "msp", + ood_criterion: type[TUOODCriterion] | None = None, eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index cd724c68..e959caec 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -10,6 +10,13 @@ ) from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DECLoss, ELBOLoss +from torch_uncertainty.ood_criteria import ( + EnergyCriterion, + EntropyCriterion, + LogitCriterion, + MutualInformationCriterion, + VariationRatioCriterion, +) from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget @@ -31,7 +38,6 @@ def test_one_estimator_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), baseline_type="single", - ood_criterion="msp", ema=True, ) @@ -54,7 +60,7 @@ def test_two_estimators_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), baseline_type="single", - ood_criterion="logit", + ood_criterion=LogitCriterion, swa=True, ) @@ -79,7 +85,7 @@ def test_one_estimator_two_classes(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, eval_shift=True, no_mixup_params=True, @@ -105,7 +111,7 @@ def test_one_estimator_two_classes_timm(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="timm", mixup_alpha=1.0, @@ -132,7 +138,7 @@ def test_one_estimator_two_classes_mixup(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="mixup", mixup_alpha=1.0, @@ -158,7 +164,7 @@ def test_one_estimator_two_classes_mixup_io(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="mixup_io", mixup_alpha=1.0, @@ -184,7 +190,7 @@ def test_one_estimator_two_classes_regmixup(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="regmixup", mixup_alpha=1.0, @@ -210,7 +216,7 @@ def test_one_estimator_two_classes_kernel_warping_emb(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="kernel_warping", mixup_alpha=0.5, @@ -236,7 +242,7 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="entropy", + ood_criterion=EntropyCriterion, eval_ood=True, mixtype="kernel_warping", dist_sim="inp", @@ -263,7 +269,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion="energy", + ood_criterion=EnergyCriterion, eval_ood=True, eval_grouping_loss=True, calibrate=True, @@ -289,7 +295,7 @@ def test_two_estimators_two_classes_mi(self): in_channels=dm.num_channels, loss=DECLoss(1, 1e-2), baseline_type="ensemble", - ood_criterion="mi", + ood_criterion=MutualInformationCriterion, eval_ood=True, ) @@ -320,7 +326,7 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): in_channels=dm.num_channels, loss=ELBOLoss(None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4), baseline_type="ensemble", - ood_criterion="vr", + ood_criterion=VariationRatioCriterion, eval_ood=True, save_in_csv=True, ) @@ -341,14 +347,16 @@ def test_classification_failures(self): model=nn.Module(), loss=None, is_ensemble=False, - ood_criterion="mi", + ood_criterion=MutualInformationCriterion, ) + with pytest.raises(ValueError): ClassificationRoutine( num_classes=10, model=nn.Module(), loss=None, - ood_criterion="other", + is_ensemble=False, + ood_criterion=32, ) with pytest.raises(ValueError): diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index bba7683c..b795fd6a 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -2,6 +2,7 @@ from typing import Literal from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.utils import get_version @@ -24,7 +25,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", + ood_criterion: type[TUOODCriterion] | None = None, log_plots: bool = False, ) -> None: log_path = Path(log_path) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 376f5815..a86c8977 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -12,6 +12,7 @@ packed_resnet, resnet, ) +from torch_uncertainty.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget @@ -67,7 +68,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", + ood_criterion: type[TUOODCriterion] | None = None, log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, @@ -144,11 +145,8 @@ def __init__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. - MSP is the maximum softmax probability, logit is the maximum - logit, entropy is the entropy of the mean prediction, mi is the - mutual information of the ensemble and vr is the variation ratio - of the ensemble. + ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task. + Defaults to None which amounts to the maximum softmax probability score (MSP). log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 520c6425..2b0d2e07 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -8,6 +8,7 @@ packed_vgg, vgg, ) +from torch_uncertainty.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget @@ -38,7 +39,7 @@ def __init__( groups: int = 1, alpha: int | None = None, gamma: int = 1, - ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", + ood_criterion: type[TUOODCriterion] | None = None, log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, @@ -90,11 +91,8 @@ def __init__( gamma (int, optional): Number of groups within each estimator. Only used if :attr:`version` is ``"packed"`` and scales with :attr:`groups`. Defaults to ``1s``. - ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. - MSP is the maximum softmax probability, logit is the maximum - logit, entropy is the entropy of the mean prediction, mi is the - mutual information of the ensemble and vr is the variation ratio - of the ensemble. + ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task. + Defaults to None which amounts to the maximum softmax probability score (MSP). log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 477b83fd..d10b3adb 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -11,6 +11,7 @@ packed_wideresnet28x10, wideresnet28x10, ) +from torch_uncertainty.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ( ClassificationRoutine, ) @@ -47,7 +48,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", + ood_criterion: type[TUOODCriterion] | None = None, log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, @@ -102,11 +103,8 @@ def __init__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. - MSP is the maximum softmax probability, logit is the maximum - logit, entropy is the entropy of the mean prediction, mi is the - mutual information of the ensemble and vr is the variation ratio - of the ensemble. + ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task. + Defaults to None which amounts to the maximum softmax probability score (MSP). log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in diff --git a/torch_uncertainty/ood_criteria.py b/torch_uncertainty/ood_criteria.py new file mode 100644 index 00000000..2d42f64a --- /dev/null +++ b/torch_uncertainty/ood_criteria.py @@ -0,0 +1,74 @@ +from abc import ABC, abstractmethod +from enum import Enum + +import torch +from torch import Tensor, nn + +from torch_uncertainty.metrics import MutualInformation, VariationRatio + + +class OODCriterionInputType(Enum): + LOGIT = 1 + PROB = 2 + ESTIMATOR_PROB = 3 + + +class TUOODCriterion(ABC, nn.Module): + input_type: OODCriterionInputType + ensemble_only = False + + @abstractmethod + def forward(self, inputs: Tensor) -> Tensor: + pass + + +class LogitCriterion(TUOODCriterion): + input_type = OODCriterionInputType.LOGIT + + def forward(self, inputs: Tensor) -> Tensor: + return -inputs.mean(dim=1).max(dim=-1).values + + +class EnergyCriterion(TUOODCriterion): + input_type = OODCriterionInputType.LOGIT + + def forward(self, inputs: Tensor) -> Tensor: + return -inputs.mean(dim=1).logsumexp(dim=-1) + + +class MaxSoftmaxProbabilityCriterion(TUOODCriterion): + input_type = OODCriterionInputType.PROB + + def forward(self, inputs: Tensor) -> Tensor: + return -inputs.max(-1)[0] + + +class EntropyCriterion(TUOODCriterion): + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def forward(self, inputs: Tensor) -> Tensor: + return torch.special.entr(inputs).sum(dim=-1).mean(dim=1) + + +class MutualInformationCriterion(TUOODCriterion): + ensemble_only = True + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def __init__(self) -> None: + super().__init__() + self.mi_metric = MutualInformation(reduction="none") + + def forward(self, inputs: Tensor) -> Tensor: + return self.mi_metric(inputs) + + +class VariationRatioCriterion(TUOODCriterion): + ensemble_only = True + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def __init__(self) -> None: + super().__init__() + self.vr_metric = VariationRatio(reduction="none", probabilistic=False) + + def forward(self, inputs: Tensor) -> Tensor: + return self.vr_metric(inputs.transpose(0, 1)) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 52b8dc46..c97ac664 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,6 +1,5 @@ from collections.abc import Callable from pathlib import Path -from typing import Literal import torch import torch.nn.functional as F @@ -32,12 +31,16 @@ GroupingLoss, MutualInformation, RiskAt80Cov, - VariationRatio, ) from torch_uncertainty.models import ( EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) +from torch_uncertainty.ood_criteria import ( + MaxSoftmaxProbabilityCriterion, + OODCriterionInputType, + TUOODCriterion, +) from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing from torch_uncertainty.transforms import ( Mixup, @@ -72,7 +75,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", + ood_criterion: type[TUOODCriterion] | None = None, post_processing: PostProcessing | None = None, num_bins_cal_err: int = 15, log_plots: bool = False, @@ -100,13 +103,8 @@ def __init__( shift performance. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - ood_criterion (str, optional): OOD criterion. Available options are - - ``"msp"`` (default): Maximum softmax probability. - - ``"logit"``: Maximum logit. - - ``"energy"``: Logsumexp of the mean logits. - - ``"entropy"``: Entropy of the mean prediction. - - ``"mi"``: Mutual information of the ensemble. - - ``"vr"``: Variation ratio of the ensemble. + ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task. + Defaults to None which amounts to the maximum softmax probability score (MSP). post_processing (PostProcessing, optional): Post-processing method to train on the calibration set. No post-processing if None. Defaults to ``None``. @@ -155,12 +153,14 @@ def __init__( if format_batch_fn is None: format_batch_fn = nn.Identity() + if ood_criterion is None: + ood_criterion = MaxSoftmaxProbabilityCriterion self.num_classes = num_classes self.eval_ood = eval_ood self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss - self.ood_criterion = ood_criterion + self.ood_criterion = ood_criterion() self.log_plots = log_plots self.save_in_csv = save_in_csv self.binary_cls = num_classes == 1 @@ -484,22 +484,13 @@ def test_step( logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - confs = probs.max(-1)[0] - - if self.ood_criterion == "logit": - ood_scores = -logits.mean(dim=1).max(dim=-1).values - elif self.ood_criterion == "energy": - ood_scores = -logits.mean(dim=1).logsumexp(dim=-1) - elif self.ood_criterion == "entropy": - ood_scores = torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) - elif self.ood_criterion == "mi": - mi_metric = MutualInformation(reduction="none") - ood_scores = mi_metric(probs_per_est) - elif self.ood_criterion == "vr": - vr_metric = VariationRatio(reduction="none", probabilistic=False) - ood_scores = vr_metric(probs_per_est.transpose(0, 1)) + + if self.ood_criterion.input_type == OODCriterionInputType.LOGIT: + ood_scores = self.ood_criterion(logits) + elif self.ood_criterion.input_type == OODCriterionInputType.PROB: + ood_scores = self.ood_criterion(probs) else: - ood_scores = -confs + ood_scores = self.ood_criterion(probs_per_est) if dataloader_idx == 0: # squeeze if binary classification only for binary metrics @@ -688,40 +679,34 @@ def _classification_routine_checks( model: nn.Module, num_classes: int, is_ensemble: bool, - ood_criterion: str, + ood_criterion: type[TUOODCriterion] | None, eval_grouping_loss: bool, num_bins_cal_err: int, mixup_params: dict | None, post_processing: PostProcessing | None, format_batch_fn: nn.Module | None, ) -> None: - """Check the domains of the routine's parameters. + """Check the domains of the arguments of the classification routine. Args: model (nn.Module): the model used to make classification predictions. num_classes (int): the number of classes in the dataset. is_ensemble (bool): whether the model is an ensemble or a single model. - ood_criterion (str): the criterion for the binary OOD detection task. + ood_criterion (TUOODCriterion, optional): OOD criterion for the binary OOD detection task. eval_grouping_loss (bool): whether to evaluate the grouping loss. num_bins_cal_err (int): the number of bins for the evaluation of the calibration. mixup_params (dict | None): the dictionary to setup the mixup augmentation. post_processing (PostProcessing | None): the post-processing module. format_batch_fn (nn.Module | None): the function for formatting the batch for ensembles. """ - if ood_criterion not in [ - "msp", - "logit", - "energy", - "entropy", - "mi", - "vr", - ]: + if ood_criterion is not None and ( + not isinstance(ood_criterion, type) or not issubclass(ood_criterion, TUOODCriterion) + ): raise ValueError( - "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," - f" 'mi' or 'vr'. Got {ood_criterion}." + f"Use `ood_criteria.TUOODCriterion` classes as OOD criteria. Got {type(ood_criterion)}." ) - if not is_ensemble and ood_criterion in ["mi", "vr"]: + if not is_ensemble and ood_criterion is not None and ood_criterion.ensemble_only: raise ValueError( "You cannot use mutual information or variation ratio with a single model." ) From 86ffe4b19ee9537263f1045a0016a085aff5804d Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 20 Mar 2025 18:58:46 +0100 Subject: [PATCH 04/69] :hammer: Allow for string OOD criteria Co-authored-by: Adrien Lafage --- tests/_dummies/baseline.py | 2 +- tests/routines/test_classification.py | 26 ++++++++---------- torch_uncertainty/ood_criteria.py | 29 ++++++++++++++++++-- torch_uncertainty/routines/classification.py | 20 ++++---------- 4 files changed, 45 insertions(+), 32 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 5d72d4da..36585d63 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -26,7 +26,7 @@ def __new__( baseline_type: str = "single", optim_recipe=optim_cifar10_resnet18, with_feats: bool = True, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index e959caec..205c68c1 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -11,11 +11,7 @@ from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.ood_criteria import ( - EnergyCriterion, EntropyCriterion, - LogitCriterion, - MutualInformationCriterion, - VariationRatioCriterion, ) from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget @@ -60,7 +56,7 @@ def test_two_estimators_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), baseline_type="single", - ood_criterion=LogitCriterion, + ood_criterion="logit", swa=True, ) @@ -111,7 +107,7 @@ def test_one_estimator_two_classes_timm(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="timm", mixup_alpha=1.0, @@ -138,7 +134,7 @@ def test_one_estimator_two_classes_mixup(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="mixup", mixup_alpha=1.0, @@ -164,7 +160,7 @@ def test_one_estimator_two_classes_mixup_io(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="mixup_io", mixup_alpha=1.0, @@ -190,7 +186,7 @@ def test_one_estimator_two_classes_regmixup(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="regmixup", mixup_alpha=1.0, @@ -216,7 +212,7 @@ def test_one_estimator_two_classes_kernel_warping_emb(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="kernel_warping", mixup_alpha=0.5, @@ -242,7 +238,7 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EntropyCriterion, + ood_criterion="entropy", eval_ood=True, mixtype="kernel_warping", dist_sim="inp", @@ -269,7 +265,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), baseline_type="single", - ood_criterion=EnergyCriterion, + ood_criterion="energy", eval_ood=True, eval_grouping_loss=True, calibrate=True, @@ -295,7 +291,7 @@ def test_two_estimators_two_classes_mi(self): in_channels=dm.num_channels, loss=DECLoss(1, 1e-2), baseline_type="ensemble", - ood_criterion=MutualInformationCriterion, + ood_criterion="mutual_information", eval_ood=True, ) @@ -326,7 +322,7 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): in_channels=dm.num_channels, loss=ELBOLoss(None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4), baseline_type="ensemble", - ood_criterion=VariationRatioCriterion, + ood_criterion="variation_ratio", eval_ood=True, save_in_csv=True, ) @@ -347,7 +343,7 @@ def test_classification_failures(self): model=nn.Module(), loss=None, is_ensemble=False, - ood_criterion=MutualInformationCriterion, + ood_criterion="mutual_information", ) with pytest.raises(ValueError): diff --git a/torch_uncertainty/ood_criteria.py b/torch_uncertainty/ood_criteria.py index 2d42f64a..cbaac9d2 100644 --- a/torch_uncertainty/ood_criteria.py +++ b/torch_uncertainty/ood_criteria.py @@ -18,11 +18,11 @@ class TUOODCriterion(ABC, nn.Module): ensemble_only = False @abstractmethod - def forward(self, inputs: Tensor) -> Tensor: + def forward(self, inputs: Tensor) -> Tensor: # coverage: ignore pass -class LogitCriterion(TUOODCriterion): +class MaxLogitCriterion(TUOODCriterion): input_type = OODCriterionInputType.LOGIT def forward(self, inputs: Tensor) -> Tensor: @@ -72,3 +72,28 @@ def __init__(self) -> None: def forward(self, inputs: Tensor) -> Tensor: return self.vr_metric(inputs.transpose(0, 1)) + + +def get_ood_criterion(ood_criterion): + if isinstance(ood_criterion, str): + if ood_criterion == "logit": + return MaxLogitCriterion() + if ood_criterion == "energy": + return EnergyCriterion() + if ood_criterion == "msp": + return MaxSoftmaxProbabilityCriterion() + if ood_criterion == "entropy": + return EntropyCriterion() + if ood_criterion == "mutual_information": + return MutualInformationCriterion() + if ood_criterion == "variation_ratio": + return VariationRatioCriterion() + raise ValueError( + "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," + f" 'mutual_information' or 'variation_ratio'. Got {ood_criterion}." + ) + if isinstance(ood_criterion, type) and issubclass(ood_criterion, TUOODCriterion): + return ood_criterion() + raise ValueError( + f"The OOD criterion should be a string or a subclass of TUOODCriterion. Got {type(ood_criterion)}." + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index c97ac664..4f790203 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -37,9 +37,9 @@ STEP_UPDATE_MODEL, ) from torch_uncertainty.ood_criteria import ( - MaxSoftmaxProbabilityCriterion, OODCriterionInputType, TUOODCriterion, + get_ood_criterion, ) from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing from torch_uncertainty.transforms import ( @@ -75,7 +75,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", post_processing: PostProcessing | None = None, num_bins_cal_err: int = 15, log_plots: bool = False, @@ -153,14 +153,12 @@ def __init__( if format_batch_fn is None: format_batch_fn = nn.Identity() - if ood_criterion is None: - ood_criterion = MaxSoftmaxProbabilityCriterion self.num_classes = num_classes self.eval_ood = eval_ood self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss - self.ood_criterion = ood_criterion() + self.ood_criterion = get_ood_criterion(ood_criterion) self.log_plots = log_plots self.save_in_csv = save_in_csv self.binary_cls = num_classes == 1 @@ -679,7 +677,7 @@ def _classification_routine_checks( model: nn.Module, num_classes: int, is_ensemble: bool, - ood_criterion: type[TUOODCriterion] | None, + ood_criterion: type[TUOODCriterion] | str, eval_grouping_loss: bool, num_bins_cal_err: int, mixup_params: dict | None, @@ -699,14 +697,8 @@ def _classification_routine_checks( post_processing (PostProcessing | None): the post-processing module. format_batch_fn (nn.Module | None): the function for formatting the batch for ensembles. """ - if ood_criterion is not None and ( - not isinstance(ood_criterion, type) or not issubclass(ood_criterion, TUOODCriterion) - ): - raise ValueError( - f"Use `ood_criteria.TUOODCriterion` classes as OOD criteria. Got {type(ood_criterion)}." - ) - - if not is_ensemble and ood_criterion is not None and ood_criterion.ensemble_only: + ood_criterion_cls = get_ood_criterion(ood_criterion) + if not is_ensemble and ood_criterion_cls.ensemble_only: raise ValueError( "You cannot use mutual information or variation ratio with a single model." ) From 53509d7c257a00b024faa6ea900f4e9bf58665b6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 20 Mar 2025 19:12:13 +0100 Subject: [PATCH 05/69] :bug: Add unpushed modifications --- torch_uncertainty/baselines/classification/deep_ensembles.py | 2 +- torch_uncertainty/baselines/classification/resnet.py | 2 +- torch_uncertainty/baselines/classification/vgg.py | 2 +- torch_uncertainty/baselines/classification/wideresnet.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index b795fd6a..c66bc9e3 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -25,7 +25,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", log_plots: bool = False, ) -> None: log_path = Path(log_path) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index a86c8977..1e2b18ab 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -68,7 +68,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 2b0d2e07..0d0b887e 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -39,7 +39,7 @@ def __init__( groups: int = 1, alpha: int | None = None, gamma: int = 1, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index d10b3adb..1be8939c 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -48,7 +48,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: type[TUOODCriterion] | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", log_plots: bool = False, save_in_csv: bool = False, eval_ood: bool = False, From 404917de687d519ec2e72428bd4c7f22b02e69a5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 20 Mar 2025 20:09:34 +0100 Subject: [PATCH 06/69] :sparkles: Add first impl. of TTA --- tests/_dummies/datamodule.py | 12 +++--- torch_uncertainty/datamodules/abstract.py | 40 ++++++++++++++---- .../datamodules/classification/cifar10.py | 25 +++++++---- .../datamodules/classification/cifar100.py | 28 ++++++++----- .../datamodules/classification/imagenet.py | 33 +++++++++------ .../datamodules/classification/mnist.py | 35 +++++++++------- .../classification/tiny_imagenet.py | 42 +++++++------------ torch_uncertainty/datamodules/depth/base.py | 2 +- torch_uncertainty/datamodules/depth/muad.py | 2 +- .../datamodules/segmentation/cityscapes.py | 2 +- .../datamodules/segmentation/muad.py | 2 +- torch_uncertainty/routines/classification.py | 7 +++- torch_uncertainty/utils/__init__.py | 3 +- torch_uncertainty/utils/data.py | 31 ++++++++++++++ torch_uncertainty/utils/misc.py | 14 ------- 15 files changed, 171 insertions(+), 107 deletions(-) create mode 100644 torch_uncertainty/utils/data.py diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 8da6b391..a47080a2 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -104,11 +104,11 @@ def setup(self, stage: str | None = None) -> None: self.shift.shift_severity = 1 def test_dataloader(self) -> DataLoader | list[DataLoader]: - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader def _get_train_data(self) -> ArrayLike: @@ -171,7 +171,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, shuffle=False)] class DummySegmentationDataModule(TUDataModule): @@ -256,7 +256,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, shuffle=False)] def _get_train_data(self) -> ArrayLike: return self.train.data @@ -345,7 +345,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, shuffle=False)] def _get_train_data(self) -> ArrayLike: return self.train.data diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index ed56b8e5..d9db703f 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -6,6 +6,8 @@ from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike +from torch_uncertainty.utils import TTADataset + if util.find_spec("sklearn"): from sklearn.model_selection import StratifiedKFold @@ -24,6 +26,8 @@ class TUDataModule(ABC, LightningDataModule): train: Dataset val: Dataset test: Dataset + ood: Dataset + shift: Dataset shift_severity = 1 @@ -35,6 +39,7 @@ def __init__( num_workers: int, pin_memory: bool, persistent_workers: bool, + num_tta: int = 1, postprocess_set: Literal["val", "test"] = "val", ) -> None: """Abstract DataModule class for TorchUncertainty. @@ -50,8 +55,9 @@ def __init__( num_workers (int): Number of workers to use for data loading. pin_memory (bool): Whether to pin memory. persistent_workers (bool): Whether to use persistent workers. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). postprocess_set (str): Which split to use as post-processing set to fit the - post-processing method. + post-processing method. Defaults to ``val``. """ super().__init__() @@ -63,8 +69,13 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers + if not num_tta % batch_size: + raise ValueError( + f"The number of Test-time augmentations num_tta should divide batch_size. Got {num_tta} and {batch_size}." + ) + self.num_tta = num_tta if postprocess_set == "test": - logging.warning("Fitting the calibration method on the test set!") + logging.warning("You might be fitting the calibration method on the test set!") self.postprocess_set = postprocess_set @abstractmethod @@ -77,12 +88,28 @@ def get_train_set(self) -> Dataset: def get_val_set(self) -> Dataset: """Get the validation set.""" + if self.num_tta > 1: + return TTADataset(self.val, self.num_tta) return self.val def get_test_set(self) -> Dataset: """Get the test set.""" + if self.num_tta > 1: + return TTADataset(self.test, self.num_tta) return self.test + def get_ood_set(self) -> Dataset: + """Get the shifted set.""" + if self.num_tta > 1: + return TTADataset(self.ood, self.num_tta) + return self.ood + + def get_shift_set(self) -> Dataset: + """Get the shifted set.""" + if self.num_tta > 1: + return TTADataset(self.shift, self.num_tta) + return self.shift + def train_dataloader(self) -> DataLoader: r"""Get the training dataloader. @@ -97,7 +124,7 @@ def val_dataloader(self) -> DataLoader: Return: DataLoader: validation dataloader. """ - return self._data_loader(self.val) + return self._data_loader(self.get_val_set(), shuffle=False) def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. @@ -106,7 +133,7 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data and out-of-distribution data. """ - return [self._data_loader(self.test)] + return [self._data_loader(self.get_test_set(), shuffle=False)] def postprocess_dataloader(self) -> DataLoader: r"""Get the calibration dataloader. @@ -116,13 +143,12 @@ def postprocess_dataloader(self) -> DataLoader: """ return self.val_dataloader() if self.postprocess_set == "val" else self.test_dataloader()[0] - def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: + def _data_loader(self, dataset: Dataset, shuffle: bool) -> DataLoader: """Create a dataloader for a given dataset. Args: dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. + shuffle (bool): Whether to shuffle the dataset Return: DataLoader: Dataloader for the given dataset. diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index d52211e2..68d0b283 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -31,6 +31,7 @@ def __init__( batch_size: int, eval_ood: bool = False, eval_shift: bool = False, + num_tta: int = 1, shift_severity: int = 1, val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", @@ -65,6 +66,7 @@ def __init__( ``False``. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. test_alt (str): Which test set to use. Defaults to ``None``. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). shift_severity (int): Severity of corruption to apply for CIFAR10-C. Defaults to ``1``. num_dataloaders (int): Number of dataloaders to use. Defaults to ``1``. @@ -76,6 +78,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + num_tta=num_tta, postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, @@ -131,12 +134,16 @@ def __init__( ] ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] + self.test_transform = ( + v2.Compose( + [ + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + if num_tta == 1 + else self.train_transform ) def prepare_data(self) -> None: # coverage: ignore @@ -231,11 +238,11 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.get_test_set(), shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 6334b10c..65496a5b 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -31,6 +31,7 @@ def __init__( batch_size: int, eval_ood: bool = False, eval_shift: bool = False, + num_tta: int = 1, shift_severity: int = 1, val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", @@ -62,8 +63,8 @@ def __init__( randaugment (bool): Whether to apply RandAugment. Defaults to ``False``. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. - shift_severity (int): Severity of corruption to apply to - CIFAR100-C. Defaults to ``1``. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). + shift_severity (int): Severity of corruption to apply to CIFAR100-C. Defaults to ``1``. num_dataloaders (int): Number of dataloaders to use. Defaults to ``1``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. @@ -75,6 +76,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + num_tta=num_tta, postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, @@ -125,12 +127,16 @@ def __init__( v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] + self.test_transform = ( + v2.Compose( + [ + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + if num_tta == 1 + else self.train_transform ) def prepare_data(self) -> None: # coverage: ignore @@ -218,11 +224,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-100C data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.get_test_set(), shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 23012f66..ad629d7d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -48,6 +48,7 @@ def __init__( batch_size: int, eval_ood: bool = False, eval_shift: bool = False, + num_tta: int = 1, shift_severity: int = 1, val_split: float | Path | None = None, postprocess_set: Literal["val", "test"] = "val", @@ -66,12 +67,13 @@ def __init__( Args: root (str): Root directory of the datasets. + batch_size (int): Number of samples per batch. eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. - shift_severity: int = 1, - batch_size (int): Number of samples per batch. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). + shift_severity (int): Severity of corruption to apply to ImageNet-C. Defaults to ``1``. val_split (float or Path): Share of samples to use for validation or path to a yaml file containing a list of validation images ids. Defaults to ``0.0``. @@ -97,6 +99,7 @@ def __init__( root=Path(root), batch_size=batch_size, val_split=val_split, + num_tta=num_tta, postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, @@ -180,14 +183,18 @@ def __init__( ] ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.Resize(256, interpolation=self.interpolation), - v2.CenterCrop(224), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] + self.test_transform = ( + v2.Compose( + [ + v2.ToImage(), + v2.Resize(256, interpolation=self.interpolation), + v2.CenterCrop(224), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + if num_tta == 1 + else self.train_transform ) def _verify_splits(self, split: str) -> None: @@ -299,11 +306,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split (out-of-distribution data), and/or ImageNetC data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.get_test_set(), shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a1d72de4..8424165c 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -29,6 +29,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, ood_ds: Literal["fashion", "notMNIST"] = "fashion", + num_tta: int = 1, val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, @@ -51,6 +52,7 @@ def __init__( notMNIST. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). postprocess_set (str, optional): The post-hoc calibration dataset to use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults @@ -66,6 +68,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + num_tta=num_tta, postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, @@ -100,24 +103,28 @@ def __init__( v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.CenterCrop(28), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - if self.eval_ood: # NotMNIST has 3 channels - self.ood_transform = v2.Compose( + self.test_transform = ( + v2.Compose( [ v2.ToImage(), - v2.Grayscale(num_output_channels=1), v2.CenterCrop(28), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) + if num_tta == 1 + else self.train_transform + ) + + self.ood_transform = v2.Compose( + [ + v2.ToImage(), + v2.Grayscale(num_output_channels=1), + v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore """Download the datasets.""" @@ -182,9 +189,9 @@ def test_dataloader(self) -> list[DataLoader]: distribution data), FashionMNIST or NotMNIST test split (out-of-distribution data), and/or MNISTC (shifted data). """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.get_test_set(), shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index f84b972c..c535d1b7 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -37,6 +37,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + num_tta: int = 1, postprocess_set: Literal["val", "test"] = "val", ood_ds: str = "svhn", interpolation: str = "bilinear", @@ -50,6 +51,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + num_tta=num_tta, postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, @@ -98,13 +100,17 @@ def __init__( ] ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.Resize(64, interpolation=self.interpolation), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] + self.test_transform = ( + v2.Compose( + [ + v2.ToImage(), + v2.Resize(64, interpolation=self.interpolation), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + if num_tta == 1 + else self.train_transform ) def _verify_splits(self, split: str) -> None: # coverage: ignore @@ -222,22 +228,6 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.test_transform, ) - def train_dataloader(self) -> DataLoader: - r"""Get the training dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet training dataloader. - """ - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - r"""Get the validation dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. @@ -245,11 +235,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data, OOD data, and/or TinyImageNetC data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.get_test_set(), shuffle=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.get_ood_set(), shuffle=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.get_shift_set(), shuffle=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 59804c0a..99f14cfc 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -9,7 +9,7 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils.misc import create_train_val_split +from torch_uncertainty.utils import create_train_val_split class DepthDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index 032a4292..f0272fc1 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -3,7 +3,7 @@ from torch.nn.common_types import _size_2_t from torch_uncertainty.datasets import MUAD -from torch_uncertainty.utils.misc import create_train_val_split +from torch_uncertainty.utils import create_train_val_split from .base import DepthDataModule diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index a6005893..82aadcf5 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -10,7 +10,7 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils.misc import create_train_val_split +from torch_uncertainty.utils import create_train_val_split class CityscapesDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 00f39251..579a7679 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -9,7 +9,7 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils.misc import create_train_val_split +from torch_uncertainty.utils import create_train_val_split class MUADDataModule(TUDataModule): diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 52b8dc46..3a5d6522 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -66,6 +66,7 @@ def __init__( num_classes: int, loss: nn.Module, is_ensemble: bool = False, + num_tta: int = 1, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, mixup_params: dict | None = None, @@ -86,6 +87,7 @@ def __init__( loss (torch.nn.Module): Loss function to optimize the :attr:`model`. is_ensemble (bool, optional): Indicates whether the model is an ensemble at test time or not. Defaults to ``False``. + num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and @@ -161,6 +163,7 @@ def __init__( self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = ood_criterion + self.num_tta = num_tta self.log_plots = log_plots self.save_in_csv = save_in_csv self.binary_cls = num_classes == 1 @@ -449,7 +452,7 @@ def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0) // self.num_tta) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) @@ -481,7 +484,7 @@ def test_step( """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0) // self.num_tta) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index 885f2dd0..5ba095bf 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,7 +1,8 @@ # ruff: noqa: F401 from .checkpoints import get_version from .cli import TULightningCLI +from .data import TTADataset, create_train_val_split from .hub import load_hf -from .misc import create_train_val_split, csv_writer, plot_hist +from .misc import csv_writer, plot_hist from .trainer import TUTrainer from .transforms import interpolation_modes_from_str diff --git a/torch_uncertainty/utils/data.py b/torch_uncertainty/utils/data.py new file mode 100644 index 00000000..6339d5d7 --- /dev/null +++ b/torch_uncertainty/utils/data.py @@ -0,0 +1,31 @@ +import copy +from collections.abc import Callable +from typing import Any + +from torch.utils.data import Dataset, random_split + + +def create_train_val_split( + dataset: Dataset, + val_split_rate: float, + val_transforms: Callable | None = None, +) -> tuple[Dataset, Dataset]: + train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) + val = copy.deepcopy(val) + val.dataset.transform = val_transforms + return train, val + + +class TTADataset(Dataset): + def __init__(self, dataset: Dataset, num_augmentations: int) -> None: + super().__init__() + self.dataset = dataset + self.num_augmentations = num_augmentations + + def __len__(self): + """Get the virtual length of the dataset.""" + return len(self.dataset) * self.num_augmentations + + def __getitem__(self, index) -> Any: + """Get the item corresponding to idx // :attr:`self.num_augmentations`.""" + return self.dataset[index // self.num_augmentations] diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index ab5d697d..328626a6 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,13 +1,10 @@ -import copy import csv -from collections.abc import Callable from pathlib import Path import matplotlib.pyplot as plt import torch from matplotlib.axes import Axes from matplotlib.figure import Figure -from torch.utils.data import Dataset, random_split def csv_writer(path: Path, dic: dict) -> None: @@ -69,14 +66,3 @@ def plot_hist( plt.legend() fig.tight_layout() return fig, ax - - -def create_train_val_split( - dataset: Dataset, - val_split_rate: float, - val_transforms: Callable | None = None, -) -> tuple[Dataset, Dataset]: - train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) - val = copy.deepcopy(val) - val.dataset.transform = val_transforms - return train, val From bebe01f468e32579b83a9425efed489f0d35d6f7 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 20 Mar 2025 20:24:22 +0100 Subject: [PATCH 07/69] :shirt: remove mention of calibration methods --- docs/source/quickstart.rst | 2 +- torch_uncertainty/datamodules/abstract.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 26ebcb89..679a7d8d 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -9,7 +9,7 @@ These routines make it very easy to: - train ensembles-like methods (Deep Ensembles, Packed-Ensembles, MIMO, Masksembles, etc) - compute and monitor uncertainty metrics: calibration, out-of-distribution detection, proper scores, grouping loss, etc. -- leverage calibration methods automatically during evaluation +- leverage post-processing methods automatically during evaluation Yet, we take account that their will be as many different uses of TorchUncertainty as there are of users. This page provides ideas on how to benefit from TorchUncertainty at all levels: from ready-to-train lightning-based models to using only specific diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index ed56b8e5..9880e787 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -64,7 +64,7 @@ def __init__( self.persistent_workers = persistent_workers if postprocess_set == "test": - logging.warning("Fitting the calibration method on the test set!") + logging.warning("You might be fitting the post-processing method on the test set!") self.postprocess_set = postprocess_set @abstractmethod From 88b4d00a7fd1b12c520c21bdc6dc361f661f0e0e Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 21 Mar 2025 11:12:44 +0100 Subject: [PATCH 08/69] :bug: Fix the TIN dataset & add doc --- .../datamodules/classification/imagenet.py | 5 +- .../classification/tiny_imagenet.py | 131 ++++++++---------- 2 files changed, 61 insertions(+), 75 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 23012f66..084c285c 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -62,7 +62,10 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: - """DataModule for ImageNet. + """DataModule for the ImageNet dataset. + + This datamodule uses ImageNet as In-distribution dataset, OpenImage-O, INaturalist, + ImageNet-0, SVHN or DTD as Out-of-distribution dataset and ImageNet-C as shifted dataset. Args: root (str): Root directory of the datasets. diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index f84b972c..195377d3 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -6,13 +6,14 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import DataLoader from torchvision.datasets import DTD, SVHN from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( ImageNetO, + OpenImageO, TinyImageNet, TinyImageNetC, ) @@ -46,6 +47,34 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: + """DataModule for the Tiny-ImageNet dataset. + + This datamodule uses Tiny-ImageNet as In-distribution dataset, OpenImage-O, ImageNet-0, + SVHN or DTD as Out-of-distribution dataset and Tiny-ImageNet-C as shifted dataset. + + Args: + root (str): Root directory of the datasets. + batch_size (int): Number of samples per batch. + eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. + shift_severity (int): Severity of the shift. Defaults to ``1``. + val_split (float or Path): Share of samples to use for validation + or path to a yaml file containing a list of validation images + ids. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. + ood_ds (str): Which out-of-distribution dataset to use. Defaults to + ``"openimage-o"``. + test_alt (str): Which test set to use. Defaults to ``None``. + procedure (str): Which procedure to use. Defaults to ``None``. + train_size (int): Size of training images. Defaults to ``224``. + interpolation (str): Interpolation method for the Resize Crops. Defaults to ``"bilinear"``. + basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. + rand_augment_opt (str): Which RandAugment to use. Defaults to ``None``. + num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + pin_memory (bool): Whether to pin memory. Defaults to ``True``. + persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. + """ super().__init__( root=root, batch_size=batch_size, @@ -70,6 +99,8 @@ def __init__( self.ood_dataset = SVHN elif ood_ds == "textures": self.ood_dataset = DTD + elif ood_ds == "openimage-o": + self.ood_dataset = OpenImageO else: raise ValueError(f"OOD dataset {ood_ds} not supported for TinyImageNet.") self.shift_dataset = TinyImageNetC @@ -116,43 +147,19 @@ def _verify_splits(self, split: str) -> None: # coverage: ignore def prepare_data(self) -> None: # coverage: ignore if self.eval_ood: - if self.ood_ds != "textures": - self.ood_dataset( - self.root, - split="test", - download=True, - transform=self.test_transform, - ) - else: - ConcatDataset( - [ - self.ood_dataset( - self.root, - split="train", - download=True, - transform=self.test_transform, - ), - self.ood_dataset( - self.root, - split="val", - download=True, - transform=self.test_transform, - ), - self.ood_dataset( - self.root, - split="test", - download=True, - transform=self.test_transform, - ), - ] - ) - if self.eval_shift: - self.shift_dataset( - self.root, - download=True, - transform=self.test_transform, - shift_severity=self.shift_severity, - ) + self.ood_dataset( + self.root, + split="test", + download=True, + transform=self.test_transform, + ) + if self.eval_shift: + self.shift_dataset( + self.root, + download=True, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: @@ -184,43 +191,19 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: - if self.ood_ds == "textures": - self.ood = ConcatDataset( - [ - self.ood_dataset( - self.root, - split="train", - download=True, - transform=self.test_transform, - ), - self.ood_dataset( - self.root, - split="val", - download=True, - transform=self.test_transform, - ), - self.ood_dataset( - self.root, - split="test", - download=True, - transform=self.test_transform, - ), - ] - ) - else: - self.ood = self.ood_dataset( - self.root, - split="test", - transform=self.test_transform, - ) + self.ood = self.ood_dataset( + self.root, + split="test", + transform=self.test_transform, + ) - if self.eval_shift: - self.shift = self.shift_dataset( - self.root, - download=False, - shift_severity=self.shift_severity, - transform=self.test_transform, - ) + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + shift_severity=self.shift_severity, + transform=self.test_transform, + ) def train_dataloader(self) -> DataLoader: r"""Get the training dataloader for TinyImageNet. From 9b7fb7df8eaae99893cfdeaf2cfb17c0b0bbe7c5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 21 Mar 2025 13:59:29 +0100 Subject: [PATCH 09/69] :wrench: Don't show an error when deleting inexistent ssh-agent --- .github/workflows/build-docs.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index 131164ed..70868b13 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -57,8 +57,8 @@ jobs: external_repository: torch-uncertainty/torch-uncertainty.github.io publish_branch: main publish_dir: docs/build/html - + + # ||: not to error if there is no running ssh-agent - name: Kill SSH Agent run: | - killall ssh-agent - continue-on-error: true + killall ssh-agent ||: From 9117bb258fc22efdbc4e010a82005b6786574c92 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 21 Mar 2025 14:05:47 +0100 Subject: [PATCH 10/69] :wrench: We probably don't need to run the docs every week anymore for cache purposes --- .github/workflows/build-docs.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index 70868b13..412aba76 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -7,8 +7,6 @@ on: types: [opened, reopened, ready_for_review, synchronize] branches: - main - schedule: - - cron: "00 12 * * 0" # Every Sunday noon (preserve the cache folders) workflow_dispatch: env: From cd94ae24d8be894813db06b12d1f23d40ca8369b Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 15:05:25 +0100 Subject: [PATCH 11/69] :wrench: Improve MNIST configs --- experiments/classification/mnist/configs/bayesian_lenet.yaml | 1 + experiments/classification/mnist/configs/lenet.yaml | 1 + .../classification/mnist/configs/lenet_batch_ensemble.yaml | 2 +- .../classification/mnist/configs/lenet_checkpoint_ensemble.yaml | 1 + .../classification/mnist/configs/lenet_deep_ensemble.yaml | 2 +- experiments/classification/mnist/configs/lenet_ema.yaml | 1 + experiments/classification/mnist/configs/lenet_swa.yaml | 1 + experiments/classification/mnist/configs/lenet_swag.yaml | 1 + 8 files changed, 8 insertions(+), 2 deletions(-) diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 70f5cf8e..70474951 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -53,6 +53,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 3f8b63c2..afd095a0 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -44,6 +44,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index d385b100..1625044d 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -51,7 +51,7 @@ model: data: root: ./data batch_size: 128 - num_workers: 127 + num_workers: 8 eval_ood: true eval_shift: true optimizer: diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index 354b9bf7..28afacfe 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -60,6 +60,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml index 1d47b782..59f8ea7c 100644 --- a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -61,7 +61,7 @@ model: data: root: ./data batch_size: 128 - num_workers: 127 + num_workers: 8 eval_ood: true eval_shift: true optimizer: diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index d453df0b..d45988dd 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -48,6 +48,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 09d7d506..3e58b813 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -50,6 +50,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index e33d954f..761a0554 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -50,6 +50,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 8 optimizer: lr: 0.05 momentum: 0.9 From 81b0015608b08719518bbd0d7fb59f88fb869ea6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 15:05:50 +0100 Subject: [PATCH 12/69] :shirt: Fix arg tab --- .../layers/channel_layer_norm.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torch_uncertainty/layers/channel_layer_norm.py b/torch_uncertainty/layers/channel_layer_norm.py index dce247e2..0052dc18 100644 --- a/torch_uncertainty/layers/channel_layer_norm.py +++ b/torch_uncertainty/layers/channel_layer_norm.py @@ -18,23 +18,23 @@ def __init__( r"""Layer normalization over the channel dimension. Args: - normalized_shape (int or list or torch.Size): input shape from an expected input - of size - - .. math:: - [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] - \times \ldots \times \text{normalized\_shape}[-1]] - - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the channel dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability. Default: 1e-5 - elementwise_affine (bool): a boolean value that when set to ``True``, this module - has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if - :attr:`elementwise_affine` is ``True``). Default: ``True``. - device (torch.device or str or None): the desired device of the module. - dtype (torch.dtype or str or None): the desired floating point type of the module. + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the channel dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool): a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + device (torch.device or str or None): the desired device of the module. + dtype (torch.dtype or str or None): the desired floating point type of the module. Attributes: weight: the learnable weights of the module of shape From dd8e304a854679de9d6959f8418ea01fa34d3985 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 15:07:36 +0100 Subject: [PATCH 13/69] :shirt: Normalize import of nn.functional --- torch_uncertainty/layers/bayesian/abnn.py | 2 +- torch_uncertainty/layers/bayesian/bayes_conv.py | 2 +- torch_uncertainty/layers/packed.py | 2 +- torch_uncertainty/losses/classification.py | 2 +- torch_uncertainty/models/depth/bts.py | 2 +- torch_uncertainty/models/segmentation/deeplab.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/layers/bayesian/abnn.py b/torch_uncertainty/layers/bayesian/abnn.py index dcf75122..8ec13be2 100644 --- a/torch_uncertainty/layers/bayesian/abnn.py +++ b/torch_uncertainty/layers/bayesian/abnn.py @@ -1,6 +1,6 @@ import torch +import torch.nn.functional as F from torch import Tensor, nn -from torch.nn import functional as F class BatchNormAdapter2d(nn.Module): diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index f1277ed4..2db8f1bf 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -1,7 +1,7 @@ import torch +import torch.nn.functional as F from torch import Tensor from torch.nn import Module, init -from torch.nn import functional as F from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t from torch.nn.modules.utils import ( _pair, diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 66150367..c5ad45f1 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -3,9 +3,9 @@ from typing import Any import torch +import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn -from torch.nn import functional as F from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t from .functional.packed import packed_linear, packed_multi_head_attention_forward diff --git a/torch_uncertainty/losses/classification.py b/torch_uncertainty/losses/classification.py index 96769338..34901cb5 100644 --- a/torch_uncertainty/losses/classification.py +++ b/torch_uncertainty/losses/classification.py @@ -1,6 +1,6 @@ import torch +import torch.nn.functional as F from torch import Tensor, nn -from torch.nn import functional as F class DECLoss(nn.Module): diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py index 10e105f5..408b2344 100644 --- a/torch_uncertainty/models/depth/bts.py +++ b/torch_uncertainty/models/depth/bts.py @@ -2,9 +2,9 @@ from typing import Literal import torch +import torch.nn.functional as F import torchvision.models as tv_models from torch import Tensor, nn -from torch.nn import functional as F from torchvision.models.densenet import DenseNet121_Weights, DenseNet161_Weights from torchvision.models.resnet import ( ResNet50_Weights, diff --git a/torch_uncertainty/models/segmentation/deeplab.py b/torch_uncertainty/models/segmentation/deeplab.py index 13bbfe8b..b38bb776 100644 --- a/torch_uncertainty/models/segmentation/deeplab.py +++ b/torch_uncertainty/models/segmentation/deeplab.py @@ -1,9 +1,9 @@ from typing import Literal import torch +import torch.nn.functional as F import torchvision.models as tv_models from torch import Tensor, nn -from torch.nn import functional as F from torch.nn.common_types import _size_2_t from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights From 629aca9f4ad737eda15028ea72010656529fc52d Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 15:08:01 +0100 Subject: [PATCH 14/69] :shirt: Improve some docstrings --- .../datamodules/classification/mnist.py | 21 +++++++------------ torch_uncertainty/models/wrappers/ema.py | 5 ++++- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 8424165c..5bd49e13 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -42,27 +42,21 @@ def __init__( Args: root (str): Root directory of the datasets. - eval_ood (bool): Whether to evaluate on out-of-distribution data. - Defaults to ``False``. - eval_shift (bool): Whether to evaluate on shifted data. Defaults to + eval_ood (bool): Whether to evaluate on out-of-distribution data. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. batch_size (int): Number of samples per batch. ood_ds (str): Which out-of-distribution dataset to use. Defaults to - ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for - notMNIST. - val_split (float): Share of samples to use for validation. Defaults - to ``0.0``. + ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for notMNIST. + val_split (float): Share of samples to use for validation. Defaults to ``0.0``. num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). postprocess_set (str, optional): The post-hoc calibration dataset to use for the post-processing method. Defaults to ``val``. - num_workers (int): Number of workers to use for data loading. Defaults - to ``1``. - basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. pin_memory (bool): Whether to pin memory. Defaults to ``True``. - persistent_workers (bool): Whether to use persistent workers. Defaults - to ``True``. + persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. """ super().__init__( root=root, @@ -87,6 +81,7 @@ def __init__( self.ood_dataset = NotMNIST else: raise ValueError(f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}.") + self.shift_dataset = MNISTC self.shift_severity = 1 diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py index 1e4c276c..a796408f 100644 --- a/torch_uncertainty/models/wrappers/ema.py +++ b/torch_uncertainty/models/wrappers/ema.py @@ -11,6 +11,9 @@ def __init__( ) -> None: """Exponential Moving Average. + The :attr:`model` given as argument is used to compute the gradient during the training. + The EMA model is regularly updated with the inner-model and used at evaluation time. + Args: model (nn.Module): The model to train and ensemble. momentum (float): The momentum of the moving average. @@ -46,4 +49,4 @@ def forward(self, x: Tensor) -> Tensor: def _ema_checks(momentum: float) -> None: if momentum < 0.0 or momentum >= 1.0: - raise ValueError(f"`momentum` must be in the range [0, 1). Got {momentum}.") + raise ValueError(f"`momentum` must be in [0, 1). Got {momentum}.") From 13d01c79677d3cc3c700c30f0a4862c73f0eca69 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 15:10:07 +0100 Subject: [PATCH 15/69] :sparkles: Add Zero and fix TTA in cls --- .../mnist/configs/lenet_tta_zero.yaml | 64 +++++++++++++++++ torch_uncertainty/models/wrappers/__init__.py | 1 + torch_uncertainty/models/wrappers/zero.py | 68 +++++++++++++++++++ torch_uncertainty/routines/classification.py | 12 +++- torch_uncertainty/utils/cli.py | 1 + 5 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 experiments/classification/mnist/configs/lenet_tta_zero.yaml create mode 100644 torch_uncertainty/models/wrappers/zero.py diff --git a/experiments/classification/mnist/configs/lenet_tta_zero.yaml b/experiments/classification/mnist/configs/lenet_tta_zero.yaml new file mode 100644 index 00000000..1e88526f --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_tta_zero.yaml @@ -0,0 +1,64 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 1 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_tta_zero + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.Zero + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + layer_args: {} + num_tta: 64 + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 + num_tta: 64 + num_workers: 8 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index fb4ff50c..b9959356 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -9,6 +9,7 @@ from .stochastic import StochasticModel from .swa import SWA from .swag import SWAG +from .zero import Zero STEP_UPDATE_MODEL = (EMA,) EPOCH_UPDATE_MODEL = (SWA, SWAG, CheckpointEnsemble) diff --git a/torch_uncertainty/models/wrappers/zero.py b/torch_uncertainty/models/wrappers/zero.py new file mode 100644 index 00000000..851e953c --- /dev/null +++ b/torch_uncertainty/models/wrappers/zero.py @@ -0,0 +1,68 @@ +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.special import entr + + +class Zero(nn.Module): + def __init__(self, model: nn.Module, num_tta: int, filter_views: float = 0.1) -> None: + """Zero for test-time adaptation. + + Zero performs "0-temperature averaging" (i.e. majority voting) at evaluation. It starts + by filtering the :attr:`filter_views` most confident predictions, and returns the majority vote + as a prediction. If used during training, the predictions will be those of the inner-model + passed as argument (:attr:`model`). + + Args: + model (nn.Module): The inner model to train. + num_tta (int): The number of views at evaluation time. + filter_views (float): Filter out 1-:attr:`filter_views` of the predictions of the augmented views. + """ + super().__init__() + _zero_checks(num_tta, filter_views) + self.core_model = model + self.filter = filter_views + self.kept_views = int(filter_views * num_tta) + self.num_tta = num_tta + + def eval_forward(self, x: Tensor) -> Tensor: + # predict and separate the views from the batch + all_predictions = rearrange(self.core_model(x), "(b v) c -> b v c", v=self.num_tta) + batch_size, _, num_classes = all_predictions.shape + entropies = entr(all_predictions).sum(1) + + # Get the index of the most confident predictions on the views + conf_idx = torch.argsort(entropies, dim=-1) + votes = all_predictions.argmax(-1) + + # Count the votes + predictions = torch.zeros((batch_size, num_classes), device=all_predictions.device) + for img_id, img_votes in enumerate(votes): + predictions[img_id, :] += torch.bincount( + img_votes[conf_idx[img_id, : self.kept_views]], minlength=all_predictions.shape[-1] + ) + maximum = predictions[img_id, :].max() + i = 0 + # If the maximum corresponds to two predictions, look at an additional + while i < self.num_tta and torch.sum(1 * (predictions[img_id, :] == maximum)) > 1: + predictions[img_id, img_votes[conf_idx[img_id, self.kept_views + i]]] += 1 + maximum = predictions[img_id, :].max() + i += 1 + + predictions /= self.num_tta + # We will apply the softmax in the routine, so let's apply its inverse here + return predictions.log() + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.core_model.forward(x) + return self.eval_forward(x) + + +def _zero_checks(num_tta: int, filter_views: float) -> None: + if filter_views <= 0.0 or filter_views >= 1.0: + raise ValueError(f"`filter_views` must be in the range (0, 1). Got {filter_views}.") + if num_tta < 1 / filter_views: + raise ValueError( + f"`num_tta` should be greater than 1/filter_views to use Zero. Got {num_tta} < {1 / filter_views}." + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 3a5d6522..6f5f88e9 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -451,8 +451,11 @@ def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: batch (tuple[Tensor, Tensor]): the validation data and their corresponding targets """ inputs, targets = batch + # remove duplicates when doing TTA + targets = targets[:: self.num_tta] + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0) // self.num_tta) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) @@ -483,8 +486,11 @@ def test_step( distribution-shifted. """ inputs, targets = batch + # remove duplicates when doing TTA + targets = targets[:: self.num_tta] + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0) // self.num_tta) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] @@ -731,7 +737,7 @@ def _classification_routine_checks( if is_ensemble and eval_grouping_loss: raise NotImplementedError( - "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." + "Grouping loss for ensembles is not yet implemented. Raise an issue if needed." ) if num_classes < 1: diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 7ec97f86..a0758006 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -127,3 +127,4 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: super().add_arguments_to_parser(parser) parser.link_arguments("data.eval_ood", "model.eval_ood") parser.link_arguments("data.eval_shift", "model.eval_shift") + parser.link_arguments("data.num_tta", "model.num_tta") From f8ff1b9eafedd26e63355ed848f2c998b42ed028 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 23 Mar 2025 23:14:52 +0100 Subject: [PATCH 16/69] :sparkles: Add eval_batch_size --- tests/_dummies/datamodule.py | 20 +++++++--- tests/datamodules/test_abstract_datamodule.py | 8 ++-- torch_uncertainty/datamodules/abstract.py | 39 ++++++++++++------- .../datamodules/classification/cifar10.py | 23 ++++++----- .../datamodules/classification/cifar100.py | 17 ++++---- .../datamodules/classification/imagenet.py | 20 +++++----- .../datamodules/classification/mnist.py | 18 +++++---- .../classification/tiny_imagenet.py | 16 +++++--- .../classification/uci/bank_marketing.py | 3 ++ .../classification/uci/dota2_games.py | 6 ++- .../datamodules/classification/uci/htru2.py | 6 ++- .../classification/uci/online_shoppers.py | 3 ++ .../classification/uci/spam_base.py | 3 ++ .../classification/uci/uci_classification.py | 4 ++ torch_uncertainty/datamodules/depth/base.py | 6 ++- torch_uncertainty/datamodules/depth/kitti.py | 6 ++- torch_uncertainty/datamodules/depth/muad.py | 6 ++- torch_uncertainty/datamodules/depth/nyu.py | 6 ++- .../datamodules/segmentation/camvid.py | 6 ++- .../datamodules/segmentation/cityscapes.py | 6 ++- .../datamodules/segmentation/muad.py | 6 ++- .../datamodules/uci_regression.py | 14 ++++--- 22 files changed, 163 insertions(+), 79 deletions(-) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 8da6b391..c86de1bc 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -26,6 +26,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, num_classes: int = 2, num_workers: int = 1, eval_ood: bool = False, @@ -38,6 +39,7 @@ def __init__( root=root, val_split=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -104,11 +106,11 @@ def setup(self, stage: str | None = None) -> None: self.shift.shift_severity = 1 def test_dataloader(self) -> DataLoader | list[DataLoader]: - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader def _get_train_data(self) -> ArrayLike: @@ -126,6 +128,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, out_features: int = 2, num_workers: int = 1, pin_memory: bool = True, @@ -134,6 +137,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -171,7 +175,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, training=False)] class DummySegmentationDataModule(TUDataModule): @@ -184,6 +188,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, num_classes: int = 2, num_workers: int = 1, image_size: int = 4, @@ -195,6 +200,7 @@ def __init__( root=root, val_split=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -256,7 +262,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, training=False)] def _get_train_data(self) -> ArrayLike: return self.train.data @@ -273,6 +279,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, output_dim: int = 2, num_workers: int = 1, image_size: int = 4, @@ -284,6 +291,7 @@ def __init__( root=root, val_split=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -345,7 +353,7 @@ def setup(self, stage: str | None = None) -> None: ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - return [self._data_loader(self.test)] + return [self._data_loader(self.test, training=False)] def _get_train_data(self) -> ArrayLike: return self.train.data diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 1983a028..a757f485 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -14,7 +14,7 @@ class TestTUDataModule: def test_errors(self): TUDataModule.__abstractmethods__ = set() - dm = TUDataModule("root", 128, 0.0, 4, True, True) + dm = TUDataModule("root", 128, 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -26,12 +26,12 @@ class TestCrossValDataModule: def test_cv_main(self): TUDataModule.__abstractmethods__ = set() - dm = TUDataModule("root", 128, 0.0, 4, True, True) + dm = TUDataModule("root", 128, 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 0.0, 4, True, True) + cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 128, 0.0, 4, True, True) cv_dm.setup() cv_dm.setup("test") @@ -50,6 +50,7 @@ def test_errors(self): dm = TUDataModule( root="root", batch_size=128, + eval_batch_size=None, val_split=0.0, num_workers=4, pin_memory=True, @@ -65,6 +66,7 @@ def test_errors(self): val_idx=[1], datamodule=dm, batch_size=128, + eval_batch_size=128, val_split=0.0, num_workers=4, pin_memory=True, diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 9880e787..fa04df4d 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -31,6 +31,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None, val_split: float | None, num_workers: int, pin_memory: bool, @@ -45,18 +46,24 @@ def __init__( Args: root (str): Root directory of the datasets. - batch_size (int): Number of samples per batch. - val_split (float): Share of samples to use for validation. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. + val_split (float | None): Share of samples to use for validation. num_workers (int): Number of workers to use for data loading. pin_memory (bool): Whether to pin memory. persistent_workers (bool): Whether to use persistent workers. postprocess_set (str): Which split to use as post-processing set to fit the - post-processing method. + post-processing method. Defaults to ``val``. """ super().__init__() self.root = Path(root) self.batch_size = batch_size + if eval_batch_size is None: + self.eval_batch_size = batch_size + else: + self.eval_batch_size = eval_batch_size self.val_split = val_split self.num_workers = num_workers @@ -89,7 +96,7 @@ def train_dataloader(self) -> DataLoader: Return: DataLoader: training dataloader. """ - return self._data_loader(self.train, shuffle=True) + return self._data_loader(self.train, training=True, shuffle=True) def val_dataloader(self) -> DataLoader: r"""Get the validation dataloader. @@ -97,7 +104,7 @@ def val_dataloader(self) -> DataLoader: Return: DataLoader: validation dataloader. """ - return self._data_loader(self.val) + return self._data_loader(self.val, training=False) def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. @@ -106,7 +113,7 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data and out-of-distribution data. """ - return [self._data_loader(self.test)] + return [self._data_loader(self.test, training=False)] def postprocess_dataloader(self) -> DataLoader: r"""Get the calibration dataloader. @@ -116,11 +123,12 @@ def postprocess_dataloader(self) -> DataLoader: """ return self.val_dataloader() if self.postprocess_set == "val" else self.test_dataloader()[0] - def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: + def _data_loader(self, dataset: Dataset, training: bool, shuffle: bool = False) -> DataLoader: """Create a dataloader for a given dataset. Args: dataset (Dataset): Dataset to create a dataloader for. + training (bool): Whether it is a training or evaluation dataloader. shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. @@ -129,7 +137,7 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: """ return DataLoader( dataset, - batch_size=self.batch_size, + batch_size=self.batch_size if training else self.eval_batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory, @@ -171,6 +179,7 @@ def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list val_idx=val_idx, datamodule=self, batch_size=self.batch_size, + eval_batch_size=self.eval_batch_size, val_split=self.val_split, postprocess_set=self.postprocess_set, num_workers=self.num_workers, @@ -190,6 +199,7 @@ def __init__( val_idx: ArrayLike, datamodule: TUDataModule, batch_size: int, + eval_batch_size: int | None, val_split: float, num_workers: int, pin_memory: bool, @@ -199,6 +209,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, @@ -219,12 +230,12 @@ def setup(self, stage: str | None = None) -> None: else: raise ValueError(f"Stage {stage} not supported.") - def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: + def _data_loader(self, dataset: Dataset, idx: ArrayLike, training: bool) -> DataLoader: return DataLoader( dataset=dataset, sampler=SubsetRandomSampler(idx), shuffle=False, - batch_size=self.batch_size, + batch_size=self.batch_size if training else self.eval_batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -244,12 +255,12 @@ def get_test_set(self) -> Dataset: def train_dataloader(self) -> DataLoader: """Get the training dataloader for the current fold.""" - return self._data_loader(self.dm.get_train_set(), self.train_idx) + return self._data_loader(self.dm.get_train_set(), training=True, idx=self.train_idx) def val_dataloader(self) -> DataLoader: """Get the validation dataloader for the current fold.""" - return self._data_loader(self.dm.get_train_set(), self.val_idx) + return self._data_loader(self.dm.get_train_set(), training=False, idx=self.val_idx) - def test_dataloader(self) -> DataLoader: + def test_dataloader(self) -> list[DataLoader]: """Get the test dataloader for the current fold.""" - return self._data_loader(self.dm.get_train_set(), self.val_idx) + return [self._data_loader(self.dm.get_train_set(), training=False, idx=self.val_idx)] diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index d52211e2..c92ba2cf 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -29,6 +29,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, eval_ood: bool = False, eval_shift: bool = False, shift_severity: int = 1, @@ -46,12 +47,12 @@ def __init__( """DataModule for CIFAR10. Args: - root (str): Root directory of the datasets. - eval_ood (bool): Whether to evaluate on out-of-distribution data. - Defaults to ``False``. - eval_shift (bool): Whether to evaluate on shifted data. Defaults to - ``False``. - batch_size (int): Number of samples per batch. + root (str | Path): Root directory of the datasets. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. + eval_ood (bool): Whether to evaluate on out-of-distribution data. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. postprocess_set (str, optional): The post-hoc calibration dataset to @@ -75,6 +76,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, postprocess_set=postprocess_set, num_workers=num_workers, @@ -222,8 +224,9 @@ def train_dataloader(self) -> DataLoader: return self._data_loader( AggregatedDataset(self.train, self.num_dataloaders), shuffle=True, + training=True, ) - return self._data_loader(self.train, shuffle=True) + return self._data_loader(self.train, training=True, shuffle=True) def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. @@ -231,11 +234,11 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 6334b10c..884ef9d4 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -29,6 +29,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, eval_ood: bool = False, eval_shift: bool = False, shift_severity: int = 1, @@ -51,7 +52,9 @@ def __init__( performance. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. postprocess_set (str, optional): The post-hoc calibration dataset to @@ -74,6 +77,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, postprocess_set=postprocess_set, num_workers=num_workers, @@ -206,10 +210,9 @@ def train_dataloader(self) -> DataLoader: """ if self.num_dataloaders > 1: return self._data_loader( - AggregatedDataset(self.train, self.num_dataloaders), - shuffle=True, + AggregatedDataset(self.train, self.num_dataloaders), shuffle=True, training=True ) - return self._data_loader(self.train, shuffle=True) + return self._data_loader(self.train, training=True, shuffle=True) def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. @@ -218,11 +221,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-100C data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 084c285c..30e78602 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -46,6 +46,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, eval_ood: bool = False, eval_shift: bool = False, shift_severity: int = 1, @@ -69,12 +70,12 @@ def __init__( Args: root (str): Root directory of the datasets. - eval_ood (bool): Whether to evaluate out-of-distribution - performance. Defaults to ``False``. - eval_shift (bool): Whether to evaluate on shifted data. Defaults to - ``False``. - shift_severity: int = 1, - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. + eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. + shift_severity (int): Severity of the shift. Defaults to ``1``. val_split (float or Path): Share of samples to use for validation or path to a yaml file containing a list of validation images ids. Defaults to ``0.0``. @@ -99,6 +100,7 @@ def __init__( super().__init__( root=Path(root), batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, postprocess_set=postprocess_set, num_workers=num_workers, @@ -302,11 +304,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split (out-of-distribution data), and/or ImageNetC data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a1d72de4..06963bf9 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -26,6 +26,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, eval_ood: bool = False, eval_shift: bool = False, ood_ds: Literal["fashion", "notMNIST"] = "fashion", @@ -41,11 +42,11 @@ def __init__( Args: root (str): Root directory of the datasets. - eval_ood (bool): Whether to evaluate on out-of-distribution data. - Defaults to ``False``. - eval_shift (bool): Whether to evaluate on shifted data. Defaults to - ``False``. - batch_size (int): Number of samples per batch. + eval_ood (bool): Whether to evaluate on out-of-distribution data. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for notMNIST. @@ -65,6 +66,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, postprocess_set=postprocess_set, num_workers=num_workers, @@ -182,9 +184,9 @@ def test_dataloader(self) -> list[DataLoader]: distribution data), FashionMNIST or NotMNIST test split (out-of-distribution data), and/or MNISTC (shifted data). """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 195377d3..80ee6e1b 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -34,6 +34,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, eval_ood: bool = False, eval_shift: bool = False, shift_severity: int = 1, @@ -54,7 +55,9 @@ def __init__( Args: root (str): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. shift_severity (int): Severity of the shift. Defaults to ``1``. @@ -78,6 +81,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, postprocess_set=postprocess_set, num_workers=num_workers, @@ -211,7 +215,7 @@ def train_dataloader(self) -> DataLoader: Return: DataLoader: TinyImageNet training dataloader. """ - return self._data_loader(self.train, shuffle=True) + return self._data_loader(self.train, training=True, shuffle=True) def val_dataloader(self) -> DataLoader: r"""Get the validation dataloader for TinyImageNet. @@ -219,7 +223,7 @@ def val_dataloader(self) -> DataLoader: Return: DataLoader: TinyImageNet validation dataloader. """ - return self._data_loader(self.val) + return self._data_loader(self.test, training=False) def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. @@ -228,11 +232,11 @@ def test_dataloader(self) -> list[DataLoader]: list[DataLoader]: test set for in distribution data, OOD data, and/or TinyImageNetC data. """ - dataloader = [self._data_loader(self.test)] + dataloader = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) + dataloader.append(self._data_loader(self.ood, training=False)) if self.eval_shift: - dataloader.append(self._data_loader(self.shift)) + dataloader.append(self._data_loader(self.shift, training=False)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/uci/bank_marketing.py b/torch_uncertainty/datamodules/classification/uci/bank_marketing.py index 9f07ce69..49e3f410 100644 --- a/torch_uncertainty/datamodules/classification/uci/bank_marketing.py +++ b/torch_uncertainty/datamodules/classification/uci/bank_marketing.py @@ -10,6 +10,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -22,6 +23,8 @@ def __init__( Args: root (string): Root directory of the datasets. batch_size (int): The batch size for training and testing. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. diff --git a/torch_uncertainty/datamodules/classification/uci/dota2_games.py b/torch_uncertainty/datamodules/classification/uci/dota2_games.py index 8269a6c9..a5f5bc67 100644 --- a/torch_uncertainty/datamodules/classification/uci/dota2_games.py +++ b/torch_uncertainty/datamodules/classification/uci/dota2_games.py @@ -10,6 +10,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -22,8 +23,9 @@ def __init__( Args: root (string): Root directory of the datasets. batch_size (int): The batch size for training and testing. - val_split (float, optional): Share of validation samples among the - non-test samples. Defaults to ``0``. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. + val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. num_workers (int, optional): How many subprocesses to use for data loading. Defaults to ``1``. diff --git a/torch_uncertainty/datamodules/classification/uci/htru2.py b/torch_uncertainty/datamodules/classification/uci/htru2.py index 3dfdaf45..a102ddf9 100644 --- a/torch_uncertainty/datamodules/classification/uci/htru2.py +++ b/torch_uncertainty/datamodules/classification/uci/htru2.py @@ -10,6 +10,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -22,8 +23,9 @@ def __init__( Args: root (string): Root directory of the datasets. batch_size (int): The batch size for training and testing. - val_split (float, optional): Share of validation samples among the - non-test samples. Defaults to ``0``. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. + val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. num_workers (int, optional): How many subprocesses to use for data loading. Defaults to ``1``. diff --git a/torch_uncertainty/datamodules/classification/uci/online_shoppers.py b/torch_uncertainty/datamodules/classification/uci/online_shoppers.py index c5d24e11..28189cdc 100644 --- a/torch_uncertainty/datamodules/classification/uci/online_shoppers.py +++ b/torch_uncertainty/datamodules/classification/uci/online_shoppers.py @@ -10,6 +10,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -22,6 +23,8 @@ def __init__( Args: root (string): Root directory of the datasets. batch_size (int): The batch size for training and testing. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. diff --git a/torch_uncertainty/datamodules/classification/uci/spam_base.py b/torch_uncertainty/datamodules/classification/uci/spam_base.py index 868ab738..66988398 100644 --- a/torch_uncertainty/datamodules/classification/uci/spam_base.py +++ b/torch_uncertainty/datamodules/classification/uci/spam_base.py @@ -10,6 +10,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -22,6 +23,8 @@ def __init__( Args: root (string): Root directory of the datasets. batch_size (int): The batch size for training and testing. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. diff --git a/torch_uncertainty/datamodules/classification/uci/uci_classification.py b/torch_uncertainty/datamodules/classification/uci/uci_classification.py index 40ed06a0..1f9160bb 100644 --- a/torch_uncertainty/datamodules/classification/uci/uci_classification.py +++ b/torch_uncertainty/datamodules/classification/uci/uci_classification.py @@ -14,6 +14,7 @@ def __init__( root: str | Path, dataset: type[Dataset], batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, test_split: float = 0.2, num_workers: int = 1, @@ -27,6 +28,8 @@ def __init__( root (string): Root directory of the datasets. dataset (type[Dataset]): The UCI classification dataset class. batch_size (int): The batch size for training and testing. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float, optional): Share of validation samples among the non-test samples. Defaults to ``0``. test_split (float, optional): Share of test samples. Defaults to ``0.2``. @@ -43,6 +46,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 59804c0a..2abebf72 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -22,6 +22,7 @@ def __init__( max_depth: float, crop_size: _size_2_t, eval_size: _size_2_t, + eval_batch_size: int | None = None, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -32,7 +33,9 @@ def __init__( Args: dataset (type[VisionDataset]): Dataset class to use. root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. min_depth (float, optional): Minimum depth value for evaluation. max_depth (float, optional): Maximum depth value for training and evaluation. @@ -56,6 +59,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py index 69227769..de50a51d 100644 --- a/torch_uncertainty/datamodules/depth/kitti.py +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -12,6 +12,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, min_depth: float = 1e-3, max_depth: float = 80.0, crop_size: _size_2_t = (352, 704), @@ -25,7 +26,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. min_depth (float, optional): Minimum depth value for evaluation. Defaults to ``1e-3``. max_depth (float, optional): Maximum depth value for training and @@ -55,6 +58,7 @@ def __init__( dataset=KITTIDepth, root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index 032a4292..5e0c2b06 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -17,6 +17,7 @@ def __init__( max_depth: float, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), + eval_batch_size: int | None = None, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -26,7 +27,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. min_depth (float, optional): Minimum depth value for evaluation max_depth (float, optional): Maximum depth value for training and evaluation. @@ -55,6 +58,7 @@ def __init__( dataset=MUAD, root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py index 077badff..cbfd6dd1 100644 --- a/torch_uncertainty/datamodules/depth/nyu.py +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -12,6 +12,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, min_depth: float = 1e-3, max_depth: float = 10.0, crop_size: _size_2_t = (416, 544), @@ -25,7 +26,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. min_depth (float, optional): Minimum depth value for evaluation. Defaults to ``1e-3``. max_depth (float, optional): Maximum depth value for training and @@ -55,6 +58,7 @@ def __init__( dataset=NYUv2, root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 66d774b8..b028a48e 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -23,6 +23,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, crop_size: _size_2_t = 640, eval_size: _size_2_t = (720, 960), group_classes: bool = True, @@ -36,7 +37,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. crop_size (sequence or int, optional): Desired input image and segmentation mask sizes during training. If :attr:`crop_size` is an int instead of sequence like :math:`(H, W)`, a square crop @@ -93,6 +96,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index a6005893..2328df64 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -24,6 +24,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), basic_augment: bool = True, @@ -36,7 +37,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. crop_size (sequence or int, optional): Desired input image and segmentation mask sizes during training. If :attr:`crop_size` is an int instead of sequence like :math:`(H, W)`, a square crop @@ -110,6 +113,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 00f39251..82c31555 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -21,6 +21,7 @@ def __init__( self, root: str | Path, batch_size: int, + eval_batch_size: int | None = None, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, @@ -32,7 +33,9 @@ def __init__( Args: root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. crop_size (sequence or int, optional): Desired input image and segmentation mask sizes during training. If :attr:`crop_size` is an int instead of sequence like :math:`(H, W)`, a square crop @@ -104,6 +107,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 6dae8899..55eac3fe 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -15,8 +15,9 @@ class UCIRegressionDataModule(TUDataModule): def __init__( self, root: str | Path, - batch_size: int, dataset_name: str, + batch_size: int, + eval_batch_size: int | None = None, val_split: float = 0.0, num_workers: int = 1, pin_memory: bool = True, @@ -28,11 +29,13 @@ def __init__( Args: root (string): Root directory of the datasets. - batch_size (int): The batch size for training and testing. dataset_name (string, optional): The name of the dataset. One of - "boston-housing", "concrete", "energy", "kin8nm", - "naval-propulsion-plant", "power-plant", "protein", - "wine-quality-red", and "yacht". + ``boston-housing``, ``concrete``, ``energy``, ``kin8nm``, + ``naval-propulsion-plant``, ``power-plant``, ``protein``, + ``wine-quality-red``, and ``yacht``. + batch_size (int): The batch size for training and testing. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. val_split (float, optional): Share of validation samples. Defaults to ``0``. num_workers (int, optional): How many subprocesses to use for data @@ -49,6 +52,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, From d79ef8e1350cdd63079482a836878bc0cc3a3af8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 24 Mar 2025 10:23:53 +0100 Subject: [PATCH 17/69] :bug: Add forgotten eval_batch_size in init --- .../datamodules/classification/uci/bank_marketing.py | 1 + torch_uncertainty/datamodules/classification/uci/dota2_games.py | 1 + torch_uncertainty/datamodules/classification/uci/htru2.py | 1 + .../datamodules/classification/uci/online_shoppers.py | 1 + torch_uncertainty/datamodules/classification/uci/spam_base.py | 1 + 5 files changed, 5 insertions(+) diff --git a/torch_uncertainty/datamodules/classification/uci/bank_marketing.py b/torch_uncertainty/datamodules/classification/uci/bank_marketing.py index 49e3f410..c2a620b9 100644 --- a/torch_uncertainty/datamodules/classification/uci/bank_marketing.py +++ b/torch_uncertainty/datamodules/classification/uci/bank_marketing.py @@ -42,6 +42,7 @@ def __init__( root=root, dataset=BankMarketing, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, test_split=test_split, num_workers=num_workers, diff --git a/torch_uncertainty/datamodules/classification/uci/dota2_games.py b/torch_uncertainty/datamodules/classification/uci/dota2_games.py index a5f5bc67..a746f0cf 100644 --- a/torch_uncertainty/datamodules/classification/uci/dota2_games.py +++ b/torch_uncertainty/datamodules/classification/uci/dota2_games.py @@ -41,6 +41,7 @@ def __init__( root=root, dataset=DOTA2Games, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, test_split=test_split, num_workers=num_workers, diff --git a/torch_uncertainty/datamodules/classification/uci/htru2.py b/torch_uncertainty/datamodules/classification/uci/htru2.py index a102ddf9..781eb43f 100644 --- a/torch_uncertainty/datamodules/classification/uci/htru2.py +++ b/torch_uncertainty/datamodules/classification/uci/htru2.py @@ -41,6 +41,7 @@ def __init__( root=root, dataset=HTRU2, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, test_split=test_split, num_workers=num_workers, diff --git a/torch_uncertainty/datamodules/classification/uci/online_shoppers.py b/torch_uncertainty/datamodules/classification/uci/online_shoppers.py index 28189cdc..117f837a 100644 --- a/torch_uncertainty/datamodules/classification/uci/online_shoppers.py +++ b/torch_uncertainty/datamodules/classification/uci/online_shoppers.py @@ -42,6 +42,7 @@ def __init__( root=root, dataset=OnlineShoppers, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, test_split=test_split, num_workers=num_workers, diff --git a/torch_uncertainty/datamodules/classification/uci/spam_base.py b/torch_uncertainty/datamodules/classification/uci/spam_base.py index 66988398..5839130d 100644 --- a/torch_uncertainty/datamodules/classification/uci/spam_base.py +++ b/torch_uncertainty/datamodules/classification/uci/spam_base.py @@ -42,6 +42,7 @@ def __init__( root=root, dataset=SpamBase, batch_size=batch_size, + eval_batch_size=eval_batch_size, val_split=val_split, test_split=test_split, num_workers=num_workers, From 1cec6aef4fe8d95f53de93d7607fdc6effebba96 Mon Sep 17 00:00:00 2001 From: Adrien Lafage Date: Mon, 24 Mar 2025 19:04:22 +0100 Subject: [PATCH 18/69] :hammer: Set default value of `reset_model_parameters` to `True` --- torch_uncertainty/models/wrappers/deep_ensembles.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 4c49a3d5..e7b4b014 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -66,7 +66,7 @@ def deep_ensembles( "classification", "regression", "segmentation", "pixel_regression" ] = "classification", probabilistic: bool | None = None, - reset_model_parameters: bool = False, + reset_model_parameters: bool = True, ) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. From 2f1d77ec20ec92112882e348248bb0924da7cbe6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 25 Mar 2025 13:34:33 +0100 Subject: [PATCH 19/69] :zap: Update License time span --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 14e1280a..d6b36815 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2023-2024 Adrien Lafage and Olivier Laurent + Copyright 2023-2025 Adrien Lafage and Olivier Laurent Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From fd21db23f7c210b2a5c292f935e15983b1ba07fe Mon Sep 17 00:00:00 2001 From: fira7s Date: Tue, 25 Mar 2025 17:43:52 +0100 Subject: [PATCH 20/69] :books: Added doc for ood scores and fixed covergae --- docs/source/api.rst | 21 +++ docs/source/quickstart.rst | 2 +- tests/routines/test_classification.py | 9 ++ torch_uncertainty/ood_criteria.py | 182 +++++++++++++++++++++++++- 4 files changed, 211 insertions(+), 3 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 0165ad49..b96240e5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -393,6 +393,27 @@ Scaling Methods VectorScaler MatrixScaler + + +OOD Scores +----------------------- + +.. currentmodule:: torch_uncertainty.ood_criteria + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class_inherited.rst + + TUOODCriterion + MaxLogitCriterion + EnergyCriterion + MaxSoftmaxProbabilityCriterion + EntropyCriterion + MutualInformationCriterion + VariationRatioCriterion + + Datamodules ----------- diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 4e076379..3cc95aa1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -46,7 +46,7 @@ and its parameters. # ... eval_ood: bool = False, eval_grouping_loss: bool = False, - ood_criterion: TUOODCriterion | None = None, + ood_criterion: type[TUOODCriterion] | str = "msp", log_plots: bool = False, save_in_csv: bool = False, ) -> None: diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 205c68c1..a7410cc6 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -416,3 +416,12 @@ def test_classification_failures(self): is_ensemble=True, post_processing=nn.Module(), ) + + def test_invalid_ood_criterion_random_string(self): + with pytest.raises(ValueError): + DummyClassificationBaseline( + num_classes=2, + in_channels=3, + loss=nn.CrossEntropyLoss(), + ood_criterion="gsgsds", + ) diff --git a/torch_uncertainty/ood_criteria.py b/torch_uncertainty/ood_criteria.py index cbaac9d2..081345b0 100644 --- a/torch_uncertainty/ood_criteria.py +++ b/torch_uncertainty/ood_criteria.py @@ -8,6 +8,15 @@ class OODCriterionInputType(Enum): + """Enum representing the type of input expected by the OOD (Out-of-Distribution) criteria. + + Attributes: + LOGIT (int): Represents that the input is in the form of logits (pre-softmax values). + PROB (int): Represents that the input is in the form of probabilities (post-softmax values). + ESTIMATOR_PROB (int): Represents that the input is in the form of estimated probabilities + from an ensemble or other probabilistic model. + """ + LOGIT = 1 PROB = 2 ESTIMATOR_PROB = 3 @@ -17,36 +26,146 @@ class TUOODCriterion(ABC, nn.Module): input_type: OODCriterionInputType ensemble_only = False + def __init__(self) -> None: + """Abstract base class for Out-of-Distribution (OOD) criteria. + + This class defines a common interface for implementing various OOD detection + criteria. Subclasses must implement the `forward` method. + + Attributes: + input_type (OODCriterionInputType): Type of input expected by the criterion. + ensemble_only (bool): Whether the criterion requires ensemble outputs. + """ + super().__init__() + @abstractmethod - def forward(self, inputs: Tensor) -> Tensor: # coverage: ignore - pass + def forward(self, inputs: Tensor) -> Tensor: + """Forward pass for the OOD criterion. + + Args: + inputs (Tensor): The input tensor representing model outputs. + + Returns: + Tensor: OOD score computed according to the criterion. + """ class MaxLogitCriterion(TUOODCriterion): input_type = OODCriterionInputType.LOGIT + def __init__(self) -> None: + """OOD criterion based on the maximum logit value. + + This criterion computes the negative of the highest logit value across + the output dimensions. Lower maximum logits indicate greater uncertainty. + + Attributes: + input_type (OODCriterionInputType): Expected input type is logits. + """ + super().__init__() + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative of the maximum logit value. + + Args: + inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). + + Returns: + Tensor: Negative of the maximum logit value for each sample. + """ return -inputs.mean(dim=1).max(dim=-1).values class EnergyCriterion(TUOODCriterion): input_type = OODCriterionInputType.LOGIT + def __init__(self) -> None: + r"""OOD criterion based on the energy function. + + This criterion computes the negative log-sum-exp of the logits. + Higher energy values indicate greater uncertainty. + + .. math:: + E(\mathbf{z}) = -\log\left(\sum_{i=1}^{C} \exp(z_i)\right) + + where :math:`\mathbf{z} = [z_1, z_2, \dots, z_C]` is the logit vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is logits. + """ + super().__init__() + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative energy score. + + Args: + inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). + + Returns: + Tensor: Negative energy score for each sample. + """ return -inputs.mean(dim=1).logsumexp(dim=-1) class MaxSoftmaxProbabilityCriterion(TUOODCriterion): input_type = OODCriterionInputType.PROB + def __init__(self) -> None: + r"""OOD criterion based on maximum softmax probability. + + This criterion computes the negative of the highest softmax probability. + Lower maximum probabilities indicate greater uncertainty. + + .. math:: + \text{score} = -\max_{i}(p_i) + + where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is probabilities. + """ + super().__init__() + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative of the maximum softmax probability. + + Args: + inputs (Tensor): Tensor of probabilities with shape (batch_size, num_classes). + + Returns: + Tensor: Negative of the highest softmax probability for each sample. + """ return -inputs.max(-1)[0] class EntropyCriterion(TUOODCriterion): input_type = OODCriterionInputType.ESTIMATOR_PROB + def __init__(self) -> None: + r"""OOD criterion based on entropy. + + This criterion computes the mean entropy of the predicted probability distribution. + Higher entropy values indicate greater uncertainty. + + .. math:: + H(\mathbf{p}) = -\sum_{i=1}^{C} p_i \log(p_i) + + where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ + super().__init__() + def forward(self, inputs: Tensor) -> Tensor: + """Compute the entropy of the predicted probability distribution. + + Args: + inputs (Tensor): Tensor of estimated probabilities with shape (batch_size, num_classes). + + Returns: + Tensor: Mean entropy value for each sample. + """ return torch.special.entr(inputs).sum(dim=-1).mean(dim=1) @@ -55,10 +174,33 @@ class MutualInformationCriterion(TUOODCriterion): input_type = OODCriterionInputType.ESTIMATOR_PROB def __init__(self) -> None: + r"""OOD criterion based on mutual information. + + This criterion computes the mutual information between ensemble predictions. + Higher mutual information values indicate lower uncertainty. + + Given ensemble predictions :math:`\{\mathbf{p}^{(k)}\}_{k=1}^{K}`, the mutual information is computed as: + + .. math:: + I(y, \theta) = H\Big(\frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\Big) - \frac{1}{K}\sum_{k=1}^{K} H(\mathbf{p}^{(k)}) + + Attributes: + ensemble_only (bool): Requires ensemble predictions. + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ super().__init__() self.mi_metric = MutualInformation(reduction="none") def forward(self, inputs: Tensor) -> Tensor: + """Compute mutual information from ensemble predictions. + + Args: + inputs (Tensor): Tensor of ensemble probabilities with shape + (ensemble_size, batch_size, num_classes). + + Returns: + Tensor: Mutual information for each sample. + """ return self.mi_metric(inputs) @@ -67,14 +209,50 @@ class VariationRatioCriterion(TUOODCriterion): input_type = OODCriterionInputType.ESTIMATOR_PROB def __init__(self) -> None: + r"""OOD criterion based on variation ratio. + + This criterion computes the variation ratio from ensemble predictions. + Higher variation ratio values indicate greater uncertainty. + + Given ensemble predictions where :math:`n_{\text{mode}}` is the count of the most frequently + predicted class among :math:`K` predictions, the variation ratio is computed as: + + .. math:: + \text{VR} = 1 - \frac{n_{\text{mode}}}{K} + + Attributes: + ensemble_only (bool): Requires ensemble predictions. + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ super().__init__() self.vr_metric = VariationRatio(reduction="none", probabilistic=False) def forward(self, inputs: Tensor) -> Tensor: + """Compute variation ratio from ensemble predictions. + + Args: + inputs (Tensor): Tensor of ensemble probabilities with shape + (ensemble_size, batch_size, num_classes). + + Returns: + Tensor: Variation ratio for each sample. + """ return self.vr_metric(inputs.transpose(0, 1)) def get_ood_criterion(ood_criterion): + """Get an OOD criterion instance based on a string identifier or class type. + + Args: + ood_criterion (str or type): A string identifier for a predefined OOD criterion + or a subclass of `TUOODCriterion`. + + Returns: + TUOODCriterion: An instance of the requested OOD criterion. + + Raises: + ValueError: If the input string or class type is invalid. + """ if isinstance(ood_criterion, str): if ood_criterion == "logit": return MaxLogitCriterion() From d91658c9d4ce47bf11eb696442dcbb720e304344 Mon Sep 17 00:00:00 2001 From: alafage Date: Tue, 25 Mar 2025 17:18:04 +0100 Subject: [PATCH 21/69] :art: Move failure test + rename MaxSoftmaxProbabilityCriterion -> MaxSoftmaxCriterion --- docs/source/api.rst | 2 +- tests/routines/test_classification.py | 18 +++++++++--------- torch_uncertainty/ood_criteria.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index b96240e5..83a9cdc5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -408,7 +408,7 @@ OOD Scores TUOODCriterion MaxLogitCriterion EnergyCriterion - MaxSoftmaxProbabilityCriterion + MaxSoftmaxCriterion EntropyCriterion MutualInformationCriterion VariationRatioCriterion diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index a7410cc6..52ef05d7 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -355,6 +355,15 @@ def test_classification_failures(self): ood_criterion=32, ) + with pytest.raises(ValueError): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + is_ensemble=False, + ood_criterion="other", + ) + with pytest.raises(ValueError): mixup_params = {"cutmix_alpha": -1} ClassificationRoutine( @@ -416,12 +425,3 @@ def test_classification_failures(self): is_ensemble=True, post_processing=nn.Module(), ) - - def test_invalid_ood_criterion_random_string(self): - with pytest.raises(ValueError): - DummyClassificationBaseline( - num_classes=2, - in_channels=3, - loss=nn.CrossEntropyLoss(), - ood_criterion="gsgsds", - ) diff --git a/torch_uncertainty/ood_criteria.py b/torch_uncertainty/ood_criteria.py index 081345b0..2302240f 100644 --- a/torch_uncertainty/ood_criteria.py +++ b/torch_uncertainty/ood_criteria.py @@ -107,7 +107,7 @@ def forward(self, inputs: Tensor) -> Tensor: return -inputs.mean(dim=1).logsumexp(dim=-1) -class MaxSoftmaxProbabilityCriterion(TUOODCriterion): +class MaxSoftmaxCriterion(TUOODCriterion): input_type = OODCriterionInputType.PROB def __init__(self) -> None: @@ -259,7 +259,7 @@ def get_ood_criterion(ood_criterion): if ood_criterion == "energy": return EnergyCriterion() if ood_criterion == "msp": - return MaxSoftmaxProbabilityCriterion() + return MaxSoftmaxCriterion() if ood_criterion == "entropy": return EntropyCriterion() if ood_criterion == "mutual_information": From 2233294c1838df087cfea523c7368c6120549d9a Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 26 Mar 2025 11:11:51 +0100 Subject: [PATCH 22/69] :shirt: Add new rules, Lint & format --- pyproject.toml | 7 ++++ .../classification/test_cifar10.py | 10 ++--- .../classification/test_imagenet.py | 2 +- tests/datamodules/test_abstract_datamodule.py | 6 +-- tests/layers/test_batch.py | 8 ++-- tests/layers/test_bayesian.py | 30 ++++++-------- tests/layers/test_distributions.py | 6 +-- tests/layers/test_mask.py | 8 ++-- tests/layers/test_packed.py | 40 +++++++++---------- .../classification/test_brier_score.py | 24 +++++------ .../classification/test_disagreement.py | 6 +-- tests/metrics/classification/test_entropy.py | 6 +-- .../classification/test_mutual_information.py | 4 +- .../classification/test_variation_ratio.py | 6 +-- tests/models/wrappers/test_batch_ensemble.py | 2 +- tests/models/wrappers/test_stochastic.py | 2 +- tests/models/wrappers/test_swa.py | 4 +- tests/routines/test_classification.py | 2 +- tests/routines/test_regression.py | 12 +++--- tests/transforms/test_corruption.py | 2 +- tests/transforms/test_image.py | 6 +-- tests/transforms/test_mixup.py | 2 +- .../datasets/classification/cifar/cifar_h.py | 3 +- .../routines/pixel_regression.py | 4 +- torch_uncertainty/routines/segmentation.py | 4 +- 25 files changed, 104 insertions(+), 102 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9d8de32a..901ab845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,18 +89,22 @@ lint.extend-select = [ "A", "ARG", "B", + "BLE", "C4", "D", "ERA", "F", + "FURB", "G", "I", + "INT", "ISC", "ICN", "N", "NPY", "PERF", "PIE", + "PT", "PTH", "PYI", "Q", @@ -109,6 +113,8 @@ lint.extend-select = [ "RSE", "S", "SIM", + "T20", + "TC", "TCH", "TID", "TRY", @@ -128,6 +134,7 @@ lint.ignore = [ "ISC001", "N818", "N812", + "PT011", "RUF012", "S101", "TRY003", diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index d583219a..0a962c30 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -72,12 +72,12 @@ def test_cifar10_main(self): auto_augment="rand-m9-n2-mstd0.5", ) + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + test_alt="h", + ) with pytest.raises(ValueError, match="CIFAR-H can only be used in testing."): - dm = CIFAR10DataModule( - root="./data/", - batch_size=128, - test_alt="h", - ) dm.setup("fit") with pytest.raises(ValueError, match="Test set "): diff --git a/tests/datamodules/classification/test_imagenet.py b/tests/datamodules/classification/test_imagenet.py index dd949cce..9ee5ad7c 100644 --- a/tests/datamodules/classification/test_imagenet.py +++ b/tests/datamodules/classification/test_imagenet.py @@ -88,6 +88,6 @@ def test_imagenet(self): with pytest.raises(FileNotFoundError): dm._verify_splits(split="test") + dm.root = Path("./tests/testlog") with pytest.raises(FileNotFoundError): - dm.root = Path("./tests/testlog") dm._verify_splits(split="test") diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index a757f485..a7e8d863 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -15,10 +15,9 @@ class TestTUDataModule: def test_errors(self): TUDataModule.__abstractmethods__ = set() dm = TUDataModule("root", 128, 128, 0.0, 4, True, True) + dm.setup() with pytest.raises(NotImplementedError): - dm.setup() dm._get_train_data() - dm._get_train_targets() class TestCrossValDataModule: @@ -72,10 +71,9 @@ def test_errors(self): pin_memory=True, persistent_workers=True, ) + cv_dm.setup() with pytest.raises(NotImplementedError): - cv_dm.setup() cv_dm._get_train_data() - cv_dm._get_train_targets() with pytest.raises(ValueError): cv_dm.setup("other") diff --git a/tests/layers/test_batch.py b/tests/layers/test_batch.py index 75e54485..e11befdc 100644 --- a/tests/layers/test_batch.py +++ b/tests/layers/test_batch.py @@ -4,12 +4,12 @@ from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear -@pytest.fixture() +@pytest.fixture def feat_input() -> torch.Tensor: return torch.rand((4, 6)) -@pytest.fixture() +@pytest.fixture def img_input() -> torch.Tensor: return torch.rand((5, 6, 3, 3)) @@ -19,7 +19,7 @@ class TestBatchLinear: def test_linear_one_estimator(self, feat_input: torch.Tensor): layer = BatchLinear(6, 2, num_estimators=1) - print(layer) + print(layer) # noqa: T201 out = layer(feat_input) assert out.shape == torch.Size([4, 2]) @@ -50,7 +50,7 @@ class TestBatchConv2d: def test_conv_one_estimator(self, img_input: torch.Tensor): layer = BatchConv2d(6, 2, num_estimators=1, kernel_size=1) - print(layer) + print(layer) # noqa: T201 out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) diff --git a/tests/layers/test_bayesian.py b/tests/layers/test_bayesian.py index 66b09523..db90a6e9 100644 --- a/tests/layers/test_bayesian.py +++ b/tests/layers/test_bayesian.py @@ -12,32 +12,32 @@ from torch_uncertainty.layers.bayesian.sampler import TrainableDistribution -@pytest.fixture() +@pytest.fixture def feat_input_odd() -> torch.Tensor: return torch.rand((5, 10)) -@pytest.fixture() +@pytest.fixture def feat_input_even() -> torch.Tensor: return torch.rand((8, 10)) -@pytest.fixture() +@pytest.fixture def img_input_odd() -> torch.Tensor: return torch.rand((5, 10, 3, 3)) -@pytest.fixture() +@pytest.fixture def img_input_even() -> torch.Tensor: return torch.rand((8, 10, 3, 3)) -@pytest.fixture() +@pytest.fixture def cube_input_odd() -> torch.Tensor: return torch.rand((1, 10, 3, 3, 3)) -@pytest.fixture() +@pytest.fixture def cube_input_even() -> torch.Tensor: return torch.rand((2, 10, 3, 3, 3)) @@ -47,7 +47,7 @@ class TestBayesLinear: def test_linear(self, feat_input_odd: torch.Tensor) -> None: layer = BayesLinear(10, 2, sigma_init=0) - print(layer) + print(layer) # noqa: T201 out = layer(feat_input_odd) assert out.shape == torch.Size([5, 2]) layer.sample() @@ -71,7 +71,7 @@ class TestBayesConv1d: def test_conv1(self, feat_input_odd: torch.Tensor) -> None: layer = BayesConv1d(5, 2, kernel_size=1, sigma_init=0) - print(layer) + print(layer) # noqa: T201 out = layer(feat_input_odd) assert out.shape == torch.Size([2, 10]) @@ -81,7 +81,7 @@ def test_conv1(self, feat_input_odd: torch.Tensor) -> None: def test_conv1_even(self, feat_input_even: torch.Tensor) -> None: layer = BayesConv1d(8, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") - print(layer) + print(layer) # noqa: T201 out = layer(feat_input_even) assert out.shape == torch.Size([2, 10]) @@ -100,7 +100,7 @@ class TestBayesConv2d: def test_conv2(self, img_input_odd: torch.Tensor) -> None: layer = BayesConv2d(10, 2, kernel_size=1, sigma_init=0) - print(layer) + print(layer) # noqa: T201 out = layer(img_input_odd) assert out.shape == torch.Size([5, 2, 3, 3]) layer.sample() @@ -112,7 +112,6 @@ def test_conv2(self, img_input_odd: torch.Tensor) -> None: def test_conv2_even(self, img_input_even: torch.Tensor) -> None: layer = BayesConv2d(10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") - print(layer) out = layer(img_input_even) assert out.shape == torch.Size([8, 2, 3, 3]) @@ -125,18 +124,16 @@ class TestBayesConv3d: def test_conv3(self, cube_input_odd: torch.Tensor) -> None: layer = BayesConv3d(10, 2, kernel_size=1, sigma_init=0) - print(layer) + print(layer) # noqa: T201 out = layer(cube_input_odd) assert out.shape == torch.Size([1, 2, 3, 3, 3]) layer = BayesConv3d(10, 2, kernel_size=1, sigma_init=0, bias=False) - print(layer) out = layer(cube_input_odd) assert out.shape == torch.Size([1, 2, 3, 3, 3]) def test_conv3_even(self, cube_input_even: torch.Tensor) -> None: layer = BayesConv3d(10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") - print(layer) out = layer(cube_input_even) assert out.shape == torch.Size([2, 2, 3, 3, 3]) @@ -158,7 +155,7 @@ class TestLPBNNLinear: def test_linear(self, feat_input_odd: torch.Tensor) -> None: layer = LPBNNLinear(10, 2, num_estimators=4) - print(layer) + print(layer) # noqa: T201 out = layer(feat_input_odd.repeat(4, 1)) assert out.shape == torch.Size([5 * 4, 2]) @@ -180,7 +177,7 @@ class TestLPBNNConv2d: def test_conv2(self, img_input_odd: torch.Tensor) -> None: layer = LPBNNConv2d(10, 2, kernel_size=1, num_estimators=4) - print(layer) + print(layer) # noqa: T201 out = layer(img_input_odd.repeat(4, 1, 1, 1)) assert out.shape == torch.Size([5 * 4, 2, 3, 3]) @@ -191,7 +188,6 @@ def test_conv2(self, img_input_odd: torch.Tensor) -> None: def test_conv2_even(self, img_input_even: torch.Tensor) -> None: layer = LPBNNConv2d(10, 2, kernel_size=1, num_estimators=4, padding_mode="reflect") - print(layer) out = layer(img_input_even.repeat(4, 1, 1, 1)) assert out.shape == torch.Size([8 * 4, 2, 3, 3]) diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index cee3ebe5..98250d8b 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -7,7 +7,7 @@ ) -@pytest.fixture() +@pytest.fixture def feat_input() -> torch.Tensor: return torch.rand((3, 8)) # (B, Hin) @@ -99,8 +99,8 @@ def test_failures(self): with pytest.raises(NotImplementedError): get_dist_linear_layer("unknown") + layer_class = get_dist_linear_layer("normal") with pytest.raises(ValueError): - layer_class = get_dist_linear_layer("normal") layer_class( base_layer=torch.nn.Conv2d, event_dim=2, @@ -197,8 +197,8 @@ def test_failures(self): with pytest.raises(NotImplementedError): get_dist_conv_layer("unknown") + layer_class = get_dist_conv_layer("normal") with pytest.raises(ValueError): - layer_class = get_dist_conv_layer("normal") layer_class( base_layer=torch.nn.Linear, event_dim=2, diff --git a/tests/layers/test_mask.py b/tests/layers/test_mask.py index bf8e2c2d..f0da2820 100644 --- a/tests/layers/test_mask.py +++ b/tests/layers/test_mask.py @@ -4,22 +4,22 @@ from torch_uncertainty.layers.masksembles import MaskedConv2d, MaskedLinear -@pytest.fixture() +@pytest.fixture def feat_input_odd() -> torch.Tensor: return torch.rand((5, 10)) -@pytest.fixture() +@pytest.fixture def feat_input_even() -> torch.Tensor: return torch.rand((8, 10)) -@pytest.fixture() +@pytest.fixture def img_input_odd() -> torch.Tensor: return torch.rand((5, 10, 3, 3)) -@pytest.fixture() +@pytest.fixture def img_input_even() -> torch.Tensor: return torch.rand((8, 10, 3, 3)) diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index e30e7cd3..0efdfab0 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -18,68 +18,68 @@ ) -@pytest.fixture() +@pytest.fixture def feat_input() -> torch.Tensor: return torch.rand((6, 1)) # (Cin, Lin) -@pytest.fixture() +@pytest.fixture def feat_input_one_rearrange() -> torch.Tensor: return torch.rand((1 * 3, 5)) -@pytest.fixture() +@pytest.fixture def feat_multi_dim() -> torch.Tensor: return torch.rand((1, 2, 3, 4, 6)) -@pytest.fixture() +@pytest.fixture def feat_input_16_features() -> torch.Tensor: return torch.rand((3, 16)) -@pytest.fixture() +@pytest.fixture def seq_input() -> torch.Tensor: return torch.rand((5, 6, 3)) -@pytest.fixture() +@pytest.fixture def img_input() -> torch.Tensor: return torch.rand((5, 6, 3, 3)) -@pytest.fixture() +@pytest.fixture def voxels_input() -> torch.Tensor: return torch.rand((5, 6, 3, 3, 3)) -@pytest.fixture() +@pytest.fixture def unbatched_qkv() -> torch.Tensor: return torch.rand((3, 6)) -@pytest.fixture() +@pytest.fixture def unbatched_q_kv() -> tuple[torch.Tensor, torch.Tensor]: return torch.rand((3, 6)), torch.rand((4, 2)) -@pytest.fixture() +@pytest.fixture def unbatched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return torch.rand((3, 6)), torch.rand((4, 2)), torch.rand((4, 4)) -@pytest.fixture() +@pytest.fixture def batched_qkv() -> torch.Tensor: return torch.rand((2, 3, 6)) -@pytest.fixture() +@pytest.fixture def extended_batched_qkv() -> torch.Tensor: expansion = 2 return torch.rand((2, 3, 6 * expansion)) -@pytest.fixture() +@pytest.fixture def batched_q_kv() -> tuple[torch.Tensor, torch.Tensor]: return ( torch.rand((2, 3, 6)), @@ -87,7 +87,7 @@ def batched_q_kv() -> tuple[torch.Tensor, torch.Tensor]: ) -@pytest.fixture() +@pytest.fixture def batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return ( torch.rand((2, 3, 6)), @@ -96,7 +96,7 @@ def batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -@pytest.fixture() +@pytest.fixture def extended_batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: expansion = 2 return ( @@ -106,12 +106,12 @@ def extended_batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -@pytest.fixture() +@pytest.fixture def unbatched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: return torch.rand((3, 6)), torch.rand((4, 6)) -@pytest.fixture() +@pytest.fixture def batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: return ( torch.rand((2, 3, 6)), @@ -119,7 +119,7 @@ def batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: ) -@pytest.fixture() +@pytest.fixture def extended_batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: expansion = 2 return ( @@ -246,9 +246,9 @@ def test_linear_failures(self): implementation="invalid", ) + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full") + layer.implementation = "invalid" with pytest.raises(ValueError): - layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full") - layer.implementation = "invalid" _ = layer(torch.rand((2, 16))) diff --git a/tests/metrics/classification/test_brier_score.py b/tests/metrics/classification/test_brier_score.py index 518a6d5e..2642dcc4 100644 --- a/tests/metrics/classification/test_brier_score.py +++ b/tests/metrics/classification/test_brier_score.py @@ -4,51 +4,51 @@ from torch_uncertainty.metrics import BrierScore -@pytest.fixture() +@pytest.fixture def vec2d_max() -> torch.Tensor: vec = torch.as_tensor([0.5, 0.5]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec2d_max_target() -> torch.Tensor: vec = torch.as_tensor([0, 1]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec2d_max_target1d() -> torch.Tensor: return torch.as_tensor([1]) -@pytest.fixture() +@pytest.fixture def vec2d_min() -> torch.Tensor: vec = torch.as_tensor([0.0, 1.0]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec2d_min_target() -> torch.Tensor: vec = torch.as_tensor([0, 1]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec2d_5classes() -> torch.Tensor: return torch.as_tensor([[0.2, 0.6, 0.1, 0.05, 0.05], [0.05, 0.25, 0.1, 0.3, 0.3]]) -@pytest.fixture() +@pytest.fixture def vec2d_5classes_target() -> torch.Tensor: return torch.as_tensor([[0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]) -@pytest.fixture() +@pytest.fixture def vec2d_5classes_target1d() -> torch.Tensor: return torch.as_tensor([3, 4]) -@pytest.fixture() +@pytest.fixture def vec3d() -> torch.Tensor: """Return a torch tensor with a mean BrierScore of 0 and a BrierScore of the mean of 0.5 to test the `ensemble` parameter of `BrierScore`. @@ -57,13 +57,13 @@ def vec3d() -> torch.Tensor: return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec3d_target() -> torch.Tensor: vec = torch.as_tensor([0, 1]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec3d_target1d() -> torch.Tensor: vec = torch.as_tensor([1]) return vec.unsqueeze(0) @@ -170,8 +170,8 @@ def test_compute_3d_to_2d(self, vec3d: torch.Tensor, vec3d_target: torch.Tensor) assert metric.compute() == 0.5 def test_bad_input(self) -> None: + metric = BrierScore(num_classes=2, reduction="none") with pytest.raises(ValueError): - metric = BrierScore(num_classes=2, reduction="none") metric.update(torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2)) def test_bad_argument(self): diff --git a/tests/metrics/classification/test_disagreement.py b/tests/metrics/classification/test_disagreement.py index 1fa7d45c..3814a987 100644 --- a/tests/metrics/classification/test_disagreement.py +++ b/tests/metrics/classification/test_disagreement.py @@ -4,17 +4,17 @@ from torch_uncertainty.metrics import Disagreement -@pytest.fixture() +@pytest.fixture def disagreement_probas() -> torch.Tensor: return torch.as_tensor([[[0.0, 1.0], [1.0, 0.0]]]) -@pytest.fixture() +@pytest.fixture def agreement_probas() -> torch.Tensor: return torch.as_tensor([[[1.0, 0.0], [1.0, 0.0]]]) -@pytest.fixture() +@pytest.fixture def disagreement_probas_3() -> torch.Tensor: return torch.as_tensor([[[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]]) diff --git a/tests/metrics/classification/test_entropy.py b/tests/metrics/classification/test_entropy.py index 0a119c30..be1a1f46 100644 --- a/tests/metrics/classification/test_entropy.py +++ b/tests/metrics/classification/test_entropy.py @@ -6,19 +6,19 @@ from torch_uncertainty.metrics import Entropy -@pytest.fixture() +@pytest.fixture def vec2d_max() -> torch.Tensor: vec = torch.as_tensor([0.5, 0.5]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec2d_min() -> torch.Tensor: vec = torch.as_tensor([0.0, 1.0]) return vec.unsqueeze(0) -@pytest.fixture() +@pytest.fixture def vec3d() -> torch.Tensor: """Return a torch tensor with a mean entropy of 0 and an entropy of the mean of ln(2) to test the `ensemble` parameter of `Entropy`. diff --git a/tests/metrics/classification/test_mutual_information.py b/tests/metrics/classification/test_mutual_information.py index 22597f5e..f2730bc7 100644 --- a/tests/metrics/classification/test_mutual_information.py +++ b/tests/metrics/classification/test_mutual_information.py @@ -6,13 +6,13 @@ from torch_uncertainty.metrics import MutualInformation -@pytest.fixture() +@pytest.fixture def disagreement_probas() -> torch.Tensor: """Return a vector with mean entropy ~ln(2) and entropy of mean =0.""" return torch.as_tensor([[[1e-8, 1 - 1e-8], [1 - 1e-8, 1e-8]]]) -@pytest.fixture() +@pytest.fixture def agreement_probas() -> torch.Tensor: return torch.as_tensor([[[0.9, 0.1], [0.9, 0.1]]]) diff --git a/tests/metrics/classification/test_variation_ratio.py b/tests/metrics/classification/test_variation_ratio.py index 10936f2d..24be5027 100644 --- a/tests/metrics/classification/test_variation_ratio.py +++ b/tests/metrics/classification/test_variation_ratio.py @@ -4,18 +4,18 @@ from torch_uncertainty.metrics import VariationRatio -@pytest.fixture() +@pytest.fixture def disagreement_probas_3est() -> torch.Tensor: """Return a vector with mean entropy ~ln(2) and entropy of mean =0.""" return torch.as_tensor([[[0.2, 0.8]], [[0.7, 0.3]], [[0.6, 0.4]]]) -@pytest.fixture() +@pytest.fixture def agreement_probas() -> torch.Tensor: return torch.as_tensor([[[0.9, 0.1]], [[0.9, 0.1]]]) -@pytest.fixture() +@pytest.fixture def agreement_probas_3est() -> torch.Tensor: """Return a vector with mean entropy ~ln(2) and entropy of mean =0.""" return torch.as_tensor([[[0.2, 0.8]], [[0.3, 0.7]], [[0.4, 0.6]]]) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 3ec42082..f51cde15 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -6,7 +6,7 @@ from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble, batch_ensemble -@pytest.fixture() +@pytest.fixture def img_input() -> torch.Tensor: return torch.rand((5, 6, 3, 3)) diff --git a/tests/models/wrappers/test_stochastic.py b/tests/models/wrappers/test_stochastic.py index dc8a814e..be31e5b0 100644 --- a/tests/models/wrappers/test_stochastic.py +++ b/tests/models/wrappers/test_stochastic.py @@ -69,7 +69,7 @@ def test_mix(self): state = model.sample()[0] keys = state.keys() - print(list(keys)) + print(list(keys)) # noqa: T201 assert list(keys) == [ "layer.weight", "layer2.weight", diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index 1bc0da0e..4a33b2ec 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -90,12 +90,12 @@ def test_training(self): def test_state_dict(self): mod = dummy_model(1, 10) swag = SWAG(mod, cycle_start=1, cycle_length=1, num_estimators=3) - print(swag.state_dict()) + print(swag.state_dict()) # noqa: T201 swag.load_state_dict(swag.state_dict()) def test_failures(self): + swag = SWAG(nn.Module(), scale=1, cycle_start=1, cycle_length=1) with pytest.raises(NotImplementedError, match="Raise an issue if you need this feature"): - swag = SWAG(nn.Module(), scale=1, cycle_start=1, cycle_length=1) swag.sample(scale=1, block=True) with pytest.raises(ValueError, match="`scale` must be non-negative."): SWAG(nn.Module(), scale=-1, cycle_start=1, cycle_length=1) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 52ef05d7..368f4ec7 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -364,8 +364,8 @@ def test_classification_failures(self): ood_criterion="other", ) + mixup_params = {"cutmix_alpha": -1} with pytest.raises(ValueError): - mixup_params = {"cutmix_alpha": -1} ClassificationRoutine( num_classes=10, model=nn.Module(), diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index e491209e..2eabc1be 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -159,11 +159,11 @@ def test_regression_failures(self): loss=nn.MSELoss(), ) + routine = RegressionRoutine( + dist_family="normal", + output_dim=1, + model=nn.Identity(), + loss=nn.MSELoss(), + ) with pytest.raises(TypeError): - routine = RegressionRoutine( - dist_family="normal", - output_dim=1, - model=nn.Identity(), - loss=nn.MSELoss(), - ) routine(torch.randn(1, 1)) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index bf1184a5..bc33c6d9 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -42,7 +42,7 @@ def test_gaussian_noise(self): inputs = torch.rand(3, 3, 32, 32) assert transform(inputs).ndim == 4 - print(transform) + print(transform) # noqa: T201 def test_shot_noise(self): inputs = torch.rand(3, 32, 32) diff --git a/tests/transforms/test_image.py b/tests/transforms/test_image.py index c79e5210..277db309 100644 --- a/tests/transforms/test_image.py +++ b/tests/transforms/test_image.py @@ -22,14 +22,14 @@ ) -@pytest.fixture() +@pytest.fixture def img_input() -> torch.Tensor: rng = np.random.default_rng() imarray = rng.uniform(low=0, high=255, size=(28, 28, 3)) return Image.fromarray(imarray.astype("uint8")).convert("RGB") -@pytest.fixture() +@pytest.fixture def tv_tensors_input() -> tuple[torch.Tensor, torch.Tensor]: rng = np.random.default_rng() imarray1 = rng.uniform(low=0, high=255, size=(3, 28, 28)) @@ -40,7 +40,7 @@ def tv_tensors_input() -> tuple[torch.Tensor, torch.Tensor]: ) -@pytest.fixture() +@pytest.fixture def batch_input() -> tuple[torch.Tensor, torch.Tensor]: imgs = torch.rand(2, 3, 28, 28) return imgs, torch.tensor([0, 1]) diff --git a/tests/transforms/test_mixup.py b/tests/transforms/test_mixup.py index 01900fb4..697cd7f1 100644 --- a/tests/transforms/test_mixup.py +++ b/tests/transforms/test_mixup.py @@ -5,7 +5,7 @@ from torch_uncertainty.transforms.mixup import AbstractMixup -@pytest.fixture() +@pytest.fixture def batch_input() -> tuple[torch.Tensor, torch.Tensor]: imgs = torch.rand(2, 3, 28, 28) return imgs, torch.tensor([0, 1]) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index 4cbbecb1..67ed3ca4 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path from typing import Any @@ -39,7 +40,7 @@ def __init__( ) -> None: if train: raise ValueError("CIFAR10H does not support training data.") - print("WARNING: CIFAR10H cannot be used within Classification routines for now.") + logging.warning("WARNING: CIFAR10H cannot be used within Classification routines for now.") super().__init__( Path(root), train=False, diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 92b9b9da..1db3009a 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -303,9 +303,7 @@ def test_step( preds, dist = self.evaluation_forward(inputs) if batch_idx == 0 and self.log_plots: - num_images = ( - self.num_image_plot if self.num_image_plot < inputs.size(0) else inputs.size(0) - ) + num_images = min(inputs.size(0), self.num_image_plot) self._plot_depth( inputs[:num_images, ...], preds[:num_images, ...], diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 662d5581..a68a97ca 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -1,3 +1,5 @@ +import logging + import torch from einops import rearrange from lightning.pytorch import LightningModule @@ -286,7 +288,7 @@ def on_test_epoch_end(self) -> None: if self.trainer.datamodule is not None: self.log_segmentation_plots() else: - print("No datamodule found, skipping segmentation plots.") + logging.info("No datamodule found, skipping segmentation plots.") def log_segmentation_plots(self) -> None: """Build and log examples of segmentation plots from the test set.""" From 020ab8937a2d48e3635b87c436a826a5e50c727e Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Mon, 31 Mar 2025 14:46:35 +0200 Subject: [PATCH 23/69] :books: Update Dockerfile documentation --- docker/DOCKER.md | 159 ++++++++++++++++++++++++++--------------------- 1 file changed, 88 insertions(+), 71 deletions(-) diff --git a/docker/DOCKER.md b/docker/DOCKER.md index 35adbb81..4e2de285 100644 --- a/docker/DOCKER.md +++ b/docker/DOCKER.md @@ -1,71 +1,88 @@ -# :whale: Docker image for contributors - -### Pre-built Docker image -1. To pull the pre-built image from Docker Hub, simply run: - ```bash - docker pull docker.io/tonyzamyatin/torch-uncertainty:latest - ``` - - This image includes: - - PyTorch with CUDA support - - OpenGL (for visualization tasks) - - Git, OpenSSH, and all Python dependencies - - Checkout the [registry on Docker Hub](https://hub.docker.com/repository/docker/tonyzamyatin/torch-uncertainty/general) for all available images. - -2. To start a container using this image, set up the necessary environment variables and run: - ```bash - docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \ - -e VM_SSH_PUBLIC_KEY="your-public-key" \ - -e GITHUB_SSH_PRIVATE_KEY="your-github-key" \ - -e GITHUB_USER="your-github-username" \ - -e GIT_USER_EMAIL="your-git-email" \ - -e GIT_USER_NAME="your-git-name" \ - docker.io/tonyzamyatin/torch-uncertainty - ``` - - Optionally, you can also set `-e USER_COMPACT_SHELL_PROMPT="true"` - to make the VM's shell prompts compact and colorized. - - **Note:** Some cloud providers offer templates, in which you can preconfigure - in advance which Docker image to pull and which environment variables to set. - In this case, the provider will pull the image, set all environment variables, - and start the container for you. - -3. Once your cloud provider has deployed the VM, it will display the host address and SSH port. - You can connect to the container via SSH using: - ```bash - ssh -i /path/to/private_key root@ -p - ``` - - Replace `` and `` with the values provided by your cloud provider, - and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`. - -4. The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard. - - **Note:** The `/workspace` directory is mounted from your local machine or cloud storage, - so changes persist across container restarts. - If using a cloud provider, ensure your network volume is correctly attached to avoid losing data. - -### Modifying and publishing custom Docker image - -If you want to make changes to the Dockerfile, follow these steps: -1. Edit the Dockerfile to fit your needs. - -2. Build the modified image: - ``` - docker build -t my-custom-image . - ``` - -3. Push to a Docker registry (if you want to use it on another VM): - ``` - docker tag my-custom-image mydockerhubuser/my-custom-image:tag - docker push mydockerhubuser/my-custom-image:tag - ``` - -4. Pull the custom image onto your VM: - ``` - docker pull mydockerhubuser/my-custom-image - ``` - -5. Run the container using the same docker run command with the new image name. +# 🐋 Docker image for contributors + +This Docker image is designed for users and contributors who want to run experiments with `torch-uncertainty` on remote virtual machines with GPU support. It is particularly useful for those who do not have access to a local GPU and need a pre-configured environment for development and experimentation. + +--- +## How to Use The Docker Image +### Step 1: Fork the Repository + +Before proceeding, ensure you have forked the `torch-uncertainty` repository to your own GitHub account. You can do this by visiting the [torch-uncertainty GitHub repository](https://github.com/ENSTA-U2IS-AI/torch-uncertainty) and clicking the **Fork** button in the top-right corner. + +Once forked, clone your forked repository to your local machine: +```bash +git clone git@github.com:/torch-uncertainty.git +cd torch-uncertainty +``` + +> ### ⚠️ IMPORTANT NOTE: Keep Your Fork Synced +> +> **To ensure that you are working with the latest stable version and bug fixes, you must manually sync your fork with the upstream repository before building the Docker image. Failure to sync your fork may result in outdated dependencies or missing bug fixes in the Docker image.** + +### Step 2: Build the Docker image locally +Build the modified image locally and push it to a Docker registry: +``` +docker build -t my-torch-uncertainty-docker:version . +docker push my-dockerhub-user/my-torch-uncertainty-image:version +``` +### Step 3: Set environment variables on your VM +Connect to you VM and set the following environment variables: +```bash +export VM_SSH_PUBLIC_KEY="$(cat ~/.ssh/id_rsa.pub)" +export GITHUB_SSH_PRIVATE_KEY="$(cat ~/.ssh/id_rsa)" +export GITHUB_USER="your-github-username" +export GIT_USER_EMAIL="your-email@example.com" +export GIT_USER_NAME="Your Name" +export USE_COMPACT_SHELL_PROMPT=true +``` + +Here is a brief explanation of the environment variables used in the Docker setup: +- **`VM_SSH_PUBLIC_KEY`**: The public SSH key used to authenticate with the container via SSH. +- **`GITHUB_SSH_PRIVATE_KEY`**: The private SSH key used to authenticate with GitHub for cloning and pushing repositories. +- **`GITHUB_USER`**: The GitHub username used to clone the repository during the first-time setup. +- **`GIT_USER_EMAIL`**: The email address associated with the Git configuration for commits. +- **`GIT_USER_NAME`**: The name associated with the Git configuration for commits. +- **`USE_COMPACT_SHELL_PROMPT`** (optional): Enables a compact and colorized shell prompt inside the container if set to `"true"`. + +### Step 4: Run the Docker container +First, authenticate with your Docker registry if you use a private registry. +Then run the following command to run the Docker image from your docker registriy +```bash +docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \ + -e VM_SSH_PUBLIC_KEY \ + -e GITHUB_SSH_PRIVATE_KEY \ + -e GITHUB_USER \ + -e GIT_USER_EMAIL \ + -e GIT_USER_NAME \ + -e USE_COMPACT_SHELL_PROMPT \ + docker.io/my-dockerhub-user/my-torch-uncertainty-image:version +``` + +### Step 5: Connect to your container +Once the container is up and running, you can connect to it via SSH: +```bash +ssh -i /path/to/private_key root@ -p +``` +Replace `` and `` with the host and port of your VM, +and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`. + +The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard. + +**Note:** The `/workspace` directory is mounted from your local machine or cloud storage, +so changes persist across container restarts. +If using a cloud provider, ensure your network volume is correctly attached to avoid losing data. + +## Remote Development + +This Docker setup also allows for remote development on the VM, since GitHub SSH access is set up and the whole repo is cloned to the VM from your GitHub fork. +For example, you can seamlessly connect your VS Code editor to your remote VM and run experiments, as if on your local machine but with the GPU acceleration of your VM. +See [VS Code Remote Development](https://code.visualstudio.com/docs/remote/remote-overview) for further details. + +## Streamline setup with your Cloud provider of choice + +Many cloud providers offer "templates" where you can specify a Docker image to use as a base. This means you can: + +1. Specify the Docker image from your Docker registry as the base image. +2. Preconfigure the necessary environment variables in the template. +3. Reuse the template any time you need to spin up a virtual machine for experiments. + +The cloud provider will handle setting the environment variables, pulling the Docker image, and spinning up the container for you. This approach simplifies the process and ensures consistency across experiments. From afa984ccc70cae8e5a889a847d3ce032a54b7bfb Mon Sep 17 00:00:00 2001 From: alafage Date: Tue, 1 Apr 2025 12:30:42 +0200 Subject: [PATCH 24/69] :hammer: Slight update of the `TULightningCLI` to support lightning<2.5.1 --- torch_uncertainty/utils/cli.py | 63 +++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 7ec97f86..14943cde 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -69,32 +69,61 @@ def __init__( trainer_defaults: dict[str, Any] | None = None, seed_everything_default: bool | int = True, parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, - parser_class: type[LightningArgumentParser] = LightningArgumentParser, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, run: bool = True, auto_configure_optimizers: bool = True, eval_after_fit_default: bool = False, + **kwargs: Any, ) -> None: """Custom LightningCLI for torch-uncertainty. Args: - model_class (type[LightningModule] | Callable[..., LightningModule] | None, optional): _description_. Defaults to None. - datamodule_class (type[LightningDataModule] | Callable[..., LightningDataModule] | None, optional): _description_. Defaults to None. - save_config_callback (type[SaveConfigCallback] | None, optional): _description_. Defaults to TUSaveConfigCallback. - save_config_kwargs (dict[str, Any] | None, optional): _description_. Defaults to None. - trainer_class (type[Trainer] | Callable[..., Trainer], optional): _description_. Defaults to Trainer. - trainer_defaults (dict[str, Any] | None, optional): _description_. Defaults to None. - seed_everything_default (bool | int, optional): _description_. Defaults to True. - parser_kwargs (dict[str, Any] | dict[str, dict[str, Any]] | None, optional): _description_. Defaults to None. - parser_class (type[LightningArgumentParser], optional): _description_. Defaults to ``LightningArgumentParser``. - subclass_mode_model (bool, optional): _description_. Defaults to False. - subclass_mode_data (bool, optional): _description_. Defaults to False. - args (ArgsType, optional): _description_. Defaults to None. - run (bool, optional): _description_. Defaults to True. - auto_configure_optimizers (bool, optional): _description_. Defaults to True. - eval_after_fit_default (bool, optional): _description_. Defaults to False. + model_class (type[LightningModule] | Callable[..., LightningModule] | None, optional): + An optional `LightningModule` class to train or a callable which returns a + ``LightningModule`` instance when called. If ``None``, you can pass a registered model + with ``--model=MyModel``. Defaults to ``None``. + datamodule_class (type[LightningDataModule] | Callable[..., LightningDataModule] | None, optional): + An optional ``LightningDataModule`` class or a callable which returns a ``LightningDataModule`` + instance when called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``. + Defaults to ``None``. + save_config_callback (type[SaveConfigCallback] | None, optional): A callback class to + save the config. Defaults to ``TUSaveConfigCallback``. + save_config_kwargs (dict[str, Any] | None, optional): Parameters that will be used to + instantiate the save_config_callback. Defaults to ``None``. + trainer_class (type[Trainer] | Callable[..., Trainer], optional): An optional subclass + of the Trainer class or a callable which returns a ``Trainer`` instance when called. + Defaults to ``TUTrainer``. + trainer_defaults (dict[str, Any] | None, optional): Set to override Trainer defaults + or add persistent callbacks. The callbacks added through this argument will not + be configurable from a configuration file and will always be present for this + particular CLI. Alternatively, configurable callbacks can be added as explained + in the CLI docs. Defaults to ``None``. + seed_everything_default (bool | int, optional): Number for the ``seed_everything()`` + seed value. Set to ``True`` to automatically choose a seed value. Setting it to ``False`` + will avoid calling seed_everything. Defaults to ``True``. + parser_kwargs (dict[str, Any] | dict[str, dict[str, Any]] | None, optional): Additional + arguments to instantiate each ``LightningArgumentParser``. Defaults to + ``LightningArgumentParser``. Defaults to ``None``. + subclass_mode_model (bool, optional): Whether model can be any subclass of the given + class. Defaults to ``False``. + subclass_mode_data (bool, optional): Whether datamodule can be any subclass of the + given class. Defaults to ``False``. + args (ArgsType, optional): Arguments to parse. If `None` the arguments are taken from + ``sys.argv``. Command line style arguments can be given in a ``list``. Alternatively, + structured config options can be given in a ``dict`` or ``jsonargparse.Namespace``. + Defaults to `None`. + run (bool, optional): Whether subcommands should be added to run a ``Trainer`` method. If + set to `False`, the trainer and model classes will be instantiated only. Defaults + to ``True``. + auto_configure_optimizers (bool, optional): Defaults to ``True``. + eval_after_fit_default (bool, optional): Indicate. Defaults to False. + **kwargs: Additional keyword arguments to pass to the parent class added for + ``lightning>2.5.0``: + + - parser_class: The parser class to use. Defaults to `LightningArgumentParser`. + Available in ``lightning>=2.5.1`` """ self.eval_after_fit_default = eval_after_fit_default super().__init__( @@ -106,12 +135,12 @@ def __init__( trainer_defaults=trainer_defaults, seed_everything_default=seed_everything_default, parser_kwargs=parser_kwargs, - parser_class=parser_class, subclass_mode_model=subclass_mode_model, subclass_mode_data=subclass_mode_data, args=args, run=run, auto_configure_optimizers=auto_configure_optimizers, + **kwargs, ) def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: From 1dcaa96098594a51901b71f539fd0691ea5ebf8e Mon Sep 17 00:00:00 2001 From: alafage Date: Tue, 1 Apr 2025 17:26:46 +0200 Subject: [PATCH 25/69] :book: Fix TULightningCLI docstring --- torch_uncertainty/utils/cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 14943cde..105176d1 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -118,7 +118,8 @@ def __init__( set to `False`, the trainer and model classes will be instantiated only. Defaults to ``True``. auto_configure_optimizers (bool, optional): Defaults to ``True``. - eval_after_fit_default (bool, optional): Indicate. Defaults to False. + eval_after_fit_default (bool, optional): Store whether an evaluation should be performed + after the training. Defaults to ``False``. **kwargs: Additional keyword arguments to pass to the parent class added for ``lightning>2.5.0``: From aacfcf884f9b7ac939ac61aae510909cefbeb512 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 10 Apr 2025 10:58:39 +0200 Subject: [PATCH 26/69] :hammer: Enable setting train and test transforms for Depth datamodules --- tests/datamodules/test_depth.py | 38 ++++++- torch_uncertainty/datamodules/depth/base.py | 109 ++++++++++++------- torch_uncertainty/datamodules/depth/kitti.py | 16 ++- torch_uncertainty/datamodules/depth/muad.py | 16 ++- torch_uncertainty/datamodules/depth/nyu.py | 16 ++- 5 files changed, 150 insertions(+), 45 deletions(-) diff --git a/tests/datamodules/test_depth.py b/tests/datamodules/test_depth.py index 733e2adf..3608e732 100644 --- a/tests/datamodules/test_depth.py +++ b/tests/datamodules/test_depth.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from tests._dummies.dataset import DummPixelRegressionDataset from torch_uncertainty.datamodules.depth import ( @@ -6,12 +7,47 @@ MUADDataModule, NYUv2DataModule, ) +from torch_uncertainty.datamodules.depth.base import DepthDataModule from torch_uncertainty.datasets import MUAD, KITTIDepth, NYUv2 class TestMUADDataModule: """Testing the MUADDataModule datamodule.""" + def test_depth_dm(self): + dm = DepthDataModule( + dataset=DummPixelRegressionDataset, + root="./data/", + batch_size=128, + min_depth=0, + max_depth=100, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) + + def test_depth_dm_failures(self): + with pytest.raises(ValueError): + DepthDataModule( + dataset=DummPixelRegressionDataset, + root="./data/", + batch_size=128, + min_depth=0, + max_depth=100, + eval_size=(224, 224), + ) + + with pytest.raises(ValueError): + DepthDataModule( + dataset=DummPixelRegressionDataset, + root="./data/", + batch_size=128, + min_depth=0, + max_depth=100, + crop_size=(224, 224), + ) + def test_muad_main(self): dm = MUADDataModule(root="./data/", min_depth=0, max_depth=100, batch_size=128) @@ -42,7 +78,7 @@ def test_muad_main(self): class TestNYUDataModule: - """Testing the MUADDataModule datamodule.""" + """Testing the NYUv2DataModule datamodule.""" def test_nyu_main(self): dm = NYUv2DataModule(root="./data/", max_depth=100, batch_size=128) diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 2abebf72..3677023b 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -1,6 +1,7 @@ from pathlib import Path import torch +from torch import nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair from torchvision import tv_tensors @@ -20,8 +21,10 @@ def __init__( batch_size: int, min_depth: float, max_depth: float, - crop_size: _size_2_t, - eval_size: _size_2_t, + crop_size: _size_2_t | None = None, + eval_size: _size_2_t | None = None, + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, eval_batch_size: int | None = None, val_split: float | None = None, num_workers: int = 1, @@ -44,12 +47,20 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``None``. eval_size (sequence or int, optional): Desired input image and depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Has to be provided if :attr:`test_transform` is not provided. Otherwise + has no effect. Defaults to ``None``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. val_split (float or None, optional): Share of training samples to use for validation. num_workers (int, optional): Number of dataloaders to use. @@ -69,41 +80,63 @@ def __init__( self.dataset = dataset self.min_depth = min_depth self.max_depth = max_depth - self.crop_size = _pair(crop_size) - self.eval_size = _pair(eval_size) - - self.train_transform = v2.Compose( - [ - RandomRescale(min_scale=0.5, max_scale=2.0), - v2.RandomCrop( - size=self.crop_size, - pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: float("nan")}, - ), - v2.RandomHorizontalFlip(), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - self.test_transform = v2.Compose( - [ - v2.Resize(size=self.eval_size), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) + + if train_transform is not None: + self.crop_size = None + self.train_transform = train_transform + else: + if crop_size is None: + raise ValueError( + "crop_size must be provided if train_transform is not provided." + " Please provide a valid crop_size." + ) + + self.crop_size = _pair(crop_size) + + self.train_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: float("nan")}, + ), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + if test_transform is not None: + self.eval_size = None + self.test_transform = test_transform + else: + if eval_size is None: + raise ValueError( + "eval_size must be provided if test_transform is not provided." + " Please provide a valid eval_size." + ) + + self.eval_size = _pair(eval_size) + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.eval_size), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset( diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py index de50a51d..e15e184e 100644 --- a/torch_uncertainty/datamodules/depth/kitti.py +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -1,5 +1,6 @@ from pathlib import Path +from torch import nn from torch.nn.common_types import _size_2_t from torch_uncertainty.datasets import KITTIDepth @@ -17,6 +18,8 @@ def __init__( max_depth: float = 80.0, crop_size: _size_2_t = (352, 704), eval_size: _size_2_t = (375, 1242), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -38,13 +41,20 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(375, 1242)``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``(375, 1242)``. eval_size (sequence or int, optional): Desired input image and depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(375, 1242)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(375, 1242)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -63,6 +73,8 @@ def __init__( max_depth=max_depth, crop_size=crop_size, eval_size=eval_size, + train_transform=train_transform, + test_transform=test_transform, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index 5e0c2b06..36d5a74d 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -1,5 +1,6 @@ from pathlib import Path +from torch import nn from torch.nn.common_types import _size_2_t from torch_uncertainty.datasets import MUAD @@ -17,6 +18,8 @@ def __init__( max_depth: float, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, eval_batch_size: int | None = None, val_split: float | None = None, num_workers: int = 1, @@ -38,13 +41,20 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``1024``. eval_size (sequence or int, optional): Desired input image and depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(1024,2048)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -63,6 +73,8 @@ def __init__( max_depth=max_depth, crop_size=crop_size, eval_size=eval_size, + train_transform=train_transform, + test_transform=test_transform, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py index cbfd6dd1..ee5d368c 100644 --- a/torch_uncertainty/datamodules/depth/nyu.py +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -1,5 +1,6 @@ from pathlib import Path +from torch import nn from torch.nn.common_types import _size_2_t from torch_uncertainty.datasets import NYUv2 @@ -17,6 +18,8 @@ def __init__( max_depth: float = 10.0, crop_size: _size_2_t = (416, 544), eval_size: _size_2_t = (480, 640), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -38,13 +41,20 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(416, 544)``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``(416, 544)``. eval_size (sequence or int, optional): Desired input image and depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(480, 640)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(480, 640)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -63,6 +73,8 @@ def __init__( max_depth=max_depth, crop_size=crop_size, eval_size=eval_size, + train_transform=train_transform, + test_transform=test_transform, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, From 2586d9abb78d9d1e30fa45a3da4adaa7ab43ecb7 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 10 Apr 2025 11:12:44 +0200 Subject: [PATCH 27/69] :hammer: Enable setting train and test transforms in Segmentation datamodules #163 --- tests/datamodules/segmentation/test_camvid.py | 11 +- .../segmentation/test_cityscapes.py | 10 +- tests/datamodules/segmentation/test_muad.py | 11 +- .../datamodules/segmentation/camvid.py | 106 ++++++++++------- .../datamodules/segmentation/cityscapes.py | 110 ++++++++++-------- .../datamodules/segmentation/muad.py | 97 ++++++++------- 6 files changed, 210 insertions(+), 135 deletions(-) diff --git a/tests/datamodules/segmentation/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py index f9017228..7779e664 100644 --- a/tests/datamodules/segmentation/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from tests._dummies.dataset import DummySegmentationDataset from torch_uncertainty.datamodules.segmentation import CamVidDataModule @@ -9,7 +10,15 @@ class TestCamVidDataModule: """Testing the CamVidDataModule datamodule.""" def test_camvid_main(self): - dm = CamVidDataModule(root="./data/", batch_size=128, group_classes=False) + dm = CamVidDataModule( + root="./data/", + batch_size=128, + group_classes=False, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) dm = CamVidDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == CamVid diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py index 4cb46709..aa247c59 100644 --- a/tests/datamodules/segmentation/test_cityscapes.py +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from tests._dummies.dataset import DummySegmentationDataset from torch_uncertainty.datamodules.segmentation import CityscapesDataModule @@ -9,7 +10,14 @@ class TestCityscapesDataModule: """Testing the CityscapesDataModule datamodule.""" def test_camvid_main(self): - dm = CityscapesDataModule(root="./data/", batch_size=128) + dm = CityscapesDataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) dm = CityscapesDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == Cityscapes diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py index 862206f0..5edcb5e6 100644 --- a/tests/datamodules/segmentation/test_muad.py +++ b/tests/datamodules/segmentation/test_muad.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from tests._dummies.dataset import DummySegmentationDataset from torch_uncertainty.datamodules.segmentation import MUADDataModule @@ -8,7 +9,15 @@ class TestMUADDataModule: """Testing the MUADDataModule datamodule.""" - def test_camvid_main(self): + def test_muad_main(self): + dm = MUADDataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) dm = MUADDataModule(root="./data/", batch_size=128) assert dm.dataset == MUAD diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index b028a48e..64ae0578 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -26,6 +26,8 @@ def __init__( eval_batch_size: int | None = None, crop_size: _size_2_t = 640, eval_size: _size_2_t = (720, 960), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, group_classes: bool = True, basic_augment: bool = True, val_split: float | None = None, @@ -45,17 +47,24 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``640``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``640``. eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(720,960)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(720,960)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. group_classes (bool, optional): Whether to group the 32 classes into 11 superclasses. Default: ``True``. basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + ``True``. Only used if ``train_transform`` is not provided. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -66,7 +75,7 @@ def __init__( Defaults to ``True``. Note: - This datamodule injects the following transforms into the training and + By default this datamodule injects the following transforms into the training and validation/test datasets: .. code-block:: python @@ -87,8 +96,8 @@ def __init__( ] ) - This behavior can be modified by overriding ``self.train_transform`` - and ``self.test_transform`` after initialization. + This behavior can be modified by setting up ``train_transform`` + and ``test_transform`` at initialization. """ if val_split is not None: # coverage: ignore logging.warning("val_split is not used for CamVidDataModule.") @@ -111,50 +120,57 @@ def __init__( self.crop_size = _pair(crop_size) self.eval_size = _pair(eval_size) - if basic_augment: - basic_transform = v2.Compose( + if train_transform is not None: + self.train_transform = train_transform + else: + if basic_augment: + basic_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + self.train_transform = v2.Compose( [ - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop( - size=self.crop_size, - pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + basic_transform, + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, ), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), + v2.Normalize(mean=self.mean, std=self.std), ] ) + + if test_transform is not None: + self.test_transform = test_transform else: - basic_transform = nn.Identity() - - self.train_transform = v2.Compose( - [ - basic_transform, - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - self.test_transform = v2.Compose( - [ - v2.Resize(size=self.eval_size, antialias=True), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.eval_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset(root=self.root, download=True) diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 2328df64..c9963e6d 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -27,6 +27,8 @@ def __init__( eval_batch_size: int | None = None, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, basic_augment: bool = True, val_split: float | None = None, num_workers: int = 1, @@ -45,15 +47,22 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``1024``. eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(1024,2048)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + ``True``. Only used if ``train_transform`` is not provided. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -65,7 +74,7 @@ def __init__( Note: - This datamodule injects the following transforms into the training and + By default this datamodule injects the following transforms into the training and validation/test datasets: Training transforms: @@ -107,8 +116,8 @@ def __init__( std=[0.229, 0.224, 0.225]) ]) - This behavior can be modified by overriding ``self.train_transform`` - and ``self.test_transform`` after initialization. + This behavior can be modified by setting ``train_transform`` + and ``test_transform`` at initialization. """ super().__init__( root=root, @@ -124,52 +133,59 @@ def __init__( self.crop_size = _pair(crop_size) self.eval_size = _pair(eval_size) - if basic_augment: - basic_transform = v2.Compose( + if train_transform is not None: + self.train_transform = train_transform + else: + if basic_augment: + basic_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + self.train_transform = v2.Compose( [ - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop( - size=self.crop_size, - pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + v2.ToImage(), + basic_transform, + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, ), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), + v2.Normalize(mean=self.mean, std=self.std), ] ) + + if test_transform is not None: + self.test_transform = test_transform else: - basic_transform = nn.Identity() - - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.Resize(size=self.eval_size, antialias=True), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(size=self.eval_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset(root=self.root, split="train", mode=self.mode) diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 82c31555..61f9f31a 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -1,6 +1,7 @@ from pathlib import Path import torch +from torch import nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair from torchvision import tv_tensors @@ -24,6 +25,8 @@ def __init__( eval_batch_size: int | None = None, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -41,13 +44,20 @@ def __init__( int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + :math:`(\text{size[0]},\text{size[1]})`. Has to be provided if + :attr:`train_transform` is not provided. Otherwise has no effect. + Defaults to ``1024``. eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. + Has to be provided if :attr:`test_transform` is not provided. + Otherwise has no effect. Defaults to ``(1024,2048)``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -59,7 +69,7 @@ def __init__( Note: - This datamodule injects the following transforms into the training and + By default this datamodule injects the following transforms into the training and validation/test datasets: Training transforms: @@ -101,8 +111,8 @@ def __init__( std=[0.229, 0.224, 0.225]) ]) - This behavior can be modified by overriding ``self.train_transform`` - and ``self.test_transform`` after initialization. + This behavior can be modified by setting up ``train_transform`` + and ``test_transform`` at initialization. """ super().__init__( root=root, @@ -118,41 +128,48 @@ def __init__( self.crop_size = _pair(crop_size) self.eval_size = _pair(eval_size) - self.train_transform = v2.Compose( - [ - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop( - size=self.crop_size, - pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, - ), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - self.test_transform = v2.Compose( - [ - v2.Resize(size=self.eval_size, antialias=True), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + if train_transform is not None: + self.train_transform = train_transform + else: + self.train_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + + if test_transform is not None: + self.test_transform = test_transform + else: + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.eval_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset(root=self.root, split="train", target_type="semantic", download=True) From db09be65a7fd9ee2a44a97b6c9b8ae10032a7c21 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 10 Apr 2025 11:30:40 +0200 Subject: [PATCH 28/69] :hammer: Enable setting train and test transforms in Classification datamodules Fix #163 --- .../classification/test_cifar10.py | 10 ++ .../classification/test_cifar100.py | 10 ++ .../classification/test_imagenet.py | 11 +- .../datamodules/classification/test_mnist.py | 12 +++ .../classification/test_tiny_imagenet.py | 10 +- .../datamodules/classification/cifar10.py | 81 ++++++++------ .../datamodules/classification/cifar100.py | 91 +++++++++------- .../datamodules/classification/imagenet.py | 101 ++++++++++-------- .../datamodules/classification/mnist.py | 65 +++++++---- .../classification/tiny_imagenet.py | 69 +++++++----- .../datamodules/segmentation/cityscapes.py | 2 +- 11 files changed, 296 insertions(+), 166 deletions(-) diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 0a962c30..b928a8a0 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from torchvision.datasets import CIFAR10 from tests._dummies.dataset import DummyClassificationDataset @@ -10,6 +11,15 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) + dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16, postprocess_set="test") assert dm.dataset == CIFAR10 diff --git a/tests/datamodules/classification/test_cifar100.py b/tests/datamodules/classification/test_cifar100.py index 47394cfd..0f87f722 100644 --- a/tests/datamodules/classification/test_cifar100.py +++ b/tests/datamodules/classification/test_cifar100.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from torchvision.datasets import CIFAR100 from tests._dummies.dataset import DummyClassificationDataset @@ -10,6 +11,15 @@ class TestCIFAR100DataModule: """Testing the CIFAR100DataModule datamodule class.""" def test_cifar100(self): + dm = CIFAR100DataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) + dm = CIFAR100DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR100 diff --git a/tests/datamodules/classification/test_imagenet.py b/tests/datamodules/classification/test_imagenet.py index 9ee5ad7c..00110b89 100644 --- a/tests/datamodules/classification/test_imagenet.py +++ b/tests/datamodules/classification/test_imagenet.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +from torch import nn from torchvision.datasets import ImageNet from tests._dummies.dataset import DummyClassificationDataset @@ -11,7 +12,15 @@ class TestImageNetDataModule: """Testing the ImageNetDataModule datamodule class.""" def test_imagenet(self): - dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=0.1) + dm = ImageNetDataModule( + root="./data/", + batch_size=128, + val_split=0.1, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) assert dm.dataset == ImageNet dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index f6ab4f8e..9088bf98 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -11,6 +11,18 @@ class TestMNISTDataModule: """Testing the MNISTDataModule datamodule class.""" def test_mnist_cutout(self): + dm = MNISTDataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=True, + ood_transform=nn.Identity(), + ) + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) + assert isinstance(dm.ood_transform, nn.Identity) + dm = MNISTDataModule( root="./data/", batch_size=128, diff --git a/tests/datamodules/classification/test_tiny_imagenet.py b/tests/datamodules/classification/test_tiny_imagenet.py index 8826d849..5b0f3045 100644 --- a/tests/datamodules/classification/test_tiny_imagenet.py +++ b/tests/datamodules/classification/test_tiny_imagenet.py @@ -1,4 +1,5 @@ import pytest +from torch import nn from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.datamodules import TinyImageNetDataModule @@ -9,9 +10,16 @@ class TestTinyImageNetDataModule: """Testing the TinyImageNetDataModule datamodule class.""" def test_tiny_imagenet(self): - dm = TinyImageNetDataModule(root="./data/", batch_size=128) + dm = TinyImageNetDataModule( + root="./data/", + batch_size=128, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) assert dm.dataset == TinyImageNet + assert isinstance(dm.train_transform, nn.Identity) + assert isinstance(dm.test_transform, nn.Identity) dm = TinyImageNetDataModule( root="./data/", diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index c92ba2cf..5dec3fab 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -36,8 +36,11 @@ def __init__( val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, basic_augment: bool = True, cutout: int | None = None, + randaugment: bool = False, auto_augment: str | None = None, test_alt: Literal["h"] | None = None, num_dataloaders: int = 1, @@ -59,12 +62,18 @@ def __init__( use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + ``True``. Only used if ``train_transform`` is not provided. cutout (int): Size of cutout to apply to images. Defaults to ``None``. + Only used if ``train_transform`` is not provided. randaugment (bool): Whether to apply RandAugment. Defaults to - ``False``. + ``False``. Only used if ``train_transform`` is not provided. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. + Only used if ``train_transform`` is not provided. test_alt (str): Which test set to use. Defaults to ``None``. shift_severity (int): Severity of corruption to apply for CIFAR10-C. Defaults to ``1``. @@ -100,46 +109,54 @@ def __init__( self.ood_dataset = SVHN self.shift_dataset = CIFAR10C - if (cutout is not None) + int(auto_augment is not None) > 1: + if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: raise ValueError( "Only one data augmentation can be chosen at a time. Raise a " "GitHub issue if needed." ) - if basic_augment: - basic_transform = v2.Compose( + if train_transform is not None: + self.train_transform = train_transform + else: + if basic_augment: + basic_transform = v2.Compose( + [ + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + if cutout: + main_transform = Cutout(cutout) + elif randaugment: + main_transform = v2.RandAugment(num_ops=2, magnitude=20) + elif auto_augment: + main_transform = rand_augment_transform(auto_augment, {}) + else: + main_transform = nn.Identity() + + self.train_transform = v2.Compose( [ - v2.RandomCrop(32, padding=4), - v2.RandomHorizontalFlip(), + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - else: - basic_transform = nn.Identity() - if cutout: - main_transform = Cutout(cutout) - elif auto_augment: - main_transform = rand_augment_transform(auto_augment, {}) + if test_transform is not None: + self.test_transform = test_transform else: - main_transform = nn.Identity() - - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - main_transform, - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore if self.test_alt is None: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 884ef9d4..1e2a0b6e 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -35,6 +35,8 @@ def __init__( shift_severity: int = 1, val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, basic_augment: bool = True, cutout: int | None = None, randaugment: bool = False, @@ -59,12 +61,18 @@ def __init__( to ``0.0``. postprocess_set (str, optional): The post-hoc calibration dataset to use for the post-processing method. Defaults to ``val``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + ``True``. Only used if train_transform is not provided. cutout (int): Size of cutout to apply to images. Defaults to ``None``. + Only used if train_transform is not provided. randaugment (bool): Whether to apply RandAugment. Defaults to - ``False``. + ``False``. Only used if train_transform is not provided. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. + Only used if train_transform is not provided. shift_severity (int): Severity of corruption to apply to CIFAR100-C. Defaults to ``1``. num_dataloaders (int): Number of dataloaders to use. Defaults to ``1``. @@ -95,47 +103,54 @@ def __init__( self.shift_severity = shift_severity - if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: - raise ValueError( - "Only one data augmentation can be chosen at a time. Raise a " - "GitHub issue if needed." - ) + if train_transform is not None: + self.train_transform = train_transform + else: + if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: + raise ValueError( + "Only one data augmentation can be chosen at a time. Raise a " + "GitHub issue if needed." + ) + + if basic_augment: + basic_transform = v2.Compose( + [ + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + if cutout: + main_transform = Cutout(cutout) + elif randaugment: + main_transform = v2.RandAugment(num_ops=2, magnitude=20) + elif auto_augment: + main_transform = rand_augment_transform(auto_augment, {}) + else: + main_transform = nn.Identity() - if basic_augment: - basic_transform = v2.Compose( + self.train_transform = v2.Compose( [ - v2.RandomCrop(32, padding=4), - v2.RandomHorizontalFlip(), + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) + + if test_transform is not None: + self.test_transform = test_transform else: - basic_transform = nn.Identity() - - if cutout: - main_transform = Cutout(cutout) - elif randaugment: - main_transform = v2.RandAugment(num_ops=2, magnitude=20) - elif auto_augment: - main_transform = rand_augment_transform(auto_augment, {}) - else: - main_transform = nn.Identity() - - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - main_transform, - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset(self.root, train=True, download=True) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 30e78602..d1b228f0 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -52,6 +52,8 @@ def __init__( shift_severity: int = 1, val_split: float | Path | None = None, postprocess_set: Literal["val", "test"] = "val", + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, ood_ds: str = "openimage-o", test_alt: str | None = None, procedure: str | None = None, @@ -81,16 +83,23 @@ def __init__( ids. Defaults to ``0.0``. postprocess_set (str, optional): The post-hoc calibration dataset to use for the post-processing method. Defaults to ``val``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"openimage-o"``. test_alt (str): Which test set to use. Defaults to ``None``. procedure (str): Which procedure to use. Defaults to ``None``. + Only used if ``train_transform`` is not provided. train_size (int): Size of training images. Defaults to ``224``. interpolation (str): Interpolation method for the Resize Crops. - Defaults to ``"bilinear"``. + Defaults to ``"bilinear"``. Only used if ``train_transform`` is not + provided. basic_augment (bool): Whether to apply base augmentations. Defaults to - ``True``. + ``True``. Only used if ``train_transform`` is not provided. rand_augment_opt (str): Which RandAugment to use. Defaults to ``None``. + Only used if ``train_transform`` is not provided. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. pin_memory (bool): Whether to pin memory. Defaults to ``True``. @@ -146,54 +155,60 @@ def __init__( self.procedure = procedure - if basic_augment: - basic_transform = v2.Compose( - [ - v2.RandomResizedCrop(train_size, interpolation=self.interpolation), - v2.RandomHorizontalFlip(), - ] - ) + if train_transform is not None: + self.train_transform = train_transform else: - basic_transform = nn.Identity() + if basic_augment: + basic_transform = v2.Compose( + [ + v2.RandomResizedCrop(train_size, interpolation=self.interpolation), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() - if self.procedure is None: - if rand_augment_opt is not None: - main_transform = rand_augment_transform(rand_augment_opt, {}) + if self.procedure is None: + if rand_augment_opt is not None: + main_transform = rand_augment_transform(rand_augment_opt, {}) + else: + main_transform = nn.Identity() + elif self.procedure == "ViT": + train_size = 224 + main_transform = v2.Compose( + [ + Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), + rand_augment_transform("rand-m9-n2-mstd0.5", {}), + ] + ) + elif self.procedure == "A3": + train_size = 160 + main_transform = rand_augment_transform("rand-m6-mstd0.5-inc1", {}) else: - main_transform = nn.Identity() - elif self.procedure == "ViT": - train_size = 224 - main_transform = v2.Compose( + raise ValueError("The procedure is unknown") + + self.train_transform = v2.Compose( [ - Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), - rand_augment_transform("rand-m9-n2-mstd0.5", {}), + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - elif self.procedure == "A3": - train_size = 160 - main_transform = rand_augment_transform("rand-m6-mstd0.5-inc1", {}) - else: - raise ValueError("The procedure is unknown") - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - main_transform, - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.Resize(256, interpolation=self.interpolation), - v2.CenterCrop(224), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + if test_transform is not None: + self.test_transform = test_transform + else: + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(256, interpolation=self.interpolation), + v2.CenterCrop(224), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def _verify_splits(self, split: str) -> None: if split not in list(self.root.iterdir()): diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 06963bf9..8d2b06a9 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -33,6 +33,9 @@ def __init__( val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, + ood_transform: nn.Module | None = None, basic_augment: bool = True, cutout: int | None = None, pin_memory: bool = True, @@ -56,6 +59,13 @@ def __init__( use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. + ood_transform (nn.Module | None): Custom transform for out-of-distribution + datasets. Defaults to ``None``. If not provided, a default transform + is used. basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. @@ -89,37 +99,48 @@ def __init__( self.shift_dataset = MNISTC self.shift_severity = 1 - basic_transform = v2.RandomCrop(28, padding=4) if basic_augment else nn.Identity() + if train_transform is not None: + self.train_transform = train_transform + else: + basic_transform = v2.RandomCrop(28, padding=4) if basic_augment else nn.Identity() - main_transform = Cutout(cutout) if cutout else nn.Identity() + main_transform = Cutout(cutout) if cutout else nn.Identity() - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - main_transform, - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.CenterCrop(28), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - if self.eval_ood: # NotMNIST has 3 channels - self.ood_transform = v2.Compose( + self.train_transform = v2.Compose( + [ + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + + if test_transform is not None: + self.test_transform = test_transform + else: + self.test_transform = v2.Compose( [ v2.ToImage(), - v2.Grayscale(num_output_channels=1), v2.CenterCrop(28), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) + if self.eval_ood: + if ood_transform is not None: + self.ood_transform = ood_transform + else: + # NotMNIST has 3 channels + self.ood_transform = v2.Compose( + [ + v2.ToImage(), + v2.Grayscale(num_output_channels=1), + v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def prepare_data(self) -> None: # coverage: ignore """Download the datasets.""" diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 80ee6e1b..52598553 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -40,6 +40,8 @@ def __init__( shift_severity: int = 1, val_split: float | None = None, postprocess_set: Literal["val", "test"] = "val", + train_transform: nn.Module | None = None, + test_transform: nn.Module | None = None, ood_ds: str = "svhn", interpolation: str = "bilinear", basic_augment: bool = True, @@ -66,6 +68,10 @@ def __init__( ids. Defaults to ``0.0``. postprocess_set (str, optional): The post-hoc calibration dataset to use for the post-processing method. Defaults to ``val``. + train_transform (nn.Module | None): Custom training transform. Defaults + to ``None``. If not provided, a default transform is used. + test_transform (nn.Module | None): Custom test transform. Defaults to + ``None``. If not provided, a default transform is used. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"openimage-o"``. test_alt (str): Which test set to use. Defaults to ``None``. @@ -108,39 +114,46 @@ def __init__( else: raise ValueError(f"OOD dataset {ood_ds} not supported for TinyImageNet.") self.shift_dataset = TinyImageNetC - if basic_augment: - basic_transform = v2.Compose( + + if train_transform is not None: + self.train_transform = train_transform + else: + if basic_augment: + basic_transform = v2.Compose( + [ + v2.RandomCrop(64, padding=4), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + if rand_augment_opt is not None: + main_transform = rand_augment_transform(rand_augment_opt, {}) + else: + main_transform = nn.Identity() + + self.train_transform = v2.Compose( [ - v2.RandomCrop(64, padding=4), - v2.RandomHorizontalFlip(), + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - else: - basic_transform = nn.Identity() - if rand_augment_opt is not None: - main_transform = rand_augment_transform(rand_augment_opt, {}) + if test_transform is not None: + self.test_transform = test_transform else: - main_transform = nn.Identity() - - self.train_transform = v2.Compose( - [ - v2.ToImage(), - basic_transform, - main_transform, - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) - - self.test_transform = v2.Compose( - [ - v2.ToImage(), - v2.Resize(64, interpolation=self.interpolation), - v2.ToDtype(dtype=torch.float32, scale=True), - v2.Normalize(mean=self.mean, std=self.std), - ] - ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(64, interpolation=self.interpolation), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) def _verify_splits(self, split: str) -> None: # coverage: ignore if split not in list(self.root.iterdir()): diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index c9963e6d..f5b8e01c 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -116,7 +116,7 @@ def __init__( std=[0.229, 0.224, 0.225]) ]) - This behavior can be modified by setting ``train_transform`` + This behavior can be modified by setting up ``train_transform`` and ``test_transform`` at initialization. """ super().__init__( From 0d5ba10a39c6950c059cb763f4af6cbe75ce0b72 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 10 Apr 2025 11:42:30 +0200 Subject: [PATCH 29/69] :white_check_mark: Improve coverage --- tests/datamodules/classification/test_cifar10.py | 2 +- tests/datamodules/classification/test_tiny_imagenet.py | 2 ++ tests/datamodules/segmentation/test_camvid.py | 1 + tests/datamodules/segmentation/test_cityscapes.py | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index b928a8a0..70dae652 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -63,9 +63,9 @@ def test_cifar10_main(self): dm = CIFAR10DataModule( root="./data/", batch_size=128, - cutout=16, num_dataloaders=2, val_split=0.1, + randaugment=True, ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/classification/test_tiny_imagenet.py b/tests/datamodules/classification/test_tiny_imagenet.py index 5b0f3045..f5800bc5 100644 --- a/tests/datamodules/classification/test_tiny_imagenet.py +++ b/tests/datamodules/classification/test_tiny_imagenet.py @@ -35,6 +35,8 @@ def test_tiny_imagenet(self): basic_augment=False, ) + dm = TinyImageNetDataModule(root="./data/", batch_size=128, ood_ds="openimage-o") + with pytest.raises(ValueError): TinyImageNetDataModule(root="./data/", batch_size=128, ood_ds="other") diff --git a/tests/datamodules/segmentation/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py index 7779e664..a9b3add8 100644 --- a/tests/datamodules/segmentation/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -19,6 +19,7 @@ def test_camvid_main(self): ) assert isinstance(dm.train_transform, nn.Identity) assert isinstance(dm.test_transform, nn.Identity) + dm = CamVidDataModule(root="./data/", batch_size=128, basic_augment=True) dm = CamVidDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == CamVid diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py index aa247c59..60b37014 100644 --- a/tests/datamodules/segmentation/test_cityscapes.py +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -18,6 +18,7 @@ def test_camvid_main(self): ) assert isinstance(dm.train_transform, nn.Identity) assert isinstance(dm.test_transform, nn.Identity) + dm = CityscapesDataModule(root="./data/", batch_size=128, basic_augment=True) dm = CityscapesDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == Cityscapes From 68376cb45d945431a28c5164b8f92cabfd7c834a Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 11:45:18 +0200 Subject: [PATCH 30/69] :construction: Implement specific ModelCheckpoint for Classification --- torch_uncertainty/callbacks/__init__.py | 0 .../callbacks/model_checkpoint.py | 26 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 torch_uncertainty/callbacks/__init__.py create mode 100644 torch_uncertainty/callbacks/model_checkpoint.py diff --git a/torch_uncertainty/callbacks/__init__.py b/torch_uncertainty/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/callbacks/model_checkpoint.py b/torch_uncertainty/callbacks/model_checkpoint.py new file mode 100644 index 00000000..3bdbee4d --- /dev/null +++ b/torch_uncertainty/callbacks/model_checkpoint.py @@ -0,0 +1,26 @@ +from lightning.pytorch.callbacks import Checkpoint, ModelCheckpoint + + +# FIXME: this is incomplete +class TUClsCheckpoint(Checkpoint): + """Custom ModelCheckpoint class for saving the best model based on validation loss.""" + + def __init__(self): + super().__init__() + self.callbacks = { + "acc": ModelCheckpoint( + filename="{epoch}-{step}-val_acc={val/cls/Acc:.2f}", + monitor="val/cls/Acc", + mode="max", + ), + "ece": ModelCheckpoint( + filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", + monitor="val/cal/ECE", + mode="min", + ), + "brier": ModelCheckpoint( + filename="{epoch}-{step}-val_brier={val/cls/Brier:.2f}", + monitor="val/cls/Brier", + mode="min", + ), + } From 7de882e8afc4a064b4a69297bb829c0f62f27eaa Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 16:34:22 +0200 Subject: [PATCH 31/69] :hammer: Enable storing models on cpu in `deep_ensembles` --- tests/_dummies/model.py | 3 + tests/models/wrappers/test_deep_ensembles.py | 16 ++++ .../models/wrappers/deep_ensembles.py | 82 +++++++++++++++++-- 3 files changed, 96 insertions(+), 5 deletions(-) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 02517617..45bcae38 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -39,6 +39,7 @@ def forward(self, x: Tensor) -> Tensor: torch.ones( (x.shape[0], 1), dtype=torch.float32, + device=x.device, ) ) ) @@ -50,6 +51,7 @@ def feats_forward(self, x: Tensor) -> Tensor: return torch.ones( (x.shape[0], 1), dtype=torch.float32, + device=x.device, ) @@ -92,6 +94,7 @@ def forward(self, x: Tensor) -> Tensor: self.image_size, ), dtype=torch.float32, + device=x.device, ) ) ) diff --git a/tests/models/wrappers/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py index 5c49f029..a2289c44 100644 --- a/tests/models/wrappers/test_deep_ensembles.py +++ b/tests/models/wrappers/test_deep_ensembles.py @@ -31,6 +31,22 @@ def test_list_singleton(self): with pytest.raises(ValueError): deep_ensembles([model_1], num_estimators=1) + def test_store_on_cpu(self): + model_1 = dummy_model(1, 10) + model_2 = dummy_model(1, 10) + + de = deep_ensembles([model_1, model_2], store_on_cpu=True) + de.to("cuda") + assert de.store_on_cpu + assert de.core_models[0].linear.weight.device == torch.device("cpu") + assert de.core_models[1].linear.weight.device == torch.device("cpu") + + inputs = torch.randn(3, 4, 1).cuda() + out = de(inputs) + assert out.device == inputs.device + assert de.core_models[0].linear.weight.device == torch.device("cpu") + assert de.core_models[1].linear.weight.device == torch.device("cpu") + def test_error_prob_regression(self): # The output dicts will have different keys model_1 = dummy_model(1, 2, dist_family="normal") diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index e7b4b014..2c0c528a 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -1,4 +1,5 @@ import copy +import warnings from typing import Literal import torch @@ -9,11 +10,14 @@ class _DeepEnsembles(nn.Module): def __init__( self, models: list[nn.Module], + store_on_cpu: bool = False, ) -> None: """Create a classification deep ensembles from a list of models.""" super().__init__() self.core_models = nn.ModuleList(models) self.num_estimators = len(models) + self.store_on_cpu = store_on_cpu + self.device = None def forward(self, x: torch.Tensor) -> torch.Tensor: r"""Return the logits of the ensemble. @@ -26,17 +30,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: where :math:`B` is the batch size, :math:`N` is the number of estimators, and :math:`C` is the number of classes. """ + if self.store_on_cpu: + preds = torch.tensor([], device=self.device) + for model in self.core_models: + model.to(self.device) + preds = torch.cat([preds, model.forward(x)], dim=0) + model.to("cpu") + return preds return torch.cat([model.forward(x) for model in self.core_models], dim=0) + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + self.device = device + + if self.store_on_cpu: + device = torch.device("cpu") + + if dtype is not None: + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError( + "nn.Module.to only accepts floating point or complex " + f"dtypes, but got desired dtype={dtype}" + ) + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "if a complex module does not work as expected.", + stacklevel=2, + ) + + def convert(t): + try: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + ) + except NotImplementedError as e: + if str(e) == "Cannot copy out of meta tensor; no data!": + raise NotImplementedError( + f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " + f"when moving module from meta to a different device." + ) from None + raise + + return self._apply(convert) + class _RegDeepEnsembles(_DeepEnsembles): def __init__( self, probabilistic: bool, models: list[nn.Module], + store_on_cpu: bool = False, ) -> None: """Create a regression deep ensembles from a list of models.""" - super().__init__(models) + super().__init__(models=models, store_on_cpu=store_on_cpu) self.probabilistic = probabilistic def forward(self, x: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: @@ -51,7 +110,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: :math:`*` is any other dimension. """ if self.probabilistic: - out = [model.forward(x) for model in self.core_models] + if self.store_on_cpu: + out = [] + for model in self.core_models: + model.to(self.device) + out.append(model.forward(x)) + model.to("cpu") + else: + out = [model.forward(x) for model in self.core_models] key_set = {tuple(o.keys()) for o in out} if len(key_set) != 1: raise ValueError("The output of the models must have the same keys.") @@ -67,6 +133,7 @@ def deep_ensembles( ] = "classification", probabilistic: bool | None = None, reset_model_parameters: bool = True, + store_on_cpu: bool = False, ) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. @@ -77,7 +144,10 @@ def deep_ensembles( Defaults to "classification". probabilistic (bool): Whether the regression model is probabilistic. reset_model_parameters (bool): Whether to reset the model parameters - when :attr:models is a module or a list of length 1. + when :attr:models is a module or a list of length 1. Defaults to ``True``. + store_on_cpu (bool): Whether to store the models on CPU. Defaults to ``False``. + This is useful for large models that do not fit in GPU memory. Only one + model will be stored on GPU at a time during forward. The rest will be stored on CPU. Returns: _DeepEnsembles: The ensembled model. @@ -118,9 +188,11 @@ def deep_ensembles( raise ValueError("num_estimators must be None if you provided a non-singleton list.") if task in ("classification", "segmentation"): - return _DeepEnsembles(models=models) + return _DeepEnsembles(models=models, store_on_cpu=store_on_cpu) if task in ("regression", "pixel_regression"): if probabilistic is None: raise ValueError("probabilistic must be specified for regression models.") - return _RegDeepEnsembles(probabilistic=probabilistic, models=models) + return _RegDeepEnsembles( + probabilistic=probabilistic, models=models, store_on_cpu=store_on_cpu + ) raise ValueError(f"Unknown task: {task}.") From 5bb320c5e928f5c8c1c71af7b4b7d928323a4c97 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 18:13:36 +0200 Subject: [PATCH 32/69] :bug: Fix ``_DeepEnsembles.forward()`` --- torch_uncertainty/models/wrappers/deep_ensembles.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 2c0c528a..d92ba7a7 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -17,7 +17,6 @@ def __init__( self.core_models = nn.ModuleList(models) self.num_estimators = len(models) self.store_on_cpu = store_on_cpu - self.device = None def forward(self, x: torch.Tensor) -> torch.Tensor: r"""Return the logits of the ensemble. @@ -31,9 +30,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: estimators, and :math:`C` is the number of classes. """ if self.store_on_cpu: - preds = torch.tensor([], device=self.device) + preds = torch.tensor([], device=x.device) for model in self.core_models: - model.to(self.device) + model.to(x.device) preds = torch.cat([preds, model.forward(x)], dim=0) model.to("cpu") return preds @@ -42,8 +41,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - self.device = device - if self.store_on_cpu: device = torch.device("cpu") @@ -113,7 +110,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: if self.store_on_cpu: out = [] for model in self.core_models: - model.to(self.device) + model.to(x.device) out.append(model.forward(x)) model.to("cpu") else: @@ -160,6 +157,10 @@ def deep_ensembles( ValueError: If :attr:num_estimators is defined while :attr:models is a (non-singleton) list. + Warning: + The :attr:`store_on_cpu` option is not supported for training. It is + only supported for inference. + References: Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep From fedcb811673904b3b4a3d46a5313e2ab185eeaf5 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 20:47:36 +0200 Subject: [PATCH 33/69] :hammer: Change metric logging names --- .../cifar10/configs/resnet.yaml | 4 +- .../cifar10/configs/resnet18/batched.yaml | 4 +- .../cifar10/configs/resnet18/masked.yaml | 4 +- .../cifar10/configs/resnet18/mimo.yaml | 4 +- .../cifar10/configs/resnet18/packed.yaml | 4 +- .../cifar10/configs/resnet18/standard.yaml | 4 +- .../cifar10/configs/resnet50/batched.yaml | 4 +- .../cifar10/configs/resnet50/masked.yaml | 4 +- .../cifar10/configs/resnet50/mimo.yaml | 4 +- .../cifar10/configs/resnet50/packed.yaml | 4 +- .../cifar10/configs/resnet50/standard.yaml | 4 +- .../cifar10/configs/wideresnet28x10.yaml | 4 +- .../configs/wideresnet28x10/batched.yaml | 4 +- .../configs/wideresnet28x10/masked.yaml | 4 +- .../cifar10/configs/wideresnet28x10/mimo.yaml | 4 +- .../configs/wideresnet28x10/packed.yaml | 4 +- .../configs/wideresnet28x10/standard.yaml | 4 +- .../cifar100/configs/resnet.yaml | 4 +- .../cifar100/configs/resnet18/batched.yaml | 4 +- .../cifar100/configs/resnet18/masked.yaml | 4 +- .../cifar100/configs/resnet18/mimo.yaml | 4 +- .../cifar100/configs/resnet18/packed.yaml | 4 +- .../cifar100/configs/resnet18/standard.yaml | 4 +- .../cifar100/configs/resnet50/batched.yaml | 4 +- .../cifar100/configs/resnet50/masked.yaml | 4 +- .../cifar100/configs/resnet50/mimo.yaml | 4 +- .../cifar100/configs/resnet50/packed.yaml | 4 +- .../cifar100/configs/resnet50/standard.yaml | 4 +- .../configs/wideresnet28x10/standard.yaml | 4 +- .../mnist/configs/bayesian_lenet.yaml | 4 +- .../classification/mnist/configs/lenet.yaml | 5 +- .../mnist/configs/lenet_batch_ensemble.yaml | 4 +- .../configs/lenet_checkpoint_ensemble.yaml | 4 +- .../mnist/configs/lenet_deep_ensemble.yaml | 6 +- .../mnist/configs/lenet_ema.yaml | 4 +- .../mnist/configs/lenet_swa.yaml | 4 +- .../mnist/configs/lenet_swag.yaml | 4 +- .../configs/resnet18/standard.yaml | 4 +- experiments/depth/kitti/configs/bts.yaml | 2 +- experiments/depth/nyu/configs/bts.yaml | 2 +- .../configs/boston/mlp/laplace.yaml | 4 +- .../configs/boston/mlp/normal.yaml | 4 +- .../configs/boston/mlp/point_wise.yaml | 4 +- .../configs/concrete/mlp/laplace.yaml | 4 +- .../configs/concrete/mlp/normal.yaml | 4 +- .../configs/concrete/mlp/point_wise.yaml | 4 +- .../energy-efficiency/mlp/laplace.yaml | 4 +- .../configs/energy-efficiency/mlp/normal.yaml | 4 +- .../energy-efficiency/mlp/point_wise.yaml | 4 +- .../configs/kin8nm/mlp/laplace.yaml | 4 +- .../configs/kin8nm/mlp/normal.yaml | 4 +- .../configs/kin8nm/mlp/point_wise.yaml | 4 +- .../naval-propulsion-plant/mlp/laplace.yaml | 4 +- .../naval-propulsion-plant/mlp/normal.yaml | 4 +- .../mlp/point_wise.yaml | 4 +- .../configs/power-plant/mlp/laplace.yaml | 4 +- .../configs/power-plant/mlp/normal.yaml | 4 +- .../configs/power-plant/mlp/point_wise.yaml | 4 +- .../configs/protein/mlp/laplace.yaml | 4 +- .../configs/protein/mlp/normal.yaml | 4 +- .../configs/protein/mlp/point_wise.yaml | 4 +- .../configs/wine-quality-red/mlp/laplace.yaml | 4 +- .../configs/wine-quality-red/mlp/normal.yaml | 4 +- .../wine-quality-red/mlp/point_wise.yaml | 4 +- .../configs/yacht/mlp/laplace.yaml | 4 +- .../configs/yacht/mlp/normal.yaml | 4 +- .../configs/yacht/mlp/point_wise.yaml | 4 +- .../segmentation/camvid/configs/deeplab.yaml | 2 +- .../camvid/configs/segformer.yaml | 2 +- .../cityscapes/configs/deeplab.yaml | 2 +- .../cityscapes/configs/segformer.yaml | 2 +- tests/test_cli.py | 2 +- torch_uncertainty/routines/classification.py | 78 +++++++++---------- .../routines/pixel_regression.py | 34 ++++---- torch_uncertainty/routines/regression.py | 18 ++--- torch_uncertainty/routines/segmentation.py | 46 +++++------ torch_uncertainty/utils/evaluation_loop.py | 38 ++++----- 77 files changed, 246 insertions(+), 245 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index feb656c8..bc9efb11 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 69f1fea2..92c5785b 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index a989dc2d..895989b7 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index 187ec011..da25452b 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index 3e1e1dbe..9c716194 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index 2eb2586b..e45e1212 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index fc0cfeae..715729d0 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 41ea41a3..57233649 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index 766b7371..3215d5a1 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index 9ffd0a90..b07af86c 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 39b076e1..c17ebe48 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 3cb97464..53c81230 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -14,7 +14,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -22,7 +22,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 6ad00b9a..4a8ad472 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 3fecaf27..5d1caa1c 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index b71c670f..dd949e2c 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index cd45736c..637ae3e1 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 65616694..1fb982e7 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index f61f467b..b5362462 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index ce2057dd..9bf5cef0 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index 36048d65..182bc4a8 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index ddd474c9..f0bbaa0c 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 6cf74dc5..33800cba 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index e62de94f..182a5815 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 1884c845..62acdb3c 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index a58f4453..35a476df 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index 9acb534a..f85d5c61 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 0e1f9185..b25231ca 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index a1f10fab..aa6b5760 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml index 44ccba6d..14fe4f84 100644 --- a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 70f5cf8e..44602262 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 3f8b63c2..28384c16 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: @@ -44,6 +44,7 @@ model: data: root: ./data batch_size: 128 + num_workers: 10 optimizer: lr: 0.05 momentum: 0.9 diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index d385b100..03bf424a 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -16,7 +16,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -24,7 +24,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index 354b9bf7..50602726 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml index 1d47b782..b64e1f1b 100644 --- a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -16,7 +16,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -24,7 +24,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: @@ -61,7 +61,7 @@ model: data: root: ./data batch_size: 128 - num_workers: 127 + num_workers: 10 eval_ood: true eval_shift: true optimizer: diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index d453df0b..2e2332bd 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 09d7d506..e61bee01 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index e33d954f..773da587 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml index 330181df..d6bceaf6 100644 --- a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml +++ b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/cls/Acc + monitor: val_cls_Acc patience: 1000 check_finite: true model: diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml index 3c20e048..692da6df 100644 --- a/experiments/depth/kitti/configs/bts.yaml +++ b/experiments/depth/kitti/configs/bts.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/SILog + monitor: val_reg_SILog mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml index 48f9d7db..0d7e99b2 100644 --- a/experiments/depth/nyu/configs/bts.yaml +++ b/experiments/depth/nyu/configs/bts.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/SILog + monitor: val_reg_SILog mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml index f8adbf90..abfc97ef 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml index 95eaface..ee7371d1 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml index 90daf59b..01fa9b1c 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml index b6ff80c6..7e869fb6 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml index 683333d6..070a9a73 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml index cff6fd10..9676663b 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml index 1837894f..ff02764c 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml index da02570b..a300499b 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml index cff6fd10..9676663b 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml index be42d710..b4f781e5 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml index b5553356..13a6dcad 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml index fdc8fc44..f58ec8d0 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml index b5b66dfc..9673c31d 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml index 92169a3c..6bf3603c 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml index d8e6eb8f..43a4bdef 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml index 4c2ffd85..f184dd90 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml index 6173c120..319377aa 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml index d0ec4670..469ee71f 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml index 8b794d78..ec9e61b9 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml index 82dc62ea..34424213 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml index b984e681..0fcec544 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml index 32275d18..366a057e 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml index 188b8fb2..85ff8971 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml index 522e280e..9b87ffe0 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml index 2d77e0cd..cf9d02ca 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml index c3641593..b92567b0 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/NLL + monitor: val_reg_NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml index e4002049..e4c2a788 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/reg/MSE + monitor: val_reg_MSE patience: 1000 check_finite: true model: diff --git a/experiments/segmentation/camvid/configs/deeplab.yaml b/experiments/segmentation/camvid/configs/deeplab.yaml index d8bcdc4c..195636bd 100644 --- a/experiments/segmentation/camvid/configs/deeplab.yaml +++ b/experiments/segmentation/camvid/configs/deeplab.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/seg/mIoU + monitor: val_seg_mIoU mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 96f200b3..d3119655 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/seg/mIoU + monitor: val_seg_mIoU mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index 51cc2a1e..55858770 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/seg/mIoU + monitor: val_seg_mIoU mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index 145a96eb..f606686b 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -14,7 +14,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/seg/mIoU + monitor: val_seg_mIoU mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/tests/test_cli.py b/tests/test_cli.py index 8683a523..f215bcb3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,7 +27,7 @@ def test_cli_init(self): "--data.batch_size", "4", "--trainer.callbacks+=ModelCheckpoint", - "--trainer.callbacks.monitor=val/cls/Acc", + "--trainer.callbacks.monitor=val_cls_Acc", "--trainer.callbacks.mode=max", ] cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 4f790203..b189f958 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -191,47 +191,47 @@ def _init_metrics(self) -> None: task = "binary" if self.binary_cls else "multiclass" metrics_dict = { - "cls/Acc": Accuracy(task=task, num_classes=self.num_classes), - "cls/Brier": BrierScore(num_classes=self.num_classes), - "cls/NLL": CategoricalNLL(), - "cal/ECE": CalibrationError( + "cls_Acc": Accuracy(task=task, num_classes=self.num_classes), + "cls_Brier": BrierScore(num_classes=self.num_classes), + "cls_NLL": CategoricalNLL(), + "cal_ECE": CalibrationError( task=task, num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), - "cal/aECE": CalibrationError( + "cal_aECE": CalibrationError( task=task, adaptive=True, num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), - "sc/AURC": AURC(), - "sc/AUGRC": AUGRC(), - "sc/Cov@5Risk": CovAt5Risk(), - "sc/Risk@80Cov": RiskAt80Cov(), + "sc_AURC": AURC(), + "sc_AUGRC": AUGRC(), + "sc_Cov@5Risk": CovAt5Risk(), + "sc_Risk@80Cov": RiskAt80Cov(), } groups = [ - ["cls/Acc"], - ["cls/Brier"], - ["cls/NLL"], - ["cal/ECE", "cal/aECE"], - ["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"], + ["cls_Acc"], + ["cls_Brier"], + ["cls_NLL"], + ["cal_ECE", "cal_aECE"], + ["sc_AURC", "sc_AUGRC", "sc_Cov@5Risk", "sc_Risk@80Cov"], ] if self.binary_cls: metrics_dict |= { - "cls/AUROC": BinaryAUROC(), - "cls/AUPR": BinaryAveragePrecision(), - "cls/FRP95": FPR95(pos_label=1), + "cls_AUROC": BinaryAUROC(), + "cls_AUPR": BinaryAveragePrecision(), + "cls_FRP95": FPR95(pos_label=1), } - groups.extend([["cls/AUROC", "cls/AUPR"], ["cls/FRP95"]]) + groups.extend([["cls_AUROC", "cls_AUPR"], ["cls_FRP95"]]) cls_metrics = MetricCollection(metrics_dict, compute_groups=groups) - self.val_cls_metrics = cls_metrics.clone(prefix="val/") - self.test_cls_metrics = cls_metrics.clone(prefix="test/") + self.val_cls_metrics = cls_metrics.clone(prefix="val_") + self.test_cls_metrics = cls_metrics.clone(prefix="test_") if self.post_processing is not None: - self.post_cls_metrics = cls_metrics.clone(prefix="test/post/") + self.post_cls_metrics = cls_metrics.clone(prefix="test_post_") self.test_id_entropy = Entropy() @@ -244,11 +244,11 @@ def _init_metrics(self) -> None: }, compute_groups=[["AUROC", "AUPR"], ["FPR95"]], ) - self.test_ood_metrics = ood_metrics.clone(prefix="ood/") + self.test_ood_metrics = ood_metrics.clone(prefix="ood_") self.test_ood_entropy = Entropy() if self.eval_shift: - self.test_shift_metrics = cls_metrics.clone(prefix="shift/") + self.test_shift_metrics = cls_metrics.clone(prefix="shift_") # metrics for ensembles only if self.is_ensemble: @@ -260,18 +260,18 @@ def _init_metrics(self) -> None: } ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens_") + self.test_id_ens_metrics = ens_metrics.clone(prefix="test_ens_") if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") + self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood_ens_") if self.eval_shift: - self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens_") + self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift_ens_") if self.eval_grouping_loss: - grouping_loss = MetricCollection({"cls/grouping_loss": GroupingLoss()}) - self.val_grouping_loss = grouping_loss.clone(prefix="val/") - self.test_grouping_loss = grouping_loss.clone(prefix="test/") + grouping_loss = MetricCollection({"cls_grouping_loss": GroupingLoss()}) + self.val_grouping_loss = grouping_loss.clone(prefix="val_") + self.test_grouping_loss = grouping_loss.clone(prefix="test_") def _init_mixup(self, mixup_params: dict | None) -> Callable: """Setup the optional mixup augmentation based on the :attr:`mixup_params` dict. @@ -502,7 +502,7 @@ def test_step( self.log_dict(self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False) self.test_id_entropy(probs) self.log( - "test/cls/Entropy", + "test_cls_Entropy", self.test_id_entropy, on_epoch=True, add_dataloader_idx=False, @@ -529,7 +529,7 @@ def test_step( self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs) self.log( - "ood/Entropy", + "ood_Entropy", self.test_ood_entropy, on_epoch=True, add_dataloader_idx=False, @@ -551,7 +551,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "Acc%", - res_dict["val/cls/Acc"] * 100, + res_dict["val_cls_Acc"] * 100, prog_bar=True, logger=False, sync_dist=True, @@ -568,7 +568,7 @@ def on_test_epoch_end(self) -> None: result_dict = self.test_cls_metrics.compute() # already logged - result_dict.update({"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True) + result_dict.update({"test_Entropy": self.test_id_entropy.compute()}, sync_dist=True) if self.post_processing is not None: tmp_metrics = self.post_cls_metrics.compute() @@ -592,7 +592,7 @@ def on_test_epoch_end(self) -> None: result_dict.update(tmp_metrics) # already logged - result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) + result_dict.update({"ood_Entropy": self.test_ood_entropy.compute()}) if self.is_ensemble: tmp_metrics = self.test_ood_ens_metrics.compute() @@ -602,7 +602,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() shift_severity = self.trainer.datamodule.shift_severity - tmp_metrics["shift/shift_severity"] = shift_severity + tmp_metrics["shift_shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -613,21 +613,21 @@ def on_test_epoch_end(self) -> None: if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Reliabity diagram", self.test_cls_metrics["cal/ECE"].plot()[0] + "Reliabity diagram", self.test_cls_metrics["cal_ECE"].plot()[0] ) self.logger.experiment.add_figure( "Risk-Coverage curve", - self.test_cls_metrics["sc/AURC"].plot()[0], + self.test_cls_metrics["sc_AURC"].plot()[0], ) self.logger.experiment.add_figure( "Generalized Risk-Coverage curve", - self.test_cls_metrics["sc/AUGRC"].plot()[0], + self.test_cls_metrics["sc_AUGRC"].plot()[0], ) if self.post_processing is not None: self.logger.experiment.add_figure( "Reliabity diagram after calibration", - self.post_cls_metrics["cal/ECE"].plot()[0], + self.post_cls_metrics["cal_ECE"].plot()[0], ) # plot histograms of logits and likelihoods diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 1db3009a..682759d6 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -114,28 +114,28 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" depth_metrics = MetricCollection( { - "reg/SILog": SILog(), - "reg/log10": Log10(), - "reg/ARE": MeanGTRelativeAbsoluteError(), - "reg/RSRE": MeanGTRelativeSquaredError(squared=False), - "reg/RMSE": MeanSquaredError(squared=False), - "reg/RMSELog": MeanSquaredLogError(squared=False), - "reg/iMAE": MeanAbsoluteErrorInverse(), - "reg/iRMSE": MeanSquaredErrorInverse(squared=False), - "reg/d1": ThresholdAccuracy(power=1), - "reg/d2": ThresholdAccuracy(power=2), - "reg/d3": ThresholdAccuracy(power=3), + "reg_SILog": SILog(), + "reg_log10": Log10(), + "reg_ARE": MeanGTRelativeAbsoluteError(), + "reg_RSRE": MeanGTRelativeSquaredError(squared=False), + "reg_RMSE": MeanSquaredError(squared=False), + "reg_RMSELog": MeanSquaredLogError(squared=False), + "reg_iMAE": MeanAbsoluteErrorInverse(), + "reg_iRMSE": MeanSquaredErrorInverse(squared=False), + "reg_d1": ThresholdAccuracy(power=1), + "reg_d2": ThresholdAccuracy(power=2), + "reg_d3": ThresholdAccuracy(power=3), }, compute_groups=False, ) - self.val_metrics = depth_metrics.clone(prefix="val/") - self.test_metrics = depth_metrics.clone(prefix="test/") + self.val_metrics = depth_metrics.clone(prefix="val_") + self.test_metrics = depth_metrics.clone(prefix="test_") if self.probabilistic: - depth_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) - self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/") - self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/") + depth_prob_metrics = MetricCollection({"reg_NLL": DistributionNLL(reduction="mean")}) + self.val_prob_metrics = depth_prob_metrics.clone(prefix="val_") + self.test_prob_metrics = depth_prob_metrics.clone(prefix="test_") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -322,7 +322,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "RMSE", - res_dict["val/reg/RMSE"], + res_dict["val_reg_RMSE"], prog_bar=True, logger=False, sync_dist=True, diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 88514d39..260e79bf 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -101,20 +101,20 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" reg_metrics = MetricCollection( { - "reg/MAE": MeanAbsoluteError(), - "reg/MSE": MeanSquaredError(squared=True), - "reg/RMSE": MeanSquaredError(squared=False), + "reg_MAE": MeanAbsoluteError(), + "reg_MSE": MeanSquaredError(squared=True), + "reg_RMSE": MeanSquaredError(squared=False), }, compute_groups=True, ) - self.val_metrics = reg_metrics.clone(prefix="val/") - self.test_metrics = reg_metrics.clone(prefix="test/") + self.val_metrics = reg_metrics.clone(prefix="val_") + self.test_metrics = reg_metrics.clone(prefix="test_") if self.probabilistic: - reg_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) - self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/") - self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/") + reg_prob_metrics = MetricCollection({"reg_NLL": DistributionNLL(reduction="mean")}) + self.val_prob_metrics = reg_prob_metrics.clone(prefix="val_") + self.test_prob_metrics = reg_prob_metrics.clone(prefix="test_") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -289,7 +289,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "RMSE", - res_dict["val/reg/RMSE"], + res_dict["valreg_RMSE"], prog_bar=True, logger=False, sync_dist=True, diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index a68a97ca..ad8ee47a 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -105,46 +105,46 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" seg_metrics = MetricCollection( { - "seg/mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), + "seg_mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), }, compute_groups=False, ) sbsmpl_seg_metrics = MetricCollection( { - "seg/mAcc": Accuracy( + "seg_mAcc": Accuracy( task="multiclass", average="macro", num_classes=self.num_classes ), - "seg/Brier": BrierScore(num_classes=self.num_classes), - "seg/NLL": CategoricalNLL(), - "seg/pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), - "cal/ECE": CalibrationError( + "seg_Brier": BrierScore(num_classes=self.num_classes), + "seg_NLL": CategoricalNLL(), + "seg_pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), + "cal_ECE": CalibrationError( task="multiclass", num_classes=self.num_classes, num_bins=self.num_bins_cal_err, ), - "cal/aECE": CalibrationError( + "cal_aECE": CalibrationError( task="multiclass", adaptive=True, num_classes=self.num_classes, num_bins=self.num_bins_cal_err, ), - "sc/AURC": AURC(), - "sc/AUGRC": AUGRC(), + "sc_AURC": AURC(), + "sc_AUGRC": AUGRC(), }, compute_groups=[ - ["seg/mAcc"], - ["seg/Brier"], - ["seg/NLL"], - ["seg/pixAcc"], - ["cal/ECE", "cal/aECE"], - ["sc/AURC", "sc/AUGRC"], + ["seg_mAcc"], + ["seg_Brier"], + ["seg_NLL"], + ["seg_pixAcc"], + ["cal_ECE", "cal_aECE"], + ["sc_AURC", "sc_AUGRC"], ], ) - self.val_seg_metrics = seg_metrics.clone(prefix="val/") - self.val_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="val/") - self.test_seg_metrics = seg_metrics.clone(prefix="test/") - self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") + self.val_seg_metrics = seg_metrics.clone(prefix="val_") + self.val_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="val_") + self.test_seg_metrics = seg_metrics.clone(prefix="test_") + self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test_") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -260,7 +260,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "mIoU%", - res_dict["val/seg/mIoU"] * 100, + res_dict["val_seg_mIoU"] * 100, prog_bar=True, sync_dist=True, ) @@ -275,15 +275,15 @@ def on_test_epoch_end(self) -> None: if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( "Calibration/Reliabity diagram", - self.test_sbsmpl_seg_metrics["cal/ECE"].plot()[0], + self.test_sbsmpl_seg_metrics["cal_ECE"].plot()[0], ) self.logger.experiment.add_figure( "Selective Classification/Risk-Coverage curve", - self.test_sbsmpl_seg_metrics["sc/AURC"].plot()[0], + self.test_sbsmpl_seg_metrics["sc_AURC"].plot()[0], ) self.logger.experiment.add_figure( "Selective Classification/Generalized Risk-Coverage curve", - self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0], + self.test_sbsmpl_seg_metrics["sc_AUGRC"].plot()[0], ) if self.trainer.datamodule is not None: self.log_segmentation_plots() diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index f3088103..5f4e8d88 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -35,56 +35,56 @@ def _add_row(table: Table, metric_name: str, value: Tensor) -> None: class TUEvaluationLoop(_EvaluationLoop): @staticmethod def _print_results(results: list[_OUT_DICT], stage: str) -> None: - # test/cls: Classification Metrics - # test/cal: Calibration Metrics + # test_cls: Classification Metrics + # test_cal: Calibration Metrics # ood: OOD Detection Metrics # shift: Distribution shift Metrics - # test/sc: Selective Classification Metrics - # test/post: Post-Processing Metrics - # test/seg: Segmentation Metrics + # test_sc: Selective Classification Metrics + # test_post: Post-Processing Metrics + # test_seg: Segmentation Metrics metrics = {} for result in results: for key, value in result.items(): - if key.startswith("test/cls"): + if key.startswith("test_cls"): if "cls" not in metrics: metrics["cls"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["cls"].update({metric_name: value}) - elif key.startswith("test/cal"): + elif key.startswith("test_cal"): if "cal" not in metrics: metrics["cal"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["cal"].update({metric_name: value}) elif key.startswith("ood"): if "ood" not in metrics: metrics["ood"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["ood"].update({metric_name: value}) elif key.startswith("shift"): if "shift" not in metrics: metrics["shift"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["shift"].update({metric_name: value}) - elif key.startswith("test/sc"): + elif key.startswith("test_sc"): if "sc" not in metrics: metrics["sc"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["sc"].update({metric_name: value}) - elif key.startswith("test/post"): + elif key.startswith("test_post"): if "post" not in metrics: metrics["post"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["post"].update({metric_name: value}) - elif key.startswith("test/seg"): + elif key.startswith("test_seg"): if "seg" not in metrics: metrics["seg"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["seg"].update({metric_name: value}) - elif key.startswith("test/reg"): + elif key.startswith("test_reg"): if "reg" not in metrics: metrics["reg"] = {} - metric_name = key.split("/")[-1] + metric_name = key.split("_")[-1] metrics["reg"].update({metric_name: value}) tables = [] From 85971105397fa09fde7cb4ebf356e46492414b7b Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 20:49:50 +0200 Subject: [PATCH 34/69] :sparkles: Add ``TUClsCheckpoint`` --- torch_uncertainty/callbacks/__init__.py | 2 + torch_uncertainty/callbacks/checkpoint.py | 83 +++++++++++++++++++ .../callbacks/model_checkpoint.py | 26 ------ 3 files changed, 85 insertions(+), 26 deletions(-) create mode 100644 torch_uncertainty/callbacks/checkpoint.py delete mode 100644 torch_uncertainty/callbacks/model_checkpoint.py diff --git a/torch_uncertainty/callbacks/__init__.py b/torch_uncertainty/callbacks/__init__.py index e69de29b..5561d42f 100644 --- a/torch_uncertainty/callbacks/__init__.py +++ b/torch_uncertainty/callbacks/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .checkpoint import TUClsCheckpoint diff --git a/torch_uncertainty/callbacks/checkpoint.py b/torch_uncertainty/callbacks/checkpoint.py new file mode 100644 index 00000000..1ec95656 --- /dev/null +++ b/torch_uncertainty/callbacks/checkpoint.py @@ -0,0 +1,83 @@ +from typing import Any + +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import Checkpoint, ModelCheckpoint +from typing_extensions import override + + +class TUCheckpoint(Checkpoint): + callbacks: dict[str, Checkpoint] + + @override + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + for callback in self.callbacks.values(): + callback.setup(trainer=trainer, pl_module=pl_module, stage=stage) + + @override + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + for callback in self.callbacks.values(): + callback.on_train_start(trainer=trainer, pl_module=pl_module) + + @override + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: dict, + batch_idx: int, + ) -> None: + for callback in self.callbacks.values(): + callback.on_train_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_idx=batch_idx, + ) + + @override + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + for callback in self.callbacks.values(): + callback.on_train_epoch_end(trainer=trainer, pl_module=pl_module) + + @override + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + for callback in self.callbacks.values(): + callback.on_validation_epoch_end(trainer=trainer, pl_module=pl_module) + + @override + def load_state_dict(self, state_dict: dict[str, dict[str, Any]]) -> None: + for key, callback in self.callbacks.items(): + callback.load_state_dict(state_dict=state_dict[key]) + + +class TUClsCheckpoint(TUCheckpoint): + def __init__(self): + super().__init__() + self.callbacks = { + "acc": ModelCheckpoint( + filename="{epoch}-{step}-{val_cls_Acc:.2f}", + monitor="val_cls_Acc", + mode="max", + ), + "ece": ModelCheckpoint( + filename="{epoch}-{step}-{val_cal_ECE:.2f}", + monitor="val_cal_ECE", + mode="min", + ), + "brier": ModelCheckpoint( + filename="{epoch}-{step}-{val_cls_Brier:.2f}", + monitor="val_cls_Brier", + mode="min", + ), + "nll": ModelCheckpoint( + filename="{epoch}-{step}-{val_cls_NLL:.2f}", + monitor="val_cls_NLL", + mode="min", + ), + } + + @override + def state_dict(self) -> dict[str, dict[str, Any]]: + return {key: callback.state_dict() for key, callback in self.callbacks.items()} diff --git a/torch_uncertainty/callbacks/model_checkpoint.py b/torch_uncertainty/callbacks/model_checkpoint.py deleted file mode 100644 index 3bdbee4d..00000000 --- a/torch_uncertainty/callbacks/model_checkpoint.py +++ /dev/null @@ -1,26 +0,0 @@ -from lightning.pytorch.callbacks import Checkpoint, ModelCheckpoint - - -# FIXME: this is incomplete -class TUClsCheckpoint(Checkpoint): - """Custom ModelCheckpoint class for saving the best model based on validation loss.""" - - def __init__(self): - super().__init__() - self.callbacks = { - "acc": ModelCheckpoint( - filename="{epoch}-{step}-val_acc={val/cls/Acc:.2f}", - monitor="val/cls/Acc", - mode="max", - ), - "ece": ModelCheckpoint( - filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", - monitor="val/cal/ECE", - mode="min", - ), - "brier": ModelCheckpoint( - filename="{epoch}-{step}-val_brier={val/cls/Brier:.2f}", - monitor="val/cls/Brier", - mode="min", - ), - } From 028161d90f1ed7d06ce7b58fedfd04ef3c2e1196 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 21:04:47 +0200 Subject: [PATCH 35/69] :wrench: Use GPU pytorch version --- .github/workflows/run-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index ce43aa90..5e24c16d 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -63,7 +63,7 @@ jobs: - name: Install dependencies if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision python3 -m pip install .[all] - name: Check style & format From 8db167a3ee8c83a8967318194fe18ca2285ec7e3 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 16 Apr 2025 22:37:17 +0200 Subject: [PATCH 36/69] :bug: Check whether cuda is available in ``test_deep_ensembles.py`` --- tests/models/wrappers/test_deep_ensembles.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/wrappers/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py index a2289c44..9bc02857 100644 --- a/tests/models/wrappers/test_deep_ensembles.py +++ b/tests/models/wrappers/test_deep_ensembles.py @@ -41,7 +41,9 @@ def test_store_on_cpu(self): assert de.core_models[0].linear.weight.device == torch.device("cpu") assert de.core_models[1].linear.weight.device == torch.device("cpu") - inputs = torch.randn(3, 4, 1).cuda() + device = "cuda" if torch.cuda.is_available() else "cpu" + + inputs = torch.randn(3, 4, 1).to(device) out = de(inputs) assert out.device == inputs.device assert de.core_models[0].linear.weight.device == torch.device("cpu") From 00ff02df7a94a8877a62d508cf5f7ce5c81cc67e Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 00:18:56 +0200 Subject: [PATCH 37/69] :hammer: Simplify ``_DeepEnsembles.to()`` method --- .../models/wrappers/deep_ensembles.py | 40 +------------------ 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index d92ba7a7..00ea959e 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -1,5 +1,4 @@ import copy -import warnings from typing import Literal import torch @@ -44,44 +43,7 @@ def to(self, *args, **kwargs): if self.store_on_cpu: device = torch.device("cpu") - if dtype is not None: - if not (dtype.is_floating_point or dtype.is_complex): - raise TypeError( - "nn.Module.to only accepts floating point or complex " - f"dtypes, but got desired dtype={dtype}" - ) - if dtype.is_complex: - warnings.warn( - "Complex modules are a new feature under active development whose design may change, " - "and some modules might not work as expected when using complex tensors as parameters or buffers. " - "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " - "if a complex module does not work as expected.", - stacklevel=2, - ) - - def convert(t): - try: - if convert_to_format is not None and t.dim() in (4, 5): - return t.to( - device, - dtype if t.is_floating_point() or t.is_complex() else None, - non_blocking, - memory_format=convert_to_format, - ) - return t.to( - device, - dtype if t.is_floating_point() or t.is_complex() else None, - non_blocking, - ) - except NotImplementedError as e: - if str(e) == "Cannot copy out of meta tensor; no data!": - raise NotImplementedError( - f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " - f"when moving module from meta to a different device." - ) from None - raise - - return self._apply(convert) + return super().to(device=device, dtype=dtype, non_blocking=non_blocking) class _RegDeepEnsembles(_DeepEnsembles): From e4af17bf19e2d2347dfbde371d4915bf77668ffa Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 00:21:27 +0200 Subject: [PATCH 38/69] :shirt: Improve coverage --- tests/models/wrappers/test_deep_ensembles.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/wrappers/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py index 9bc02857..4dc04bdd 100644 --- a/tests/models/wrappers/test_deep_ensembles.py +++ b/tests/models/wrappers/test_deep_ensembles.py @@ -36,19 +36,26 @@ def test_store_on_cpu(self): model_2 = dummy_model(1, 10) de = deep_ensembles([model_1, model_2], store_on_cpu=True) - de.to("cuda") + + device = "cuda" if torch.cuda.is_available() else "cpu" + + de.to(device) assert de.store_on_cpu assert de.core_models[0].linear.weight.device == torch.device("cpu") assert de.core_models[1].linear.weight.device == torch.device("cpu") - device = "cuda" if torch.cuda.is_available() else "cpu" - inputs = torch.randn(3, 4, 1).to(device) out = de(inputs) assert out.device == inputs.device assert de.core_models[0].linear.weight.device == torch.device("cpu") assert de.core_models[1].linear.weight.device == torch.device("cpu") + de = deep_ensembles([model_1, model_2], store_on_cpu=False) + de.to(device) + assert not de.store_on_cpu + assert de.core_models[0].linear.weight.device == inputs.device + assert de.core_models[1].linear.weight.device == inputs.device + def test_error_prob_regression(self): # The output dicts will have different keys model_1 = dummy_model(1, 2, dist_family="normal") From 22f80834d7bef6cadde5c6313883a380cc5a96ca Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 00:52:07 +0200 Subject: [PATCH 39/69] :bug: Fix metric logging names --- torch_uncertainty/routines/classification.py | 2 +- torch_uncertainty/routines/regression.py | 2 +- torch_uncertainty/utils/evaluation_loop.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b189f958..1bcb6e4b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -602,7 +602,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() shift_severity = self.trainer.datamodule.shift_severity - tmp_metrics["shift_shift_severity"] = shift_severity + tmp_metrics["shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 260e79bf..1c7edc68 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -289,7 +289,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "RMSE", - res_dict["valreg_RMSE"], + res_dict["val_reg_RMSE"], prog_bar=True, logger=False, sync_dist=True, diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index 5f4e8d88..f5bf84b1 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -162,7 +162,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "shift" in metrics: table = Table() table.add_column(first_col_name, justify="center", style="cyan", width=12) - shift_severity = int(metrics["shift"]["shift_severity"]) + shift_severity = int(metrics["shift"]["severity"]) table.add_column( f"Distribution Shift lvl{shift_severity}", justify="center", @@ -171,7 +171,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: ) shift_metrics = OrderedDict(sorted(metrics["shift"].items())) for metric_name, value in shift_metrics.items(): - if metric_name == "shift_severity": + if metric_name == "severity": continue _add_row(table, metric_name, value) tables.append(table) From 9d63666435d818afc5a85b37633358b0db2e5398 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 00:52:43 +0200 Subject: [PATCH 40/69] :sparkles: Add ``TUSegCheckpoint`` callback --- torch_uncertainty/callbacks/checkpoint.py | 33 ++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/callbacks/checkpoint.py b/torch_uncertainty/callbacks/checkpoint.py index 1ec95656..8199e73f 100644 --- a/torch_uncertainty/callbacks/checkpoint.py +++ b/torch_uncertainty/callbacks/checkpoint.py @@ -46,6 +46,10 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) for callback in self.callbacks.values(): callback.on_validation_epoch_end(trainer=trainer, pl_module=pl_module) + @override + def state_dict(self) -> dict[str, dict[str, Any]]: + return {key: callback.state_dict() for key, callback in self.callbacks.items()} + @override def load_state_dict(self, state_dict: dict[str, dict[str, Any]]) -> None: for key, callback in self.callbacks.items(): @@ -78,6 +82,29 @@ def __init__(self): ), } - @override - def state_dict(self) -> dict[str, dict[str, Any]]: - return {key: callback.state_dict() for key, callback in self.callbacks.items()} + +class TUSegCheckpoint(TUCheckpoint): + def __init__(self): + super().__init__() + self.callbacks = { + "acc": ModelCheckpoint( + filename="{epoch}-{step}-{val_seg_mIoU:.2f}", + monitor="val_seg_mIoU", + mode="max", + ), + "ece": ModelCheckpoint( + filename="{epoch}-{step}-{val_cal_ECE:.2f}", + monitor="val_cal_ECE", + mode="min", + ), + "brier": ModelCheckpoint( + filename="{epoch}-{step}-{val_seg_Brier:.2f}", + monitor="val_seg_Brier", + mode="min", + ), + "nll": ModelCheckpoint( + filename="{epoch}-{step}-{val_seg_NLL:.2f}", + monitor="val_seg_NLL", + mode="min", + ), + } From 3945aec0c7a13179fb0351e9a52228a251cd6f5d Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 01:04:02 +0200 Subject: [PATCH 41/69] :shirt: Improve coverage --- tests/models/wrappers/test_deep_ensembles.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/models/wrappers/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py index 4dc04bdd..a854ad46 100644 --- a/tests/models/wrappers/test_deep_ensembles.py +++ b/tests/models/wrappers/test_deep_ensembles.py @@ -66,6 +66,25 @@ def test_error_prob_regression(self): with pytest.raises(ValueError): de(torch.randn(5, 1)) + def test_store_on_cpu_prob_regression(self): + # The output dicts will have different keys + model_1 = dummy_model(1, 2, dist_family="normal") + model_2 = dummy_model(1, 2, dist_family="normal") + + de = deep_ensembles( + [model_1, model_2], task="regression", probabilistic=True, store_on_cpu=True + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + de.to(device) + assert de.store_on_cpu + assert de.core_models[0].linear.weight.device == torch.device("cpu") + assert de.core_models[1].linear.weight.device == torch.device("cpu") + inputs = torch.randn(3, 4, 1).to(device) + out = de(inputs) + assert out["loc"].device == inputs.device + assert de.core_models[0].linear.weight.device == torch.device("cpu") + assert de.core_models[1].linear.weight.device == torch.device("cpu") + def test_errors(self): model_1 = dummy_model(1, 10) with pytest.raises(ValueError): From 8e0d43103f5e9bf8b2077e198498dc810115c10d Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 14:13:12 +0200 Subject: [PATCH 42/69] :sparkles: Add `TURegCheckpoint`callback --- torch_uncertainty/callbacks/checkpoint.py | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/callbacks/checkpoint.py b/torch_uncertainty/callbacks/checkpoint.py index 8199e73f..25b298de 100644 --- a/torch_uncertainty/callbacks/checkpoint.py +++ b/torch_uncertainty/callbacks/checkpoint.py @@ -55,6 +55,9 @@ def load_state_dict(self, state_dict: dict[str, dict[str, Any]]) -> None: for key, callback in self.callbacks.items(): callback.load_state_dict(state_dict=state_dict[key]) + @property + def best_model_path(self) -> str: ... + class TUClsCheckpoint(TUCheckpoint): def __init__(self): @@ -82,12 +85,16 @@ def __init__(self): ), } + @property + def best_model_path(self) -> str: + return self.callbacks["acc"].best_model_path + class TUSegCheckpoint(TUCheckpoint): def __init__(self): super().__init__() self.callbacks = { - "acc": ModelCheckpoint( + "miou": ModelCheckpoint( filename="{epoch}-{step}-{val_seg_mIoU:.2f}", monitor="val_seg_mIoU", mode="max", @@ -108,3 +115,30 @@ def __init__(self): mode="min", ), } + + @property + def best_model_path(self) -> str: + return self.callbacks["miou"].best_model_path + + +class TURegCheckpoint(TUCheckpoint): + def __init__(self, probabilistic: bool = False): + super().__init__() + self.callbacks = { + "mse": ModelCheckpoint( + filename="{epoch}-{step}-{val_reg_MSE:.2f}", + monitor="val_reg_MSE", + mode="min", + ), + } + + if probabilistic: + self.callbacks["nll"] = ModelCheckpoint( + filename="{epoch}-{step}-{val_reg_NLL:.2f}", + monitor="val_reg_NLL", + mode="min", + ) + + @property + def best_model_path(self) -> str: + return self.callbacks["mse"].best_model_path From c4e823208fcf4fc0eccdb631a14bacb757678e23 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 17 Apr 2025 14:13:53 +0200 Subject: [PATCH 43/69] :wrench: Update all ModelCheckpoint callbacks in configs --- experiments/classification/cifar10/configs/resnet.yaml | 6 +----- .../classification/cifar10/configs/resnet18/batched.yaml | 6 +----- .../classification/cifar10/configs/resnet18/masked.yaml | 6 +----- .../classification/cifar10/configs/resnet18/mimo.yaml | 6 +----- .../classification/cifar10/configs/resnet18/packed.yaml | 6 +----- .../classification/cifar10/configs/resnet18/standard.yaml | 6 +----- .../classification/cifar10/configs/resnet50/batched.yaml | 6 +----- .../classification/cifar10/configs/resnet50/masked.yaml | 6 +----- .../classification/cifar10/configs/resnet50/mimo.yaml | 6 +----- .../classification/cifar10/configs/resnet50/packed.yaml | 6 +----- .../classification/cifar10/configs/resnet50/standard.yaml | 6 +----- .../classification/cifar10/configs/wideresnet28x10.yaml | 6 +----- .../cifar10/configs/wideresnet28x10/batched.yaml | 6 +----- .../cifar10/configs/wideresnet28x10/masked.yaml | 6 +----- .../cifar10/configs/wideresnet28x10/mimo.yaml | 6 +----- .../cifar10/configs/wideresnet28x10/packed.yaml | 6 +----- .../cifar10/configs/wideresnet28x10/standard.yaml | 6 +----- experiments/classification/cifar100/configs/resnet.yaml | 6 +----- .../classification/cifar100/configs/resnet18/batched.yaml | 6 +----- .../classification/cifar100/configs/resnet18/masked.yaml | 6 +----- .../classification/cifar100/configs/resnet18/mimo.yaml | 6 +----- .../classification/cifar100/configs/resnet18/packed.yaml | 6 +----- .../classification/cifar100/configs/resnet18/standard.yaml | 6 +----- .../classification/cifar100/configs/resnet50/batched.yaml | 6 +----- .../classification/cifar100/configs/resnet50/masked.yaml | 6 +----- .../classification/cifar100/configs/resnet50/mimo.yaml | 6 +----- .../classification/cifar100/configs/resnet50/packed.yaml | 6 +----- .../classification/cifar100/configs/resnet50/standard.yaml | 6 +----- .../cifar100/configs/wideresnet28x10/standard.yaml | 6 +----- .../classification/mnist/configs/bayesian_lenet.yaml | 6 +----- experiments/classification/mnist/configs/lenet.yaml | 6 +----- .../classification/mnist/configs/lenet_batch_ensemble.yaml | 6 +----- .../mnist/configs/lenet_checkpoint_ensemble.yaml | 6 +----- .../classification/mnist/configs/lenet_deep_ensemble.yaml | 6 +----- experiments/classification/mnist/configs/lenet_ema.yaml | 6 +----- experiments/classification/mnist/configs/lenet_swa.yaml | 6 +----- experiments/classification/mnist/configs/lenet_swag.yaml | 6 +----- .../tiny-imagenet/configs/resnet18/standard.yaml | 6 +----- .../regression/uci_datasets/configs/boston/mlp/laplace.yaml | 6 ++---- .../regression/uci_datasets/configs/boston/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/boston/mlp/point_wise.yaml | 6 +----- .../uci_datasets/configs/concrete/mlp/laplace.yaml | 6 ++---- .../uci_datasets/configs/concrete/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/concrete/mlp/point_wise.yaml | 6 +----- .../uci_datasets/configs/energy-efficiency/mlp/laplace.yaml | 6 ++---- .../uci_datasets/configs/energy-efficiency/mlp/normal.yaml | 6 ++---- .../configs/energy-efficiency/mlp/point_wise.yaml | 6 +----- .../regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml | 6 ++---- .../regression/uci_datasets/configs/kin8nm/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/kin8nm/mlp/point_wise.yaml | 6 +----- .../configs/naval-propulsion-plant/mlp/laplace.yaml | 6 ++---- .../configs/naval-propulsion-plant/mlp/normal.yaml | 6 ++---- .../configs/naval-propulsion-plant/mlp/point_wise.yaml | 6 +----- .../uci_datasets/configs/power-plant/mlp/laplace.yaml | 6 ++---- .../uci_datasets/configs/power-plant/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/power-plant/mlp/point_wise.yaml | 6 +----- .../uci_datasets/configs/protein/mlp/laplace.yaml | 6 ++---- .../regression/uci_datasets/configs/protein/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/protein/mlp/point_wise.yaml | 6 +----- .../uci_datasets/configs/wine-quality-red/mlp/laplace.yaml | 6 ++---- .../uci_datasets/configs/wine-quality-red/mlp/normal.yaml | 6 ++---- .../configs/wine-quality-red/mlp/point_wise.yaml | 6 +----- .../regression/uci_datasets/configs/yacht/mlp/laplace.yaml | 6 ++---- .../regression/uci_datasets/configs/yacht/mlp/normal.yaml | 6 ++---- .../uci_datasets/configs/yacht/mlp/point_wise.yaml | 6 +----- experiments/segmentation/camvid/configs/deeplab.yaml | 6 +----- experiments/segmentation/camvid/configs/segformer.yaml | 6 +----- experiments/segmentation/cityscapes/configs/deeplab.yaml | 6 +----- experiments/segmentation/cityscapes/configs/segformer.yaml | 6 +----- 69 files changed, 87 insertions(+), 327 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index bc9efb11..d0cf4bff 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -11,11 +11,7 @@ trainer: save_dir: logs/resnet default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 92c5785b..04046e62 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -13,11 +13,7 @@ trainer: name: batched default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 895989b7..88f06e25 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -13,11 +13,7 @@ trainer: name: masked default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index da25452b..7698cb22 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -13,11 +13,7 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index 9c716194..e7e9449f 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -13,11 +13,7 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index e45e1212..0e4f9123 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index 715729d0..c5d5843b 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -13,11 +13,7 @@ trainer: name: batched default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 57233649..a9f65652 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -13,11 +13,7 @@ trainer: name: masked default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index 3215d5a1..30e0f94a 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -13,11 +13,7 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index b07af86c..f2d3f0b0 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -13,11 +13,7 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index c17ebe48..76708d83 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 53c81230..42c0e8d2 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -12,11 +12,7 @@ trainer: save_dir: logs/wideresnet28x10 default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 4a8ad472..814d2977 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -13,11 +13,7 @@ trainer: name: batched default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 5d1caa1c..458b7aa6 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -13,11 +13,7 @@ trainer: name: masked default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index dd949e2c..b304a564 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -13,11 +13,7 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index 637ae3e1..c25c4ef7 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -13,11 +13,7 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 1fb982e7..b89c354b 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index b5362462..8e61ed9e 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -11,11 +11,7 @@ trainer: save_dir: logs/ default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 9bf5cef0..eff53b5b 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -13,11 +13,7 @@ trainer: name: batched default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index 182bc4a8..ee08873c 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -13,11 +13,7 @@ trainer: name: masked default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index f0bbaa0c..e9151d4c 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -13,11 +13,7 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 33800cba..65eda482 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -13,11 +13,7 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 182a5815..d718677f 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 62acdb3c..e7d8ac31 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -13,11 +13,7 @@ trainer: name: batched default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index 35a476df..89e97446 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -13,11 +13,7 @@ trainer: name: masked default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index f85d5c61..fda93bee 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -13,11 +13,7 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index b25231ca..a8470e61 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -13,11 +13,7 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index aa6b5760..499eec28 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml index 14fe4f84..018f15cb 100644 --- a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 44602262..60f7a916 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 28384c16..59cca79d 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index 03bf424a..494c2db4 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -14,11 +14,7 @@ trainer: name: batch_ensemble default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index 50602726..1aef7b5f 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml index b64e1f1b..cec1a47d 100644 --- a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -14,11 +14,7 @@ trainer: name: deep_ensemble default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 2e2332bd..100efdb4 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index e61bee01..b6dedfe8 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index 773da587..923ceb26 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml index d6bceaf6..a1c4e1bf 100644 --- a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml +++ b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_cls_Acc - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUClsCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml index abfc97ef..7f01833a 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml index ee7371d1..a3992c52 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml index 01fa9b1c..ebd5fae1 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml index 7e869fb6..d8283ca1 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml index 070a9a73..daab4000 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml index 9676663b..9c55090b 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml index ff02764c..b96dbb87 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml index a300499b..73c8172f 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml index 9676663b..9c55090b 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml index b4f781e5..5c744112 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml index 13a6dcad..d1343b25 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml index f58ec8d0..311d93d0 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml index 9673c31d..99401f2e 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml index 6bf3603c..ee90e040 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml index 43a4bdef..599e4a91 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml index f184dd90..35e4f6ed 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml index 319377aa..6e166b3e 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml index 469ee71f..db139899 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml index ec9e61b9..c17faf76 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml index 34424213..48c22f71 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml index 0fcec544..e62b4580 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml index 366a057e..0c18edd9 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml index 85ff8971..c767c49d 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml index 9b87ffe0..0caab837 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml index cf9d02ca..3478b589 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml index b92567b0..c46be6a5 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml @@ -13,11 +13,9 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint + - class_path: torch_uncertainty.callbacks.TURegCheckpoint init_args: - monitor: val_reg_NLL - mode: min - save_last: true + probabilistic: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml index e4c2a788..3ea5ac2e 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_reg_MSE - mode: min - save_last: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/segmentation/camvid/configs/deeplab.yaml b/experiments/segmentation/camvid/configs/deeplab.yaml index 195636bd..aad8633c 100644 --- a/experiments/segmentation/camvid/configs/deeplab.yaml +++ b/experiments/segmentation/camvid/configs/deeplab.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_seg_mIoU - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index d3119655..b3b2bcbf 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -11,11 +11,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_seg_mIoU - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index 55858770..0a3c99df 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -13,11 +13,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_seg_mIoU - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index f606686b..6169c9e1 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -12,11 +12,7 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_seg_mIoU - mode: max - save_last: true + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: step From 798e612e1300f138a64e5eee12c8d0933279aaf7 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 23 Apr 2025 15:26:27 +0200 Subject: [PATCH 44/69] :hammer: Roll back metric name changes and use `auto_insert_metric_name=False` in TU checkpoint callbacks --- .../cifar10/configs/resnet.yaml | 2 +- .../cifar10/configs/resnet18/batched.yaml | 2 +- .../cifar10/configs/resnet18/masked.yaml | 2 +- .../cifar10/configs/resnet18/mimo.yaml | 2 +- .../cifar10/configs/resnet18/packed.yaml | 2 +- .../cifar10/configs/resnet18/standard.yaml | 2 +- .../cifar10/configs/resnet50/batched.yaml | 2 +- .../cifar10/configs/resnet50/masked.yaml | 2 +- .../cifar10/configs/resnet50/mimo.yaml | 2 +- .../cifar10/configs/resnet50/packed.yaml | 2 +- .../cifar10/configs/resnet50/standard.yaml | 2 +- .../cifar10/configs/wideresnet28x10.yaml | 2 +- .../configs/wideresnet28x10/batched.yaml | 2 +- .../configs/wideresnet28x10/masked.yaml | 2 +- .../cifar10/configs/wideresnet28x10/mimo.yaml | 2 +- .../configs/wideresnet28x10/packed.yaml | 2 +- .../configs/wideresnet28x10/standard.yaml | 2 +- .../cifar100/configs/resnet.yaml | 2 +- .../cifar100/configs/resnet18/batched.yaml | 2 +- .../cifar100/configs/resnet18/masked.yaml | 2 +- .../cifar100/configs/resnet18/mimo.yaml | 2 +- .../cifar100/configs/resnet18/packed.yaml | 2 +- .../cifar100/configs/resnet18/standard.yaml | 2 +- .../cifar100/configs/resnet50/batched.yaml | 2 +- .../cifar100/configs/resnet50/masked.yaml | 2 +- .../cifar100/configs/resnet50/mimo.yaml | 2 +- .../cifar100/configs/resnet50/packed.yaml | 2 +- .../cifar100/configs/resnet50/standard.yaml | 2 +- .../configs/wideresnet28x10/standard.yaml | 2 +- .../mnist/configs/bayesian_lenet.yaml | 2 +- .../classification/mnist/configs/lenet.yaml | 2 +- .../mnist/configs/lenet_batch_ensemble.yaml | 2 +- .../configs/lenet_checkpoint_ensemble.yaml | 2 +- .../mnist/configs/lenet_deep_ensemble.yaml | 2 +- .../mnist/configs/lenet_ema.yaml | 2 +- .../mnist/configs/lenet_swa.yaml | 2 +- .../mnist/configs/lenet_swag.yaml | 2 +- .../configs/resnet18/standard.yaml | 2 +- experiments/depth/kitti/configs/bts.yaml | 2 +- experiments/depth/nyu/configs/bts.yaml | 2 +- .../configs/boston/mlp/laplace.yaml | 2 +- .../configs/boston/mlp/normal.yaml | 2 +- .../configs/boston/mlp/point_wise.yaml | 2 +- .../configs/concrete/mlp/laplace.yaml | 2 +- .../configs/concrete/mlp/normal.yaml | 2 +- .../configs/concrete/mlp/point_wise.yaml | 2 +- .../energy-efficiency/mlp/laplace.yaml | 2 +- .../configs/energy-efficiency/mlp/normal.yaml | 2 +- .../energy-efficiency/mlp/point_wise.yaml | 2 +- .../configs/kin8nm/mlp/laplace.yaml | 2 +- .../configs/kin8nm/mlp/normal.yaml | 2 +- .../configs/kin8nm/mlp/point_wise.yaml | 2 +- .../naval-propulsion-plant/mlp/laplace.yaml | 2 +- .../naval-propulsion-plant/mlp/normal.yaml | 2 +- .../mlp/point_wise.yaml | 2 +- .../configs/power-plant/mlp/laplace.yaml | 2 +- .../configs/power-plant/mlp/normal.yaml | 2 +- .../configs/power-plant/mlp/point_wise.yaml | 2 +- .../configs/protein/mlp/laplace.yaml | 2 +- .../configs/protein/mlp/normal.yaml | 2 +- .../configs/protein/mlp/point_wise.yaml | 2 +- .../configs/wine-quality-red/mlp/laplace.yaml | 2 +- .../configs/wine-quality-red/mlp/normal.yaml | 2 +- .../wine-quality-red/mlp/point_wise.yaml | 2 +- .../configs/yacht/mlp/laplace.yaml | 2 +- .../configs/yacht/mlp/normal.yaml | 2 +- .../configs/yacht/mlp/point_wise.yaml | 2 +- tests/test_cli.py | 2 +- torch_uncertainty/callbacks/checkpoint.py | 50 +++++++----- torch_uncertainty/routines/classification.py | 78 +++++++++---------- .../routines/pixel_regression.py | 34 ++++---- torch_uncertainty/routines/regression.py | 18 ++--- torch_uncertainty/routines/segmentation.py | 46 +++++------ torch_uncertainty/utils/evaluation_loop.py | 38 ++++----- 74 files changed, 205 insertions(+), 195 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index d0cf4bff..69f26c26 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -17,7 +17,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 04046e62..6ebd6fbc 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 88f06e25..bd975993 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index 7698cb22..89663567 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index e7e9449f..2ba73ebc 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index 0e4f9123..74936350 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index c5d5843b..dfdc6da8 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index a9f65652..6f0d6c8d 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index 30e0f94a..f80d9a1f 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index f2d3f0b0..ce558338 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 76708d83..1c19d948 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 42c0e8d2..03f862a8 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -18,7 +18,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 814d2977..16eee6a1 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 458b7aa6..cd245579 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index b304a564..f143e501 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index c25c4ef7..a6c38c49 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index b89c354b..2f76759c 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index 8e61ed9e..5592b890 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -17,7 +17,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index eff53b5b..ea1f0e54 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index ee08873c..743b54b0 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index e9151d4c..3cde0783 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 65eda482..15ab50e0 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index d718677f..892bedaa 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index e7d8ac31..45af3ca0 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index 89e97446..24613868 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index fda93bee..d86a7877 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index a8470e61..1b084007 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index 499eec28..6f31e824 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml index 018f15cb..4812465a 100644 --- a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 60f7a916..dab14676 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 59cca79d..f2078eac 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index 494c2db4..f6fc4b9b 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -20,7 +20,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index 1aef7b5f..b33797d9 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml index cec1a47d..c5803e2d 100644 --- a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -20,7 +20,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 100efdb4..09873d61 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index b6dedfe8..8c374ff3 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index 923ceb26..9acc27d1 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml index a1c4e1bf..a4104021 100644 --- a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml +++ b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_cls_Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml index 692da6df..ee27a0b8 100644 --- a/experiments/depth/kitti/configs/bts.yaml +++ b/experiments/depth/kitti/configs/bts.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val_reg_SILog + monitor: val/reg/SILog mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml index 0d7e99b2..023869de 100644 --- a/experiments/depth/nyu/configs/bts.yaml +++ b/experiments/depth/nyu/configs/bts.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val_reg_SILog + monitor: val/reg/SILog mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml index 7f01833a..658623eb 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml index a3992c52..a8d3fed9 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml index ebd5fae1..291ae2af 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml index d8283ca1..4d441e2e 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml index daab4000..6bdabac0 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml index 9c55090b..e078d3ec 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml index b96dbb87..5af41cae 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml index 73c8172f..abc6fd38 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml index 9c55090b..e078d3ec 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml index 5c744112..c274098b 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml index d1343b25..2ecdee6d 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml index 311d93d0..94a25fe0 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml index 99401f2e..ac8b58b3 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml index ee90e040..1286d72a 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml index 599e4a91..efcdb7a9 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml index 35e4f6ed..95b87412 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml index 6e166b3e..ab0fd5d1 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml index db139899..11f1242e 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml index c17faf76..26c98034 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml index 48c22f71..1c03c5a9 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml index e62b4580..2fc24032 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml index 0c18edd9..c40efc24 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml index c767c49d..a246d27f 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml index 0caab837..8dbc05d8 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml index 3478b589..ff3b66cf 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml index c46be6a5..4d5e154b 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml index 3ea5ac2e..15fca9b0 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml @@ -19,7 +19,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val_reg_MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: diff --git a/tests/test_cli.py b/tests/test_cli.py index f215bcb3..8683a523 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,7 +27,7 @@ def test_cli_init(self): "--data.batch_size", "4", "--trainer.callbacks+=ModelCheckpoint", - "--trainer.callbacks.monitor=val_cls_Acc", + "--trainer.callbacks.monitor=val/cls/Acc", "--trainer.callbacks.mode=max", ] cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) diff --git a/torch_uncertainty/callbacks/checkpoint.py b/torch_uncertainty/callbacks/checkpoint.py index 25b298de..91abfe12 100644 --- a/torch_uncertainty/callbacks/checkpoint.py +++ b/torch_uncertainty/callbacks/checkpoint.py @@ -64,24 +64,28 @@ def __init__(self): super().__init__() self.callbacks = { "acc": ModelCheckpoint( - filename="{epoch}-{step}-{val_cls_Acc:.2f}", - monitor="val_cls_Acc", + filename="{epoch}-{step}-val_acc={val/cls/Acc:.2f}", + monitor="val/cls/Acc", mode="max", + auto_insert_metric_name=False, ), "ece": ModelCheckpoint( - filename="{epoch}-{step}-{val_cal_ECE:.2f}", - monitor="val_cal_ECE", + filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", + monitor="val/cal/ECE", mode="min", + auto_insert_metric_name=False, ), "brier": ModelCheckpoint( - filename="{epoch}-{step}-{val_cls_Brier:.2f}", - monitor="val_cls_Brier", + filename="{epoch}-{step}-val_brier={val/cls/Brier:.2f}", + monitor="val/cls/Brier", mode="min", + auto_insert_metric_name=False, ), "nll": ModelCheckpoint( - filename="{epoch}-{step}-{val_cls_NLL:.2f}", - monitor="val_cls_NLL", + filename="{epoch}-{step}-val_nll={val/cls/NLL:.2f}", + monitor="val/cls/NLL", mode="min", + auto_insert_metric_name=False, ), } @@ -95,24 +99,28 @@ def __init__(self): super().__init__() self.callbacks = { "miou": ModelCheckpoint( - filename="{epoch}-{step}-{val_seg_mIoU:.2f}", - monitor="val_seg_mIoU", + filename="{epoch}-{step}-val_miou={val/seg/mIoU:.2f}", + monitor="val/seg/mIoU", mode="max", + auto_insert_metric_name=False, ), "ece": ModelCheckpoint( - filename="{epoch}-{step}-{val_cal_ECE:.2f}", - monitor="val_cal_ECE", + filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", + monitor="val/cal/ECE", mode="min", + auto_insert_metric_name=False, ), "brier": ModelCheckpoint( - filename="{epoch}-{step}-{val_seg_Brier:.2f}", - monitor="val_seg_Brier", + filename="{epoch}-{step}-val_brier={val/seg/Brier:.2f}", + monitor="val/seg/Brier", mode="min", + auto_insert_metric_name=False, ), "nll": ModelCheckpoint( - filename="{epoch}-{step}-{val_seg_NLL:.2f}", - monitor="val_seg_NLL", + filename="{epoch}-{step}-val_nll={val/seg/NLL:.2f}", + monitor="val/seg/NLL", mode="min", + auto_insert_metric_name=False, ), } @@ -126,17 +134,19 @@ def __init__(self, probabilistic: bool = False): super().__init__() self.callbacks = { "mse": ModelCheckpoint( - filename="{epoch}-{step}-{val_reg_MSE:.2f}", - monitor="val_reg_MSE", + filename="{epoch}-{step}-val_mse{val/reg/MSE:.2f}", + monitor="val/reg/MSE", mode="min", + auto_insert_metric_name=False, ), } if probabilistic: self.callbacks["nll"] = ModelCheckpoint( - filename="{epoch}-{step}-{val_reg_NLL:.2f}", - monitor="val_reg_NLL", + filename="{epoch}-{step}-val_nll{val/reg/NLL:.2f}", + monitor="val/reg/NLL", mode="min", + auto_insert_metric_name=False, ) @property diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 1bcb6e4b..5119c316 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -191,47 +191,47 @@ def _init_metrics(self) -> None: task = "binary" if self.binary_cls else "multiclass" metrics_dict = { - "cls_Acc": Accuracy(task=task, num_classes=self.num_classes), - "cls_Brier": BrierScore(num_classes=self.num_classes), - "cls_NLL": CategoricalNLL(), - "cal_ECE": CalibrationError( + "cls/Acc": Accuracy(task=task, num_classes=self.num_classes), + "cls/Brier": BrierScore(num_classes=self.num_classes), + "cls/NLL": CategoricalNLL(), + "cal/ECE": CalibrationError( task=task, num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), - "cal_aECE": CalibrationError( + "cal/aECE": CalibrationError( task=task, adaptive=True, num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), - "sc_AURC": AURC(), - "sc_AUGRC": AUGRC(), - "sc_Cov@5Risk": CovAt5Risk(), - "sc_Risk@80Cov": RiskAt80Cov(), + "sc/AURC": AURC(), + "sc/AUGRC": AUGRC(), + "sc/Cov@5Risk": CovAt5Risk(), + "sc/Risk@80Cov": RiskAt80Cov(), } groups = [ - ["cls_Acc"], - ["cls_Brier"], - ["cls_NLL"], - ["cal_ECE", "cal_aECE"], - ["sc_AURC", "sc_AUGRC", "sc_Cov@5Risk", "sc_Risk@80Cov"], + ["cls/Acc"], + ["cls/Brier"], + ["cls/NLL"], + ["cal/ECE", "cal/aECE"], + ["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"], ] if self.binary_cls: metrics_dict |= { - "cls_AUROC": BinaryAUROC(), - "cls_AUPR": BinaryAveragePrecision(), - "cls_FRP95": FPR95(pos_label=1), + "cls/AUROC": BinaryAUROC(), + "cls/AUPR": BinaryAveragePrecision(), + "cls/FRP95": FPR95(pos_label=1), } - groups.extend([["cls_AUROC", "cls_AUPR"], ["cls_FRP95"]]) + groups.extend([["cls/AUROC", "cls/AUPR"], ["cls/FRP95"]]) cls_metrics = MetricCollection(metrics_dict, compute_groups=groups) - self.val_cls_metrics = cls_metrics.clone(prefix="val_") - self.test_cls_metrics = cls_metrics.clone(prefix="test_") + self.val_cls_metrics = cls_metrics.clone(prefix="val/") + self.test_cls_metrics = cls_metrics.clone(prefix="test/") if self.post_processing is not None: - self.post_cls_metrics = cls_metrics.clone(prefix="test_post_") + self.post_cls_metrics = cls_metrics.clone(prefix="test/post/") self.test_id_entropy = Entropy() @@ -244,11 +244,11 @@ def _init_metrics(self) -> None: }, compute_groups=[["AUROC", "AUPR"], ["FPR95"]], ) - self.test_ood_metrics = ood_metrics.clone(prefix="ood_") + self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_ood_entropy = Entropy() if self.eval_shift: - self.test_shift_metrics = cls_metrics.clone(prefix="shift_") + self.test_shift_metrics = cls_metrics.clone(prefix="shift/") # metrics for ensembles only if self.is_ensemble: @@ -260,18 +260,18 @@ def _init_metrics(self) -> None: } ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="test_ens_") + self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens/") if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood_ens_") + self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens/") if self.eval_shift: - self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift_ens_") + self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens/") if self.eval_grouping_loss: - grouping_loss = MetricCollection({"cls_grouping_loss": GroupingLoss()}) - self.val_grouping_loss = grouping_loss.clone(prefix="val_") - self.test_grouping_loss = grouping_loss.clone(prefix="test_") + grouping_loss = MetricCollection({"cls/grouping/loss": GroupingLoss()}) + self.val_grouping_loss = grouping_loss.clone(prefix="val/") + self.test_grouping_loss = grouping_loss.clone(prefix="test/") def _init_mixup(self, mixup_params: dict | None) -> Callable: """Setup the optional mixup augmentation based on the :attr:`mixup_params` dict. @@ -502,7 +502,7 @@ def test_step( self.log_dict(self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False) self.test_id_entropy(probs) self.log( - "test_cls_Entropy", + "test/cls/Entropy", self.test_id_entropy, on_epoch=True, add_dataloader_idx=False, @@ -529,7 +529,7 @@ def test_step( self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs) self.log( - "ood_Entropy", + "ood/Entropy", self.test_ood_entropy, on_epoch=True, add_dataloader_idx=False, @@ -551,7 +551,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "Acc%", - res_dict["val_cls_Acc"] * 100, + res_dict["val/cls/Acc"] * 100, prog_bar=True, logger=False, sync_dist=True, @@ -568,7 +568,7 @@ def on_test_epoch_end(self) -> None: result_dict = self.test_cls_metrics.compute() # already logged - result_dict.update({"test_Entropy": self.test_id_entropy.compute()}, sync_dist=True) + result_dict.update({"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True) if self.post_processing is not None: tmp_metrics = self.post_cls_metrics.compute() @@ -592,7 +592,7 @@ def on_test_epoch_end(self) -> None: result_dict.update(tmp_metrics) # already logged - result_dict.update({"ood_Entropy": self.test_ood_entropy.compute()}) + result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) if self.is_ensemble: tmp_metrics = self.test_ood_ens_metrics.compute() @@ -602,7 +602,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() shift_severity = self.trainer.datamodule.shift_severity - tmp_metrics["shift_severity"] = shift_severity + tmp_metrics["shift/severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -613,21 +613,21 @@ def on_test_epoch_end(self) -> None: if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Reliabity diagram", self.test_cls_metrics["cal_ECE"].plot()[0] + "Reliabity diagram", self.test_cls_metrics["cal/ECE"].plot()[0] ) self.logger.experiment.add_figure( "Risk-Coverage curve", - self.test_cls_metrics["sc_AURC"].plot()[0], + self.test_cls_metrics["sc/AURC"].plot()[0], ) self.logger.experiment.add_figure( "Generalized Risk-Coverage curve", - self.test_cls_metrics["sc_AUGRC"].plot()[0], + self.test_cls_metrics["sc/AUGRC"].plot()[0], ) if self.post_processing is not None: self.logger.experiment.add_figure( "Reliabity diagram after calibration", - self.post_cls_metrics["cal_ECE"].plot()[0], + self.post_cls_metrics["cal/ECE"].plot()[0], ) # plot histograms of logits and likelihoods diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 682759d6..1db3009a 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -114,28 +114,28 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" depth_metrics = MetricCollection( { - "reg_SILog": SILog(), - "reg_log10": Log10(), - "reg_ARE": MeanGTRelativeAbsoluteError(), - "reg_RSRE": MeanGTRelativeSquaredError(squared=False), - "reg_RMSE": MeanSquaredError(squared=False), - "reg_RMSELog": MeanSquaredLogError(squared=False), - "reg_iMAE": MeanAbsoluteErrorInverse(), - "reg_iRMSE": MeanSquaredErrorInverse(squared=False), - "reg_d1": ThresholdAccuracy(power=1), - "reg_d2": ThresholdAccuracy(power=2), - "reg_d3": ThresholdAccuracy(power=3), + "reg/SILog": SILog(), + "reg/log10": Log10(), + "reg/ARE": MeanGTRelativeAbsoluteError(), + "reg/RSRE": MeanGTRelativeSquaredError(squared=False), + "reg/RMSE": MeanSquaredError(squared=False), + "reg/RMSELog": MeanSquaredLogError(squared=False), + "reg/iMAE": MeanAbsoluteErrorInverse(), + "reg/iRMSE": MeanSquaredErrorInverse(squared=False), + "reg/d1": ThresholdAccuracy(power=1), + "reg/d2": ThresholdAccuracy(power=2), + "reg/d3": ThresholdAccuracy(power=3), }, compute_groups=False, ) - self.val_metrics = depth_metrics.clone(prefix="val_") - self.test_metrics = depth_metrics.clone(prefix="test_") + self.val_metrics = depth_metrics.clone(prefix="val/") + self.test_metrics = depth_metrics.clone(prefix="test/") if self.probabilistic: - depth_prob_metrics = MetricCollection({"reg_NLL": DistributionNLL(reduction="mean")}) - self.val_prob_metrics = depth_prob_metrics.clone(prefix="val_") - self.test_prob_metrics = depth_prob_metrics.clone(prefix="test_") + depth_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) + self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/") + self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -322,7 +322,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "RMSE", - res_dict["val_reg_RMSE"], + res_dict["val/reg/RMSE"], prog_bar=True, logger=False, sync_dist=True, diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 1c7edc68..88514d39 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -101,20 +101,20 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" reg_metrics = MetricCollection( { - "reg_MAE": MeanAbsoluteError(), - "reg_MSE": MeanSquaredError(squared=True), - "reg_RMSE": MeanSquaredError(squared=False), + "reg/MAE": MeanAbsoluteError(), + "reg/MSE": MeanSquaredError(squared=True), + "reg/RMSE": MeanSquaredError(squared=False), }, compute_groups=True, ) - self.val_metrics = reg_metrics.clone(prefix="val_") - self.test_metrics = reg_metrics.clone(prefix="test_") + self.val_metrics = reg_metrics.clone(prefix="val/") + self.test_metrics = reg_metrics.clone(prefix="test/") if self.probabilistic: - reg_prob_metrics = MetricCollection({"reg_NLL": DistributionNLL(reduction="mean")}) - self.val_prob_metrics = reg_prob_metrics.clone(prefix="val_") - self.test_prob_metrics = reg_prob_metrics.clone(prefix="test_") + reg_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) + self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/") + self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -289,7 +289,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "RMSE", - res_dict["val_reg_RMSE"], + res_dict["val/reg/RMSE"], prog_bar=True, logger=False, sync_dist=True, diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index ad8ee47a..a68a97ca 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -105,46 +105,46 @@ def _init_metrics(self) -> None: """Initialize the metrics depending on the exact task.""" seg_metrics = MetricCollection( { - "seg_mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), + "seg/mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), }, compute_groups=False, ) sbsmpl_seg_metrics = MetricCollection( { - "seg_mAcc": Accuracy( + "seg/mAcc": Accuracy( task="multiclass", average="macro", num_classes=self.num_classes ), - "seg_Brier": BrierScore(num_classes=self.num_classes), - "seg_NLL": CategoricalNLL(), - "seg_pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), - "cal_ECE": CalibrationError( + "seg/Brier": BrierScore(num_classes=self.num_classes), + "seg/NLL": CategoricalNLL(), + "seg/pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), + "cal/ECE": CalibrationError( task="multiclass", num_classes=self.num_classes, num_bins=self.num_bins_cal_err, ), - "cal_aECE": CalibrationError( + "cal/aECE": CalibrationError( task="multiclass", adaptive=True, num_classes=self.num_classes, num_bins=self.num_bins_cal_err, ), - "sc_AURC": AURC(), - "sc_AUGRC": AUGRC(), + "sc/AURC": AURC(), + "sc/AUGRC": AUGRC(), }, compute_groups=[ - ["seg_mAcc"], - ["seg_Brier"], - ["seg_NLL"], - ["seg_pixAcc"], - ["cal_ECE", "cal_aECE"], - ["sc_AURC", "sc_AUGRC"], + ["seg/mAcc"], + ["seg/Brier"], + ["seg/NLL"], + ["seg/pixAcc"], + ["cal/ECE", "cal/aECE"], + ["sc/AURC", "sc/AUGRC"], ], ) - self.val_seg_metrics = seg_metrics.clone(prefix="val_") - self.val_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="val_") - self.test_seg_metrics = seg_metrics.clone(prefix="test_") - self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test_") + self.val_seg_metrics = seg_metrics.clone(prefix="val/") + self.val_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="val/") + self.test_seg_metrics = seg_metrics.clone(prefix="test/") + self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -260,7 +260,7 @@ def on_validation_epoch_end(self) -> None: self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "mIoU%", - res_dict["val_seg_mIoU"] * 100, + res_dict["val/seg/mIoU"] * 100, prog_bar=True, sync_dist=True, ) @@ -275,15 +275,15 @@ def on_test_epoch_end(self) -> None: if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( "Calibration/Reliabity diagram", - self.test_sbsmpl_seg_metrics["cal_ECE"].plot()[0], + self.test_sbsmpl_seg_metrics["cal/ECE"].plot()[0], ) self.logger.experiment.add_figure( "Selective Classification/Risk-Coverage curve", - self.test_sbsmpl_seg_metrics["sc_AURC"].plot()[0], + self.test_sbsmpl_seg_metrics["sc/AURC"].plot()[0], ) self.logger.experiment.add_figure( "Selective Classification/Generalized Risk-Coverage curve", - self.test_sbsmpl_seg_metrics["sc_AUGRC"].plot()[0], + self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0], ) if self.trainer.datamodule is not None: self.log_segmentation_plots() diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index f5bf84b1..aebfd1ee 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -35,56 +35,56 @@ def _add_row(table: Table, metric_name: str, value: Tensor) -> None: class TUEvaluationLoop(_EvaluationLoop): @staticmethod def _print_results(results: list[_OUT_DICT], stage: str) -> None: - # test_cls: Classification Metrics - # test_cal: Calibration Metrics + # test/cls: Classification Metrics + # test/cal: Calibration Metrics # ood: OOD Detection Metrics # shift: Distribution shift Metrics - # test_sc: Selective Classification Metrics - # test_post: Post-Processing Metrics - # test_seg: Segmentation Metrics + # test/sc: Selective Classification Metrics + # test/post: Post-Processing Metrics + # test/seg: Segmentation Metrics metrics = {} for result in results: for key, value in result.items(): - if key.startswith("test_cls"): + if key.startswith("test/cls"): if "cls" not in metrics: metrics["cls"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["cls"].update({metric_name: value}) - elif key.startswith("test_cal"): + elif key.startswith("test/cal"): if "cal" not in metrics: metrics["cal"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["cal"].update({metric_name: value}) elif key.startswith("ood"): if "ood" not in metrics: metrics["ood"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["ood"].update({metric_name: value}) elif key.startswith("shift"): if "shift" not in metrics: metrics["shift"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["shift"].update({metric_name: value}) - elif key.startswith("test_sc"): + elif key.startswith("test/sc"): if "sc" not in metrics: metrics["sc"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["sc"].update({metric_name: value}) - elif key.startswith("test_post"): + elif key.startswith("test/post"): if "post" not in metrics: metrics["post"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["post"].update({metric_name: value}) - elif key.startswith("test_seg"): + elif key.startswith("test/seg"): if "seg" not in metrics: metrics["seg"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["seg"].update({metric_name: value}) - elif key.startswith("test_reg"): + elif key.startswith("test/reg"): if "reg" not in metrics: metrics["reg"] = {} - metric_name = key.split("_")[-1] + metric_name = key.split("/")[-1] metrics["reg"].update({metric_name: value}) tables = [] From c0fb2f5acec2863f1c1b8ac1e4d26a84e9444324 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 23 Apr 2025 15:30:33 +0200 Subject: [PATCH 45/69] :bug: Fix typo in Grouping Loss log name --- torch_uncertainty/routines/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 5119c316..dbcf2272 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -269,7 +269,7 @@ def _init_metrics(self) -> None: self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens/") if self.eval_grouping_loss: - grouping_loss = MetricCollection({"cls/grouping/loss": GroupingLoss()}) + grouping_loss = MetricCollection({"cls/grouping_loss": GroupingLoss()}) self.val_grouping_loss = grouping_loss.clone(prefix="val/") self.test_grouping_loss = grouping_loss.clone(prefix="test/") From 7045c5ab46eba0a7481b4cf9d8224568fa69bcfc Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 23 Apr 2025 15:47:03 +0200 Subject: [PATCH 46/69] :shirt: Improve coverage of `ClassificationRoutine` --- tests/routines/test_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 368f4ec7..e6c02495 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -285,6 +285,7 @@ def test_two_estimators_two_classes_mi(self): num_classes=2, num_images=100, eval_ood=True, + eval_shift=True, ) model = DummyClassificationBaseline( num_classes=dm.num_classes, From 68859154611a9ae6179e9c2a1327254159c22f96 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 23 Apr 2025 16:50:08 +0200 Subject: [PATCH 47/69] :hammer: Update metric log names --- torch_uncertainty/routines/classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index dbcf2272..b715346b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -260,13 +260,13 @@ def _init_metrics(self) -> None: } ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens/") + self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens_") if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens/") + self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") if self.eval_shift: - self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens/") + self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens_") if self.eval_grouping_loss: grouping_loss = MetricCollection({"cls/grouping_loss": GroupingLoss()}) From a1c7b01873134cbb15f4f1ee85874c2938a06881 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 23 Apr 2025 17:45:04 +0200 Subject: [PATCH 48/69] :hammer: Update checkpoint names --- torch_uncertainty/callbacks/checkpoint.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch_uncertainty/callbacks/checkpoint.py b/torch_uncertainty/callbacks/checkpoint.py index 91abfe12..b01db94e 100644 --- a/torch_uncertainty/callbacks/checkpoint.py +++ b/torch_uncertainty/callbacks/checkpoint.py @@ -64,25 +64,25 @@ def __init__(self): super().__init__() self.callbacks = { "acc": ModelCheckpoint( - filename="{epoch}-{step}-val_acc={val/cls/Acc:.2f}", + filename="epoch={epoch}-step={step}-val_acc={val/cls/Acc:.2f}", monitor="val/cls/Acc", mode="max", auto_insert_metric_name=False, ), "ece": ModelCheckpoint( - filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", + filename="epoch={epoch}-step={step}-val_ece={val/cal/ECE:.2f}", monitor="val/cal/ECE", mode="min", auto_insert_metric_name=False, ), "brier": ModelCheckpoint( - filename="{epoch}-{step}-val_brier={val/cls/Brier:.2f}", + filename="epoch={epoch}-step={step}-val_brier={val/cls/Brier:.2f}", monitor="val/cls/Brier", mode="min", auto_insert_metric_name=False, ), "nll": ModelCheckpoint( - filename="{epoch}-{step}-val_nll={val/cls/NLL:.2f}", + filename="epoch={epoch}-step={step}-val_nll={val/cls/NLL:.2f}", monitor="val/cls/NLL", mode="min", auto_insert_metric_name=False, @@ -99,25 +99,25 @@ def __init__(self): super().__init__() self.callbacks = { "miou": ModelCheckpoint( - filename="{epoch}-{step}-val_miou={val/seg/mIoU:.2f}", + filename="epoch={epoch}-step={step}-val_miou={val/seg/mIoU:.2f}", monitor="val/seg/mIoU", mode="max", auto_insert_metric_name=False, ), "ece": ModelCheckpoint( - filename="{epoch}-{step}-val_ece={val/cal/ECE:.2f}", + filename="epoch={epoch}-step={step}-val_ece={val/cal/ECE:.2f}", monitor="val/cal/ECE", mode="min", auto_insert_metric_name=False, ), "brier": ModelCheckpoint( - filename="{epoch}-{step}-val_brier={val/seg/Brier:.2f}", + filename="epoch={epoch}-step={step}-val_brier={val/seg/Brier:.2f}", monitor="val/seg/Brier", mode="min", auto_insert_metric_name=False, ), "nll": ModelCheckpoint( - filename="{epoch}-{step}-val_nll={val/seg/NLL:.2f}", + filename="epoch={epoch}-step={step}-val_nll={val/seg/NLL:.2f}", monitor="val/seg/NLL", mode="min", auto_insert_metric_name=False, @@ -134,7 +134,7 @@ def __init__(self, probabilistic: bool = False): super().__init__() self.callbacks = { "mse": ModelCheckpoint( - filename="{epoch}-{step}-val_mse{val/reg/MSE:.2f}", + filename="epoch={epoch}-step={step}-val_mse{val/reg/MSE:.2f}", monitor="val/reg/MSE", mode="min", auto_insert_metric_name=False, @@ -143,7 +143,7 @@ def __init__(self, probabilistic: bool = False): if probabilistic: self.callbacks["nll"] = ModelCheckpoint( - filename="{epoch}-{step}-val_nll{val/reg/NLL:.2f}", + filename="epoch={epoch}-step={step}-val_nll{val/reg/NLL:.2f}", monitor="val/reg/NLL", mode="min", auto_insert_metric_name=False, From 0979fdb5c621d9ed2f66c7f48b29d270a5e00f02 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 25 Apr 2025 15:19:51 +0200 Subject: [PATCH 49/69] :sparkles: Add `CoverageRate` metric --- torch_uncertainty/metrics/__init__.py | 1 + .../metrics/classification/__init__.py | 1 + .../metrics/classification/coverage_rate.py | 127 ++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 torch_uncertainty/metrics/classification/coverage_rate.py diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 2f292f56..840b32c6 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -9,6 +9,7 @@ CategoricalNLL, CovAt5Risk, CovAtxRisk, + CoverageRate, Disagreement, Entropy, GroupingLoss, diff --git a/torch_uncertainty/metrics/classification/__init__.py b/torch_uncertainty/metrics/classification/__init__.py index 840d543b..328a475f 100644 --- a/torch_uncertainty/metrics/classification/__init__.py +++ b/torch_uncertainty/metrics/classification/__init__.py @@ -3,6 +3,7 @@ from .brier_score import BrierScore from .calibration_error import CalibrationError from .categorical_nll import CategoricalNLL +from .coverage_rate import CoverageRate from .disagreement import Disagreement from .entropy import Entropy from .fpr import FPR95, FPRx diff --git a/torch_uncertainty/metrics/classification/coverage_rate.py b/torch_uncertainty/metrics/classification/coverage_rate.py new file mode 100644 index 00000000..8d2ad74c --- /dev/null +++ b/torch_uncertainty/metrics/classification/coverage_rate.py @@ -0,0 +1,127 @@ +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.imports import _XLA_AVAILABLE + + +class CoverageRate(Metric): + is_differentiable = False + higher_is_better = True + full_state_update = False + + def __init__( + self, + num_classes: int | None = None, + average: str = "micro", + validate_args: bool = True, + **kwargs, + ): + """Empirical coverage rate metric. + + Args: + num_classes (int | None, optional): Number of classes. Defaults to ``None``. + average (str, optional): Defines the reduction that is applied over labels. Should be + one of the following: + + - ``'macro'`` (default): Compute the metric for each class separately and find their + unweighted mean. This does not take label imbalance into account. + - ``'micro'``: Sum statistics across over all labels. + + validate_args (bool, optional): Whether to validate the arguments. Defaults to ``True``. + kwargs: Additional keyword arguments, see `Advanced metric settings + `_. + + + Raises: + ValueError: If `num_classes` is `None` and `average` is not `micro`. + ValueError: If `num_classes` is not an integer larger than 1. + ValueError: If `average` is not one of `macro` or `micro`. + """ + super().__init__(**kwargs) + + if validate_args: + if num_classes is None and average != "micro": + raise ValueError( + f"Argument `num_classes` can only be `None` for `average='micro'`, but got `average={average}`." + ) + if num_classes is not None and (not isinstance(num_classes, int) or num_classes < 2): + raise ValueError( + f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}" + ) + if average not in ["macro", "micro"]: + raise ValueError("average must be either 'macro' or 'micro'.") + + self.num_classes = num_classes + self.average = average + self.validate_args = validate_args + + size = 1 if (average == "micro" or num_classes is None) else num_classes + + self.add_state("correct", torch.zeros(size, dtype=torch.long), dist_reduce_fx="sum") + self.add_state("total", torch.zeros(size, dtype=torch.float), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update the metric state with predictions and targets. + + Args: + preds (torch.Tensor): predicted sets tensor of shape (B, C), where B is the batch size + and C is the number of classes. + target (torch.Tensor): target sets tensor of shape (B,). + """ + batch_size = preds.size(0) + target = target.long() + + covered = preds[torch.arange(batch_size), target] # (B,) + + if self.average == "micro": + self.correct += covered.sum() + self.total += batch_size + + else: + self.correct += _bincount(target[covered.bool()], self.num_classes) + self.total += _bincount(target, self.num_classes) + + def compute(self) -> Tensor: + """Compute the coverage rate. + + Returns: + Tensor: The coverage rate. + """ + if self.average == "micro": + return _safe_divide(self.correct, self.total) + return _safe_divide(self.correct, self.total).mean() + + +def _bincount(x: Tensor, minlength: int | None = None) -> Tensor: + """Implement custom bincount. + + PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running + MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of + `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption + as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``. + + Args: + x: tensor to count + minlength: minimum length to count + + Returns: + Number of occurrences for each unique element in x + + Example: + >>> x = torch.tensor([0,0,0,1,1,2,2,2,2]) + >>> _bincount(x, minlength=3) + tensor([3, 2, 4]) + + Source: + https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/data.py#L178 + + """ + if minlength is None: + minlength = len(torch.unique(x)) + + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps: + mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1) + return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0) + + return torch.bincount(x, minlength=minlength) From 36d56cf233da8db2fcf51a3316d6ac77e225ccf9 Mon Sep 17 00:00:00 2001 From: alafage Date: Tue, 29 Apr 2025 15:18:11 +0200 Subject: [PATCH 50/69] :sparkles: `deep_ensembles()` wrapper can take ckpt paths to init the ensemble --- .../models/wrappers/deep_ensembles.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 00ea959e..84e5dcae 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -1,4 +1,5 @@ import copy +from pathlib import Path from typing import Literal import torch @@ -93,6 +94,8 @@ def deep_ensembles( probabilistic: bool | None = None, reset_model_parameters: bool = True, store_on_cpu: bool = False, + ckpt_paths: list[str | Path] | Path | None = None, + use_tu_ckpt_format: bool = False, ) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. @@ -107,9 +110,18 @@ def deep_ensembles( store_on_cpu (bool): Whether to store the models on CPU. Defaults to ``False``. This is useful for large models that do not fit in GPU memory. Only one model will be stored on GPU at a time during forward. The rest will be stored on CPU. + ckpt_paths (list[str | Path] | None): The paths to the checkpoints of the models. + If provided, the models will be loaded from the checkpoints. The number of + models and the number of checkpoint paths must be the same. If not provided, + the models will be used as is. Defaults to ``None``. + use_tu_ckpt_format (bool): Whether the checkpoint is from torch-uncertainty. If ``True``, + the checkpoint will be loaded using the torch-uncertainty loading function. If + ``False``, the checkpoint will be loaded using the default PyTorch loading function. + Note that this option is only used if :attr:ckpt_paths is provided. Defaults to + ``False``. Returns: - _DeepEnsembles: The ensembled model. + _DeepEnsembles | _RegDeepEnsembles: The ensembled model. Raises: ValueError: If :attr:num_estimators is not specified and :attr:models @@ -150,6 +162,40 @@ def deep_ensembles( elif isinstance(models, list) and len(models) > 1 and num_estimators is not None: raise ValueError("num_estimators must be None if you provided a non-singleton list.") + if ckpt_paths is not None: # coverage: ignore + if isinstance(ckpt_paths, str | Path): + ckpt_dir = Path(ckpt_paths) + if not ckpt_dir.is_dir(): + raise ValueError("ckpt_paths must be a directory or a list of paths.") + ckpt_paths = sorted( + elt + for elt in ckpt_dir.iterdir() + if elt.is_file() and elt.suffix in [".pt", ".pth", ".ckpt"] + ) + if len(ckpt_paths) == 0: + raise ValueError("No checkpoint files found in the directory.") + + if len(models) != len(ckpt_paths): + raise ValueError( + "The number of models and the number of checkpoint paths must be the same." + ) + for model, ckpt_path in zip(models, ckpt_paths, strict=True): + if isinstance(ckpt_path, str | Path): + loaded_data = torch.load(ckpt_path, map_location="cpu") + if "state_dict" in loaded_data: + state_dict = loaded_data["state_dict"] + else: + state_dict = loaded_data + + if use_tu_ckpt_format: + model.load_state_dict( + {k[6:]: v for k, v in state_dict.items() if k.startswith("model.")} + ) + else: + model.load_state_dict(state_dict) + else: + raise TypeError("Checkpoint paths must be strings or Path objects.") + if task in ("classification", "segmentation"): return _DeepEnsembles(models=models, store_on_cpu=store_on_cpu) if task in ("regression", "pixel_regression"): From bfd86bd6e97815d74a2ce0bfa434520888356fa8 Mon Sep 17 00:00:00 2001 From: giannifranchi Date: Tue, 29 Apr 2025 18:27:57 +0200 Subject: [PATCH 51/69] Added ConformalClassificationRAPS, ConformalClassificationAPS, and ConformalClassificationTHR + Updated Classification Routine --- torch_uncertainty/post_processing/__init__.py | 3 + .../post_processing/conformal_APS.py | 133 +++++++++++++++++ .../post_processing/conformal_RAPS.py | 138 ++++++++++++++++++ .../post_processing/conformal_THR.py | 126 ++++++++++++++++ torch_uncertainty/routines/classification.py | 33 ++++- 5 files changed, 430 insertions(+), 3 deletions(-) create mode 100644 torch_uncertainty/post_processing/conformal_APS.py create mode 100644 torch_uncertainty/post_processing/conformal_RAPS.py create mode 100644 torch_uncertainty/post_processing/conformal_THR.py diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index bc5a59cf..cca3fb28 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -3,3 +3,6 @@ from .calibration import MatrixScaler, TemperatureScaler, VectorScaler from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm +from .conformal_THR import ConformalclassificationTHR +from .conformal_APS import ConformalclassificationAPS +from .conformal_RAPS import ConformalClassificationRAPS \ No newline at end of file diff --git a/torch_uncertainty/post_processing/conformal_APS.py b/torch_uncertainty/post_processing/conformal_APS.py new file mode 100644 index 00000000..86181c97 --- /dev/null +++ b/torch_uncertainty/post_processing/conformal_APS.py @@ -0,0 +1,133 @@ +from typing import Literal, Optional +import torch +from torch import nn, optim, Tensor +from .abstract import PostProcessing +from torch.utils.data import DataLoader +from .calibration import TemperatureScaler + + + +class ConformalclassificationAPS(PostProcessing): + def __init__( + self, + model: nn.Module, + score_type: str = "softmax", + randomized: bool = True, + device: Optional[Literal["cpu", "cuda"]] = None, + alpha: float = 0.1, + ) -> None: + """Conformal prediction with APS scores. + + Args: + model (nn.Module): Trained classification model. + score_type (str): Type of score transformation. Only 'softmax' is supported for now. + randomized (bool): Whether to use randomized smoothing in APS. + device (str, optional): 'cpu' or 'cuda'. + alpha (float): Allowed miscoverage level. + """ + super().__init__(model=model) + self.model = model.to(device=device) + self.randomized = randomized + self.alpha = alpha + self.device = device or "cpu" + self.q_hat = None + + if score_type == "softmax": + self.transform = lambda x: torch.softmax(x, dim=-1) + else: + raise NotImplementedError("Only softmax is supported for now.") + + def forward(self, inputs: Tensor) -> Tensor: + """Apply the model and return transformed scores (softmax).""" + logits = self.model(inputs) + probs = self.transform(logits) + return probs + + def _sort_sum(self, probs: Tensor): + """Sort probabilities and compute cumulative sums.""" + ordered, indices = torch.sort(probs, dim=-1, descending=True) + cumsum = torch.cumsum(ordered, dim=-1) + return indices, ordered, cumsum + + def _calculate_single_label(self, probs: Tensor, labels: Tensor): + """Compute APS score for the true label.""" + indices, ordered, cumsum = self._sort_sum(probs) + if self.randomized: + U = torch.rand(indices.shape[0], device=probs.device) + else: + U = torch.zeros(indices.shape[0], device=probs.device) + + scores = torch.zeros(probs.shape[0], device=probs.device) + for i in range(probs.shape[0]): + pos = (indices[i] == labels[i]).nonzero(as_tuple=False) + if pos.numel() == 0: + raise ValueError("True label not found.") + pos = pos[0].item() + scores[i] = cumsum[i, pos] - U[i] * ordered[i, pos] + return scores + + def _calculate_all_labels(self, probs: Tensor): + """Compute APS scores for all labels.""" + indices, ordered, cumsum = self._sort_sum(probs) + if self.randomized: + U = torch.rand(probs.shape, device=probs.device) + else: + U = torch.zeros_like(probs) + + ordered_scores = cumsum - ordered * U + _, sorted_indices = torch.sort(indices, descending=False, dim=-1) + scores = ordered_scores.gather(dim=-1, index=sorted_indices) + return scores + + def calibrate(self, dataloader: DataLoader) -> None: + """Calibrate the APS threshold q_hat on a calibration set.""" + self.model.eval() + aps_scores = [] + + with torch.no_grad(): + for images, labels in dataloader: + images, labels = images.to(self.device), labels.to(self.device) + probs = self.forward(images) + scores = self._calculate_single_label(probs, labels) + aps_scores.append(scores) + + aps_scores = torch.cat(aps_scores) + self.q_hat = torch.quantile(aps_scores, 1 - self.alpha) + print(f"APS calibration threshold (q_hat): {self.q_hat:.4f}") + + def fit(self, dataloader: DataLoader) -> None: + """Alias for calibrate to match other API style.""" + self.calibrate(dataloader) + + def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: + """Compute the prediction set for each input.""" + if self.q_hat is None: + raise ValueError("You must calibrate (fit) before calling conformal.") + + self.model.eval() + with torch.no_grad(): + probs = self.forward(inputs) + all_scores = self._calculate_all_labels(probs) + + pred_set = all_scores <= self.q_hat + set_size = pred_set.sum(dim=1).float() + + return pred_set, set_size + + @property + def quantile(self) -> Tensor: + if self.q_hat is None: + raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") + return self.q_hat.detach() + + + + + + + + + + + + diff --git a/torch_uncertainty/post_processing/conformal_RAPS.py b/torch_uncertainty/post_processing/conformal_RAPS.py new file mode 100644 index 00000000..edf63c61 --- /dev/null +++ b/torch_uncertainty/post_processing/conformal_RAPS.py @@ -0,0 +1,138 @@ +from typing import Literal, Optional +import torch +from torch import nn, Tensor +from torch.utils.data import DataLoader +from .abstract import PostProcessing + +class ConformalClassificationRAPS(PostProcessing): + def __init__( + self, + model: nn.Module, + score_type: str = "softmax", + randomized: bool = True, + penalty: float = 0.1, + kreg: int = 1, + device: Optional[Literal["cpu", "cuda"]] = None, + alpha: float = 0.1, + ) -> None: + """Conformal prediction with RAPS scores. + + Args: + model (nn.Module): Trained classification model. + score_type (str): Type of score transformation. Only 'softmax' is supported for now. + randomized (bool): Whether to use randomized smoothing in RAPS. + penalty (float): Regularization weight. + kreg (int): Rank threshold for regularization. + device (str, optional): 'cpu' or 'cuda'. + alpha (float): Allowed miscoverage level. + """ + super().__init__(model=model) + self.model = model.to(device=device) + self.score_type = score_type + self.randomized = randomized + self.penalty = penalty + self.kreg = kreg + self.device = device or "cpu" + self.alpha = alpha + self.q_hat = None + + if self.score_type == "softmax": + self.transform = lambda x: torch.softmax(x, dim=-1) + else: + raise NotImplementedError("Only softmax is supported for now.") + + def forward(self, inputs: Tensor) -> Tensor: + """Apply the model and return transformed scores (softmax).""" + logits = self.model(inputs) + probs = self.transform(logits) + return probs + + def _sort_sum(self, probs: Tensor): + """Sort probabilities and compute cumulative sums.""" + ordered, indices = torch.sort(probs, dim=-1, descending=True) + cumsum = torch.cumsum(ordered, dim=-1) + return indices, ordered, cumsum + + def _calculate_single_label(self, probs: Tensor, labels: Tensor) -> Tensor: + """Compute RAPS score for the true label.""" + indices, ordered, cumsum = self._sort_sum(probs) + batch_size = probs.shape[0] + + if self.randomized: + noise = torch.rand(batch_size, device=probs.device) + else: + noise = torch.zeros(batch_size, device=probs.device) + + scores = torch.zeros(batch_size, device=probs.device) + for i in range(batch_size): + pos_tensor = (indices[i] == labels[i]).nonzero(as_tuple=False) + if pos_tensor.numel() == 0: + raise ValueError("True label not found.") + pos = pos_tensor[0].item() + + reg = max(self.penalty * ((pos + 1) - self.kreg), 0) + scores[i] = cumsum[i, pos] - ordered[i, pos] * noise[i] + reg + return scores + + def _calculate_all_labels(self, probs: Tensor) -> Tensor: + """Compute RAPS scores for all labels.""" + indices, ordered, cumsum = self._sort_sum(probs) + batch_size, num_classes = probs.shape + + if self.randomized: + noise = torch.rand_like(probs) + else: + noise = torch.zeros_like(probs) + + ranks = torch.arange(1, num_classes + 1, device=probs.device, dtype=torch.float) + penalty_vector = self.penalty * (ranks - self.kreg) + penalty_vector = torch.clamp(penalty_vector, min=0) + penalty_matrix = penalty_vector.unsqueeze(0).expand_as(ordered) + + modified_scores = cumsum - ordered * noise + penalty_matrix + + # Reorder scores back to original label order + reordered_scores = torch.empty_like(modified_scores) + reordered_scores.scatter_(dim=-1, index=indices, src=modified_scores) + return reordered_scores + + def calibrate(self, dataloader: DataLoader) -> None: + """Calibrate the RAPS threshold q_hat on a calibration set.""" + self.model.eval() + raps_scores = [] + + with torch.no_grad(): + for images, labels in dataloader: + images, labels = images.to(self.device), labels.to(self.device) + probs = self.forward(images) + scores = self._calculate_single_label(probs, labels) + raps_scores.append(scores) + + raps_scores = torch.cat(raps_scores) + self.q_hat = torch.quantile(raps_scores, 1 - self.alpha) + print(f"RAPS calibration threshold (q_hat): {self.q_hat:.4f}") + + def fit(self, dataloader: DataLoader) -> None: + """Alias for calibrate to match other API style.""" + self.calibrate(dataloader) + + def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: + """Compute the prediction set for each input.""" + if self.q_hat is None: + raise ValueError("You must calibrate (fit) before calling conformal.") + + self.model.eval() + with torch.no_grad(): + probs = self.forward(inputs) + all_scores = self._calculate_all_labels(probs) + + pred_set = all_scores <= self.q_hat + set_size = pred_set.sum(dim=1).float() + + return pred_set, set_size + + @property + def quantile(self) -> Tensor: + if self.q_hat is None: + raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") + return self.q_hat.detach() diff --git a/torch_uncertainty/post_processing/conformal_THR.py b/torch_uncertainty/post_processing/conformal_THR.py new file mode 100644 index 00000000..630ddf1a --- /dev/null +++ b/torch_uncertainty/post_processing/conformal_THR.py @@ -0,0 +1,126 @@ +from typing import Literal, Optional +import torch +from torch import nn, optim, Tensor +from .abstract import PostProcessing +from torch.utils.data import DataLoader +from .calibration import TemperatureScaler + +class ConformalclassificationTHR(PostProcessing): + def __init__( + self, + model: nn.Module, + init_val: float = 1, + lr: float = 0.1, + max_iter: int = 100, + device: Optional[Literal["cpu", "cuda"]] = None, + alpha: float = 1.0, + ) -> None: + """Conformal prediction post-processing for calibrated models. + + Args: + model (nn.Module): Model to calibrate. + temp (float, optional): the temperature value after the calibration + Defaults to 1.5. + device (Optional[Literal["cpu", "cuda"]], optional): Device to use + for optimization. Defaults to None. + alpha (Optional[Literal["cpu", "cuda"]], optional): the confidence level meaning we allow 1-alpha error + References: + Sadinle, M. et al., (2016). Least ambiguous set-valued classifiers with bounded error levels. Journal of the American Statistical Association, 111(515), 1648-1658. + + Link : https://arxiv.org/abs/1609.00451 + """ + super().__init__(model=model) + self.model = model.to(device=device) + self.init_val =init_val + self.lr =lr + self.max_iter = max_iter + self.device = device or "cpu" + self.temp = None # Will be set after calibration + self.q_hat = None # Will be set after calibration + self.alpha= alpha + + def forward(self,inputs: Tensor) -> Tensor: + """Apply temperature scaling.""" + logits = self.model(inputs) + return logits / self.temp + + def calibrate(self, dataloader: DataLoader) -> None: + # Fit the scaler on the calibration dataset + scaled_model = TemperatureScaler(model=self.model, lr = self.lr, max_iter= self.max_iter, device = self.device ) + scaled_model.fit(dataloader=dataloader) + temp = scaled_model.temperature[0].item() + print('temperature AFTER CALIBRATION == ', temp) + + self.temp = scaled_model.temperature[0].item() + + + + + + + def fit(self, dataloader: DataLoader) -> None: + logits_list = [] + labels_list = [] + print('temperature BEFORE CONFORMAL == ', self.temp) + with torch.no_grad(): + for images, labels in dataloader: + images, labels = images.to(self.device), labels.to(self.device) + scaled_logits = self.model(images)/ self.temp + logits_list.append(scaled_logits) + labels_list.append(labels) + # Conformal scores + + scaled_logits = torch.cat(logits_list) + labels = torch.cat(labels_list) + probs = torch.softmax(scaled_logits, dim=1) + true_class_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1) + scores = 1.0 - true_class_probs # scores are (1 - true prob) + # Quantile + self.q_hat = torch.quantile(scores, 1.0 - self.alpha) + + + def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: + """Perform conformal prediction on the test set.""" + if self.q_hat is None: + raise ValueError("You must calibrate and estimate the qhat first by calling `.fit()`.") + + self.model.eval() + #prediction_sets = [] + #confidence_scores = [] + #all_labels = [] + with torch.no_grad(): + scaled_logits = self.model(inputs)/ self.temp + probs = torch.softmax(scaled_logits, dim=1) + pred_set = probs >= (1.0 - self.q_hat) + top1 = torch.argmax(probs, dim=1, keepdim=True) + pred_set.scatter_(1, top1, True) # Always include top-1 class + + confidence_score = pred_set.sum(dim=1).float() + + + return ( + pred_set, + confidence_score, + ) + + + @property + def quantile(self) -> Tensor: + if self.q_hat is None: + raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") + return self.q_hat.detach() + + + + + + + + + + + + + + + diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b715346b..3797fff5 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -25,6 +25,7 @@ BrierScore, CalibrationError, CategoricalNLL, + CoverageRate, CovAt5Risk, Disagreement, Entropy, @@ -73,6 +74,7 @@ def __init__( optim_recipe: dict | Optimizer | None = None, mixup_params: dict | None = None, eval_ood: bool = False, + isconformal: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, ood_criterion: type[TUOODCriterion] | str = "msp", @@ -101,6 +103,8 @@ def __init__( detection performance. Defaults to ``False``. eval_shift (bool, optional): Indicates whether to evaluate the Distribution shift performance. Defaults to ``False``. + isconformal (bool, optional): Indicates whether to use conformal prediction + as uncertainty criterion. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task. @@ -156,6 +160,7 @@ def __init__( self.num_classes = num_classes self.eval_ood = eval_ood + self.isconformal = isconformal self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = get_ood_criterion(ood_criterion) @@ -228,6 +233,7 @@ def _init_metrics(self) -> None: cls_metrics = MetricCollection(metrics_dict, compute_groups=groups) self.val_cls_metrics = cls_metrics.clone(prefix="val/") + self.test_cls_metrics = cls_metrics.clone(prefix="test/") if self.post_processing is not None: @@ -246,6 +252,13 @@ def _init_metrics(self) -> None: ) self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_ood_entropy = Entropy() + if self.isconformal: + cfm_metrics = MetricCollection( + { + "CovAcc": CoverageRate(), + }, + ) + self.test_cfm_metrics = cfm_metrics.clone(prefix="test/cls/") if self.eval_shift: self.test_shift_metrics = cls_metrics.clone(prefix="shift/") @@ -482,13 +495,19 @@ def test_step( logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) + if self.isconformal: + pred_conformal, confs_conformal = self.model.conformal(inputs) - if self.ood_criterion.input_type == OODCriterionInputType.LOGIT: + if self.ood_criterion.input_type == OODCriterionInputType.LOGIT and not self.isconformal: ood_scores = self.ood_criterion(logits) - elif self.ood_criterion.input_type == OODCriterionInputType.PROB: + elif self.ood_criterion.input_type == OODCriterionInputType.PROB and not self.isconformal: ood_scores = self.ood_criterion(probs) else: - ood_scores = self.ood_criterion(probs_per_est) + if not self.isconformal: + ood_scores = self.ood_criterion(probs_per_est) + else: + ood_scores = confs_conformal + if dataloader_idx == 0: # squeeze if binary classification only for binary metrics @@ -514,6 +533,9 @@ def test_step( if self.eval_ood: self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) + if self.isconformal: + self.test_cfm_metrics.update(pred_conformal, targets) + if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) @@ -591,6 +613,11 @@ def on_test_epoch_end(self) -> None: self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) + if self.isconformal: + tmp_metrics = self.test_cfm_metrics.compute() + self.log_dict(tmp_metrics, sync_dist=True) + result_dict.update(tmp_metrics) + # already logged result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) From 803579b0954ac07d1b249a6a8aaf336b28f1e0e2 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 30 Apr 2025 16:08:59 +0200 Subject: [PATCH 52/69] :art: Comply with ruff rules --- torch_uncertainty/post_processing/__init__.py | 6 +- .../{conformal_APS.py => conformal_aps.py} | 61 ++++++-------- .../{conformal_RAPS.py => conformal_raps.py} | 26 +++--- .../{conformal_THR.py => conformal_thr.py} | 84 +++++++------------ torch_uncertainty/routines/classification.py | 23 +++-- 5 files changed, 80 insertions(+), 120 deletions(-) rename torch_uncertainty/post_processing/{conformal_APS.py => conformal_aps.py} (76%) rename torch_uncertainty/post_processing/{conformal_RAPS.py => conformal_raps.py} (89%) rename torch_uncertainty/post_processing/{conformal_THR.py => conformal_thr.py} (61%) diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index cca3fb28..6c99fc12 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,8 +1,8 @@ # ruff: noqa: F401 from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler +from .conformal_aps import ConformalclassificationAPS +from .conformal_raps import ConformalClassificationRAPS +from .conformal_thr import ConformalclassificationTHR from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm -from .conformal_THR import ConformalclassificationTHR -from .conformal_APS import ConformalclassificationAPS -from .conformal_RAPS import ConformalClassificationRAPS \ No newline at end of file diff --git a/torch_uncertainty/post_processing/conformal_APS.py b/torch_uncertainty/post_processing/conformal_aps.py similarity index 76% rename from torch_uncertainty/post_processing/conformal_APS.py rename to torch_uncertainty/post_processing/conformal_aps.py index 86181c97..1dacfb23 100644 --- a/torch_uncertainty/post_processing/conformal_APS.py +++ b/torch_uncertainty/post_processing/conformal_aps.py @@ -1,29 +1,31 @@ -from typing import Literal, Optional +from typing import Literal + import torch -from torch import nn, optim, Tensor -from .abstract import PostProcessing +from torch import Tensor, nn from torch.utils.data import DataLoader -from .calibration import TemperatureScaler +from .abstract import PostProcessing class ConformalclassificationAPS(PostProcessing): def __init__( - self, - model: nn.Module, - score_type: str = "softmax", - randomized: bool = True, - device: Optional[Literal["cpu", "cuda"]] = None, - alpha: float = 0.1, + self, + model: nn.Module, + score_type: str = "softmax", + randomized: bool = True, + device: Literal["cpu", "cuda"] | torch.device | None = None, + alpha: float = 0.1, ) -> None: - """Conformal prediction with APS scores. + r"""Conformal prediction with APS scores. Args: model (nn.Module): Trained classification model. - score_type (str): Type of score transformation. Only 'softmax' is supported for now. + score_type (str): Type of score transformation. Only ``"softmax"`` is supported for now. randomized (bool): Whether to use randomized smoothing in APS. - device (str, optional): 'cpu' or 'cuda'. - alpha (float): Allowed miscoverage level. + device (Literal["cpu", "cuda"] | torch.device | None, optional): device. + Defaults to ``None``. + alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults + to ``0.1``. """ super().__init__(model=model) self.model = model.to(device=device) @@ -40,8 +42,7 @@ def __init__( def forward(self, inputs: Tensor) -> Tensor: """Apply the model and return transformed scores (softmax).""" logits = self.model(inputs) - probs = self.transform(logits) - return probs + return self.transform(logits) def _sort_sum(self, probs: Tensor): """Sort probabilities and compute cumulative sums.""" @@ -53,9 +54,9 @@ def _calculate_single_label(self, probs: Tensor, labels: Tensor): """Compute APS score for the true label.""" indices, ordered, cumsum = self._sort_sum(probs) if self.randomized: - U = torch.rand(indices.shape[0], device=probs.device) + u = torch.rand(indices.shape[0], device=probs.device) else: - U = torch.zeros(indices.shape[0], device=probs.device) + u = torch.zeros(indices.shape[0], device=probs.device) scores = torch.zeros(probs.shape[0], device=probs.device) for i in range(probs.shape[0]): @@ -63,21 +64,20 @@ def _calculate_single_label(self, probs: Tensor, labels: Tensor): if pos.numel() == 0: raise ValueError("True label not found.") pos = pos[0].item() - scores[i] = cumsum[i, pos] - U[i] * ordered[i, pos] + scores[i] = cumsum[i, pos] - u[i] * ordered[i, pos] return scores def _calculate_all_labels(self, probs: Tensor): """Compute APS scores for all labels.""" indices, ordered, cumsum = self._sort_sum(probs) if self.randomized: - U = torch.rand(probs.shape, device=probs.device) + u = torch.rand(probs.shape, device=probs.device) else: - U = torch.zeros_like(probs) + u = torch.zeros_like(probs) - ordered_scores = cumsum - ordered * U + ordered_scores = cumsum - ordered * u _, sorted_indices = torch.sort(indices, descending=False, dim=-1) - scores = ordered_scores.gather(dim=-1, index=sorted_indices) - return scores + return ordered_scores.gather(dim=-1, index=sorted_indices) def calibrate(self, dataloader: DataLoader) -> None: """Calibrate the APS threshold q_hat on a calibration set.""" @@ -93,7 +93,6 @@ def calibrate(self, dataloader: DataLoader) -> None: aps_scores = torch.cat(aps_scores) self.q_hat = torch.quantile(aps_scores, 1 - self.alpha) - print(f"APS calibration threshold (q_hat): {self.q_hat:.4f}") def fit(self, dataloader: DataLoader) -> None: """Alias for calibrate to match other API style.""" @@ -119,15 +118,3 @@ def quantile(self) -> Tensor: if self.q_hat is None: raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") return self.q_hat.detach() - - - - - - - - - - - - diff --git a/torch_uncertainty/post_processing/conformal_RAPS.py b/torch_uncertainty/post_processing/conformal_raps.py similarity index 89% rename from torch_uncertainty/post_processing/conformal_RAPS.py rename to torch_uncertainty/post_processing/conformal_raps.py index edf63c61..43f6b556 100644 --- a/torch_uncertainty/post_processing/conformal_RAPS.py +++ b/torch_uncertainty/post_processing/conformal_raps.py @@ -1,9 +1,12 @@ -from typing import Literal, Optional +from typing import Literal + import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader + from .abstract import PostProcessing + class ConformalClassificationRAPS(PostProcessing): def __init__( self, @@ -12,10 +15,10 @@ def __init__( randomized: bool = True, penalty: float = 0.1, kreg: int = 1, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Literal["cpu", "cuda"] | torch.device | None = None, alpha: float = 0.1, ) -> None: - """Conformal prediction with RAPS scores. + r"""Conformal prediction with RAPS scores. Args: model (nn.Module): Trained classification model. @@ -23,8 +26,10 @@ def __init__( randomized (bool): Whether to use randomized smoothing in RAPS. penalty (float): Regularization weight. kreg (int): Rank threshold for regularization. - device (str, optional): 'cpu' or 'cuda'. - alpha (float): Allowed miscoverage level. + device (Literal["cpu", "cuda"] | torch.device | None, optional): device. + Defaults to ``None``. + alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults + to ``0.1``. """ super().__init__(model=model) self.model = model.to(device=device) @@ -44,8 +49,7 @@ def __init__( def forward(self, inputs: Tensor) -> Tensor: """Apply the model and return transformed scores (softmax).""" logits = self.model(inputs) - probs = self.transform(logits) - return probs + return self.transform(logits) def _sort_sum(self, probs: Tensor): """Sort probabilities and compute cumulative sums.""" @@ -79,10 +83,7 @@ def _calculate_all_labels(self, probs: Tensor) -> Tensor: indices, ordered, cumsum = self._sort_sum(probs) batch_size, num_classes = probs.shape - if self.randomized: - noise = torch.rand_like(probs) - else: - noise = torch.zeros_like(probs) + noise = torch.rand_like(probs) if self.randomized else torch.zeros_like(probs) ranks = torch.arange(1, num_classes + 1, device=probs.device, dtype=torch.float) penalty_vector = self.penalty * (ranks - self.kreg) @@ -110,7 +111,6 @@ def calibrate(self, dataloader: DataLoader) -> None: raps_scores = torch.cat(raps_scores) self.q_hat = torch.quantile(raps_scores, 1 - self.alpha) - print(f"RAPS calibration threshold (q_hat): {self.q_hat:.4f}") def fit(self, dataloader: DataLoader) -> None: """Alias for calibrate to match other API style.""" diff --git a/torch_uncertainty/post_processing/conformal_THR.py b/torch_uncertainty/post_processing/conformal_thr.py similarity index 61% rename from torch_uncertainty/post_processing/conformal_THR.py rename to torch_uncertainty/post_processing/conformal_thr.py index 630ddf1a..bb333f8b 100644 --- a/torch_uncertainty/post_processing/conformal_THR.py +++ b/torch_uncertainty/post_processing/conformal_thr.py @@ -1,10 +1,13 @@ -from typing import Literal, Optional +from typing import Literal + import torch -from torch import nn, optim, Tensor -from .abstract import PostProcessing +from torch import Tensor, nn from torch.utils.data import DataLoader + +from .abstract import PostProcessing from .calibration import TemperatureScaler + class ConformalclassificationTHR(PostProcessing): def __init__( self, @@ -12,63 +15,55 @@ def __init__( init_val: float = 1, lr: float = 0.1, max_iter: int = 100, - device: Optional[Literal["cpu", "cuda"]] = None, - alpha: float = 1.0, + device: Literal["cpu", "cuda"] | torch.device | None = None, + alpha: float = 0.1, ) -> None: - """Conformal prediction post-processing for calibrated models. + r"""Conformal prediction post-processing for calibrated models. Args: - model (nn.Module): Model to calibrate. - temp (float, optional): the temperature value after the calibration - Defaults to 1.5. - device (Optional[Literal["cpu", "cuda"]], optional): Device to use - for optimization. Defaults to None. - alpha (Optional[Literal["cpu", "cuda"]], optional): the confidence level meaning we allow 1-alpha error - References: - Sadinle, M. et al., (2016). Least ambiguous set-valued classifiers with bounded error levels. Journal of the American Statistical Association, 111(515), 1648-1658. - - Link : https://arxiv.org/abs/1609.00451 + model (nn.Module): Model to be calibrated. + init_val (float, optional): Initial value for the temperature. + Defaults to ``1``. + lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``. + max_iter (int, optional): Maximum number of iterations for the + optimizer. Defaults to ``100``. + device (Literal["cpu", "cuda"] | torch.device | None, optional): device. + Defaults to ``None``. + alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults + to ``0.1``. """ super().__init__(model=model) self.model = model.to(device=device) - self.init_val =init_val - self.lr =lr + self.init_val = init_val + self.lr = lr self.max_iter = max_iter self.device = device or "cpu" self.temp = None # Will be set after calibration self.q_hat = None # Will be set after calibration - self.alpha= alpha + self.alpha = alpha - def forward(self,inputs: Tensor) -> Tensor: + def forward(self, inputs: Tensor) -> Tensor: """Apply temperature scaling.""" logits = self.model(inputs) return logits / self.temp def calibrate(self, dataloader: DataLoader) -> None: # Fit the scaler on the calibration dataset - scaled_model = TemperatureScaler(model=self.model, lr = self.lr, max_iter= self.max_iter, device = self.device ) + scaled_model = TemperatureScaler( + model=self.model, lr=self.lr, max_iter=self.max_iter, device=self.device + ) scaled_model.fit(dataloader=dataloader) - temp = scaled_model.temperature[0].item() - print('temperature AFTER CALIBRATION == ', temp) - - self.temp = scaled_model.temperature[0].item() - - - - - + self.temp = scaled_model.temperature[0].item() def fit(self, dataloader: DataLoader) -> None: logits_list = [] labels_list = [] - print('temperature BEFORE CONFORMAL == ', self.temp) with torch.no_grad(): for images, labels in dataloader: images, labels = images.to(self.device), labels.to(self.device) - scaled_logits = self.model(images)/ self.temp + scaled_logits = self.model(images) / self.temp logits_list.append(scaled_logits) labels_list.append(labels) - # Conformal scores scaled_logits = torch.cat(logits_list) labels = torch.cat(labels_list) @@ -78,18 +73,14 @@ def fit(self, dataloader: DataLoader) -> None: # Quantile self.q_hat = torch.quantile(scores, 1.0 - self.alpha) - def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Perform conformal prediction on the test set.""" if self.q_hat is None: raise ValueError("You must calibrate and estimate the qhat first by calling `.fit()`.") self.model.eval() - #prediction_sets = [] - #confidence_scores = [] - #all_labels = [] with torch.no_grad(): - scaled_logits = self.model(inputs)/ self.temp + scaled_logits = self.model(inputs) / self.temp probs = torch.softmax(scaled_logits, dim=1) pred_set = probs >= (1.0 - self.q_hat) top1 = torch.argmax(probs, dim=1, keepdim=True) @@ -97,30 +88,13 @@ def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: confidence_score = pred_set.sum(dim=1).float() - return ( pred_set, confidence_score, ) - @property def quantile(self) -> Tensor: if self.q_hat is None: raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") return self.q_hat.detach() - - - - - - - - - - - - - - - diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 3797fff5..0679e1b8 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -25,8 +25,8 @@ BrierScore, CalibrationError, CategoricalNLL, - CoverageRate, CovAt5Risk, + CoverageRate, Disagreement, Entropy, GroupingLoss, @@ -74,7 +74,7 @@ def __init__( optim_recipe: dict | Optimizer | None = None, mixup_params: dict | None = None, eval_ood: bool = False, - isconformal: bool = False, + is_conformal: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, ood_criterion: type[TUOODCriterion] | str = "msp", @@ -103,7 +103,7 @@ def __init__( detection performance. Defaults to ``False``. eval_shift (bool, optional): Indicates whether to evaluate the Distribution shift performance. Defaults to ``False``. - isconformal (bool, optional): Indicates whether to use conformal prediction + is_conformal (bool, optional): Indicates whether to use conformal prediction as uncertainty criterion. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. @@ -160,7 +160,7 @@ def __init__( self.num_classes = num_classes self.eval_ood = eval_ood - self.isconformal = isconformal + self.is_conformal = is_conformal self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = get_ood_criterion(ood_criterion) @@ -252,7 +252,7 @@ def _init_metrics(self) -> None: ) self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_ood_entropy = Entropy() - if self.isconformal: + if self.is_conformal: cfm_metrics = MetricCollection( { "CovAcc": CoverageRate(), @@ -495,20 +495,19 @@ def test_step( logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - if self.isconformal: + if self.is_conformal: pred_conformal, confs_conformal = self.model.conformal(inputs) - if self.ood_criterion.input_type == OODCriterionInputType.LOGIT and not self.isconformal: + if self.ood_criterion.input_type == OODCriterionInputType.LOGIT and not self.is_conformal: ood_scores = self.ood_criterion(logits) - elif self.ood_criterion.input_type == OODCriterionInputType.PROB and not self.isconformal: + elif self.ood_criterion.input_type == OODCriterionInputType.PROB and not self.is_conformal: ood_scores = self.ood_criterion(probs) else: - if not self.isconformal: + if not self.is_conformal: ood_scores = self.ood_criterion(probs_per_est) else: ood_scores = confs_conformal - if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( @@ -533,7 +532,7 @@ def test_step( if self.eval_ood: self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) - if self.isconformal: + if self.is_conformal: self.test_cfm_metrics.update(pred_conformal, targets) if self.id_logit_storage is not None: @@ -613,7 +612,7 @@ def on_test_epoch_end(self) -> None: self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - if self.isconformal: + if self.is_conformal: tmp_metrics = self.test_cfm_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) From 7008351d71e0e78d1fce558838aa72849cf61cdf Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 1 May 2025 10:51:59 +0200 Subject: [PATCH 53/69] :art: Modify conformal files location --- torch_uncertainty/post_processing/__init__.py | 8 +++++--- torch_uncertainty/post_processing/conformal/__init__.py | 4 ++++ .../post_processing/{ => conformal}/conformal_aps.py | 2 +- .../post_processing/{ => conformal}/conformal_raps.py | 2 +- .../post_processing/{ => conformal}/conformal_thr.py | 3 +-- 5 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 torch_uncertainty/post_processing/conformal/__init__.py rename torch_uncertainty/post_processing/{ => conformal}/conformal_aps.py (98%) rename torch_uncertainty/post_processing/{ => conformal}/conformal_raps.py (98%) rename torch_uncertainty/post_processing/{ => conformal}/conformal_thr.py (97%) diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 6c99fc12..3d9922f2 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,8 +1,10 @@ # ruff: noqa: F401 from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler -from .conformal_aps import ConformalclassificationAPS -from .conformal_raps import ConformalClassificationRAPS -from .conformal_thr import ConformalclassificationTHR +from .conformal import ( + ConformalclassificationAPS, + ConformalClassificationRAPS, + ConformalclassificationTHR, +) from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/conformal/__init__.py b/torch_uncertainty/post_processing/conformal/__init__.py new file mode 100644 index 00000000..22dcdea3 --- /dev/null +++ b/torch_uncertainty/post_processing/conformal/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .conformal_aps import ConformalclassificationAPS +from .conformal_raps import ConformalClassificationRAPS +from .conformal_thr import ConformalclassificationTHR diff --git a/torch_uncertainty/post_processing/conformal_aps.py b/torch_uncertainty/post_processing/conformal/conformal_aps.py similarity index 98% rename from torch_uncertainty/post_processing/conformal_aps.py rename to torch_uncertainty/post_processing/conformal/conformal_aps.py index 1dacfb23..74a7169e 100644 --- a/torch_uncertainty/post_processing/conformal_aps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_aps.py @@ -4,7 +4,7 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from .abstract import PostProcessing +from torch_uncertainty.post_processing import PostProcessing class ConformalclassificationAPS(PostProcessing): diff --git a/torch_uncertainty/post_processing/conformal_raps.py b/torch_uncertainty/post_processing/conformal/conformal_raps.py similarity index 98% rename from torch_uncertainty/post_processing/conformal_raps.py rename to torch_uncertainty/post_processing/conformal/conformal_raps.py index 43f6b556..60f1e04f 100644 --- a/torch_uncertainty/post_processing/conformal_raps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_raps.py @@ -4,7 +4,7 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from .abstract import PostProcessing +from torch_uncertainty.post_processing import PostProcessing class ConformalClassificationRAPS(PostProcessing): diff --git a/torch_uncertainty/post_processing/conformal_thr.py b/torch_uncertainty/post_processing/conformal/conformal_thr.py similarity index 97% rename from torch_uncertainty/post_processing/conformal_thr.py rename to torch_uncertainty/post_processing/conformal/conformal_thr.py index bb333f8b..00da0938 100644 --- a/torch_uncertainty/post_processing/conformal_thr.py +++ b/torch_uncertainty/post_processing/conformal/conformal_thr.py @@ -4,8 +4,7 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from .abstract import PostProcessing -from .calibration import TemperatureScaler +from torch_uncertainty.post_processing import PostProcessing, TemperatureScaler class ConformalclassificationTHR(PostProcessing): From 62353a7676ae017e0878ebd9b503c5c04de69daf Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 3 May 2025 10:45:20 +0200 Subject: [PATCH 54/69] :hammer: Remove `calibrate()` methods in all conformal classes + polishing --- torch_uncertainty/post_processing/__init__.py | 7 +++--- torch_uncertainty/post_processing/abstract.py | 9 ++------ .../post_processing/conformal/__init__.py | 7 +++--- .../post_processing/conformal/abstract.py | 23 +++++++++++++++++++ .../conformal/conformal_aps.py | 13 +++++------ .../conformal/conformal_raps.py | 13 +++++------ .../conformal/conformal_thr.py | 18 +++++++++++---- 7 files changed, 59 insertions(+), 31 deletions(-) create mode 100644 torch_uncertainty/post_processing/conformal/abstract.py diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 3d9922f2..6e285501 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -2,9 +2,10 @@ from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler from .conformal import ( - ConformalclassificationAPS, - ConformalClassificationRAPS, - ConformalclassificationTHR, + Conformal, + ConformalClsAPS, + ConformalClsRAPS, + ConformalClsTHR, ) from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index 7afe050c..8fc6c002 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -14,12 +14,7 @@ def set_model(self, model: nn.Module) -> None: self.model = model @abstractmethod - def fit(self, dataloader: DataLoader) -> None: - pass + def fit(self, dataloader: DataLoader) -> None: ... @abstractmethod - def forward( - self, - inputs: Tensor, - ) -> Tensor: - pass + def forward(self, inputs: Tensor) -> Tensor: ... diff --git a/torch_uncertainty/post_processing/conformal/__init__.py b/torch_uncertainty/post_processing/conformal/__init__.py index 22dcdea3..203cb6b6 100644 --- a/torch_uncertainty/post_processing/conformal/__init__.py +++ b/torch_uncertainty/post_processing/conformal/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 -from .conformal_aps import ConformalclassificationAPS -from .conformal_raps import ConformalClassificationRAPS -from .conformal_thr import ConformalclassificationTHR +from .abstract import Conformal +from .conformal_aps import ConformalClsAPS +from .conformal_raps import ConformalClsRAPS +from .conformal_thr import ConformalClsTHR diff --git a/torch_uncertainty/post_processing/conformal/abstract.py b/torch_uncertainty/post_processing/conformal/abstract.py new file mode 100644 index 00000000..b167851d --- /dev/null +++ b/torch_uncertainty/post_processing/conformal/abstract.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + +from torch import Tensor, nn +from torch.utils.data import DataLoader + + +class Conformal(ABC, nn.Module): + def __init__(self, model: nn.Module | None = None): + super().__init__() + self.model = model + self.trained = False + + def set_model(self, model: nn.Module) -> None: + self.model = model + + @abstractmethod + def fit(self, dataloader: DataLoader) -> None: ... + + @abstractmethod + def forward(self, inputs: Tensor) -> Tensor: ... + + @abstractmethod + def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: ... diff --git a/torch_uncertainty/post_processing/conformal/conformal_aps.py b/torch_uncertainty/post_processing/conformal/conformal_aps.py index 74a7169e..9041181c 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_aps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_aps.py @@ -4,10 +4,10 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from torch_uncertainty.post_processing import PostProcessing +from .abstract import Conformal -class ConformalclassificationAPS(PostProcessing): +class ConformalClsAPS(Conformal): def __init__( self, model: nn.Module, @@ -26,6 +26,9 @@ def __init__( Defaults to ``None``. alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults to ``0.1``. + + Reference: + - TODO: """ super().__init__(model=model) self.model = model.to(device=device) @@ -79,7 +82,7 @@ def _calculate_all_labels(self, probs: Tensor): _, sorted_indices = torch.sort(indices, descending=False, dim=-1) return ordered_scores.gather(dim=-1, index=sorted_indices) - def calibrate(self, dataloader: DataLoader) -> None: + def fit(self, dataloader: DataLoader) -> None: """Calibrate the APS threshold q_hat on a calibration set.""" self.model.eval() aps_scores = [] @@ -94,10 +97,6 @@ def calibrate(self, dataloader: DataLoader) -> None: aps_scores = torch.cat(aps_scores) self.q_hat = torch.quantile(aps_scores, 1 - self.alpha) - def fit(self, dataloader: DataLoader) -> None: - """Alias for calibrate to match other API style.""" - self.calibrate(dataloader) - def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Compute the prediction set for each input.""" if self.q_hat is None: diff --git a/torch_uncertainty/post_processing/conformal/conformal_raps.py b/torch_uncertainty/post_processing/conformal/conformal_raps.py index 60f1e04f..f8459707 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_raps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_raps.py @@ -4,10 +4,10 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from torch_uncertainty.post_processing import PostProcessing +from .abstract import Conformal -class ConformalClassificationRAPS(PostProcessing): +class ConformalClsRAPS(Conformal): def __init__( self, model: nn.Module, @@ -30,6 +30,9 @@ def __init__( Defaults to ``None``. alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults to ``0.1``. + + Reference: + - TODO: """ super().__init__(model=model) self.model = model.to(device=device) @@ -97,7 +100,7 @@ def _calculate_all_labels(self, probs: Tensor) -> Tensor: reordered_scores.scatter_(dim=-1, index=indices, src=modified_scores) return reordered_scores - def calibrate(self, dataloader: DataLoader) -> None: + def fit(self, dataloader: DataLoader) -> None: """Calibrate the RAPS threshold q_hat on a calibration set.""" self.model.eval() raps_scores = [] @@ -112,10 +115,6 @@ def calibrate(self, dataloader: DataLoader) -> None: raps_scores = torch.cat(raps_scores) self.q_hat = torch.quantile(raps_scores, 1 - self.alpha) - def fit(self, dataloader: DataLoader) -> None: - """Alias for calibrate to match other API style.""" - self.calibrate(dataloader) - def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Compute the prediction set for each input.""" if self.q_hat is None: diff --git a/torch_uncertainty/post_processing/conformal/conformal_thr.py b/torch_uncertainty/post_processing/conformal/conformal_thr.py index 00da0938..83bd856a 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_thr.py +++ b/torch_uncertainty/post_processing/conformal/conformal_thr.py @@ -4,10 +4,12 @@ from torch import Tensor, nn from torch.utils.data import DataLoader -from torch_uncertainty.post_processing import PostProcessing, TemperatureScaler +from torch_uncertainty.post_processing import TemperatureScaler +from .abstract import Conformal -class ConformalclassificationTHR(PostProcessing): + +class ConformalClsTHR(Conformal): def __init__( self, model: nn.Module, @@ -30,6 +32,9 @@ def __init__( Defaults to ``None``. alpha (float): The confidence level meaning we allow :math:`1-\alpha` error. Defaults to ``0.1``. + + Reference: + - `Least ambiguous set-valued classifiers with bounded error levels, Sadinle, M. et al., (2016) `_. """ super().__init__(model=model) self.model = model.to(device=device) @@ -46,15 +51,20 @@ def forward(self, inputs: Tensor) -> Tensor: logits = self.model(inputs) return logits / self.temp - def calibrate(self, dataloader: DataLoader) -> None: + def fit_temperature(self, dataloader: DataLoader) -> None: # Fit the scaler on the calibration dataset scaled_model = TemperatureScaler( - model=self.model, lr=self.lr, max_iter=self.max_iter, device=self.device + model=self.model, + init_val=self.init_val, + lr=self.lr, + max_iter=self.max_iter, + device=self.device, ) scaled_model.fit(dataloader=dataloader) self.temp = scaled_model.temperature[0].item() def fit(self, dataloader: DataLoader) -> None: + self.fit_temperature(dataloader=dataloader) logits_list = [] labels_list = [] with torch.no_grad(): From 4f46131eafda36aa7c322f3fbe5f69b7242bfd1a Mon Sep 17 00:00:00 2001 From: Olivier Laurent Date: Mon, 5 May 2025 11:42:49 +0200 Subject: [PATCH 55/69] :book: Add some doc --- torch_uncertainty/layers/functional/packed.py | 66 +++++++++---------- torch_uncertainty/layers/packed.py | 6 +- torch_uncertainty/routines/classification.py | 3 +- torch_uncertainty/utils/data.py | 19 ++++++ 4 files changed, 57 insertions(+), 37 deletions(-) diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index a5ab40f9..c962531e 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -79,15 +79,15 @@ def packed_in_projection( emb_q // num_groups, emb_v // num_groups, ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == (emb_q,), ( - f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" - ) - assert b_k is None or b_k.shape == (emb_q,), ( - f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" - ) - assert b_v is None or b_v.shape == (emb_q,), ( - f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" - ) + assert b_q is None or b_q.shape == ( + emb_q, + ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + emb_q, + ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + emb_q, + ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" return ( packed_linear(q, w_q, num_groups, implementation, b_q), @@ -324,47 +324,47 @@ def packed_multi_head_attention_forward( # noqa: D417 # longer causal. is_causal = False - assert embed_dim == embed_dim_to_check, ( - f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" - ) + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" if isinstance(embed_dim, Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, ( - f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" - ) + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert key.shape[:2] == value.shape[:2], ( - f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" - ) + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" else: - assert key.shape == value.shape, ( - f"key shape {key.shape} does not match value shape {value.shape}" - ) + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" # # compute in-projection # if not use_separate_proj_weight: - assert in_proj_weight is not None, ( - "use_separate_proj_weight is False but in_proj_weight is None" - ) + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" q, k, v = packed_in_projection_packed( q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias ) else: - assert q_proj_weight is not None, ( - "use_separate_proj_weight is True but q_proj_weight is None" - ) - assert k_proj_weight is not None, ( - "use_separate_proj_weight is True but k_proj_weight is None" - ) - assert v_proj_weight is not None, ( - "use_separate_proj_weight is True but v_proj_weight is None" - ) + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" if in_proj_bias is None: b_q = b_k = b_v = None else: diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index c5ad45f1..27c5da2b 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -700,9 +700,9 @@ def __init__( self.dropout = dropout self.batch_first = batch_first self.head_dim = self.embed_dim // self.num_heads - assert self.head_dim * self.num_heads == self.embed_dim, ( - "embed_dim must be divisible by num_heads" - ) + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" self.num_estimators = num_estimators self.alpha = alpha diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 6f5f88e9..f20601b1 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -87,7 +87,8 @@ def __init__( loss (torch.nn.Module): Loss function to optimize the :attr:`model`. is_ensemble (bool, optional): Indicates whether the model is an ensemble at test time or not. Defaults to ``False``. - num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). + num_tta (int): Number of test-time augmentations (TTA). If ``1``: no TTA. + Defaults to ``1``. format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and diff --git a/torch_uncertainty/utils/data.py b/torch_uncertainty/utils/data.py index 6339d5d7..01120952 100644 --- a/torch_uncertainty/utils/data.py +++ b/torch_uncertainty/utils/data.py @@ -10,6 +10,17 @@ def create_train_val_split( val_split_rate: float, val_transforms: Callable | None = None, ) -> tuple[Dataset, Dataset]: + """Split a dataset for training and validation. + + Args: + dataset (Dataset): The dataset to be split. + val_split_rate (float): The amount of the original dataset to use as validation split. + val_transforms (Callable | None, optional): The transformations to apply on the validation set. + Defaults to ``None``. + + Returns: + tuple[Dataset, Dataset]: The training and the validation splits. + """ train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) val = copy.deepcopy(val) val.dataset.transform = val_transforms @@ -18,6 +29,14 @@ def create_train_val_split( class TTADataset(Dataset): def __init__(self, dataset: Dataset, num_augmentations: int) -> None: + """Create a version of the dataset that returns the same sample multiple times. + + This is useful for test-time augmentation (TTA). + + Args: + dataset (Dataset): The dataset to be adapted for TTA. + num_augmentations (int): The number of augmentations to apply. + """ super().__init__() self.dataset = dataset self.num_augmentations = num_augmentations From 1c1ebfd7b0002b658bee550cdf28b6f5a9d16efc Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 11:43:28 +0200 Subject: [PATCH 56/69] :white_check_mark: Add tests for Conformal Post-Processing --- tests/post_processing/test_conformal.py | 82 +++++++++++++++++++ .../conformal/conformal_aps.py | 3 +- .../conformal/conformal_raps.py | 3 +- .../conformal/conformal_thr.py | 54 ++++++------ 4 files changed, 109 insertions(+), 33 deletions(-) create mode 100644 tests/post_processing/test_conformal.py diff --git a/tests/post_processing/test_conformal.py b/tests/post_processing/test_conformal.py new file mode 100644 index 00000000..6ffa209c --- /dev/null +++ b/tests/post_processing/test_conformal.py @@ -0,0 +1,82 @@ +import pytest +import torch +from einops import repeat +from torch import nn +from torch.utils.data import DataLoader + +from torch_uncertainty.post_processing import ConformalClsAPS, ConformalClsRAPS, ConformalClsTHR + + +class TestConformalClsAPS: + """Testing the ConformalClsRAPS class.""" + + def test_fit(self): + inputs = repeat(torch.tensor([0.6, 0.3, 0.1]), "c -> b c", b=10) + labels = torch.tensor([0, 2] + [1] * 8) + + calibration_set = list(zip(inputs, labels, strict=True)) + dl = DataLoader(calibration_set, batch_size=10) + + conformal = ConformalClsAPS(model=nn.Identity(), randomized=False) + conformal.fit(dl) + out = conformal.conformal(inputs) + assert out[0].shape == (10, 3) + assert (out[0] == repeat(torch.tensor([True, True, False]), "c -> b c", b=10)).all() + assert out[1].shape == (10,) + assert (out[1] == torch.tensor([2.0] * 10)).all() + + def test_failures(self): + with pytest.raises(NotImplementedError): + ConformalClsAPS(score_type="test") + + +class TestConformalClsRAPS: + """Testing the ConformalClsRAPS class.""" + + def test_fit(self): + inputs = repeat(torch.tensor([6.0, 4.0, 1.0]), "c -> b c", b=10) + labels = torch.tensor([0, 2] + [1] * 8) + + calibration_set = list(zip(inputs, labels, strict=True)) + dl = DataLoader(calibration_set, batch_size=10) + + conformal = ConformalClsRAPS(model=nn.Identity(), randomized=False) + conformal.fit(dl) + out = conformal.conformal(inputs) + assert out[0].shape == (10, 3) + assert (out[0] == repeat(torch.tensor([True, True, False]), "c -> b c", b=10)).all() + assert out[1].shape == (10,) + assert (out[1] == torch.tensor([2.0] * 10)).all() + + def test_failures(self): + with pytest.raises(NotImplementedError): + ConformalClsRAPS(score_type="test") + + +class TestConformalClsTHR: + """Testing the ConformalClsTHR class.""" + + def test_main(self): + conformal = ConformalClsTHR(model=None, init_val=2) + + assert conformal.temperature == 2.0 + + conformal.set_model(nn.Identity()) + + assert isinstance(conformal.model, nn.Identity) + assert isinstance(conformal.temperature_scaler.model, nn.Identity) + + def test_fit(self): + inputs = repeat(torch.tensor([0.6, 0.3, 0.1]), "c -> b c", b=10) + labels = torch.tensor([0, 2] + [1] * 8) + + calibration_set = list(zip(inputs, labels, strict=True)) + dl = DataLoader(calibration_set, batch_size=10) + + conformal = ConformalClsTHR(model=nn.Identity(), init_val=2, lr=1, max_iter=10) + conformal.fit(dl) + out = conformal.conformal(inputs) + assert out[0].shape == (10, 3) + assert (out[0] == repeat(torch.tensor([True, True, False]), "c -> b c", b=10)).all() + assert out[1].shape == (10,) + assert (out[1] == torch.tensor([2.0] * 10)).all() diff --git a/torch_uncertainty/post_processing/conformal/conformal_aps.py b/torch_uncertainty/post_processing/conformal/conformal_aps.py index 9041181c..c4da65e8 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_aps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_aps.py @@ -10,7 +10,7 @@ class ConformalClsAPS(Conformal): def __init__( self, - model: nn.Module, + model: nn.Module | None = None, score_type: str = "softmax", randomized: bool = True, device: Literal["cpu", "cuda"] | torch.device | None = None, @@ -31,7 +31,6 @@ def __init__( - TODO: """ super().__init__(model=model) - self.model = model.to(device=device) self.randomized = randomized self.alpha = alpha self.device = device or "cpu" diff --git a/torch_uncertainty/post_processing/conformal/conformal_raps.py b/torch_uncertainty/post_processing/conformal/conformal_raps.py index f8459707..eb11e976 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_raps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_raps.py @@ -10,7 +10,7 @@ class ConformalClsRAPS(Conformal): def __init__( self, - model: nn.Module, + model: nn.Module | None = None, score_type: str = "softmax", randomized: bool = True, penalty: float = 0.1, @@ -35,7 +35,6 @@ def __init__( - TODO: """ super().__init__(model=model) - self.model = model.to(device=device) self.score_type = score_type self.randomized = randomized self.penalty = penalty diff --git a/torch_uncertainty/post_processing/conformal/conformal_thr.py b/torch_uncertainty/post_processing/conformal/conformal_thr.py index 83bd856a..251928f4 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_thr.py +++ b/torch_uncertainty/post_processing/conformal/conformal_thr.py @@ -12,7 +12,7 @@ class ConformalClsTHR(Conformal): def __init__( self, - model: nn.Module, + model: nn.Module | None = None, init_val: float = 1, lr: float = 0.1, max_iter: int = 100, @@ -22,7 +22,7 @@ def __init__( r"""Conformal prediction post-processing for calibrated models. Args: - model (nn.Module): Model to be calibrated. + model (nn.Module, optional): Model to be calibrated. init_val (float, optional): Initial value for the temperature. Defaults to ``1``. lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``. @@ -37,31 +37,28 @@ def __init__( - `Least ambiguous set-valued classifiers with bounded error levels, Sadinle, M. et al., (2016) `_. """ super().__init__(model=model) - self.model = model.to(device=device) - self.init_val = init_val - self.lr = lr - self.max_iter = max_iter self.device = device or "cpu" - self.temp = None # Will be set after calibration + self.temperature_scaler = TemperatureScaler( + model=model, + init_val=init_val, + lr=lr, + max_iter=max_iter, + device=self.device, + ) self.q_hat = None # Will be set after calibration self.alpha = alpha + def set_model(self, model: nn.Module) -> None: + self.model = model + self.temperature_scaler.set_model(model=model) + def forward(self, inputs: Tensor) -> Tensor: """Apply temperature scaling.""" - logits = self.model(inputs) - return logits / self.temp + return self.temperature_scaler(inputs) def fit_temperature(self, dataloader: DataLoader) -> None: # Fit the scaler on the calibration dataset - scaled_model = TemperatureScaler( - model=self.model, - init_val=self.init_val, - lr=self.lr, - max_iter=self.max_iter, - device=self.device, - ) - scaled_model.fit(dataloader=dataloader) - self.temp = scaled_model.temperature[0].item() + self.temperature_scaler.fit(dataloader=dataloader) def fit(self, dataloader: DataLoader) -> None: self.fit_temperature(dataloader=dataloader) @@ -70,12 +67,12 @@ def fit(self, dataloader: DataLoader) -> None: with torch.no_grad(): for images, labels in dataloader: images, labels = images.to(self.device), labels.to(self.device) - scaled_logits = self.model(images) / self.temp + scaled_logits = self.forward(images) logits_list.append(scaled_logits) labels_list.append(labels) scaled_logits = torch.cat(logits_list) - labels = torch.cat(labels_list) + labels = torch.cat(labels_list).long() probs = torch.softmax(scaled_logits, dim=1) true_class_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1) scores = 1.0 - true_class_probs # scores are (1 - true prob) @@ -84,26 +81,25 @@ def fit(self, dataloader: DataLoader) -> None: def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Perform conformal prediction on the test set.""" - if self.q_hat is None: - raise ValueError("You must calibrate and estimate the qhat first by calling `.fit()`.") - self.model.eval() with torch.no_grad(): - scaled_logits = self.model(inputs) / self.temp + scaled_logits = self.forward(inputs) probs = torch.softmax(scaled_logits, dim=1) - pred_set = probs >= (1.0 - self.q_hat) + pred_set = probs >= (1.0 - self.quantile) top1 = torch.argmax(probs, dim=1, keepdim=True) pred_set.scatter_(1, top1, True) # Always include top-1 class confidence_score = pred_set.sum(dim=1).float() - return ( - pred_set, - confidence_score, - ) + return (pred_set, confidence_score) @property def quantile(self) -> Tensor: if self.q_hat is None: raise ValueError("Quantile q_hat is not set. Run `.fit()` first.") return self.q_hat.detach() + + @property + def temperature(self) -> Tensor: + """Get the temperature parameter.""" + return self.temperature_scaler.temperature[0].detach() From ae6be5427330b931cfaf15a45af9597f4b6fd7a1 Mon Sep 17 00:00:00 2001 From: Olivier Laurent Date: Mon, 5 May 2025 11:53:09 +0200 Subject: [PATCH 57/69] :hammer: Move utils/data.py to datasets/utils.py --- torch_uncertainty/datamodules/abstract.py | 2 +- torch_uncertainty/datamodules/classification/cifar10.py | 2 +- torch_uncertainty/datamodules/classification/cifar100.py | 2 +- torch_uncertainty/datamodules/classification/mnist.py | 2 +- .../datamodules/classification/uci/uci_classification.py | 2 +- torch_uncertainty/datamodules/depth/base.py | 2 +- torch_uncertainty/datamodules/depth/muad.py | 2 +- torch_uncertainty/datamodules/segmentation/cityscapes.py | 2 +- torch_uncertainty/datamodules/segmentation/muad.py | 2 +- torch_uncertainty/{utils/data.py => datasets/utils.py} | 0 torch_uncertainty/utils/__init__.py | 1 - 11 files changed, 9 insertions(+), 10 deletions(-) rename torch_uncertainty/{utils/data.py => datasets/utils.py} (100%) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index a2a10cd8..7bb87ed3 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -6,7 +6,7 @@ from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike -from torch_uncertainty.utils import TTADataset +from torch_uncertainty.datasets.utils import TTADataset if util.find_spec("sklearn"): from sklearn.model_selection import StratifiedKFold diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 68d0b283..5852a709 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -13,8 +13,8 @@ from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import Cutout -from torch_uncertainty.utils import create_train_val_split class CIFAR10DataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 65496a5b..62d59d08 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -13,8 +13,8 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import Cutout -from torch_uncertainty.utils import create_train_val_split class CIFAR100DataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 5bd49e13..fa5b514d 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -9,8 +9,8 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import Cutout -from torch_uncertainty.utils import create_train_val_split class MNISTDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/classification/uci/uci_classification.py b/torch_uncertainty/datamodules/classification/uci/uci_classification.py index 40ed06a0..9a837f4d 100644 --- a/torch_uncertainty/datamodules/classification/uci/uci_classification.py +++ b/torch_uncertainty/datamodules/classification/uci/uci_classification.py @@ -3,7 +3,7 @@ from torch.utils.data import Dataset from torch_uncertainty.datamodules.abstract import TUDataModule -from torch_uncertainty.utils import create_train_val_split +from torch_uncertainty.datasets.utils import create_train_val_split class UCIClassificationDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 99f14cfc..610b18b1 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -8,8 +8,8 @@ from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils import create_train_val_split class DepthDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index f0272fc1..7b7f4164 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -3,7 +3,7 @@ from torch.nn.common_types import _size_2_t from torch_uncertainty.datasets import MUAD -from torch_uncertainty.utils import create_train_val_split +from torch_uncertainty.datasets.utils import create_train_val_split from .base import DepthDataModule diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 82aadcf5..fa8f774b 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -9,8 +9,8 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.segmentation import Cityscapes +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils import create_train_val_split class CityscapesDataModule(TUDataModule): diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 579a7679..94d8de9e 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -8,8 +8,8 @@ from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import MUAD +from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import RandomRescale -from torch_uncertainty.utils import create_train_val_split class MUADDataModule(TUDataModule): diff --git a/torch_uncertainty/utils/data.py b/torch_uncertainty/datasets/utils.py similarity index 100% rename from torch_uncertainty/utils/data.py rename to torch_uncertainty/datasets/utils.py diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index 5ba095bf..d7d32e94 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,7 +1,6 @@ # ruff: noqa: F401 from .checkpoints import get_version from .cli import TULightningCLI -from .data import TTADataset, create_train_val_split from .hub import load_hf from .misc import csv_writer, plot_hist from .trainer import TUTrainer From 95b0ebc1c424cabc6276cd3bbdfb94c61c22d531 Mon Sep 17 00:00:00 2001 From: Olivier Laurent Date: Mon, 5 May 2025 11:57:23 +0200 Subject: [PATCH 58/69] :book: Add Zero to the references --- docs/source/references.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/references.rst b/docs/source/references.rst index b60f2d38..0162e63c 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -210,6 +210,16 @@ For Warping Mixup, consider citing: * Authors: *Quentin Bouniot, Pavlo Mozharovskyi, and Florence d'Alché-Buc* * Paper: `ArXiv 2023 `__. +Test-Time-Adaptation with ZERO +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For ZERO, consider citing: + +**Frustratingly Easy Test-Time Adaptation of Vision-Language Models** + +* Authors: *Matteo Farina, Gianni Franchi, Giovanni Iacca, Massimiliano Mancini and Elisa Ricci* +* Paper: `NeurIPS 2024 `__. + Post-Processing Methods ----------------------- From 5611fb52b2a38b22f01ead17aa821d502f6a4b52 Mon Sep 17 00:00:00 2001 From: Olivier Laurent Date: Mon, 5 May 2025 12:00:58 +0200 Subject: [PATCH 59/69] :shirt: use get_train_set in abstract datamodule --- torch_uncertainty/datamodules/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 7bb87ed3..7ca303d3 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -116,7 +116,7 @@ def train_dataloader(self) -> DataLoader: Return: DataLoader: training dataloader. """ - return self._data_loader(self.train, shuffle=True) + return self._data_loader(self.get_train_set(), shuffle=True) def val_dataloader(self) -> DataLoader: r"""Get the validation dataloader. From b0c428a74cc5d4206c7a1374d471567cd9f376c0 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 12:01:01 +0200 Subject: [PATCH 60/69] :white_check_mark: Improve coverage --- tests/post_processing/test_conformal.py | 31 ++++++++++++++++++- .../conformal/conformal_raps.py | 5 +-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/post_processing/test_conformal.py b/tests/post_processing/test_conformal.py index 6ffa209c..fb9765af 100644 --- a/tests/post_processing/test_conformal.py +++ b/tests/post_processing/test_conformal.py @@ -4,7 +4,26 @@ from torch import nn from torch.utils.data import DataLoader -from torch_uncertainty.post_processing import ConformalClsAPS, ConformalClsRAPS, ConformalClsTHR +from torch_uncertainty.post_processing import ( + Conformal, + ConformalClsAPS, + ConformalClsRAPS, + ConformalClsTHR, +) + + +class TestConformal: + """Testing the Conformal class.""" + + def test_errors(self): + Conformal.__abstractmethods__ = set() + conformal = Conformal(model=None) + assert conformal.model is None + conformal.set_model(nn.Identity()) + assert isinstance(conformal.model, nn.Identity) + conformal.fit(None) + conformal.forward(None) + conformal.conformal(None) class TestConformalClsAPS: @@ -29,6 +48,9 @@ def test_failures(self): with pytest.raises(NotImplementedError): ConformalClsAPS(score_type="test") + with pytest.raises(ValueError): + ConformalClsRAPS().quantile # noqa: B018 + class TestConformalClsRAPS: """Testing the ConformalClsRAPS class.""" @@ -52,6 +74,9 @@ def test_failures(self): with pytest.raises(NotImplementedError): ConformalClsRAPS(score_type="test") + with pytest.raises(ValueError): + ConformalClsRAPS().quantile # noqa: B018 + class TestConformalClsTHR: """Testing the ConformalClsTHR class.""" @@ -80,3 +105,7 @@ def test_fit(self): assert (out[0] == repeat(torch.tensor([True, True, False]), "c -> b c", b=10)).all() assert out[1].shape == (10,) assert (out[1] == torch.tensor([2.0] * 10)).all() + + def test_failures(self): + with pytest.raises(ValueError): + ConformalClsRAPS().quantile # noqa: B018 diff --git a/torch_uncertainty/post_processing/conformal/conformal_raps.py b/torch_uncertainty/post_processing/conformal/conformal_raps.py index eb11e976..9471558a 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_raps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_raps.py @@ -116,15 +116,12 @@ def fit(self, dataloader: DataLoader) -> None: def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Compute the prediction set for each input.""" - if self.q_hat is None: - raise ValueError("You must calibrate (fit) before calling conformal.") - self.model.eval() with torch.no_grad(): probs = self.forward(inputs) all_scores = self._calculate_all_labels(probs) - pred_set = all_scores <= self.q_hat + pred_set = all_scores <= self.quantile set_size = pred_set.sum(dim=1).float() return pred_set, set_size From 1c61a49909a6e61c6743375ac9ffcf8815dd0dc7 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 12:20:04 +0200 Subject: [PATCH 61/69] :white_check_mark: Add tests for `CoverageRate` metric --- .../classification/test_coverage_rate.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/metrics/classification/test_coverage_rate.py diff --git a/tests/metrics/classification/test_coverage_rate.py b/tests/metrics/classification/test_coverage_rate.py new file mode 100644 index 00000000..9715f4e9 --- /dev/null +++ b/tests/metrics/classification/test_coverage_rate.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from torch_uncertainty.metrics import CoverageRate + + +class TestCoverageRate: + """Testing the CoverageRate metric class.""" + + def test_main(self) -> None: + metric = CoverageRate() + + preds = torch.tensor( + [ + [True, True, False], + [False, False, True], + [True, True, False], + [False, True, False], + [True, True, False], + [True, True, False], + [False, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + ] + ) + labels = torch.tensor([0] * 10) + metric.update(preds, labels) + assert metric.compute() == pytest.approx(0.7, rel=1e-2) + metric.reset() + labels = torch.tensor([1] * 10) + metric.update(preds, labels) + assert metric.compute() == pytest.approx(0.9, rel=1e-2) + metric.reset() + labels = torch.tensor([2] * 10) + metric.update(preds, labels) + assert metric.compute() == pytest.approx(0.1, rel=1e-2) + + metric = CoverageRate(num_classes=3, average="macro") + labels = torch.tensor([0] * 3 + [1] * 3 + [2] * 4) + metric.update(preds, labels) + assert metric.compute() == pytest.approx(0.5556, rel=1e-2) + + def test_invalid_args(self) -> None: + with pytest.raises(ValueError): + CoverageRate(num_classes=1) + with pytest.raises(ValueError): + CoverageRate(num_classes=3, average="invalid") + with pytest.raises(ValueError): + CoverageRate(num_classes=None, average="macro") From bd982c6b4cc9d74500a2523178478a667c3c32cc Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 12:25:33 +0200 Subject: [PATCH 62/69] :shirt: Improve coverage slightly --- tests/post_processing/test_conformal.py | 2 +- torch_uncertainty/post_processing/conformal/conformal_aps.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/post_processing/test_conformal.py b/tests/post_processing/test_conformal.py index fb9765af..b8da378b 100644 --- a/tests/post_processing/test_conformal.py +++ b/tests/post_processing/test_conformal.py @@ -49,7 +49,7 @@ def test_failures(self): ConformalClsAPS(score_type="test") with pytest.raises(ValueError): - ConformalClsRAPS().quantile # noqa: B018 + ConformalClsAPS().quantile # noqa: B018 class TestConformalClsRAPS: diff --git a/torch_uncertainty/post_processing/conformal/conformal_aps.py b/torch_uncertainty/post_processing/conformal/conformal_aps.py index c4da65e8..7912bf94 100644 --- a/torch_uncertainty/post_processing/conformal/conformal_aps.py +++ b/torch_uncertainty/post_processing/conformal/conformal_aps.py @@ -98,15 +98,12 @@ def fit(self, dataloader: DataLoader) -> None: def conformal(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Compute the prediction set for each input.""" - if self.q_hat is None: - raise ValueError("You must calibrate (fit) before calling conformal.") - self.model.eval() with torch.no_grad(): probs = self.forward(inputs) all_scores = self._calculate_all_labels(probs) - pred_set = all_scores <= self.q_hat + pred_set = all_scores <= self.quantile set_size = pred_set.sum(dim=1).float() return pred_set, set_size From 94ce44f40949e6f4d5faac0fcc1f57ec24c6be08 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 13:42:11 +0200 Subject: [PATCH 63/69] :bug: Fix incorrect logging in Classification routine --- tests/_dummies/datamodule.py | 1 + .../classification/test_coverage_rate.py | 3 +- tests/post_processing/test_conformal.py | 2 +- tests/routines/test_classification.py | 25 +++++++++++++ .../metrics/classification/coverage_rate.py | 36 +------------------ torch_uncertainty/routines/classification.py | 17 +++++---- 6 files changed, 40 insertions(+), 44 deletions(-) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index c86de1bc..7844463b 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -43,6 +43,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, + postprocess_set="test", ) self.eval_ood = eval_ood diff --git a/tests/metrics/classification/test_coverage_rate.py b/tests/metrics/classification/test_coverage_rate.py index 9715f4e9..c2cf7aa2 100644 --- a/tests/metrics/classification/test_coverage_rate.py +++ b/tests/metrics/classification/test_coverage_rate.py @@ -27,7 +27,8 @@ def test_main(self) -> None: labels = torch.tensor([0] * 10) metric.update(preds, labels) assert metric.compute() == pytest.approx(0.7, rel=1e-2) - metric.reset() + + metric = CoverageRate(validate_args=False) labels = torch.tensor([1] * 10) metric.update(preds, labels) assert metric.compute() == pytest.approx(0.9, rel=1e-2) diff --git a/tests/post_processing/test_conformal.py b/tests/post_processing/test_conformal.py index b8da378b..49af9d29 100644 --- a/tests/post_processing/test_conformal.py +++ b/tests/post_processing/test_conformal.py @@ -108,4 +108,4 @@ def test_fit(self): def test_failures(self): with pytest.raises(ValueError): - ConformalClsRAPS().quantile # noqa: B018 + ConformalClsTHR().quantile # noqa: B018 diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index e6c02495..af1e8465 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -13,6 +13,7 @@ from torch_uncertainty.ood_criteria import ( EntropyCriterion, ) +from torch_uncertainty.post_processing import ConformalClsTHR from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget @@ -333,6 +334,30 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + def test_one_estimator_conformal(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=3, + num_images=100, + eval_ood=True, + ) + + model = dummy_model( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + ) + routine = ClassificationRoutine( + model=model, + loss=None, + num_classes=3, + is_conformal=True, + post_processing=ConformalClsTHR(), + ) + trainer.test(routine, dm) + def test_classification_failures(self): # num_classes with pytest.raises(ValueError): diff --git a/torch_uncertainty/metrics/classification/coverage_rate.py b/torch_uncertainty/metrics/classification/coverage_rate.py index 8d2ad74c..6fa01ae3 100644 --- a/torch_uncertainty/metrics/classification/coverage_rate.py +++ b/torch_uncertainty/metrics/classification/coverage_rate.py @@ -2,7 +2,7 @@ from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.imports import _XLA_AVAILABLE +from torchmetrics.utilities.data import _bincount class CoverageRate(Metric): @@ -91,37 +91,3 @@ def compute(self) -> Tensor: if self.average == "micro": return _safe_divide(self.correct, self.total) return _safe_divide(self.correct, self.total).mean() - - -def _bincount(x: Tensor, minlength: int | None = None) -> Tensor: - """Implement custom bincount. - - PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running - MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of - `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption - as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``. - - Args: - x: tensor to count - minlength: minimum length to count - - Returns: - Number of occurrences for each unique element in x - - Example: - >>> x = torch.tensor([0,0,0,1,1,2,2,2,2]) - >>> _bincount(x, minlength=3) - tensor([3, 2, 4]) - - Source: - https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/data.py#L178 - - """ - if minlength is None: - minlength = len(torch.unique(x)) - - if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps: - mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1) - return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0) - - return torch.bincount(x, minlength=minlength) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0679e1b8..f9442310 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -42,7 +42,7 @@ TUOODCriterion, get_ood_criterion, ) -from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing +from torch_uncertainty.post_processing import Conformal, LaplaceApprox, PostProcessing from torch_uncertainty.transforms import ( Mixup, MixupIO, @@ -496,7 +496,10 @@ def test_step( probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) if self.is_conformal: - pred_conformal, confs_conformal = self.model.conformal(inputs) + if isinstance(self.post_processing, Conformal): + pred_conformal, confs_conformal = self.post_processing.conformal(inputs) + else: + pred_conformal, confs_conformal = self.model.conformal(inputs) if self.ood_criterion.input_type == OODCriterionInputType.LOGIT and not self.is_conformal: ood_scores = self.ood_criterion(logits) @@ -612,11 +615,6 @@ def on_test_epoch_end(self) -> None: self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - if self.is_conformal: - tmp_metrics = self.test_cfm_metrics.compute() - self.log_dict(tmp_metrics, sync_dist=True) - result_dict.update(tmp_metrics) - # already logged result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) @@ -625,6 +623,11 @@ def on_test_epoch_end(self) -> None: self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) + if self.is_conformal: + tmp_metrics = self.test_cfm_metrics.compute() + self.log_dict(tmp_metrics, sync_dist=True) + result_dict.update(tmp_metrics) + if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() shift_severity = self.trainer.datamodule.shift_severity From 2397091d97198785b448292077fc71838e3930bb Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 5 May 2025 20:28:36 +0200 Subject: [PATCH 64/69] :white_check_mark: Improve coverage --- tests/post_processing/test_conformal.py | 12 ++++++++++++ tests/routines/test_classification.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/tests/post_processing/test_conformal.py b/tests/post_processing/test_conformal.py index 49af9d29..afe38241 100644 --- a/tests/post_processing/test_conformal.py +++ b/tests/post_processing/test_conformal.py @@ -44,6 +44,12 @@ def test_fit(self): assert out[1].shape == (10,) assert (out[1] == torch.tensor([2.0] * 10)).all() + conformal = ConformalClsAPS(model=nn.Identity(), randomized=True) + conformal.fit(dl) + out = conformal.conformal(inputs) + assert out[0].shape == (10, 3) + assert out[1].shape == (10,) + def test_failures(self): with pytest.raises(NotImplementedError): ConformalClsAPS(score_type="test") @@ -70,6 +76,12 @@ def test_fit(self): assert out[1].shape == (10,) assert (out[1] == torch.tensor([2.0] * 10)).all() + conformal = ConformalClsRAPS(model=nn.Identity(), randomized=True) + conformal.fit(dl) + out = conformal.conformal(inputs) + assert out[0].shape == (10, 3) + assert out[1].shape == (10,) + def test_failures(self): with pytest.raises(NotImplementedError): ConformalClsRAPS(score_type="test") diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index af1e8465..b7dff4a1 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -358,6 +358,23 @@ def test_one_estimator_conformal(self): ) trainer.test(routine, dm) + model = ConformalClsTHR( + model=dummy_model( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + ), + ) + model.fit(dm.postprocess_dataloader()) + + routine = ClassificationRoutine( + model=model, + loss=None, + num_classes=3, + is_conformal=True, + post_processing=None, + ) + trainer.test(routine, dm) + def test_classification_failures(self): # num_classes with pytest.raises(ValueError): From e4bc701e7683a7468969e5cffc89d77d8bb68e13 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 7 May 2025 17:31:00 +0200 Subject: [PATCH 65/69] :bug: Add missing callback higher level imports --- torch_uncertainty/callbacks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/callbacks/__init__.py b/torch_uncertainty/callbacks/__init__.py index 5561d42f..6857b64d 100644 --- a/torch_uncertainty/callbacks/__init__.py +++ b/torch_uncertainty/callbacks/__init__.py @@ -1,2 +1,2 @@ # ruff: noqa: F401 -from .checkpoint import TUClsCheckpoint +from .checkpoint import TUClsCheckpoint, TURegCheckpoint, TUSegCheckpoint From d8d7912e3cc72a2480355084200832d946c2b024 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 7 May 2025 17:32:14 +0200 Subject: [PATCH 66/69] :bug: Fix #171 --- .../datamodules/classification/cifar10.py | 8 +++++++- .../datamodules/classification/cifar100.py | 8 +++++++- .../datamodules/classification/imagenet.py | 18 ++++++++++++++++-- .../classification/tiny_imagenet.py | 8 +++++++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 5dec3fab..3ce8f128 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -133,7 +133,13 @@ def __init__( elif randaugment: main_transform = v2.RandAugment(num_ops=2, magnitude=20) elif auto_augment: - main_transform = rand_augment_transform(auto_augment, {}) + main_transform = v2.Compose( + [ + v2.ToPILImage(), + rand_augment_transform(auto_augment, {}), + v2.ToImage(), + ] + ) else: main_transform = nn.Identity() diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 1e2a0b6e..c0bab293 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -127,7 +127,13 @@ def __init__( elif randaugment: main_transform = v2.RandAugment(num_ops=2, magnitude=20) elif auto_augment: - main_transform = rand_augment_transform(auto_augment, {}) + main_transform = v2.Compose( + [ + v2.ToPILImage(), + rand_augment_transform(auto_augment, {}), + v2.ToImage(), + ] + ) else: main_transform = nn.Identity() diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d1b228f0..9f3592cb 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -170,20 +170,34 @@ def __init__( if self.procedure is None: if rand_augment_opt is not None: - main_transform = rand_augment_transform(rand_augment_opt, {}) + main_transform = v2.Compose( + [ + v2.ToPILImage(), + rand_augment_transform(rand_augment_opt, {}), + v2.ToImage(), + ] + ) else: main_transform = nn.Identity() elif self.procedure == "ViT": train_size = 224 main_transform = v2.Compose( [ + v2.ToPILImage(), Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), rand_augment_transform("rand-m9-n2-mstd0.5", {}), + v2.ToImage(), ] ) elif self.procedure == "A3": train_size = 160 - main_transform = rand_augment_transform("rand-m6-mstd0.5-inc1", {}) + main_transform = v2.Compose( + [ + v2.ToPILImage(), + rand_augment_transform("rand-m6-mstd0.5-inc1", {}), + v2.ToImage(), + ] + ) else: raise ValueError("The procedure is unknown") diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 52598553..587bdffe 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -129,7 +129,13 @@ def __init__( basic_transform = nn.Identity() if rand_augment_opt is not None: - main_transform = rand_augment_transform(rand_augment_opt, {}) + main_transform = v2.Compose( + [ + v2.ToPILImage(), + rand_augment_transform(rand_augment_opt, {}), + v2.ToImage(), + ] + ) else: main_transform = nn.Identity() From 31312045934a7f890e29a4d36e3104ee7c9737c4 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 7 May 2025 17:37:12 +0200 Subject: [PATCH 67/69] :bug: TULightningCLI won't crash if `data.eval_ood` or `data.eval_shift` do not exist --- torch_uncertainty/utils/cli.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 105176d1..78af649b 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -1,3 +1,4 @@ +import contextlib from collections.abc import Callable from pathlib import Path from typing import Any @@ -155,5 +156,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: super().add_arguments_to_parser(parser) - parser.link_arguments("data.eval_ood", "model.eval_ood") - parser.link_arguments("data.eval_shift", "model.eval_shift") + with contextlib.suppress(ValueError): + parser.link_arguments("data.eval_ood", "model.eval_ood") + with contextlib.suppress(ValueError): + parser.link_arguments("data.eval_shift", "model.eval_shift") From 4fd72abab3872729de9dcd7cf159da33cf923d6b Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 9 May 2025 11:59:22 +0200 Subject: [PATCH 68/69] :art: Refine `torch_uncertainty.models` structure and fix LeNet config files on MNIST --- auto_tutorials_source/tutorial_bayesian.py | 2 +- .../tutorial_evidential_classification.py | 2 +- .../tutorial_mc_batch_norm.py | 2 +- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- auto_tutorials_source/tutorial_scaler.py | 2 +- docs/source/quickstart.rst | 2 +- .../batch_ensemble.yaml} | 0 .../bayesian.yaml} | 30 +++++++++---------- .../checkpoint_ensemble.yaml} | 5 +--- .../deep_ensemble.yaml} | 6 +--- .../{lenet_ema.yaml => lenet/ema.yaml} | 5 +--- .../{lenet.yaml => lenet/standard.yaml} | 5 +--- .../{lenet_swa.yaml => lenet/swa.yaml} | 5 +--- .../{lenet_swag.yaml => lenet/swag.yaml} | 5 +--- tests/models/test_lenet.py | 7 ++++- tests/models/test_resnets.py | 4 +-- tests/models/test_vggs.py | 3 +- tests/models/test_wideresnets.py | 10 +++---- tests/post_processing/test_mc_batch_norm.py | 2 +- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- .../models/classification/__init__.py | 5 ++++ .../models/{ => classification}/lenet.py | 2 +- .../{ => classification}/resnet/__init__.py | 0 .../{ => classification}/resnet/batched.py | 0 .../{ => classification}/resnet/lpbnn.py | 0 .../{ => classification}/resnet/masked.py | 0 .../{ => classification}/resnet/mimo.py | 0 .../{ => classification}/resnet/packed.py | 0 .../models/{ => classification}/resnet/std.py | 0 .../{ => classification}/resnet/utils.py | 0 .../models/classification/vgg/__init__.py | 3 ++ .../models/{ => classification}/vgg/base.py | 0 .../{ => classification}/vgg/configs.py | 0 .../models/{ => classification}/vgg/packed.py | 0 .../models/{ => classification}/vgg/std.py | 0 .../wideresnet/__init__.py | 0 .../wideresnet/batched.py | 0 .../{ => classification}/wideresnet/masked.py | 0 .../{ => classification}/wideresnet/mimo.py | 0 .../{ => classification}/wideresnet/packed.py | 0 .../{ => classification}/wideresnet/std.py | 0 torch_uncertainty/models/vgg/__init__.py | 3 -- 44 files changed, 54 insertions(+), 64 deletions(-) rename experiments/classification/mnist/configs/{lenet_batch_ensemble.yaml => lenet/batch_ensemble.yaml} (100%) rename experiments/classification/mnist/configs/{bayesian_lenet.yaml => lenet/bayesian.yaml} (71%) rename experiments/classification/mnist/configs/{lenet_checkpoint_ensemble.yaml => lenet/checkpoint_ensemble.yaml} (89%) rename experiments/classification/mnist/configs/{lenet_deep_ensemble.yaml => lenet/deep_ensemble.yaml} (89%) rename experiments/classification/mnist/configs/{lenet_ema.yaml => lenet/ema.yaml} (88%) rename experiments/classification/mnist/configs/{lenet.yaml => lenet/standard.yaml} (88%) rename experiments/classification/mnist/configs/{lenet_swa.yaml => lenet/swa.yaml} (88%) rename experiments/classification/mnist/configs/{lenet_swag.yaml => lenet/swag.yaml} (88%) create mode 100644 torch_uncertainty/models/classification/__init__.py rename torch_uncertainty/models/{ => classification}/lenet.py (98%) rename torch_uncertainty/models/{ => classification}/resnet/__init__.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/batched.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/lpbnn.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/masked.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/mimo.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/packed.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/std.py (100%) rename torch_uncertainty/models/{ => classification}/resnet/utils.py (100%) create mode 100644 torch_uncertainty/models/classification/vgg/__init__.py rename torch_uncertainty/models/{ => classification}/vgg/base.py (100%) rename torch_uncertainty/models/{ => classification}/vgg/configs.py (100%) rename torch_uncertainty/models/{ => classification}/vgg/packed.py (100%) rename torch_uncertainty/models/{ => classification}/vgg/std.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/__init__.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/batched.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/masked.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/mimo.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/packed.py (100%) rename torch_uncertainty/models/{ => classification}/wideresnet/std.py (100%) delete mode 100644 torch_uncertainty/models/vgg/__init__.py diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index c32c991d..0ecee1e8 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -44,7 +44,7 @@ from torch_uncertainty import TUTrainer from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.losses import ELBOLoss -from torch_uncertainty.models.lenet import bayesian_lenet +from torch_uncertainty.models.classification import bayesian_lenet from torch_uncertainty.routines import ClassificationRoutine # %% diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index babf2a73..368ed694 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -33,7 +33,7 @@ from torch_uncertainty import TUTrainer from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.losses import DECLoss -from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.models.classification import lenet from torch_uncertainty.routines import ClassificationRoutine diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 18df5cff..90a5b3ac 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -29,7 +29,7 @@ from torch_uncertainty import TUTrainer from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.models.classification import lenet from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm from torch_uncertainty.routines import ClassificationRoutine diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index d17df7c2..caae6edb 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -35,7 +35,7 @@ from torch import nn from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.models.classification import lenet from torch_uncertainty.models import mc_dropout from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index a5c3d9d4..d0a51099 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -28,7 +28,7 @@ # %% from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.metrics import CalibrationError -from torch_uncertainty.models.resnet import resnet +from torch_uncertainty.models.classification import resnet from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.utils import load_hf diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 6812cc89..18081799 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -157,7 +157,7 @@ backbone with the following code: .. code:: python - from torch_uncertainty.models.resnet import packed_resnet + from torch_uncertainty.models.classification import packed_resnet model = packed_resnet( in_channels = 3, diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet/batch_ensemble.yaml similarity index 100% rename from experiments/classification/mnist/configs/lenet_batch_ensemble.yaml rename to experiments/classification/mnist/configs/lenet/batch_ensemble.yaml diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/lenet/bayesian.yaml similarity index 71% rename from experiments/classification/mnist/configs/bayesian_lenet.yaml rename to experiments/classification/mnist/configs/lenet/bayesian.yaml index dab14676..ce8f6f7f 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/lenet/bayesian.yaml @@ -10,7 +10,7 @@ trainer: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: save_dir: logs/lenet - name: standard + name: bayesian default_hp_metric: false callbacks: - class_path: torch_uncertainty.callbacks.TUClsCheckpoint @@ -24,28 +24,28 @@ trainer: check_finite: true model: model: - class_path: torch_uncertainty.models.StochasticModel + class_path: torch_uncertainty.models.classification.bayesian_lenet init_args: - model: - class_path: torch_uncertainty.models.lenet._LeNet - init_args: - in_channels: 1 - num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d - activation: torch.nn.ReLU - norm: torch.nn.Identity - groups: 1 - dropout_rate: 0 - layer_args: {} + in_channels: 1 + num_classes: 10 num_samples: 16 - num_classes: 10 + prior_sigma_1: null + prior_sigma_2: null + prior_pi: null + mu_init: null + sigma_init: null + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0.0 loss: class_path: torch_uncertainty.losses.ELBOLoss init_args: kl_weight: 0.00002 inner_loss: torch.nn.CrossEntropyLoss num_samples: 3 + num_classes: 10 + is_ensemble: true data: root: ./data batch_size: 128 diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet/checkpoint_ensemble.yaml similarity index 89% rename from experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml rename to experiments/classification/mnist/configs/lenet/checkpoint_ensemble.yaml index b33797d9..766251a5 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet/checkpoint_ensemble.yaml @@ -27,17 +27,14 @@ model: class_path: torch_uncertainty.models.CheckpointEnsemble init_args: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - layer_args: {} save_schedule: - 20 - 25 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet/deep_ensemble.yaml similarity index 89% rename from experiments/classification/mnist/configs/lenet_deep_ensemble.yaml rename to experiments/classification/mnist/configs/lenet/deep_ensemble.yaml index c5803e2d..1bd72897 100644 --- a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet/deep_ensemble.yaml @@ -31,18 +31,14 @@ model: init_args: models: # LeNet - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - # last_layer_dropout: false - layer_args: {} num_estimators: 5 task: classification probabilistic: false diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet/ema.yaml similarity index 88% rename from experiments/classification/mnist/configs/lenet_ema.yaml rename to experiments/classification/mnist/configs/lenet/ema.yaml index 09873d61..e74d9e5d 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet/ema.yaml @@ -27,17 +27,14 @@ model: class_path: torch_uncertainty.models.wrappers.EMA init_args: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - layer_args: {} momentum: 0.99 num_classes: 10 loss: CrossEntropyLoss diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet/standard.yaml similarity index 88% rename from experiments/classification/mnist/configs/lenet.yaml rename to experiments/classification/mnist/configs/lenet/standard.yaml index f2078eac..0dc96168 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet/standard.yaml @@ -24,17 +24,14 @@ trainer: check_finite: true model: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - layer_args: {} num_classes: 10 loss: CrossEntropyLoss data: diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet/swa.yaml similarity index 88% rename from experiments/classification/mnist/configs/lenet_swa.yaml rename to experiments/classification/mnist/configs/lenet/swa.yaml index 8c374ff3..5627c4d3 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet/swa.yaml @@ -27,17 +27,14 @@ model: class_path: torch_uncertainty.models.wrappers.SWA init_args: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - layer_args: {} cycle_start: 19 cycle_length: 5 num_classes: 10 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet/swag.yaml similarity index 88% rename from experiments/classification/mnist/configs/lenet_swag.yaml rename to experiments/classification/mnist/configs/lenet/swag.yaml index 9acc27d1..59ee3e89 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet/swag.yaml @@ -27,17 +27,14 @@ model: class_path: torch_uncertainty.models.wrappers.SWAG init_args: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.classification.lenet init_args: in_channels: 1 num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d activation: torch.nn.ReLU norm: torch.nn.Identity groups: 1 dropout_rate: 0 - layer_args: {} cycle_start: 10 cycle_length: 5 num_classes: 10 diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index c6a08180..ef1121a3 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -2,7 +2,12 @@ import torch from torch import nn -from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet +from torch_uncertainty.models.classification import ( + batchensemble_lenet, + bayesian_lenet, + lenet, + packed_lenet, +) class TestLeNet: diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py index fe757591..6ba1538c 100644 --- a/tests/models/test_resnets.py +++ b/tests/models/test_resnets.py @@ -1,7 +1,7 @@ import pytest import torch -from torch_uncertainty.models.resnet import ( +from torch_uncertainty.models.classification import ( batched_resnet, lpbnn_resnet, masked_resnet, @@ -9,7 +9,7 @@ packed_resnet, resnet, ) -from torch_uncertainty.models.resnet.utils import get_resnet_num_blocks +from torch_uncertainty.models.classification.resnet.utils import get_resnet_num_blocks class TestResnet: diff --git a/tests/models/test_vggs.py b/tests/models/test_vggs.py index e281d2d9..d96cb1c9 100644 --- a/tests/models/test_vggs.py +++ b/tests/models/test_vggs.py @@ -1,7 +1,6 @@ import pytest -from torch_uncertainty.models.vgg.packed import packed_vgg -from torch_uncertainty.models.vgg.std import vgg +from torch_uncertainty.models.classification import packed_vgg, vgg class TestVGGs: diff --git a/tests/models/test_wideresnets.py b/tests/models/test_wideresnets.py index 1d50bb91..38b3def4 100644 --- a/tests/models/test_wideresnets.py +++ b/tests/models/test_wideresnets.py @@ -1,20 +1,20 @@ import pytest import torch -from torch_uncertainty.models.wideresnet import wideresnet28x10 -from torch_uncertainty.models.wideresnet.batched import ( +from torch_uncertainty.models.classification import wideresnet28x10 +from torch_uncertainty.models.classification.wideresnet.batched import ( _BatchWideResNet, batched_wideresnet28x10, ) -from torch_uncertainty.models.wideresnet.masked import ( +from torch_uncertainty.models.classification.wideresnet.masked import ( _MaskedWideResNet, masked_wideresnet28x10, ) -from torch_uncertainty.models.wideresnet.mimo import ( +from torch_uncertainty.models.classification.wideresnet.mimo import ( _MIMOWideResNet, mimo_wideresnet28x10, ) -from torch_uncertainty.models.wideresnet.packed import ( +from torch_uncertainty.models.classification.wideresnet.packed import ( _PackedWideResNet, packed_wideresnet28x10, ) diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index 1d8d552b..209e68b9 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -8,7 +8,7 @@ from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d -from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.models.classification import lenet from torch_uncertainty.post_processing import MCBatchNorm diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 1e2b18ab..3d6faa86 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -4,7 +4,7 @@ from torch.optim import Optimizer from torch_uncertainty.models import mc_dropout -from torch_uncertainty.models.resnet import ( +from torch_uncertainty.models.classification import ( batched_resnet, lpbnn_resnet, masked_resnet, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 0d0b887e..53ac0087 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -4,7 +4,7 @@ from torch.optim import Optimizer from torch_uncertainty.models import mc_dropout -from torch_uncertainty.models.vgg import ( +from torch_uncertainty.models.classification import ( packed_vgg, vgg, ) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 1be8939c..39b13739 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -4,7 +4,7 @@ from torch.optim import Optimizer from torch_uncertainty.models import mc_dropout -from torch_uncertainty.models.wideresnet import ( +from torch_uncertainty.models.classification import ( batched_wideresnet28x10, masked_wideresnet28x10, mimo_wideresnet28x10, diff --git a/torch_uncertainty/models/classification/__init__.py b/torch_uncertainty/models/classification/__init__.py new file mode 100644 index 00000000..8780ab72 --- /dev/null +++ b/torch_uncertainty/models/classification/__init__.py @@ -0,0 +1,5 @@ +# ruff: noqa: F401, F403 +from .lenet import * +from .resnet import * +from .vgg import * +from .wideresnet import * diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/classification/lenet.py similarity index 98% rename from torch_uncertainty/models/lenet.py rename to torch_uncertainty/models/classification/lenet.py index b31ba133..d6edbe28 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/classification/lenet.py @@ -11,7 +11,7 @@ from torch_uncertainty.models import StochasticModel from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble -__all__ = ["bayesian_lenet", "lenet", "packed_lenet"] +__all__ = ["batchensemble_lenet", "bayesian_lenet", "lenet", "packed_lenet"] class _LeNet(nn.Module): diff --git a/torch_uncertainty/models/resnet/__init__.py b/torch_uncertainty/models/classification/resnet/__init__.py similarity index 100% rename from torch_uncertainty/models/resnet/__init__.py rename to torch_uncertainty/models/classification/resnet/__init__.py diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/classification/resnet/batched.py similarity index 100% rename from torch_uncertainty/models/resnet/batched.py rename to torch_uncertainty/models/classification/resnet/batched.py diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/classification/resnet/lpbnn.py similarity index 100% rename from torch_uncertainty/models/resnet/lpbnn.py rename to torch_uncertainty/models/classification/resnet/lpbnn.py diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/classification/resnet/masked.py similarity index 100% rename from torch_uncertainty/models/resnet/masked.py rename to torch_uncertainty/models/classification/resnet/masked.py diff --git a/torch_uncertainty/models/resnet/mimo.py b/torch_uncertainty/models/classification/resnet/mimo.py similarity index 100% rename from torch_uncertainty/models/resnet/mimo.py rename to torch_uncertainty/models/classification/resnet/mimo.py diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/classification/resnet/packed.py similarity index 100% rename from torch_uncertainty/models/resnet/packed.py rename to torch_uncertainty/models/classification/resnet/packed.py diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/classification/resnet/std.py similarity index 100% rename from torch_uncertainty/models/resnet/std.py rename to torch_uncertainty/models/classification/resnet/std.py diff --git a/torch_uncertainty/models/resnet/utils.py b/torch_uncertainty/models/classification/resnet/utils.py similarity index 100% rename from torch_uncertainty/models/resnet/utils.py rename to torch_uncertainty/models/classification/resnet/utils.py diff --git a/torch_uncertainty/models/classification/vgg/__init__.py b/torch_uncertainty/models/classification/vgg/__init__.py new file mode 100644 index 00000000..05beb228 --- /dev/null +++ b/torch_uncertainty/models/classification/vgg/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F401, F403 +from .packed import packed_vgg +from .std import vgg diff --git a/torch_uncertainty/models/vgg/base.py b/torch_uncertainty/models/classification/vgg/base.py similarity index 100% rename from torch_uncertainty/models/vgg/base.py rename to torch_uncertainty/models/classification/vgg/base.py diff --git a/torch_uncertainty/models/vgg/configs.py b/torch_uncertainty/models/classification/vgg/configs.py similarity index 100% rename from torch_uncertainty/models/vgg/configs.py rename to torch_uncertainty/models/classification/vgg/configs.py diff --git a/torch_uncertainty/models/vgg/packed.py b/torch_uncertainty/models/classification/vgg/packed.py similarity index 100% rename from torch_uncertainty/models/vgg/packed.py rename to torch_uncertainty/models/classification/vgg/packed.py diff --git a/torch_uncertainty/models/vgg/std.py b/torch_uncertainty/models/classification/vgg/std.py similarity index 100% rename from torch_uncertainty/models/vgg/std.py rename to torch_uncertainty/models/classification/vgg/std.py diff --git a/torch_uncertainty/models/wideresnet/__init__.py b/torch_uncertainty/models/classification/wideresnet/__init__.py similarity index 100% rename from torch_uncertainty/models/wideresnet/__init__.py rename to torch_uncertainty/models/classification/wideresnet/__init__.py diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/classification/wideresnet/batched.py similarity index 100% rename from torch_uncertainty/models/wideresnet/batched.py rename to torch_uncertainty/models/classification/wideresnet/batched.py diff --git a/torch_uncertainty/models/wideresnet/masked.py b/torch_uncertainty/models/classification/wideresnet/masked.py similarity index 100% rename from torch_uncertainty/models/wideresnet/masked.py rename to torch_uncertainty/models/classification/wideresnet/masked.py diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/classification/wideresnet/mimo.py similarity index 100% rename from torch_uncertainty/models/wideresnet/mimo.py rename to torch_uncertainty/models/classification/wideresnet/mimo.py diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/classification/wideresnet/packed.py similarity index 100% rename from torch_uncertainty/models/wideresnet/packed.py rename to torch_uncertainty/models/classification/wideresnet/packed.py diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/classification/wideresnet/std.py similarity index 100% rename from torch_uncertainty/models/wideresnet/std.py rename to torch_uncertainty/models/classification/wideresnet/std.py diff --git a/torch_uncertainty/models/vgg/__init__.py b/torch_uncertainty/models/vgg/__init__.py deleted file mode 100644 index 837e5d97..00000000 --- a/torch_uncertainty/models/vgg/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# ruff: noqa: F401, F403 -from .packed import * -from .std import * From a48a51f3d3bad692b6b442df757e4d8d35d99b54 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 10 May 2025 15:19:28 +0200 Subject: [PATCH 69/69] :hammer: Replace `legacy` implementation from the `PackedLinear` with `conv1d` implementation - Packed layers with `last=True` will pass the `num_estimators` from the feature dimension to the batch dimension --- .../tutorial_from_de_to_pe.py | 3 - auto_tutorials_source/tutorial_pe_cifar10.py | 3 +- tests/layers/test_packed.py | 87 +++++---- torch_uncertainty/layers/functional/packed.py | 7 + torch_uncertainty/layers/packed.py | 184 +++++++++++------- .../models/classification/resnet/packed.py | 12 +- .../models/classification/resnet/std.py | 4 +- .../models/classification/vgg/base.py | 8 - .../classification/wideresnet/packed.py | 2 - 9 files changed, 180 insertions(+), 130 deletions(-) diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py index ea2180a2..64f35867 100644 --- a/auto_tutorials_source/tutorial_from_de_to_pe.py +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -353,9 +353,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = F.max_pool2d(out, 2) out = F.relu(self.conv2(out)) out = F.max_pool2d(out, 2) - out = rearrange( - out, "e (m c) h w -> (m e) c h w", m=self.num_estimators - ) out = torch.flatten(out, 1) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) diff --git a/auto_tutorials_source/tutorial_pe_cifar10.py b/auto_tutorials_source/tutorial_pe_cifar10.py index d3a233cd..ae7631b2 100644 --- a/auto_tutorials_source/tutorial_pe_cifar10.py +++ b/auto_tutorials_source/tutorial_pe_cifar10.py @@ -188,7 +188,6 @@ def __init__(self) -> None: def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) - x = rearrange(x, "e (m c) h w -> (m e) c h w", m=self.num_estimators) x = x.flatten(1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) @@ -267,7 +266,7 @@ def forward(self, x): # Let us see what the Packed-Ensemble thinks these examples above are: logits = packed_net(images) -logits = rearrange(logits, "(n b) c -> b n c", n=packed_net.num_estimators) +logits = rearrange(logits, "(m b) c -> b m c", m=packed_net.num_estimators) probs_per_est = F.softmax(logits, dim=-1) outputs = probs_per_est.mean(dim=1) diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index 0efdfab0..7ad51764 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -131,35 +131,29 @@ def extended_batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: class TestPackedLinear: """Testing the PackedLinear layer class.""" - # Legacy tests - def test_linear_one_estimator_no_rearrange(self, feat_input: torch.Tensor): - layer = PackedLinear(6, 2, alpha=1, num_estimators=1, rearrange=False, bias=False) - out = layer(feat_input) - assert out.shape == torch.Size([2, 1]) - - def test_linear_two_estimators_no_rearrange(self, feat_input: torch.Tensor): - layer = PackedLinear(6, 2, alpha=1, num_estimators=2, rearrange=False) - out = layer(feat_input) - assert out.shape == torch.Size([2, 1]) - - def test_linear_one_estimator_rearrange(self, feat_input_one_rearrange: torch.Tensor): - layer = PackedLinear(5, 2, alpha=1, num_estimators=1, rearrange=True) - out = layer(feat_input_one_rearrange) - assert out.shape == torch.Size([3, 2]) - - def test_linear_two_estimator_rearrange_not_divisible(self): - feat = torch.rand((2 * 3, 3)) - layer = PackedLinear(5, 1, alpha=1, num_estimators=2, rearrange=True) - out = layer(feat) - assert out.shape == torch.Size([6, 1]) + # Conv1d implementation tests + def test_linear_conv1d_implementation( + self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor + ): + layer = PackedLinear(16, 4, alpha=2, num_estimators=1, implementation="conv1d", first=True) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([3, 8]) + layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="conv1d") + out = layer(feat_input_16_features) + assert out.shape == torch.Size([3, 4]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, implementation="conv1d") + out = layer(feat_multi_dim) + assert out.shape == torch.Size([1, 2, 3, 4, 2]) # Full implementation tests def test_linear_full_implementation( self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor ): - layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full", bias=False) + layer = PackedLinear( + 16, 4, alpha=2, num_estimators=1, implementation="full", bias=False, first=True + ) out = layer(feat_input_16_features) - assert out.shape == torch.Size([3, 4]) + assert out.shape == torch.Size([3, 8]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="full") out = layer(feat_input_16_features) assert out.shape == torch.Size([3, 4]) @@ -171,9 +165,9 @@ def test_linear_full_implementation( def test_linear_sparse_implementation( self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor ): - layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="sparse") + layer = PackedLinear(16, 4, alpha=2, num_estimators=1, implementation="sparse", first=True) out = layer(feat_input_16_features) - assert out.shape == torch.Size([3, 4]) + assert out.shape == torch.Size([3, 8]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="sparse") out = layer(feat_input_16_features) assert out.shape == torch.Size([3, 4]) @@ -185,9 +179,9 @@ def test_linear_sparse_implementation( def test_linear_einsum_implementation( self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor ): - layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="einsum") + layer = PackedLinear(16, 4, alpha=2, num_estimators=1, implementation="einsum", first=True) out = layer(feat_input_16_features) - assert out.shape == torch.Size([3, 4]) + assert out.shape == torch.Size([3, 8]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="einsum") out = layer(feat_input_16_features) assert out.shape == torch.Size([3, 4]) @@ -195,19 +189,28 @@ def test_linear_einsum_implementation( out = layer(feat_multi_dim) assert out.shape == torch.Size([1, 2, 3, 4, 2]) + # Conv1d implementation tests + def test_linear_last_parameter( + self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor + ): + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="conv1d", last=True) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([3, 4]) + layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="full", last=True) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([6, 4]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, implementation="sparse", last=True) + out = layer(feat_multi_dim) + assert out.shape == torch.Size([1, 2, 3, 4, 2]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=2, implementation="einsum", last=True) + out = layer(feat_multi_dim) + assert out.shape == torch.Size([2, 2, 3, 4, 2]) + def test_linear_extend(self): - layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy") - assert layer.weight.shape == torch.Size([4, 3, 1]) - assert layer.bias.shape == torch.Size([4]) layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full") assert layer.weight.shape == torch.Size([2, 2, 3]) assert layer.bias.shape == torch.Size([4]) # with first=True - layer = PackedLinear( - 5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy", first=True - ) - assert layer.weight.shape == torch.Size([4, 5, 1]) - assert layer.bias.shape == torch.Size([4]) layer = PackedLinear( 5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full", first=True ) @@ -216,25 +219,25 @@ def test_linear_extend(self): def test_linear_failures(self): with pytest.raises(ValueError): - _ = PackedLinear(5, 2, alpha=None, num_estimators=1, rearrange=True) + _ = PackedLinear(5, 2, alpha=None, num_estimators=1) with pytest.raises(ValueError): - _ = PackedLinear(5, 2, alpha=-1, num_estimators=1, rearrange=True) + _ = PackedLinear(5, 2, alpha=-1, num_estimators=1) with pytest.raises(ValueError): - _ = PackedLinear(5, 2, alpha=1, num_estimators=None, rearrange=True) + _ = PackedLinear(5, 2, alpha=1, num_estimators=None) with pytest.raises(TypeError): - _ = PackedLinear(5, 2, alpha=1, num_estimators=1.5, rearrange=True) + _ = PackedLinear(5, 2, alpha=1, num_estimators=1.5) with pytest.raises(ValueError): - _ = PackedLinear(5, 2, alpha=1, num_estimators=-1, rearrange=True) + _ = PackedLinear(5, 2, alpha=1, num_estimators=-1) with pytest.raises(TypeError): - _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=0.5, rearrange=True) + _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=0.5) with pytest.raises(ValueError): - _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=-1, rearrange=True) + _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=-1) with pytest.raises(ValueError): _ = PackedLinear( diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index a5ab40f9..3f4d15f6 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -24,6 +24,7 @@ def packed_linear( transformation using `torch.nn.functional.linear`. - "sparse": uses a sparse weight tensor directly to apply the linear transformation. - "einsum": uses `torch.einsum` to apply the packed linear transformation. + - "conv1d": uses `torch.nn.functional.conv1d` to apply the packed linear transformation. rearrange (bool, optional): _description_. Defaults to True. bias (Tensor | None, optional): _description_. Defaults to None. @@ -47,6 +48,12 @@ def packed_linear( if bias is not None: out += bias return out + if implementation == "conv1d": + input_size = inputs.size() + inputs = rearrange(inputs, "... d -> (...) d 1") + weight = rearrange(weight, "m i j -> (m i) j 1") + out = F.conv1d(inputs, weight, bias, stride=1, padding=0, dilation=1, groups=num_groups) + return out.reshape(input_size[:-1] + (-1,)) raise ValueError(f"Unknown implementation: {implementation}") diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 66150367..a74a1a4d 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -49,8 +49,7 @@ def __init__( bias: bool = True, first: bool = False, last: bool = False, - implementation: str = "legacy", - rearrange: bool = True, + implementation: str = "conv1d", device=None, dtype=None, ) -> None: @@ -73,17 +72,26 @@ def __init__( Defaults to ``False``. implementation (str, optional): The implementation to use. Available implementations: - - ``"legacy"`` (default): The legacy implementation of the linear layer. + - ``"conv1d"`` (default): The conv1d implementation of the linear layer. - ``"sparse"``: The sparse implementation of the linear layer. - ``"full"``: The full implementation of the linear layer. - ``"einsum"``: The einsum implementation of the linear layer. - rearrange (bool, optional): Rearrange the input and outputs for - compatibility with previous and later layers. Defaults to ``True``. device (torch.device, optional): The device to use for the layer's parameters. Defaults to ``None``. dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to ``None``. + Shape: + - Input: + - If :attr:`first` is ``True``: :math:`(B, \ast, H_{\text{in}})` where + :math:`B` is the batch size, :math:`\ast` means any number of + additional dimensions and :math:`H_{\text{in}}=\text{in\_features}`. + - Otherwise: :math:`(B, \ast, H_{\text{in}} \times \alpha)` + - Output: + - If :attr:`last` is ``True``: :math:`(B, \ast, H_{\text{out}}\times M)` where + :math:`H_{\text{out}}=\text{out\_features}` and :math:`M=\text{num\_estimators}`. + - Otherwise: :math:`(B, \ast, H_{\text{out}} \times \alpha)` + Explanation Note: Increasing :attr:`alpha` will increase the number of channels of the ensemble, increasing its representation capacity. Increasing @@ -98,24 +106,20 @@ def __init__( :attr:`n_estimators` :math:`\times`:attr:`gamma`. However, the number of input and output features will be changed to comply with this constraint. - - Note: - The input should be of shape (`batch_size`, :attr:`in_features`, 1, - 1). The (often) necessary rearrange operation is executed by - default. """ check_packed_parameters_consistency(alpha, gamma, num_estimators) - if implementation not in ["legacy", "sparse", "full", "einsum"]: + if implementation not in ["sparse", "full", "einsum", "conv1d"]: raise ValueError( f"Unknown implementation: {implementation} for PackedLinear" - "Available implementations are: 'legacy', 'sparse', 'full', 'einsum'" + "Available implementations are: 'legacy', 'sparse', 'full', 'einsum', 'conv1d'" ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.first = first + self.last = last self.num_estimators = num_estimators self.rearrange = rearrange self.implementation = implementation @@ -135,28 +139,16 @@ def __init__( num_estimators * gamma ) - if self.implementation == "legacy": - self.weight = nn.Parameter( - torch.empty( - ( - extended_out_features, - extended_in_features // actual_groups, - 1, - ), - **factory_kwargs, - ) - ) - else: - self.weight = nn.Parameter( - torch.empty( - ( - actual_groups, - extended_out_features // actual_groups, - extended_in_features // actual_groups, - ), - **factory_kwargs, - ) + self.weight = nn.Parameter( + torch.empty( + ( + actual_groups, + extended_out_features // actual_groups, + extended_in_features // actual_groups, + ), + **factory_kwargs, ) + ) self.in_features = extended_in_features // actual_groups self.out_features = extended_out_features // actual_groups @@ -183,26 +175,19 @@ def reset_parameters(self) -> None: if self.implementation == "sparse": self.weight = nn.Parameter(torch.block_diag(*self.weight).to_sparse()) - def _rearrange_forward(self, x: Tensor) -> Tensor: - x = x.unsqueeze(-1) - if not self.first: - x = rearrange(x, "(m e) c h -> e (m c) h", m=self.num_estimators) - x = F.conv1d(x, self.weight, self.bias, 1, 0, 1, self.groups) - x = rearrange(x, "e (m c) h -> (m e) c h", m=self.num_estimators) - return x.squeeze(-1) - def forward(self, inputs: Tensor) -> Tensor: - if self.implementation == "legacy": - if self.rearrange: - return self._rearrange_forward(inputs) - return F.conv1d(inputs, self.weight, self.bias, 1, 0, 1, self.groups) - return packed_linear( + out = packed_linear( inputs=inputs, weight=self.weight, num_groups=self.groups, implementation=self.implementation, bias=self.bias, ) + return ( + out + if not self.last + else rearrange(out, "b ... (m h) -> (m b) ... h", m=self.num_estimators) + ) class PackedConv1d(nn.Module): @@ -258,6 +243,17 @@ def __init__( dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to ``None``. + Shape: + - Input: + - If :attr:`first` is ``True``: :math:`(B, C_{\text{in}}, L_{\text{in}})` where + :math:`B` is the batch size, :math:`C_{\text{in}}=\text{in\_channels}`, and + :math:`L_{\text{in}}` is the length of the signal sequence. + - Otherwise: :math:`(B, C_{\text{in}} \times \alpha, L_{\text{in}})` + - Output: + - If :attr:`last` is ``True``: :math:`(B, C_{\text{out}}\times M, L_{\text{out}})` + where :math:`C_{\text{out}}=\text{out\_channels}` and :math:`M=\text{num\_estimators}`. + - Otherwise: :math:`(B, C_{\text{out}} \times \alpha, L_{\text{out}})` + Explanation Note: Increasing :attr:`alpha` will increase the number of channels of the ensemble, increasing its representation capacity. Increasing @@ -276,7 +272,8 @@ def __init__( check_packed_parameters_consistency(alpha, gamma, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - + self.first = first + self.last = last self.num_estimators = num_estimators # Define the number of channels of the underlying convolution @@ -313,7 +310,12 @@ def __init__( ) def forward(self, inputs: Tensor) -> Tensor: - return self.conv(inputs) + out = self.conv(inputs) + return ( + out + if not self.last + else rearrange(out, "b (m c) ... -> (m b) c ...", m=self.num_estimators) + ) @property def weight(self) -> Tensor: @@ -379,6 +381,17 @@ def __init__( dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to ``None``. + Shape: + - Input: + - If :attr:`first` is ``True``: :math:`(B, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})` where + :math:`B` is the batch size, :math:`C_{\text{in}}=\text{in\_channels}`, + :math:`H_{\text{in}}` and :math:`W_{\text{in}}` are the height and width of the input image. + - Otherwise: :math:`(B, C_{\text{in}} \times \alpha, H_{\text{in}}, W_{\text{in}})` + - Output: + - If :attr:`last` is ``True``: :math:`(B, C_{\text{out}}\times M, H_{\text{out}}, W_{\text{out}})` + where :math:`C_{\text{out}}=\text{out\_channels}` and :math:`M=\text{num\_estimators}`. + - Otherwise: :math:`(B, C_{\text{out}} \times \alpha, H_{\text{out}}, W_{\text{out}})` + Explanation Note: Increasing :attr:`alpha` will increase the number of channels of the ensemble, increasing its representation capacity. Increasing @@ -398,6 +411,8 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + self.first = first + self.last = last self.num_estimators = num_estimators # Define the number of channels of the underlying convolution @@ -434,7 +449,12 @@ def __init__( ) def forward(self, inputs: Tensor) -> Tensor: - return self.conv(inputs) + out = self.conv(inputs) + return ( + out + if not self.last + else rearrange(out, "b (m c) ... -> (m b) c ...", m=self.num_estimators) + ) @property def weight(self) -> Tensor: @@ -500,6 +520,18 @@ def __init__( dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to ``None``. + Shape: + - Input: + - If :attr:`first` is ``True``: :math:`(B, C_{\text{in}}, D_{\text{in}}, H__{\text{in}}, W__{\text{in}})` + where :math:`B` is the batch size, :math:`C_{\text{in}}=\text{in\_channels}`, + :math:`D_{\text{in}}` is the depth of the input, :math:`H_{\text{in}}` + and :math:`W_{\text{in}}` are height and width of the input planes. + - Otherwise: :math:`(B, C_{\text{in}} \times \alpha, D__{\text{in}}, H_{\text{in}}, W_{\text{in}})` + - Output: + - If :attr:`last` is ``True``: :math:`(B, C_{\text{out}}\times M, D_{\text{out}}, H__{\text{out}}, W__{\text{out}})` where + :math:`C_{\text{out}}=\text{out\_channels}` and :math:`M=\text{num\_estimators}`. + - Otherwise: :math:`(B, C_{\text{out}} \times \alpha, D_{\text{out}}, H__{\text{out}}, W__{\text{out}})` + Explanation Note: Increasing :attr:`alpha` will increase the number of channels of the ensemble, increasing its representation capacity. Increasing @@ -519,6 +551,8 @@ def __init__( super().__init__() check_packed_parameters_consistency(alpha, gamma, num_estimators) + self.first = first + self.last = last self.num_estimators = num_estimators # Define the number of channels of the underlying convolution @@ -555,7 +589,12 @@ def __init__( ) def forward(self, inputs: Tensor) -> Tensor: - return self.conv(inputs) + out = self.conv(inputs) + return ( + out + if not self.last + else rearrange(out, "b (m c) ... -> (m b) c ...", m=self.num_estimators) + ) @property def weight(self) -> Tensor: @@ -707,6 +746,8 @@ def __init__( self.num_estimators = num_estimators self.alpha = alpha self.gamma = gamma + self.first = first + self.last = last if not self._qkv_same_embed_dim: self.q_proj_weight = nn.Parameter( @@ -766,18 +807,22 @@ def __init__( else: self.bias_k = self.bias_v = None - self.out_proj = PackedLinear( - in_features=embed_dim, - out_features=embed_dim, - alpha=alpha, - num_estimators=num_estimators, - gamma=gamma, - implementation="einsum", - bias=bias, - first=False, - last=last, - **factory_kwargs, + out_embed_dim = int(embed_dim * (num_estimators if last else alpha)) + + self.out_proj_weight = nn.Parameter( + torch.empty( + ( + self.num_groups, + out_embed_dim // self.num_groups, + self.embed_dim // self.num_groups, + ), + **factory_kwargs, + ) ) + if bias: + self.out_proj_bias = nn.Parameter(torch.empty(out_embed_dim, **factory_kwargs)) + else: + self.register_parameter("out_proj_bias", None) self.add_zero_attn = add_zero_attn @@ -793,9 +838,12 @@ def _reset_parameters(self): nn.init.xavier_uniform_(self.k_proj_weight[i]) nn.init.xavier_uniform_(self.v_proj_weight[i]) + for i in range(self.out_proj_weight.size(0)): + nn.init.xavier_uniform_(self.out_proj_weight[i]) + if self.in_proj_bias is not None: nn.init.constant_(self.in_proj_bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) + nn.init.constant_(self.out_proj_bias, 0.0) def forward( self, @@ -909,8 +957,8 @@ def forward( self.bias_v, self.add_zero_attn, self.dropout, - self.out_proj.weight, - self.out_proj.bias, + self.out_proj_weight, + self.out_proj_bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -939,8 +987,8 @@ def forward( self.bias_v, self.add_zero_attn, self.dropout, - self.out_proj.weight, - self.out_proj.bias, + self.out_proj_weight, + self.out_proj_bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -948,6 +996,10 @@ def forward( average_attn_weights=average_attn_weights, is_causal=is_causal, ) + + if self.last: + attn_output = rearrange(attn_output, "l b (m e) -> l (m b) e", m=self.num_estimators) + if self.batch_first and is_batched: return attn_output.transpose(1, 0), None return attn_output, None diff --git a/torch_uncertainty/models/classification/resnet/packed.py b/torch_uncertainty/models/classification/resnet/packed.py index cf16343f..353011f5 100644 --- a/torch_uncertainty/models/classification/resnet/packed.py +++ b/torch_uncertainty/models/classification/resnet/packed.py @@ -1,7 +1,6 @@ from typing import Any, Literal import torch.nn.functional as F -from einops import rearrange from torch import Tensor, nn from torch_uncertainty.layers import PackedConv2d, PackedLinear @@ -205,6 +204,7 @@ def __init__( style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, normalization_layer: type[nn.Module] = nn.BatchNorm2d, + linear_implementation: str = "conv1d", ) -> None: super().__init__() @@ -323,6 +323,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, last=True, + implementation=linear_implementation, ) def _make_layer( @@ -366,9 +367,6 @@ def forward(self, x: Tensor) -> Tensor: out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) - - out = rearrange(out, "e (m c) h w -> (m e) c h w", m=self.num_estimators) - out = self.pool(out) out = self.final_dropout(self.flatten(out)) return self.linear(out) @@ -390,13 +388,14 @@ def packed_resnet( num_estimators: int, alpha: int, gamma: int, - conv_bias: bool = True, + conv_bias: bool = False, width_multiplier: float = 1.0, groups: int = 1, dropout_rate: float = 0, style: Literal["imagenet", "cifar"] = "imagenet", normalization_layer: type[nn.Module] = nn.BatchNorm2d, pretrained: bool = False, + linear_implementation: str = "conv1d", ) -> _PackedResNet: """Packed-Ensembles of ResNet. @@ -417,6 +416,8 @@ def packed_resnet( normalization_layer (nn.Module, optional): Normalization layer. pretrained (bool, optional): Whether to load pretrained weights. Defaults to ``False``. + linear_implementation (str, optional): Implementation of the + packed linear layer. Defaults to ``conv1d``. Returns: _PackedResNet: A Packed-Ensembles ResNet. @@ -437,6 +438,7 @@ def packed_resnet( style=style, in_planes=int(in_planes * width_multiplier), normalization_layer=normalization_layer, + linear_implementation=linear_implementation, ) if pretrained: # coverage: ignore weights = weight_ids[str(num_classes)][str(arch)] diff --git a/torch_uncertainty/models/classification/resnet/std.py b/torch_uncertainty/models/classification/resnet/std.py index 1cd096bc..baa2baa1 100644 --- a/torch_uncertainty/models/classification/resnet/std.py +++ b/torch_uncertainty/models/classification/resnet/std.py @@ -348,7 +348,7 @@ def resnet( in_channels: int, num_classes: int, arch: int, - conv_bias: bool = True, + conv_bias: bool = False, dropout_rate: float = 0.0, width_multiplier: float = 1.0, groups: int = 1, @@ -363,7 +363,7 @@ def resnet( num_classes (int): Number of classes to predict. arch (int): The architecture of the ResNet. conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. + ``False``. dropout_rate (float): Dropout rate. Defaults to 0.0. width_multiplier (float): Width multiplier. Defaults to 1.0. groups (int): Number of groups in convolutions. Defaults to 1. diff --git a/torch_uncertainty/models/classification/vgg/base.py b/torch_uncertainty/models/classification/vgg/base.py index 96b48f06..065bcbe5 100644 --- a/torch_uncertainty/models/classification/vgg/base.py +++ b/torch_uncertainty/models/classification/vgg/base.py @@ -1,6 +1,5 @@ from typing import Any -from einops import rearrange from torch import Tensor, nn from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear @@ -111,13 +110,6 @@ def _make_layers(self, cfg: list) -> nn.Sequential: def feats_forward(self, x: Tensor) -> Tensor: x = self.features(x) - - if self.linear_layer == PackedLinear: - x = rearrange( - x, - "e (m c) h w -> (m e) c h w", - m=self.model_kwargs["num_estimators"], - ) x = self.avgpool(x) return self.flatten(x) diff --git a/torch_uncertainty/models/classification/wideresnet/packed.py b/torch_uncertainty/models/classification/wideresnet/packed.py index d1ffbd21..01e9b1b9 100644 --- a/torch_uncertainty/models/classification/wideresnet/packed.py +++ b/torch_uncertainty/models/classification/wideresnet/packed.py @@ -1,7 +1,6 @@ from collections.abc import Callable from typing import Literal -from einops import rearrange from torch import Tensor, nn from torch.nn.functional import relu @@ -246,7 +245,6 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) - out = rearrange(out, "e (m c) h w -> (m e) c h w", m=self.num_estimators) out = self.pool(out) return self.final_dropout(self.flatten(out))