-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathstandalone_tests.cpp
More file actions
109 lines (97 loc) · 2.74 KB
/
standalone_tests.cpp
File metadata and controls
109 lines (97 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <memory>
#include <algorithm>
#include <cstdint>
#include <cassert>
#include <vector>
#include <functional>
#include <array>
#include <type_traits>
#include <string>
#include <cstring>
#include <chrono>
#include <cmath>
#include <map>
#include "tensor_processor.h"
#include "tensor_compiler.h"
#include "neural_network.h"
#include "siren.h"
#include "nn_tests.h"
#include "direct/nnd_tests.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
/// Parses test-selection arguments and runs the requested test groups.
///
/// Arguments from argv[start] onward are read as
///   <group> <test>... <group> <test>...
/// where <group> is one of "base", "gpu", "nn", "benchmark" and each <test> is
/// either a non-negative test number or "all". A bare "all" given before any
/// group name selects every test of every group.
///
/// Each group's selection is forwarded to its nn:: runner as:
///   - an empty vector        -> run every test of the group;
///   - a list of numbers      -> run only those tests;
///   - group never selected   -> the runner is not called at all.
///
/// Unrecognized arguments are reported on stdout and otherwise ignored.
///
/// @param argv  argument vector as received by main()
/// @param argc  number of entries in argv
/// @param start index of the first argument to interpret
void perform_tests_args(char **argv, int argc, int start)
{
    const std::vector<std::string> test_groups = {"base", "gpu", "nn", "benchmark"};

    // Per-group selection: {-1} = group not requested, {} = all tests,
    // otherwise the explicitly requested test numbers.
    std::map<std::string, std::vector<int>> test_ids;
    for (const std::string &g : test_groups)
        test_ids[g] = {-1};

    std::string active_group;
    for (int idx = start; idx < argc; idx++)
    {
        const std::string arg = argv[idx];

        // A group name only switches the context for the following arguments;
        // it must not fall through to the test-number parsing below.
        if (std::find(test_groups.begin(), test_groups.end(), arg) != test_groups.end())
        {
            active_group = arg;
            continue;
        }

        if (arg == "all")
        {
            if (active_group.empty()) //perform all tests in all groups
            {
                for (const std::string &g : test_groups)
                    test_ids[g] = {};
            }
            else //perform all tests in this group
                test_ids[active_group] = {};
            continue;
        }

        if (active_group.empty())
        {
            printf("invalid argument %s. A test group name or \"all\" is expected first\n", arg.c_str());
            continue;
        }

        // Accept only a complete, non-negative number that fits in an int:
        // negative values would collide with the -1 "not requested" marker.
        char *end = nullptr;
        const long test_num = strtol(arg.c_str(), &end, 10);
        if (end == arg.c_str() || *end != '\0' || test_num < 0 || test_num > INT_MAX)
        {
            printf("invalid argument %s. It should be \"all\" or number\n", arg.c_str());
            continue;
        }

        std::vector<int> &ids = test_ids[active_group];
        if (ids.empty())
            continue; // "all" already selected for this group; a single number adds nothing
        if (ids[0] == -1)
            ids[0] = static_cast<int>(test_num); // replace the "not requested" marker
        else
            ids.push_back(static_cast<int>(test_num));
    }

    // A group runs when it selects "all" (empty) or holds explicit test numbers.
    auto requested = [&test_ids](const std::string &group) {
        const std::vector<int> &ids = test_ids[group];
        return ids.empty() || ids[0] != -1;
    };

    if (requested("base"))
        nn::perform_tests_tensor_processor(test_ids["base"]);
    if (requested("gpu"))
        nn::perform_tests_tensor_processor_GPU(test_ids["gpu"]);
    if (requested("nn"))
        nn::perform_tests_neural_networks(test_ids["nn"]);
    if (requested("benchmark"))
        nn::perform_tests_performance(test_ids["benchmark"]);
}
/// Entry point of the standalone test runner.
///
/// With no arguments every test is run; with -h / -help / --help a usage
/// message is printed; otherwise the arguments are forwarded to
/// perform_tests_args().
int main(int argc, char **argv)
{
    //nnd::perform_tests();
    // argc can legally be 0, in which case argv[1] must not be read.
    if (argc <= 1)
    {
        nn::perform_all_tests();
        return 0;
    }

    // argv[1] is a C string: compare its contents, not pointer addresses
    // (comparing char* against a string literal with == is always false here).
    const std::string first_arg = argv[1];
    if (first_arg == "-h" || first_arg == "-help" || first_arg == "--help")
    {
        printf("./nn_test <test_group_1> <test_1> <test_2> ... <test_group_k> <test_1> ...\n");
        printf("test groups are [base, gpu, nn, benchmark]\n");
        printf("test_i is either test number or \"all\" for all tests\n");
        printf("e.g. ./nn_test base all gpu all ./nn_test benchmark 1 2\n");
    }
    else
        perform_tests_args(argv, argc, 1);
    return 0;
}