-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathmain.py
More file actions
181 lines (144 loc) · 4.96 KB
/
main.py
File metadata and controls
181 lines (144 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Read PMML files and make predictions.

Example usage:
    python main.py predict \
        --model=examples/deepnetwork/ResNet50.pmml \
        --input=tests/assets/cat.jpg
    python main.py runserver \
        --model=examples/deepnetwork/VGG16/model.pmml
    python main.py build_pytorch_examples
    python main.py build_keras_examples
    python main.py validate [--filename FILE]
"""
import json
import glob
import argparse
from lxml import etree
from imageio import imread
from models.gpr.parser import GaussianProcessParser
from models.deepnetwork.core.intermediate import DeepNetwork
from models.deepnetwork.core.utils import strip_namespace
def load_pmml(filename):
    """
    Load a PMML file and return the model it describes.

    The model type is determined from the PMML content: a <DeepNetwork>
    element produces a DeepNetwork instance; a <GaussianProcessModel>
    element is handed to GaussianProcessParser.

    :param filename: Path to the PMML file.
    :returns: The loaded model object.
    :raises ValueError: If neither a DeepNetwork nor a GaussianProcessModel
        element is found in the file.
    """
    tree = etree.iterparse(filename)
    root = strip_namespace(tree).root
    config = {"filename": filename}

    # A file without a <Header> element is tolerated here instead of failing
    # with AttributeError on None.attrib.
    header = root.find("Header")
    if header is not None:
        for key in ("description", "copyright"):
            if key in header.attrib:
                config[key] = header.attrib[key]

    dnn = root.find("DeepNetwork")
    gpr = root.find("GaussianProcessModel")
    if dnn is not None:
        model = DeepNetwork(**config)
        model.load_pmml(root)
    elif gpr is not None:
        # The GPR parser is given the filename and reads the file itself;
        # the header config collected above is not passed to it.
        gp_parser = GaussianProcessParser()
        model = gp_parser.parse(filename)
    else:
        raise ValueError("Could not find a valid model in %s" % filename)
    return model
def build_keras_examples():
    """
    Generate PMML example files from publicly available Keras models.

    Delegates to ``generate_keras_models.build_models`` with a fixed list
    of model identifiers.
    """
    # Imported inside the function, presumably so the Keras tooling is only
    # needed when this command is actually run.
    from models.deepnetwork import generate_keras_models as keras_generator

    model_names = [
        "VGG_16",
        "VGG_19",
        "RESNET_50",
        "MOBILENET",
        "INCEPTION_V3",
        "INCEPTION_RESNET",
        "DENSENET_121",
        "DENSENET_169",
    ]
    keras_generator.build_models(model_names)
def build_pytorch_examples():
    """
    Generate PMML example files from publicly available PyTorch models.

    Delegates to ``generate_torch_models.build_models`` with a fixed list
    of model identifiers.
    """
    # Imported inside the function, presumably so the PyTorch tooling is only
    # needed when this command is actually run.
    from models.deepnetwork import generate_torch_models as torch_generator

    # Not currently built for PyTorch (disabled in the original list):
    # MOBILENET, INCEPTION_V3, INCEPTION_RESNET, DENSENET_121, DENSENET_169,
    # DENSENET_201.
    model_names = [
        "VGG_16",
        "VGG_19",
        "RESNET_50",
    ]
    torch_generator.build_models(model_names)
def validate_models_using_schema(filename=None):
    """
    Validate a PMML file against the schema.

    :param filename: Path of a single PMML file to validate. When falsy
        (None or an empty string), every file matching
        ``examples/deepnetwork/*.pmml`` is validated instead.

    For each file, VALID or INVALID is printed. On the first invalid file,
    ``DeepNetwork.read_pmml`` is called so that the underlying error is
    raised; validation stops there if it does raise.
    """
    model = DeepNetwork()
    if filename:
        filenames = [filename]
    else:
        # Glob once: the former keras/torch lists used the identical pattern
        # and were concatenated, so every file was validated twice.
        # Sorted for a deterministic validation order.
        filenames = sorted(glob.glob("examples/deepnetwork/*.pmml"))

    for filepath in filenames:
        print("Validating {0}".format(filepath))
        if model.validate_pmml(filepath):
            print("PMML File is VALID\n")
        else:
            print("PMML File is INVALID\n")
            # Re-read the invalid file to force a descriptive error.
            # NOTE(review): elsewhere DeepNetwork is loaded via load_pmml(root);
            # confirm that read_pmml exists on DeepNetwork.
            model.read_pmml(filepath)
def predict(model, input_file):
    """
    Run ``model`` on the contents of ``input_file``, print and return the result.

    :param model: Any object exposing ``predict(data)``, e.g. the value
        returned by ``load_pmml``.
    :param input_file: Path to either a ``.json`` file, whose decoded
        contents are passed to the model, or an image file, which is read
        with ``imread``.
    :returns: Whatever ``model.predict`` returns.
    """
    if input_file.endswith(".json"):
        # Open read-only and decode the file *contents*. The previous code
        # opened with "w" (truncating the user's file) and then parsed the
        # path string itself as JSON.
        with open(input_file, "r", encoding="utf-8") as fd:
            data = json.load(fd)
    else:
        data = imread(input_file)
    result = model.predict(data)
    print("Model predicted class: %s" % result)
    return result
parser = argparse.ArgumentParser(description='Main entry point for PMML package.')
# NOTE(review): this top-level flag is never read; `runserver` is handled as a
# sub-command. Kept so existing command lines that pass it still parse.
parser.add_argument('--runserver', default=False, help='Run a server')
subparsers = parser.add_subparsers(dest='operation', help='One of [predict, runserver, validate, ...]')

# Prediction parser
predict_parser = subparsers.add_parser('predict',
    help='Usage: main.py predict --model model --input input')
predict_parser.add_argument('--model', default='',
    help='The PMML file to load')
predict_parser.add_argument('--input', default='',
    help='The path to the input file for testing')

# Runserver parser: the __main__ dispatch handles 'runserver' and reads
# args.model, so the sub-command must be registered for argparse to accept it.
runserver_parser = subparsers.add_parser('runserver',
    help='Usage: main.py runserver --model model')
runserver_parser.add_argument('--model', default='',
    help='The PMML file to load')

# Build_pytorch parser
subparsers.add_parser('build_pytorch_examples',
    help='Usage: main.py build_pytorch_examples')

# Build_keras parser
subparsers.add_parser('build_keras_examples',
    help='Usage: main.py build_keras_examples')

# Validate parser
parser_validate = subparsers.add_parser('validate',
    help='Usage: main.py validate [--filename filename]')
parser_validate.add_argument('--filename', type=str,
    help='PMML file to validate')
if __name__ == "__main__":
    args = parser.parse_args()

    # With no sub-command, argparse leaves `operation` as None; report a
    # usage error instead of crashing on None.lower().
    if args.operation is None:
        parser.error("an operation is required")

    operation = args.operation.lower()
    if operation == "build_keras_examples":
        build_keras_examples()
    elif operation == "build_pytorch_examples":
        # Previously dispatched to build_keras_examples() by mistake.
        build_pytorch_examples()
    elif operation == "validate":
        validate_models_using_schema(args.filename)
    elif operation == "predict":
        model = load_pmml(args.model)
        predict(model, args.input)
    elif operation == "runserver":
        model = load_pmml(args.model)
        # NOTE(review): the model is loaded but no server is started here;
        # confirm where the serving loop is meant to live.
    else:
        raise ValueError("Unknown operation %s" % args.operation)