Repository of benchmarking scripts for "Few-Shot Learning for Image Classification on Spacecraft".
Python version: 3.12
The following packages are required:
- PyTorch
- numpy
- tqdm
- configargparse
- scikit-learn
- pandas
- matplotlib
- seaborn
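A quick way to confirm the packages above are installed is sketched below. It is only a convenience check, not part of the repository. The import names `torch` and `sklearn` are the modules that the PyTorch and scikit-learn distributions provide; the helper name `missing_packages` is made up for this example.

```python
import importlib.util

# Package name as listed above -> module name used in `import` statements.
REQUIRED = {
    "PyTorch": "torch",
    "numpy": "numpy",
    "tqdm": "tqdm",
    "configargparse": "configargparse",
    "scikit-learn": "sklearn",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
}


def missing_packages(required=REQUIRED):
    """Return the listed package names whose import module cannot be found."""
    return [
        name
        for name, module in required.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

Run it with the same Python 3.12 interpreter you intend to use for the scripts, so the check reflects the right environment.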
Download the datasets into the data folder.
The folder should follow this directory layout:
data
├── dataset name
│   ├── class 1
│   ├── class 2
│   ├── class 3
│   └── .json files
Create .json files listing the classes you want included in each dataset. Files for the following datasets are already included:
- UC Merced Land Use (http://weegee.vision.ucmerced.edu/datasets/landuse.html)
- NWPU-RESISC45 (https://www.tensorflow.org/datasets/catalog/resisc45)
- PatternNet (https://sites.google.com/view/zhouwx/dataset)
To add a new dataset:
- Download the dataset into the data folder as shown above
- Create .json file(s) for the classes you want to use
- Create .config file(s) in the configs folder to specify the dataset, the path to the .json file, and the image file types
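The .json step can be scripted. The sketch below is illustrative only: the schema this repository expects is not documented here, so the keys `class_names` and `class_roots`, the function name `write_split`, and the output file name are all assumptions. Compare with the bundled .json files and adjust the keys before relying on it.

```python
import json
from pathlib import Path


def write_split(dataset_dir, class_names, out_name):
    """Write a split file for the chosen class folders.

    NOTE: the keys below are a hypothetical schema, not the repository's
    confirmed format. Check an existing .json file and adapt as needed.
    """
    dataset_dir = Path(dataset_dir)

    # Fail early if a requested class has no folder under the dataset.
    missing = [c for c in class_names if not (dataset_dir / c).is_dir()]
    if missing:
        raise FileNotFoundError(f"Class folders not found: {missing}")

    split = {
        "class_names": list(class_names),
        "class_roots": [str(dataset_dir / c) for c in class_names],
    }
    out_path = dataset_dir / out_name
    out_path.write_text(json.dumps(split, indent=2))
    return out_path
```

For example, `write_split("data/dataset name", ["class 1", "class 2"], "train.json")` would write the file next to the class folders, matching the layout shown above.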
Feel free to edit any of the scripts to suit your needs.
To pretrain the backbone, run from root:
./pretrain.sh
To sweep over different shots, run from root:
./run_sweep.sh
To sweep over different shots on the CNN baseline, run from root:
./run_baseline_sweep.sh
To evaluate class accuracies, run from root:
./run_confusion.sh
The full list of command-line arguments can be found in src/utils/parser_configs.py. Some of them are:
--n-tasks: number of few-shot testing episodes
--n-way: number of classes per few-shot task
--n-shot: number of support samples per class
--n-query: number of queries per few-shot task
--train-split: path to training dataset .json file
--test-split: path to testing dataset .json file
--algorithm: which classification head to use (PrototypicalNetworks, LaplacianShot, SVM)
--pretrained_dataset: name of the dataset the backbone was pretrained on
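To make the roles of --n-way, --n-shot, --n-query and --n-tasks concrete, here is a minimal sketch of how a single few-shot task (episode) can be sampled. It is not the repository's sampler: `sample_episode` and `items_by_class` are hypothetical names, and the sketch assumes queries are drawn per class, which may differ from how the scripts interpret --n-query.

```python
import random


def sample_episode(items_by_class, n_way, n_shot, n_query, rng=None):
    """Sample one few-shot task as (support, query) lists of (item, label).

    Assumption: n_query is counted per class, so a task holds
    n_way * n_shot support items and n_way * n_query query items.
    """
    rng = rng or random.Random()

    # Only classes with enough items for disjoint support and query sets.
    eligible = [
        name
        for name, items in items_by_class.items()
        if len(items) >= n_shot + n_query
    ]
    if len(eligible) < n_way:
        raise ValueError("Not enough classes with n_shot + n_query items")

    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for label, name in enumerate(classes):
        # Draw without replacement, then split into support and query.
        picked = rng.sample(items_by_class[name], n_shot + n_query)
        support += [(item, label) for item in picked[:n_shot]]
        query += [(item, label) for item in picked[n_shot:]]
    return support, query
```

Evaluation would then repeat this n-tasks times and average the query accuracy, for example `episodes = [sample_episode(data, 5, 1, 10) for _ in range(n_tasks)]` with `data` mapping class names to lists of image paths.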