✨ Version 0.5.0 #160


Draft · wants to merge 79 commits into base: main
c2da93a
:bug: Fix #152
alafage Mar 20, 2025
4cd6743
:bug: Fix `nn.Dropout.forward()` argument in `_MLP.forward()`
alafage Mar 20, 2025
a3f53b3
:hammer: Rework OOD criteria
o-laurent Mar 20, 2025
86ffe4b
:hammer: Allow for string OOD criteria
o-laurent Mar 20, 2025
53509d7
:bug: Add unpushed modifications
o-laurent Mar 20, 2025
404917d
:sparkles: Add first impl. of TTA
o-laurent Mar 20, 2025
bebe01f
:shirt: remove mention of calibration methods
o-laurent Mar 20, 2025
91166ef
Merge branch 'dev' into tta
o-laurent Mar 20, 2025
88b4d00
:bug: Fix the TIN dataset & add doc
o-laurent Mar 21, 2025
9b7fb7d
:wrench: Don't show an error when deleting inexistent ssh-agent
o-laurent Mar 21, 2025
9117bb2
:wrench: We probably don't need to run the docs every week anymore fo…
o-laurent Mar 21, 2025
cd94ae2
:wrench: Improve MNIST configs
o-laurent Mar 23, 2025
81b0015
:shirt: Fix arg tab
o-laurent Mar 23, 2025
dd8e304
:shirt: Normalize import of nn.functional
o-laurent Mar 23, 2025
629aca9
:shirt: Improve some docstrings
o-laurent Mar 23, 2025
13d01c7
:sparkles: Add Zero and fix TTA in cls
o-laurent Mar 23, 2025
f8ff1b9
:sparkles: Add eval_batch_size
o-laurent Mar 23, 2025
d79ef8e
:bug: Add forgotten eval_batch_size in init
o-laurent Mar 24, 2025
4192f19
Merge pull request #156 from ENSTA-U2IS-AI/eval_batch_size
o-laurent Mar 24, 2025
1cec6ae
:hammer: Set default value of `reset_model_parameters` to `True`
alafage Mar 24, 2025
2f1d77e
:zap: Update License time span
o-laurent Mar 25, 2025
ee9bb6a
Merge pull request #153 from ENSTA-U2IS-AI/ood_scores
alafage Mar 25, 2025
fd21db2
:books: Added doc for ood scores and fixed covergae
fira7s Mar 25, 2025
d91658c
:art: Move failure test + rename MaxSoftmaxProbabilityCriterion -> Ma…
alafage Mar 25, 2025
2233294
:shirt: Add new rules, Lint & format
o-laurent Mar 26, 2025
020ab89
:books: Update Dockerfile documentation
tonyzamyatin Mar 31, 2025
afa984c
:hammer: Slight update of the `TULightningCLI` to support lightning<2…
alafage Apr 1, 2025
1dcaa96
:book: Fix TULightningCLI docstring
alafage Apr 1, 2025
b4c3b06
Merge pull request #161 from tonyzamyatin/docs/docker
alafage Apr 2, 2025
aacfcf8
:hammer: Enable setting train and test transforms for Depth datamodules
alafage Apr 10, 2025
2586d9a
:hammer: Enable setting train and test transforms in Segmentation dat…
alafage Apr 10, 2025
db09be6
:hammer: Enable setting train and test transforms in Classification d…
alafage Apr 10, 2025
0d5ba10
:white_check_mark: Improve coverage
alafage Apr 10, 2025
68376cb
:construction: Implement specific ModelCheckpoint for Classification
alafage Apr 16, 2025
7de882e
:hammer: Enable storing models on cpu in `deep_ensembles`
alafage Apr 16, 2025
5bb320c
:bug: Fix ``_DeepEnsembles.forward()``
alafage Apr 16, 2025
fedcb81
:hammer: Change metric logging names
alafage Apr 16, 2025
8597110
:sparkles: Add ``TUClsCheckpoint``
alafage Apr 16, 2025
028161d
:wrench: Use GPU pytorch version
alafage Apr 16, 2025
8db167a
:bug: Check whether cuda is available in ``test_deep_ensembles.py``
alafage Apr 16, 2025
00ff02d
:hammer: Simplify ``_DeepEnsembles.to()`` method
alafage Apr 16, 2025
e4af17b
:shirt: Improve coverage
alafage Apr 16, 2025
c52f0c1
Merge branch 'dev' of github.com:ENSTA-U2IS-AI/torch-uncertainty into…
alafage Apr 16, 2025
22f8083
:bug: Fix metric logging names
alafage Apr 16, 2025
9d63666
:sparkles: Add ``TUSegCheckpoint`` callback
alafage Apr 16, 2025
3945aec
:shirt: Improve coverage
alafage Apr 16, 2025
8e0d431
:sparkles: Add `TURegCheckpoint`callback
alafage Apr 17, 2025
c4e8232
:wrench: Update all ModelCheckpoint callbacks in configs
alafage Apr 17, 2025
798e612
:hammer: Roll back metric name changes and use `auto_insert_metric_na…
alafage Apr 23, 2025
c0fb2f5
:bug: Fix typo in Grouping Loss log name
alafage Apr 23, 2025
7045c5a
:shirt: Improve coverage of `ClassificationRoutine`
alafage Apr 23, 2025
424c186
Merge branch 'dev' of github.com:ENSTA-U2IS-AI/torch-uncertainty into…
alafage Apr 23, 2025
6885915
:hammer: Update metric log names
alafage Apr 23, 2025
a1c7b01
:hammer: Update checkpoint names
alafage Apr 23, 2025
aa07bcb
Merge pull request #166 from ENSTA-U2IS-AI/checkpointing
alafage Apr 23, 2025
0979fdb
:sparkles: Add `CoverageRate` metric
alafage Apr 25, 2025
36d56cf
:sparkles: `deep_ensembles()` wrapper can take ckpt paths to init the…
alafage Apr 29, 2025
bfd86bd
Added ConformalClassificationRAPS, ConformalClassificationAPS, and Co…
giannifranchi Apr 29, 2025
803579b
:art: Comply with ruff rules
alafage Apr 30, 2025
7008351
:art: Modify conformal files location
alafage May 1, 2025
62353a7
:hammer: Remove `calibrate()` methods in all conformal classes + poli…
alafage May 3, 2025
4f46131
:book: Add some doc
o-laurent May 5, 2025
1c1ebfd
:white_check_mark: Add tests for Conformal Post-Processing
alafage May 5, 2025
ae6be54
:hammer: Move utils/data.py to datasets/utils.py
o-laurent May 5, 2025
95b0ebc
:book: Add Zero to the references
o-laurent May 5, 2025
5611fb5
:shirt: use get_train_set in abstract datamodule
o-laurent May 5, 2025
b0c428a
:white_check_mark: Improve coverage
alafage May 5, 2025
1c61a49
:white_check_mark: Add tests for `CoverageRate` metric
alafage May 5, 2025
bd982c6
:shirt: Improve coverage slightly
alafage May 5, 2025
94ce44f
:bug: Fix incorrect logging in Classification routine
alafage May 5, 2025
2397091
:white_check_mark: Improve coverage
alafage May 5, 2025
2318a7c
Merge pull request #168 from ENSTA-U2IS-AI/conformal
alafage May 5, 2025
e4bc701
:bug: Add missing callback higher level imports
alafage May 7, 2025
d8d7912
:bug: Fix #171
alafage May 7, 2025
3131204
:bug: TULightningCLI won't crash if `data.eval_ood` or `data.eval_shi…
alafage May 7, 2025
4fd72ab
:art: Refine `torch_uncertainty.models` structure and fix LeNet confi…
alafage May 9, 2025
a48a51f
:hammer: Replace `legacy` implementation from the `PackedLinear` with…
alafage May 10, 2025
728ed63
Merge branch 'dev' of github.com:ENSTA-U2IS/torch-uncertainty into tta
o-laurent May 12, 2025
2d9c740
Merge pull request #154 from ENSTA-U2IS-AI/tta
o-laurent May 12, 2025
8 changes: 3 additions & 5 deletions .github/workflows/build-docs.yml
@@ -7,8 +7,6 @@ on:
types: [opened, reopened, ready_for_review, synchronize]
branches:
- main
schedule:
- cron: "00 12 * * 0" # Every Sunday noon (preserve the cache folders)
workflow_dispatch:

env:
@@ -57,8 +55,8 @@ jobs:
external_repository: torch-uncertainty/torch-uncertainty.github.io
publish_branch: main
publish_dir: docs/build/html


# ||: not to error if there is no running ssh-agent
- name: Kill SSH Agent
run: |
killall ssh-agent
continue-on-error: true
killall ssh-agent ||:
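As the diff comment notes, the `||:` form makes the old `continue-on-error: true` flag unnecessary: `:` is the shell's no-op builtin and always exits 0, so `cmd ||:` succeeds even when `cmd` fails (here, when there is no running ssh-agent to kill). A minimal sketch of the idiom, using `false` as a stand-in for the failing command:

```shell
# 'false' stands in for 'killall ssh-agent' failing when no agent is running.
# '||:' falls back to ':', the always-successful no-op builtin, so the
# overall exit status is 0 and the CI step does not fail.
false ||:
echo "exit status: $?"   # prints: exit status: 0
```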
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -63,7 +63,7 @@ jobs:
- name: Install dependencies
if: steps.changed-files-specific.outputs.only_changed != 'true'
run: |
python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
python3 -m pip install torch torchvision
python3 -m pip install .[all]

- name: Check style & format
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023-2024 Adrien Lafage and Olivier Laurent
Copyright 2023-2025 Adrien Lafage and Olivier Laurent

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion auto_tutorials_source/tutorial_bayesian.py
@@ -44,7 +44,7 @@
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.lenet import bayesian_lenet
from torch_uncertainty.models.classification import bayesian_lenet
from torch_uncertainty.routines import ClassificationRoutine

# %%
@@ -33,7 +33,7 @@
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import DECLoss
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.models.classification import lenet
from torch_uncertainty.routines import ClassificationRoutine


3 changes: 0 additions & 3 deletions auto_tutorials_source/tutorial_from_de_to_pe.py
@@ -353,9 +353,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = rearrange(
out, "e (m c) h w -> (m e) c h w", m=self.num_estimators
)
out = torch.flatten(out, 1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
2 changes: 1 addition & 1 deletion auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -29,7 +29,7 @@

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.models.classification import lenet
from torch_uncertainty.optim_recipes import optim_cifar10_resnet18
from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm
from torch_uncertainty.routines import ClassificationRoutine
2 changes: 1 addition & 1 deletion auto_tutorials_source/tutorial_mc_dropout.py
@@ -35,7 +35,7 @@
from torch import nn

from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.models.classification import lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.optim_recipes import optim_cifar10_resnet18
from torch_uncertainty.routines import ClassificationRoutine
3 changes: 1 addition & 2 deletions auto_tutorials_source/tutorial_pe_cifar10.py
@@ -188,7 +188,6 @@ def __init__(self) -> None:
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = rearrange(x, "e (m c) h w -> (m e) c h w", m=self.num_estimators)
x = x.flatten(1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
@@ -267,7 +266,7 @@ def forward(self, x):
# Let us see what the Packed-Ensemble thinks these examples above are:

logits = packed_net(images)
logits = rearrange(logits, "(n b) c -> b n c", n=packed_net.num_estimators)
logits = rearrange(logits, "(m b) c -> b m c", m=packed_net.num_estimators)
probs_per_est = F.softmax(logits, dim=-1)
outputs = probs_per_est.mean(dim=1)
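The corrected pattern only renames the axis (`n` → `m`, the symbol the tutorial uses for the number of estimators); the operation is unchanged: split the stacked `(m b)` axis and move the estimator axis after the batch axis. A dependency-free sketch of the same index arithmetic on toy data (the helper name is ours, not the library's):

```python
def unpack_logits(stacked, m):
    """Mimic einops' "(m b) c -> b m c": the m estimator blocks are stacked
    along the leading axis, so row j*b + i holds estimator j's output for
    sample i; the result is indexed [sample][estimator]."""
    b = len(stacked) // m
    return [[stacked[j * b + i] for j in range(m)] for i in range(b)]

# toy example: m=2 estimators, batch of 2 samples, 1 logit each
stacked = [[0.1], [0.2], [0.3], [0.4]]
print(unpack_logits(stacked, m=2))  # [[[0.1], [0.3]], [[0.2], [0.4]]]
```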

2 changes: 1 addition & 1 deletion auto_tutorials_source/tutorial_scaler.py
@@ -28,7 +28,7 @@
# %%
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.metrics import CalibrationError
from torch_uncertainty.models.resnet import resnet
from torch_uncertainty.models.classification import resnet
from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.utils import load_hf

159 changes: 88 additions & 71 deletions docker/DOCKER.md
@@ -1,71 +1,88 @@
# :whale: Docker image for contributors

### Pre-built Docker image
1. To pull the pre-built image from Docker Hub, simply run:
```bash
docker pull docker.io/tonyzamyatin/torch-uncertainty:latest
```

This image includes:
- PyTorch with CUDA support
- OpenGL (for visualization tasks)
- Git, OpenSSH, and all Python dependencies

Checkout the [registry on Docker Hub](https://hub.docker.com/repository/docker/tonyzamyatin/torch-uncertainty/general) for all available images.

2. To start a container using this image, set up the necessary environment variables and run:
```bash
docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \
-e VM_SSH_PUBLIC_KEY="your-public-key" \
-e GITHUB_SSH_PRIVATE_KEY="your-github-key" \
-e GITHUB_USER="your-github-username" \
-e GIT_USER_EMAIL="your-git-email" \
-e GIT_USER_NAME="your-git-name" \
docker.io/tonyzamyatin/torch-uncertainty
```

Optionally, you can also set `-e USER_COMPACT_SHELL_PROMPT="true"`
to make the VM's shell prompts compact and colorized.

**Note:** Some cloud providers offer templates, in which you can preconfigure
in advance which Docker image to pull and which environment variables to set.
In this case, the provider will pull the image, set all environment variables,
and start the container for you.

3. Once your cloud provider has deployed the VM, it will display the host address and SSH port.
You can connect to the container via SSH using:
```bash
ssh -i /path/to/private_key root@<VM_HOST> -p <VM_PORT>
```

Replace `<VM_HOST>` and `<VM_PORT>` with the values provided by your cloud provider,
and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`.

4. The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard.

**Note:** The `/workspace` directory is mounted from your local machine or cloud storage,
so changes persist across container restarts.
If using a cloud provider, ensure your network volume is correctly attached to avoid losing data.

### Modifying and publishing custom Docker image

If you want to make changes to the Dockerfile, follow these steps:
1. Edit the Dockerfile to fit your needs.

2. Build the modified image:
```
docker build -t my-custom-image .
```

3. Push to a Docker registry (if you want to use it on another VM):
```
docker tag my-custom-image mydockerhubuser/my-custom-image:tag
docker push mydockerhubuser/my-custom-image:tag
```

4. Pull the custom image onto your VM:
```
docker pull mydockerhubuser/my-custom-image
```

5. Run the container using the same docker run command with the new image name.
# 🐋 Docker image for contributors

This Docker image is designed for users and contributors who want to run experiments with `torch-uncertainty` on remote virtual machines with GPU support. It is particularly useful for those who do not have access to a local GPU and need a pre-configured environment for development and experimentation.

---
## How to Use The Docker Image
### Step 1: Fork the Repository

Before proceeding, ensure you have forked the `torch-uncertainty` repository to your own GitHub account. You can do this by visiting the [torch-uncertainty GitHub repository](https://github.com/ENSTA-U2IS-AI/torch-uncertainty) and clicking the **Fork** button in the top-right corner.

Once forked, clone your forked repository to your local machine:
```bash
git clone git@github.com:<your-username>/torch-uncertainty.git
cd torch-uncertainty
```

> ### ⚠️ IMPORTANT NOTE: Keep Your Fork Synced
>
> **To ensure that you are working with the latest stable version and bug fixes, you must manually sync your fork with the upstream repository before building the Docker image. Failure to sync your fork may result in outdated dependencies or missing bug fixes in the Docker image.**

### Step 2: Build the Docker image locally
Build the modified image locally and push it to a Docker registry:
```
docker build -t my-dockerhub-user/my-torch-uncertainty-image:version .
docker push my-dockerhub-user/my-torch-uncertainty-image:version
```
### Step 3: Set environment variables on your VM
Connect to your VM and set the following environment variables:
```bash
export VM_SSH_PUBLIC_KEY="$(cat ~/.ssh/id_rsa.pub)"
export GITHUB_SSH_PRIVATE_KEY="$(cat ~/.ssh/id_rsa)"
export GITHUB_USER="your-github-username"
export GIT_USER_EMAIL="your-email@example.com"
export GIT_USER_NAME="Your Name"
export USE_COMPACT_SHELL_PROMPT=true
```

Here is a brief explanation of the environment variables used in the Docker setup:
- **`VM_SSH_PUBLIC_KEY`**: The public SSH key used to authenticate with the container via SSH.
- **`GITHUB_SSH_PRIVATE_KEY`**: The private SSH key used to authenticate with GitHub for cloning and pushing repositories.
- **`GITHUB_USER`**: The GitHub username used to clone the repository during the first-time setup.
- **`GIT_USER_EMAIL`**: The email address associated with the Git configuration for commits.
- **`GIT_USER_NAME`**: The name associated with the Git configuration for commits.
- **`USE_COMPACT_SHELL_PROMPT`** (optional): Enables a compact and colorized shell prompt inside the container if set to `"true"`.

### Step 4: Run the Docker container
First, authenticate with your Docker registry if you are using a private registry.
Then run the following command to start a container from the image in your registry:
```bash
docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \
-e VM_SSH_PUBLIC_KEY \
-e GITHUB_SSH_PRIVATE_KEY \
-e GITHUB_USER \
-e GIT_USER_EMAIL \
-e GIT_USER_NAME \
-e USE_COMPACT_SHELL_PROMPT \
docker.io/my-dockerhub-user/my-torch-uncertainty-image:version
```

### Step 5: Connect to your container
Once the container is up and running, you can connect to it via SSH:
```bash
ssh -i /path/to/private_key root@<VM_HOST> -p <VM_PORT>
```
Replace `<VM_HOST>` and `<VM_PORT>` with the host and port of your VM,
and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`.

The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard.

**Note:** The `/workspace` directory is mounted from your local machine or cloud storage,
so changes persist across container restarts.
If using a cloud provider, ensure your network volume is correctly attached to avoid losing data.

## Remote Development

This Docker setup also allows for remote development on the VM, since GitHub SSH access is set up and the whole repo is cloned to the VM from your GitHub fork.
For example, you can seamlessly connect your VS Code editor to your remote VM and run experiments, as if on your local machine but with the GPU acceleration of your VM.
See [VS Code Remote Development](https://code.visualstudio.com/docs/remote/remote-overview) for further details.

## Streamline setup with your Cloud provider of choice

Many cloud providers offer "templates" where you can specify a Docker image to use as a base. This means you can:

1. Specify the Docker image from your Docker registry as the base image.
2. Preconfigure the necessary environment variables in the template.
3. Reuse the template any time you need to spin up a virtual machine for experiments.

The cloud provider will handle setting the environment variables, pulling the Docker image, and spinning up the container for you. This approach simplifies the process and ensures consistency across experiments.
21 changes: 21 additions & 0 deletions docs/source/api.rst
@@ -393,6 +393,27 @@ Scaling Methods
VectorScaler
MatrixScaler



OOD Scores
-----------------------

.. currentmodule:: torch_uncertainty.ood_criteria

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class_inherited.rst

TUOODCriterion
MaxLogitCriterion
EnergyCriterion
MaxSoftmaxCriterion
EntropyCriterion
MutualInformationCriterion
VariationRatioCriterion


Datamodules
-----------

9 changes: 3 additions & 6 deletions docs/source/quickstart.rst
@@ -9,7 +9,7 @@ These routines make it very easy to:

- train ensembles-like methods (Deep Ensembles, Packed-Ensembles, MIMO, Masksembles, etc)
- compute and monitor uncertainty metrics: calibration, out-of-distribution detection, proper scores, grouping loss, etc.
- leverage calibration methods automatically during evaluation
- leverage post-processing methods automatically during evaluation

Yet, we acknowledge that there will be as many different uses of TorchUncertainty as there are users.
This page provides ideas on how to benefit from TorchUncertainty at all levels: from ready-to-train lightning-based models to using only specific
@@ -46,12 +46,9 @@
and its parameters.
# ...
eval_ood: bool = False,
eval_grouping_loss: bool = False,
ood_criterion: Literal[
"msp", "logit", "energy", "entropy", "mi", "vr"
] = "msp",
ood_criterion: type[TUOODCriterion] | str = "msp",
log_plots: bool = False,
save_in_csv: bool = False,
calibration_set: Literal["val", "test"] | None = None,
) -> None:
...
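The string shortcuts presumably map onto the criterion classes listed under `torch_uncertainty.ood_criteria` (e.g. `"msp"` for the maximum softmax probability, `"entropy"` for predictive entropy). As a rough, dependency-free illustration of what these two scores compute (the helper names are ours, not the library's):

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def msp_score(logits):
    # "msp": maximum softmax probability -- high for confident,
    # likely in-distribution inputs
    return max(softmax(logits))

def entropy_score(logits):
    # "entropy": predictive entropy -- high for uncertain,
    # possibly out-of-distribution inputs
    probs = softmax(logits)
    return -sum(p * math.log(p) for p in probs if p > 0)

confident, uncertain = [8.0, 0.0, 0.0], [1.0, 1.0, 1.0]
print(msp_score(confident) > msp_score(uncertain))          # True
print(entropy_score(confident) < entropy_score(uncertain))  # True
```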

@@ -160,7 +157,7 @@ backbone with the following code:

.. code:: python

from torch_uncertainty.models.resnet import packed_resnet
from torch_uncertainty.models.classification import packed_resnet

model = packed_resnet(
in_channels = 3,
10 changes: 10 additions & 0 deletions docs/source/references.rst
@@ -210,6 +210,16 @@ For Warping Mixup, consider citing:
* Authors: *Quentin Bouniot, Pavlo Mozharovskyi, and Florence d'Alché-Buc*
* Paper: `ArXiv 2023 <https://arxiv.org/abs/2311.01434>`__.

Test-Time-Adaptation with ZERO
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For ZERO, consider citing:

**Frustratingly Easy Test-Time Adaptation of Vision-Language Models**

* Authors: *Matteo Farina, Gianni Franchi, Giovanni Iacca, Massimiliano Mancini and Elisa Ricci*
* Paper: `NeurIPS 2024 <https://arxiv.org/abs/2405.18330>`__.

Post-Processing Methods
-----------------------

6 changes: 1 addition & 5 deletions experiments/classification/cifar10/configs/resnet.yaml
@@ -11,11 +11,7 @@ trainer:
save_dir: logs/resnet
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: torch_uncertainty.callbacks.TUClsCheckpoint
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
@@ -13,11 +13,7 @@
name: batched
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: torch_uncertainty.callbacks.TUClsCheckpoint
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
@@ -13,11 +13,7 @@
name: masked
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: torch_uncertainty.callbacks.TUClsCheckpoint
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
@@ -13,11 +13,7 @@
name: mimo
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: torch_uncertainty.callbacks.TUClsCheckpoint
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step