diff --git a/README.md b/README.md index 79f5614..a770ad8 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,5 @@ `torchcvnn/examples` is a repository showcasing examples of using [torchcvnn](https://www.github.com/torchcvnn/torchcvnn). - [Classification of MNIST in the Fourier space with complex valued CNNs](./mnist_conv/README.md) +- [Classification of MSTAR using a patched RVNN](./mstar_resnet/README.md) - [Complex valued Neural Implicit Representation for cardiac reconstruction](./nir_cinejense/README.md) diff --git a/mstar_resnet/README.md b/mstar_resnet/README.md new file mode 100644 index 0000000..37252cc --- /dev/null +++ b/mstar_resnet/README.md @@ -0,0 +1,9 @@ +# MSTAR Classification with patched RVNN + +This code ... + +```bash +python -m pip install -r requirements.txt +python mstar.py +``` + diff --git a/mstar_resnet/mstar.py b/mstar_resnet/mstar.py new file mode 100644 index 0000000..4e9df98 --- /dev/null +++ b/mstar_resnet/mstar.py @@ -0,0 +1,264 @@ +# coding: utf-8 + +# MIT License + +# Copyright (c) 2024 Jérémy Fix + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class LogAmplitudeTransform:
    """
    Rescale the modulus of a complex valued tensor on a log10 scale and keep
    its phase untouched.

    The modulus is clipped to [min_value, max_value] and then mapped to
    [0, 1] with

        (log10(|z|) - log10(min_value)) / (log10(max_value) - log10(min_value))

    The output is this rescaled modulus multiplied by exp(1j * angle(z)).

    Arguments:
        min_value: lower clipping bound of the modulus (strictly positive).
                   Falls back to the module level MIN_VALUE when None
        max_value: upper clipping bound of the modulus (larger than
                   min_value). Falls back to the module level MAX_VALUE
                   when None
    """

    def __init__(self, min_value=None, max_value=None):
        # The module level constants are only looked up when no explicit
        # bound is given
        self.min_value = MIN_VALUE if min_value is None else min_value
        self.max_value = MAX_VALUE if max_value is None else max_value
        if not 0 < self.min_value < self.max_value:
            raise ValueError(
                "Expected 0 < min_value < max_value, got "
                f"min_value={self.min_value}, max_value={self.max_value}"
            )

    def __call__(self, tensor) -> torch.Tensor:
        """
        Apply the log-amplitude rescaling to `tensor`

        Arguments:
            tensor: the complex valued data to transform

        Returns:
            A torch.complex64 tensor with the same shape as the input
        """
        return self._transform_amplitude(tensor)

    def _transform_amplitude(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = torch.as_tensor(tensor)

        # All the operations below are elementwise, hence there is no need to
        # loop over the leading dimension. Staying within torch also keeps
        # the computation on the device of the input.
        amplitude = torch.clip(torch.abs(tensor), self.min_value, self.max_value)
        phase = torch.angle(tensor)

        # The bounds are plain python scalars, so are their logarithms
        log_min = float(np.log10(self.min_value))
        log_max = float(np.log10(self.max_value))
        scaled_amplitude = (torch.log10(amplitude) - log_min) / (log_max - log_min)

        return (scaled_amplitude * torch.exp(1j * phase)).to(torch.complex64)
def convert_to_complex(module: nn.Module) -> nn.Module:
    """
    Recursively replace, in place, the real valued layers of `module` by
    complex valued counterparts.

    The direct children of `module` are visited and replaced as follows :
        nn.Conv2d      -> nn.Conv2d with a torch.complex64 dtype
        nn.ReLU        -> c_nn.modReLU
        nn.BatchNorm2d -> c_nn.BatchNorm2d
        nn.MaxPool2d   -> c_nn.MaxPool2d
        nn.Linear      -> nn.Linear with a torch.complex64 dtype
    Any other child is recursed into. `module` itself is never replaced,
    only its descendants.

    The replacement layers are newly constructed : the parameters of the
    layers they replace are not copied.

    Arguments:
        module: the module to patch

    Returns:
        The very same `module` object, patched
    """
    cdtype = torch.complex64
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            # Forward every structural setting of the convolution. Dropping
            # dilation/groups/padding_mode would silently build a different
            # layer for grouped or dilated architectures.
            replacement = nn.Conv2d(
                child.in_channels,
                child.out_channels,
                child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                groups=child.groups,
                bias=child.bias is not None,
                padding_mode=child.padding_mode,
                dtype=cdtype,
            )
        elif isinstance(child, nn.ReLU):
            replacement = c_nn.modReLU()
        elif isinstance(child, nn.BatchNorm2d):
            # NOTE(review): only num_features is forwarded; eps, momentum,
            # affine, ... of the original layer are not. Confirm against the
            # c_nn.BatchNorm2d signature if they matter.
            replacement = c_nn.BatchNorm2d(child.num_features)
        elif isinstance(child, nn.MaxPool2d):
            # NOTE(review): dilation and ceil_mode of the original layer are
            # not forwarded. Confirm against the c_nn.MaxPool2d signature.
            replacement = c_nn.MaxPool2d(
                child.kernel_size,
                stride=child.stride,
                padding=child.padding,
            )
        elif isinstance(child, nn.Linear):
            replacement = nn.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                dtype=cdtype,
            )
        else:
            # Not a layer we know how to convert : look into its children
            convert_to_complex(child)
            continue

        # Rebinding an existing child name does not change the number of
        # registered children, hence it is safe while iterating over them
        setattr(module, name, replacement)

    return module
def train(datadir, valid_ratio=0.3, batch_size=128, epochs=100, lr=0.001):
    """
    Train a complex valued patched ResNet50 and checkpoint the best model

    The dataloaders are built with get_dataloaders. A real valued "resnet50"
    is created with timm, patched with convert_to_complex and followed by
    c_nn.Mod so that real valued logits are given to the CrossEntropyLoss.
    The model is optimized with AdamW and a cosine annealing of the learning
    rate. The parameters with the lowest validation loss are saved in a
    newly created directory under ./logs

    Arguments:
        datadir: the directory given to get_dataloaders
        valid_ratio: the fraction of the samples used for validation
        batch_size: the size of the minibatches
        epochs: the number of training epochs
        lr: the initial learning rate of AdamW
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataloading
    train_loader, valid_loader, num_classes = get_dataloaders(
        datadir, batch_size=batch_size, valid_ratio=valid_ratio
    )

    # Peek at one minibatch to get the number of input channels. The inputs
    # are taken by index, like utils.train_epoch/test_epoch do, so that we do
    # not depend on the number of fields a minibatch carries
    first_batch = next(iter(train_loader))
    in_chans = first_batch[0].shape[1]

    # Build the model as a patched TIMM
    # and send it to the right device
    real_valued_model = timm.create_model(
        "resnet50", pretrained=False, num_classes=num_classes, in_chans=in_chans
    )
    model = convert_to_complex(real_valued_model)
    # Add a final layer to the model to transform the complex valued logits into
    # real valued logits to go into the CrossEntropyLoss
    model = nn.Sequential(
        model,
        c_nn.Mod(),
    )  # not sure if this is the right way to do it

    # Initialize the weights
    model.apply(init_weights)

    model.to(device)

    # Loss, optimizer, callbacks
    f_loss = nn.CrossEntropyLoss()
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
    logpath = utils.generate_unique_logpath("./logs", "SAMPLE")
    logging.info(f"Logging to {logpath}")
    checkpoint = utils.ModelCheckpoint(model, logpath, 4, min_is_best=True)

    # Training loop
    for e in range(epochs):
        logging.info(">> Training")
        train_loss, train_acc = utils.train_epoch(
            model, train_loader, f_loss, optim, device
        )

        logging.info(">> Testing")
        valid_loss, valid_acc = utils.test_epoch(model, valid_loader, f_loss, device)
        # The model is saved whenever the validation loss decreases
        updated = checkpoint.update(valid_loss)
        scheduler.step()
        better_str = "[>> BETTER <<]" if updated else ""

        # The learning rate is read after scheduler.step()
        logging.info(
            f"[Step {e}] Train : CE {train_loss:5.2f} Acc {train_acc:5.2f} | Valid : CE {valid_loss:5.2f} Acc {valid_acc:5.2f}"
            + better_str
            + f" | LR {scheduler.get_last_lr()[0]:.5f}"
        )
def train_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    f_loss: nn.Module,
    optim: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run one training pass over all the minibatches of the dataloader

    Arguments:
        model: the model to train
        loader: an iterable dataloader, the inputs and the targets are
                respectively the first and second items of its minibatches
        f_loss (nn.Module): the loss
        optim : an optimizing algorithm
        device: the device on which to run the code

    Returns:
        The averaged training loss
        The averaged training accuracy

    Raises:
        ValueError: if the loader does not provide any sample
    """
    model.train()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    for minibatch in tqdm.tqdm(loader):
        inputs = minibatch[0]
        targets = minibatch[1]

        # The inputs are cast to complex64 before being sent to the device
        inputs = inputs.cfloat().to(device)
        targets = targets.to(device)

        # Forward propagate through the model
        pred_outputs = model(inputs)

        # Forward propagate through the loss
        loss = f_loss(pred_outputs, targets)

        # Backward pass and update
        optim.zero_grad()
        loss.backward()
        optim.step()

        batch_size = inputs.shape[0]
        num_samples += batch_size

        # Denormalize the loss that is supposed to be averaged over the
        # minibatch
        loss_sum += batch_size * loss.item()
        pred_cls = pred_outputs.argmax(dim=-1)
        num_correct += (pred_cls == targets).sum().item()

    if num_samples == 0:
        # Fail with an explicit message rather than a division by zero
        raise ValueError("train_epoch got an empty loader, there is nothing to train on")

    return loss_sum / num_samples, num_correct / num_samples
class ModelCheckpoint(object):
    """
    Saves the parameters of a model every time the monitored score improves
    """

    def __init__(
        self,
        model: torch.nn.Module,
        savepath: str,
        num_input_dims: int,
        min_is_best: bool = True,
    ) -> None:
        """
        Model checkpointing callback

        Arguments:
            model: the model to save
            savepath: the location where to save the model's parameters
            num_input_dims: the number of dimensions for the input tensor.
                            It is only stored, this class does not use it
            min_is_best: whether the min metric or the max metric as the best
        """
        self.model = model
        self.savepath = savepath
        self.num_input_dims = num_input_dims
        # No score registered yet : the first call to update always saves
        self.best_score = None
        self.is_better = (
            self.lower_is_better if min_is_best else self.higher_is_better
        )

    def lower_is_better(self, score: float) -> bool:
        """
        Test if the provided score is lower than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res : is the provided score lower than the best score so far ?
        """
        return self.best_score is None or score < self.best_score

    def higher_is_better(self, score: float) -> bool:
        """
        Test if the provided score is higher than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res : is the provided score higher than the best score so far ?
        """
        return self.best_score is None or score > self.best_score

    def update(self, score: float) -> bool:
        """
        If the provided score is better than the best score registered so far,
        saves the model's parameters on disk as savepath/best_model.pt

        The savepath directory is created if it does not exist. The
        train/eval mode of the model is the same before and after the call.

        Arguments:
            score: the new score to consider

        Returns:
            res: whether or not the provided score is better than the best score
                 registered so far
        """
        if not self.is_better(score):
            return False

        # Remember the mode of the model to put it back once the model is
        # saved, the caller may be in the middle of a training
        was_training = self.model.training
        self.model.eval()
        try:
            os.makedirs(self.savepath, exist_ok=True)
            torch.save(
                self.model.state_dict(), os.path.join(self.savepath, "best_model.pt")
            )
        finally:
            self.model.train(was_training)

        self.best_score = score
        return True