-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtesting.py
More file actions
22 lines (20 loc) · 749 Bytes
/
testing.py
File metadata and controls
22 lines (20 loc) · 749 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import data_handlers
import cnn_model
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
# Evaluate a previously trained model (stored as a whole pickled module in
# "model.pt") on the test split and print its classification accuracy.


def _count_correct(model, loader, device):
    """Return ``(correct, total)`` counted over every batch from *loader*.

    Each batch is indexed as ``batch[0]`` (inputs, cast to float) and
    ``batch[1]`` (labels, cast to long); both are moved to *device*. The
    predicted class is the arg-max over dimension 1 of the model output
    (presumably ``(batch, num_classes)`` scores — TODO confirm against
    cnn_model).

    Args:
        model: the network to evaluate; it is switched to eval mode.
        loader: an iterable of batches, e.g. a ``DataLoader``.
        device: the ``torch.device`` the model lives on.

    Returns:
        A tuple of Python ints: correctly classified samples, total samples.
    """
    correct = 0
    total = 0
    # eval() puts layers such as dropout / batch-norm (if the model has any)
    # into inference behaviour; without it the measured accuracy is skewed.
    model.eval()
    with torch.no_grad():
        for batch in loader:
            inputs = batch[0].float().to(device)
            labels = batch[1].long().to(device)
            predictions = model(inputs).argmax(dim=1)
            # Compare on the device; only the scalar count is moved to host.
            correct += (predictions == labels).sum().item()
            total += len(labels)
    return correct, total


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 50

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses a
# whole pickled nn.Module; pass weights_only=False there if loading fails.
model = torch.load("model.pt", map_location=device).to(device)
training_set, validation_set, test_set = data_handlers.set_loaders(batch_size=batch_size)

# accuracy: number of correct predictions; num: number of test samples seen.
accuracy, num = _count_correct(model, test_set, device)
if num == 0:
    # Avoid ZeroDivisionError when the test loader yields no samples.
    print("Accuracy: n/a (test set is empty)")
else:
    print("Accuracy: {:.2f}%".format(100 * accuracy / num))