-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassify.py
More file actions
105 lines (73 loc) · 2.8 KB
/
classify.py
File metadata and controls
105 lines (73 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# ===================================================================
# Example : apply a specific pre-trained classifier to the test images
# path to the directory containing images is specified on the command line
# e.g. python classify.py --data=path_to_data
# path to the pre-trained network weights can be specified on the command line
# e.g. python classify.py --model=path_to_model
# python classify.py --data=test_images --model=classifier.model
# Author : Amir Atapour Abarghouei, amir.atapour-abarghouei@durham.ac.uk
# Copyright (c) 2022 Amir Atapour Abarghouei
# License : LGPL - http://www.gnu.org/licenses/lgpl.html
# ===================================================================
import os
import argparse
import numpy as np
import cv2
# ===================================================================
# Command-line interface: both options are optional and fall back to the
# defaults shown in the usage example in the header of this file.
parser = argparse.ArgumentParser(
    description='Perform image classification on test images',
)

# directory holding the images to classify
parser.add_argument(
    "--data", default='test_images', type=str,
    help="specify path to test images",
)

# file the network weights are loaded from
parser.add_argument(
    "--model", default='classifier.model', type=str,
    help="specify path to model weights",
)

args = parser.parse_args()
# ===================================================================
# load the pre-trained network from the path given via --model:
model = cv2.dnn.readNetFromONNX(args.model)

# filename prefixes that encode the expected label:
# im01..im20 are the healthy images, im21..im40 are the sick ones
healthys = [f'im{str(idx).zfill(2)}' for idx in range(1, 21)]
sicks = [f'im{str(idx).zfill(2)}' for idx in range(21, 41)]

# list reserved for images (nothing is stored in it in this section):
images = []

# every entry of the data directory, in sorted order, without the
# .DS_Store metadata file a Mac might have put in there:
names = [
    entry
    for entry in sorted(os.listdir(args.data))
    if entry != ".DS_Store"
]
# Classify every file in `names` and score each prediction against the label
# implied by its filename prefix (see `sicks` / `healthys`).
#   correct   - predictions that agree with the filename-derived label
#   evaluated - files OpenCV could actually decode; this is the accuracy
#               denominator, so stray non-image files in the directory no
#               longer drag the reported accuracy down
correct = 0
evaluated = 0

# str.startswith accepts a tuple of prefixes; build both tuples once here
# instead of on every loop iteration:
sick_prefixes = tuple(sicks)
healthy_prefixes = tuple(healthys)

# main loop:
for filename in names:
    # read image:
    img = cv2.imread(os.path.join(args.data, filename))
    if img is None:
        # imread returns None for anything it cannot decode: skip the file
        # and do not count it towards the accuracy
        continue
    evaluated += 1

    # pass the image through the neural network: pixel values are multiplied
    # by 1/255, the image is resized to 256x256 without cropping, no mean is
    # subtracted and the B and R channels are swapped
    blob = cv2.dnn.blobFromImage(
        img, 1.0 / 255, (256, 256), (0, 0, 0), swapRB=True, crop=False)
    model.setInput(blob)
    output = model.forward()

    # NOTE(review): assumes the network emits exactly one score per image -
    # confirm against the model. item() turns that single-element array into
    # a plain Python scalar and raises ValueError for any other size.
    score = output.item()

    # identify what the predicted label is:
    if score > 0.5:
        print(f'{filename}: sick')
        if filename.startswith(sick_prefixes):
            correct += 1
    else:
        print(f'{filename}: healthy')
        if filename.startswith(healthy_prefixes):
            correct += 1

# print final accuracy (guarding against a directory without readable images,
# which would otherwise end in a ZeroDivisionError):
if evaluated:
    print(f'Accuracy is {correct / evaluated}')
else:
    print('Accuracy is undefined: no readable images found')
# ===================================================================