-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtest.py
More file actions
92 lines (79 loc) · 3.74 KB
/
test.py
File metadata and controls
92 lines (79 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from trainer import ClassifierTrainer
from utils import get_config, check_dir, get_local_time
# ---- Command-line arguments ----
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='configs/GTSRB.yaml', help="net configuration")
# BUGFIX: the original default was 'GTSRB-new\RB\val\0-clear' in a NON-raw
# string, where '\v' is a vertical-tab escape and '\0' is a NUL byte, so the
# default path was silently corrupted. Forward slashes are unambiguous and
# work on Windows as well as POSIX.
parser.add_argument('-i', '--input_dir', type=str,
                    default='GTSRB-new/RB/val/0-clear',
                    help="input image path")
parser.add_argument('-o', '--output_dir', type=str, default='result-2/GTSRB/original/',
                    help="output image path")
parser.add_argument('-p', '--checkpoint', type=str, default='checkpoints-new-2/0-0.2/outputs/GTSRB/checkpoints/classifier.pt',
                    help="checkpoint")
parser.add_argument('-l', '--log_name', type=str, default='0-0.2.log', help="log name")
parser.add_argument('-g', '--gpu_id', type=int, default=0, help="gpu id")
opts = parser.parse_args()
# ---- Load experiment setting ----
config = get_config(opts.config)
# ---- Setup model and restore weights from the checkpoint ----
trainer = ClassifierTrainer(config)
# NOTE(review): torch.load unpickles arbitrary objects; on an untrusted
# checkpoint this can execute code. Prefer weights_only=True where the
# installed torch version supports it.
state_dict = torch.load(opts.checkpoint, map_location='cuda:{}'.format(opts.gpu_id))
trainer.net.load_state_dict(state_dict['net'])
epochs = state_dict['epochs']
min_loss = state_dict['min_loss']
# Older checkpoints may not record accuracy; default to 0.0.
acc = state_dict.get('acc', 0.0)
print("=" * 40)
print('Resume from epoch: {}, min-loss: {} acc: {}'.format(epochs, min_loss, acc))
print("=" * 40)
trainer.cuda()
trainer.eval()
# Collect only the images whose filename contains 'input' from the input dir.
test_list = os.listdir(opts.input_dir)
test_list = [os.path.join(opts.input_dir, x) for x in test_list]
test_list = [x for x in test_list if 'input' in os.path.basename(x)]
# ---- Preprocessing: resize to the configured crop size, then apply the
# standard ImageNet mean/std normalization ----
transform = transforms.Compose([
    transforms.Resize([config['crop_image_height'], config['crop_image_width']]),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])
to_tensor = transforms.ToTensor()
# Log file is named '<input-dir-basename>-<log_name>' inside the output dir.
opts.log_name = os.path.basename(opts.input_dir) + '-' + opts.log_name
log_pwd = os.path.join(opts.output_dir, opts.log_name)
check_dir(opts.output_dir)
accuracy_list = []
# Evaluate every test image; each misclassification is printed and appended
# to the log file, followed by the overall accuracy.
with torch.no_grad():
    progress = tqdm(test_list)
    progress.set_description('Processing')
    with open(log_pwd, 'w') as fid_w:
        for img_pwd in progress:
            image = Image.open(img_pwd).convert('RGB')
            # Ground-truth label is the integer prefix of the parent
            # directory name, e.g. '.../0-clear/x.png' -> 0.
            label = int(os.path.dirname(img_pwd).split(os.sep)[-1].split('-')[0])
            batch = transform(image).unsqueeze(0).cuda()
            pred = trainer.net(batch)
            # Exponentiate scores before topk; presumably the net emits
            # log-probabilities — TODO confirm against ClassifierTrainer.
            ps = torch.exp(pred)
            top_p, top_class = ps.topk(1, dim=1)
            hit = int(top_class.item() == label)
            accuracy_list.append(float(hit))
            if hit < 1:
                line_info = '{} | pred: {}, label: {}'.format(img_pwd, top_class.item(), label)
                print(line_info)
                fid_w.write(line_info + '\n')
        mean_acc = np.mean(accuracy_list)
        print('\n<{}> Test result: accuracy: {}'.format(get_local_time(), mean_acc))
        fid_w.write('\n<{}> Test result: accuracy: {}\n'.format(get_local_time(), mean_acc))