Skip to content

Commit 5ffbe49

Browse files
authored
Merge pull request #20 from daniel-code/feat/pytorch_models
Feat/pytorch models
2 parents bc8d8b7 + 8cc4685 commit 5ffbe49

File tree

9 files changed

+61
-77
lines changed

9 files changed

+61
-77
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,19 @@ python train.py -r "datasets/final/train"
8787
python train.py -r "datasets/final/train" --user-pretrained-weight --finetune-last-layer --use-lr-scheduler --use-auto-augment
8888
```
8989

90-
- Training with different model types. See more details in `scripts/different_models.sh`
90+
- Training with different model types. See more details in `scripts/different_models.sh`.
91+
Support [pytorch built-in model types](https://pytorch.org/vision/main/models.html#classification).
9192

9293
```commandline
9394
python train.py -r "datasets/final/train" --model-type resnext50_32x4d
9495
```
9596

96-
Support model types:
97+
- Training with different image size. Some model has image resolution constraint, e.g. vit, only accept image size by (
98+
244, 244).
9799

98-
- ResNet: resnet18, resnet34, resnet_50, resnet_101
99-
- ResNext: resnext50_32x4d, resnext101_32x8d
100-
- Swin: swin_t, swin_s, swin_b
100+
```commandline
101+
python train.py -r "datasets/final/train" --model-type vit_b_16 --image-size 224 224
102+
```
101103

102104
After training, the model weight will export to `model_weights/<model-type>_<exp_time>`.
103105
Use `tensorboard --logdir model_weights` to browse training log.
@@ -159,7 +161,8 @@ Options:
159161
python analysis.py -r "datasets/final/train" --model-path "model_weights/<model-type>_<exp_time>/model.pt"
160162
```
161163

162-
By default, the `reports/test.png` is AUC of ROC curve and confusion matrix, and the `reports/test_images.jpg` shows the fail cases.
164+
By default, the `reports/test.png` is AUC of ROC curve and confusion matrix, and the `reports/test_images.jpg` shows the
165+
fail cases.
163166

164167
## Inference
165168

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from .resnet import ResNet
2-
from .resnext import ResNext
3-
from .swin import Swin
1+
from .base import ModelBase
2+
from .torch_model_wrapper import is_torch_builtin_models, TorchModelWrapper
43

5-
__all__ = ['ResNet', 'Swin', 'ResNext']
4+
__all__ = ['ModelBase', 'TorchModelWrapper', 'is_torch_builtin_models']

dogs_cats_classifier/models/base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def __init__(self,
2626
self.user_pretrained_weight = user_pretrained_weight
2727
self.finetune_last_layer = finetune_last_layer
2828

29-
self.models_mapping = self._setup_models_mapping()
30-
assert model_type in self.models_mapping, f'{model_type} is not available. There is available model types: {list(self.models_mapping.keys())}'
3129
self.model = self._setup_model(model_type=model_type)
3230

3331
if finetune_last_layer:
@@ -39,9 +37,6 @@ def __init__(self,
3937

4038
self.example_input_array = torch.zeros((1, 3, input_shape[0], input_shape[1]), dtype=torch.float32)
4139

42-
def _setup_models_mapping(self) -> dict:
43-
raise NotImplementedError
44-
4540
def _setup_model(self, model_type) -> torch.nn.Module:
4641
raise NotImplementedError
4742

dogs_cats_classifier/models/resnet.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

dogs_cats_classifier/models/resnext.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

dogs_cats_classifier/models/swin.py

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import types
2+
3+
from .base import ModelBase
4+
from torch.nn import Module, Sequential, Linear
5+
6+
7+
def is_torch_builtin_models(model_type: str):
8+
model_type = model_type.lower()
9+
model_module = __import__('torchvision.models', fromlist=(model_type, ), level=0)
10+
if model_type in dir(model_module):
11+
model_func = getattr(model_module, model_type)
12+
return True if isinstance(model_func, types.FunctionType) else False
13+
14+
return False
15+
16+
17+
class TorchModelWrapper(ModelBase):
18+
def _setup_model(self, model_type) -> Module:
19+
# dynamic import module
20+
model_module = __import__('torchvision.models', fromlist=(model_type, ), level=0)
21+
model_func = getattr(model_module, model_type)
22+
model = model_func(weights='DEFAULT' if self.user_pretrained_weight else None)
23+
24+
# get last layer's name
25+
layer_name = list(model.named_children())[-1][0]
26+
27+
# check last layer
28+
last_layer = model.get_submodule(layer_name)
29+
if isinstance(last_layer, Sequential):
30+
last_layer = last_layer[-1]
31+
in_features = last_layer.in_features
32+
33+
# replace the last layer
34+
last_layer = getattr(model, layer_name)
35+
if isinstance(last_layer, Sequential):
36+
last_layer[-1] = Linear(in_features=in_features, out_features=self.num_classes)
37+
else:
38+
setattr(model, layer_name, Linear(in_features=in_features, out_features=self.num_classes))
39+
40+
return model

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name='dogs_cats_classifier',
55
packages=find_packages(),
6-
version='0.1.0',
6+
version='0.1.1',
77
description='Create an algorithm to distinguish dogs from cats',
88
author='YanRu',
99
license="MIT",

train.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytorch_lightning.callbacks import LearningRateMonitor
88

99
from dogs_cats_classifier.data import DogsCatsImagesDataModule
10-
from dogs_cats_classifier.models import ResNet, Swin, ResNext
10+
from dogs_cats_classifier.models import TorchModelWrapper, is_torch_builtin_models
1111
from dogs_cats_classifier.utils import Evaluator
1212
from datetime import datetime
1313

@@ -78,16 +78,12 @@ def main(batch_size, max_epochs, num_workers, image_size, dataset_root, fast_dev
7878
print(dogs_cats_datamodule)
7979

8080
# prepare model
81-
if 'swin' in model_type:
82-
model = Swin
83-
elif 'resnext' in model_type:
84-
model = ResNext
85-
elif 'resnet' in model_type:
86-
model = ResNet
81+
if is_torch_builtin_models(model_type):
82+
model_class = TorchModelWrapper
8783
else:
8884
raise ValueError(f'{model_type} is not available.')
8985

90-
model = model(
86+
model = model_class(
9187
num_classes=1,
9288
model_type=model_type,
9389
input_shape=image_size,
@@ -117,9 +113,10 @@ def main(batch_size, max_epochs, num_workers, image_size, dataset_root, fast_dev
117113
torch.jit.save(script_model, os.path.join(output_path, 'model.pt'))
118114

119115
# evaluation
120-
dogs_cats_datamodule.setup()
121-
evaluator = Evaluator(model=model, output_path=output_path)
122-
evaluator.evaluate(dataloader=dogs_cats_datamodule.test_dataloader(), title=f'{model_type}_test', verbose=False)
116+
if not fast_dev_run:
117+
dogs_cats_datamodule.setup()
118+
evaluator = Evaluator(model=model, output_path=output_path)
119+
evaluator.evaluate(dataloader=dogs_cats_datamodule.test_dataloader(), title=f'{model_type}_test', verbose=False)
123120

124121

125122
if __name__ == '__main__':

0 commit comments

Comments
 (0)