This repository contains the code necessary to reproduce the results described in the associated paper:
- Associated paper: https://arxiv.org/abs/2411.09296
- Set up the environment.
  - This project was tested with Python 3.11.5.
  - You can use the `requirement.txt` file to install the required Python packages.
  - Alternatively, the Docker build file can be used to create an image that can run the package.
- Download the RS3L dataset.
  - The data files are too large to be stored in this repository.
  - You can find them on Zenodo:
    - DOI: 10.5281/zenodo.10633814
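
The files can be downloaded by hand from the Zenodo page. As a hedged alternative, the sketch below uses only the Python standard library to turn the DOI into a record id and download every file of the record. It assumes the public `https://zenodo.org/api/records/<id>` endpoint and the `files` / `key` / `links.self` layout of its JSON response; `download_record` and the target directory are hypothetical and not part of this repository.

```python
import json
import os
import re
import urllib.request

# Assumption: Zenodo's public REST endpoint for a single record.
ZENODO_RECORD_API = "https://zenodo.org/api/records/{record_id}"


def record_id_from_doi(doi: str) -> str:
    """Extract the numeric record id from a Zenodo DOI such as '10.5281/zenodo.10633814'."""
    match = re.fullmatch(r"10\.5281/zenodo\.(\d+)", doi.strip())
    if match is None:
        raise ValueError(f"not a Zenodo DOI: {doi!r}")
    return match.group(1)


def list_record_files(doi: str) -> list[tuple[str, str]]:
    """Return (file name, download URL) pairs for a Zenodo record. Requires network access."""
    url = ZENODO_RECORD_API.format(record_id=record_id_from_doi(doi))
    with urllib.request.urlopen(url) as response:
        record = json.load(response)
    # Assumption: each entry in "files" has its name under "key" and a URL under links["self"].
    return [(entry["key"], entry["links"]["self"]) for entry in record.get("files", [])]


def download_record(doi: str, target_dir: str) -> None:
    """Hypothetical helper: download every file of the record into target_dir."""
    os.makedirs(target_dir, exist_ok=True)
    for name, url in list_record_files(doi):
        urllib.request.urlretrieve(url, os.path.join(target_dir, name))


record_id = record_id_from_doi("10.5281/zenodo.10633814")
```

For example, `download_record("10.5281/zenodo.10633814", "data/rs3l")` would place the files in the hypothetical `data/rs3l` directory, which you would then reference in the `paths` config.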

This package uses Snakemake, Hydra and OmegaConf.

- The different Snakemake modules are found in the `snakemake/workflow` folder. They are all invoked by the main file `all.smk`, where individual results can be selected.
- General training configuration is given by the parent config file `configs/train.yaml`, which composes all the others, while specific training configurations are found in the `configs/experiment/` folder.
  - In this file you can specify the `network_name`, `seed`, etc.
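
Hydra composes the run configuration by merging the parent `configs/train.yaml` with the chosen experiment file, with later values overriding earlier ones (in the actual package OmegaConf performs the merge). The stdlib-only sketch below mimics that override behaviour on plain dictionaries; it is an illustration, not Hydra's implementation, and all values, as well as the nested `trainer` entries, are made up.

```python
from copy import deepcopy
from typing import Any


def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge `override` into a copy of `base`; values in `override` win."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested sections key by key
        else:
            merged[key] = deepcopy(value)  # scalars and lists are replaced outright
    return merged


# Hypothetical stand-ins for configs/train.yaml and one file in configs/experiment/.
train_defaults = {
    "network_name": "default",
    "seed": 12345,
    "trainer": {"max_epochs": 100, "accelerator": "auto"},
}
experiment = {"network_name": "sam_run", "seed": 42, "trainer": {"max_epochs": 50}}

cfg = deep_merge(train_defaults, experiment)
print(cfg)
# {'network_name': 'sam_run', 'seed': 42, 'trainer': {'max_epochs': 50, 'accelerator': 'auto'}}
```

Keys that the experiment file does not mention (here `accelerator`) keep their parent values, which is why only the differences need to be written in `configs/experiment/`.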

More specific settings are found in one of the other config folders:
- `callbacks`:
  - Provides a collection of `lightning.pytorch.callbacks` to run during training.
- `datamodule`:
  - Defines the data files used for training and testing.
  - Also defines which coordinates are used for the object kinematics.
- `hydra`:
  - Configures the Hydra package for running; it does not need to be changed.
- `loggers`:
  - By default we use Weights and Biases for logging.
    - You will need to make a free account at https://wandb.ai/ and put your username in the `entity` entry of this YAML file.
- `model`:
  - Configures the model architecture and the hyperparameters for training.
  - By default, the transformer + normalising flow is used.
- `paths`:
  - Defines the path to the downloaded data as well as the desired save directory for the models.
- `trainer`:
  - Configuration for the PyTorch Lightning `Trainer` class.
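
Configs in Hydra-based projects usually describe objects through a `_target_` key naming the class to build, which `hydra.utils.instantiate` imports and calls with the remaining keys. Assuming the config folders above follow that convention, the sketch below is a heavily simplified, stdlib-only stand-in for that mechanism (no positional arguments, partials or interpolation). It uses standard-library targets so that it runs without Lightning installed.

```python
import importlib
from typing import Any


def instantiate(node: Any) -> Any:
    """Recursively build objects from dicts that name a callable under "_target_".

    Simplified stand-in for hydra.utils.instantiate: import the target,
    then call it with the remaining keys as keyword arguments.
    """
    if isinstance(node, list):
        return [instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node
    kwargs = {key: instantiate(value) for key, value in node.items() if key != "_target_"}
    if "_target_" not in node:
        return kwargs  # plain mapping: keep it, with its children instantiated
    module_name, _, attr_name = node["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(**kwargs)


# Hypothetical config; real files would point at classes such as
# lightning.pytorch.callbacks.ModelCheckpoint instead of stdlib types.
run_cfg = {
    "max_time": {"_target_": "datetime.timedelta", "hours": 2},
    "callbacks": [{"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3}],
    "seed": 42,
}
built = instantiate(run_cfg)
```

Keeping class names in YAML rather than in code is what lets a single `scripts/train.py` swap callbacks, loggers or models purely through config changes.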

The main workflow can be launched using `invoke experiment-run name --workflow all`, where `name` is used to separate different runs and can be chosen arbitrarily. The execution order of the scripts can be retraced by following the `input` and `output` sections of the individual Snakemake files.
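
Snakemake links rules by matching one rule's `output` files to another rule's `input` files and runs them in an order consistent with the resulting dependency graph. The sketch below reproduces that reasoning with the standard-library `graphlib.TopologicalSorter`; the rule names and file names are hypothetical and do not come from `snakemake/workflow`.

```python
from graphlib import TopologicalSorter

# Hypothetical rules: name -> (input files, output files).
RULES = {
    "export": (["model.ckpt"], ["outputs.h5"]),
    "train": (["rs3l_train.h5"], ["model.ckpt"]),
    "export_hessian": (["model.ckpt"], ["hessian.h5"]),
}


def execution_order(rules: dict[str, tuple[list[str], list[str]]]) -> list[str]:
    """Order rules so that every rule runs after the rules producing its inputs."""
    producer = {out: name for name, (_, outputs) in rules.items() for out in outputs}
    # Inputs without a producing rule (e.g. the downloaded dataset) are treated as existing files.
    graph = {
        name: {producer[path] for path in inputs if path in producer}
        for name, (inputs, _) in rules.items()
    }
    return list(TopologicalSorter(graph).static_order())


order = execution_order(RULES)
print(order)  # "train" always precedes "export" and "export_hessian"
```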

The most important scripts are given below:
- `scripts/train.py`
  - Compiles the run config as described above and trains the model.
  - Saves checkpoints based on the `paths.output_dir` key.
- `scripts/export.py`
  - Creates an output `.h5` file containing results for each event in the model's test set.
- `franckstools/sam.py`
  - Contains a wrapper with all the logic to train a PyTorch Lightning model using the different sharpness-aware methods (weight space).
- `franckstools/adversarial_attack.py`
  - Contains a wrapper with all the logic to train a PyTorch Lightning model using the different adversarial methods (feature space).
- `scripts/export_hessian.py`
  - Calculates the largest eigenvalue of the Hessian for the models.
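
`franckstools/sam.py` implements several sharpness-aware methods whose exact variants are not described here. For orientation, the sketch below shows only the basic two-step sharpness-aware minimization (SAM) update on a toy two-parameter loss: perturb the weights towards higher loss within an L2 ball of radius `rho`, then descend using the gradient taken at the perturbed point. The loss and the `lr` and `rho` values are illustrative assumptions.

```python
import math


# Toy loss with one sharp and one flat direction: f(w) = 0.5 * (10 * w0^2 + w1^2).
def loss(w):
    return 0.5 * (10.0 * w[0] ** 2 + w[1] ** 2)


def grad(w):
    return [10.0 * w[0], w[1]]


def sam_step(w, lr=0.05, rho=0.05):
    """One sharpness-aware minimization (SAM) step."""
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # guard against division by zero
    # 1) Ascent: first-order worst-case weight perturbation inside an L2 ball of radius rho.
    e = [rho * gi / norm for gi in g]
    # 2) Gradient evaluated at the perturbed weights w + e.
    g_sharp = grad([wi + ei for wi, ei in zip(w, e)])
    # 3) Descent: update the original weights with that gradient.
    return [wi - lr * gi for wi, gi in zip(w, g_sharp)]


w = [1.0, 1.0]
initial_loss = loss(w)
for _ in range(200):
    w = sam_step(w)
final_loss = loss(w)
```

Each SAM step costs two gradient evaluations, which is the main overhead compared with plain gradient descent.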
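
The adversarial methods in `franckstools/adversarial_attack.py` are not spelled out here. As an assumed example of a feature-space perturbation, the sketch below applies the fast gradient sign method (FGSM) to a toy logistic-regression model and then takes one training step on the perturbed input; the model, the data point and `epsilon` are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def predict(w, b, x):
    """Probability of class 1 under a toy logistic-regression model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def bce(p: float, y: int) -> float:
    eps = 1e-12  # avoid log(0)
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def fgsm(w, b, x, y, epsilon):
    """Fast gradient sign method: x_adv = x + epsilon * sign(dL/dx)."""
    # For logistic regression with binary cross-entropy, dL/dx = (p - y) * w.
    p = predict(w, b, x)
    return [xi + epsilon * sign((p - y) * wi) for xi, wi in zip(x, w)]


def adversarial_step(w, b, x, y, lr=0.1, epsilon=0.1):
    """One SGD step computed on the adversarial input instead of the clean one."""
    x_adv = fgsm(w, b, x, y, epsilon)
    p = predict(w, b, x_adv)
    # dL/dw = (p - y) * x_adv and dL/db = p - y
    new_w = [wi - lr * (p - y) * xi for wi, xi in zip(w, x_adv)]
    return new_w, b - lr * (p - y)


w, b = [2.0, -1.0], 0.5
x, y = [0.3, 0.8], 1
x_adv = fgsm(w, b, x, y, epsilon=0.1)
clean_loss = bce(predict(w, b, x), y)
adv_loss = bce(predict(w, b, x_adv), y)
w_new, b_new = adversarial_step(w, b, x, y)
```

Unlike SAM, which perturbs the weights, this perturbs the inputs, matching the weight-space versus feature-space distinction drawn in the list above.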
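
The largest Hessian eigenvalue is commonly used as a sharpness measure. How `scripts/export_hessian.py` computes it is not stated here; one standard approach is power iteration driven by Hessian-vector products, sketched below on a toy quadratic whose Hessian is known. A real network would obtain the Hessian-vector product through automatic differentiation rather than the finite difference assumed here. Note that power iteration returns the eigenvalue of largest magnitude, which equals the largest eigenvalue when that one dominates.

```python
import math


def hessian_vector_product(grad_fn, w, v, h=1e-5):
    """Central finite difference: H v ≈ (∇f(w + h v) - ∇f(w - h v)) / (2 h)."""
    g_plus = grad_fn([wi + h * vi for wi, vi in zip(w, v)])
    g_minus = grad_fn([wi - h * vi for wi, vi in zip(w, v)])
    return [(a - c) / (2 * h) for a, c in zip(g_plus, g_minus)]


def largest_eigenvalue(grad_fn, w, iters=100):
    """Power iteration on the Hessian at w, using only Hessian-vector products."""
    v = [1.0 / math.sqrt(len(w))] * len(w)  # unit-norm starting vector
    eigenvalue = 0.0
    for _ in range(iters):
        hv = hessian_vector_product(grad_fn, w, v)
        eigenvalue = sum(vi * hvi for vi, hvi in zip(v, hv))  # Rayleigh quotient, |v| = 1
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:
            break
        v = [x / norm for x in hv]
    return eigenvalue


# Toy loss f(w) = 0.5 * w^T H w with H = [[2, 1], [1, 3]], so ∇f(w) = H w.
def toy_grad(w):
    return [2.0 * w[0] + 1.0 * w[1], 1.0 * w[0] + 3.0 * w[1]]


lam_max = largest_eigenvalue(toy_grad, [0.5, -0.2])
print(round(lam_max, 6))  # 3.618034, i.e. (5 + sqrt(5)) / 2
```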
