-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathModel.py
More file actions
22 lines (17 loc) · 702 Bytes
/
Model.py
File metadata and controls
22 lines (17 loc) · 702 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf
import tensorflow_hub as hub
from Dataset import Dataset
class Model:
    """Build a TF1 inference graph that classifies images with a TF-Hub module.

    Constructing a ``Model`` resets the default TensorFlow graph, builds a
    ``Dataset`` from ``params``, feeds its ``img_data`` tensor through a TF-Hub
    image-classification module (Inception v3 by default), and exposes the
    top-k softmax predictions as ``self.top_prediction``.

    Attributes:
        dataset: the ``Dataset`` built from ``params``; its ``img_data``
            tensor is the model input.
        top_prediction: result of ``tf.nn.top_k`` on the softmax output, i.e.
            a ``(values, indices)`` pair holding the ``TOP_K`` largest
            probabilities and their class indices along the last axis.
    """

    # TF-Hub module handle used when params does not supply 'MODULE_URL'.
    DEFAULT_MODULE_URL = (
        'https://tfhub.dev/google/imagenet/inception_v3/classification/1'
    )

    def __init__(self, params):
        """Read configuration from ``params`` and build the graph.

        Args:
            params: mapping indexed by string keys. ``'TOP_K'`` (required) is
                the number of predictions to keep. ``'MODULE_URL'`` (optional)
                overrides the TF-Hub module handle and defaults to
                ``DEFAULT_MODULE_URL``. The whole mapping is also passed to
                ``Dataset``.

        Raises:
            KeyError: if ``'TOP_K'`` is missing from ``params``.
        """
        self._top_k = params['TOP_K']
        # Optional override; only __getitem__ is required of params, matching
        # how 'TOP_K' is read above.
        try:
            self._module_url = params['MODULE_URL']
        except KeyError:
            self._module_url = self.DEFAULT_MODULE_URL
        self._prepare_graph(params)

    def _prepare_graph(self, params):
        """Reset the default graph and build the input and prediction ops.

        Args:
            params: configuration mapping forwarded to ``Dataset``.
        """
        # Start from an empty default graph so ops from any previously built
        # model do not accumulate alongside this one.
        tf.reset_default_graph()
        self.dataset = Dataset(params)
        logits = self._prepare_model(self.dataset.img_data)
        # Convert logits to probabilities so top_prediction's values are
        # softmax scores rather than raw logits.
        softmax = tf.nn.softmax(logits)
        self.top_prediction = tf.nn.top_k(
            softmax, self._top_k, name='top_prediction')

    def _prepare_model(self, images):
        """Apply the TF-Hub classification module to ``images``.

        Args:
            images: image tensor passed as the module's ``images`` input.
                NOTE(review): presumably a float batch shaped
                (batch, height, width, 3) with values in [0, 1], as TF-Hub
                image modules expect — confirm against Dataset.

        Returns:
            The module's default output, used by the caller as class logits.
        """
        # getattr keeps instances that bypassed __init__ (e.g. subclasses)
        # on the original hard-coded module.
        module_url = getattr(self, '_module_url', self.DEFAULT_MODULE_URL)
        module = hub.Module(module_url)
        return module(dict(images=images))