|
1 | 1 | from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator |
2 | 2 | import os |
3 | | - |
| 3 | +try: |
| 4 | + import tensorflow as tf # type: ignore |
| 5 | +except ModuleNotFoundError: |
| 6 | + raise ModuleNotFoundError( |
| 7 | + "'tensorflow' module should be installed to convert the Dataset into tensorflow format" |
| 8 | + ) |
4 | 9 |
|
5 | 10 | class DiffgramTensorflowDataset(DiffgramDatasetIterator): |
6 | 11 |
|
7 | | - def __init__(self, project, diffgram_file_id_list = None): |
| 12 | + def __init__(self, project, diffgram_file_id_list): |
8 | 13 | """ |
9 | 14 |
|
10 | 15 | :param project (sdk.core.core.Project): A Project object from the Diffgram SDK |
11 | 16 | :param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram. |
12 | 17 | :param transform (callable, optional): Optional transforms to be applied on a sample |
13 | 18 | """ |
14 | | - super(DiffgramDatasetIterator, self).__init__(project, diffgram_file_id_list) |
15 | | - global tf |
16 | | - try: |
17 | | - import tensorflow as tf # type: ignore |
18 | | - except ModuleNotFoundError: |
19 | | - raise ModuleNotFoundError( |
20 | | - "'tensorflow' module should be installed to convert the Dataset into tensorflow format" |
21 | | - ) |
| 19 | + super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list) |
| 20 | + |
22 | 21 | self.diffgram_file_id_list = diffgram_file_id_list |
23 | 22 |
|
24 | 23 | self.project = project |
@@ -52,29 +51,34 @@ def __iter__(self): |
52 | 51 | self.current_file_index = 0 |
53 | 52 | return self |
54 | 53 |
|
| 54 | + def get_next_elm(self): |
| 55 | + yield self.__next__() |
| 56 | + |
55 | 57 | def __next__(self): |
56 | 58 | file_id = self.diffgram_file_id_list[self.current_file_index] |
57 | 59 | diffgram_file = self.project.file.get_by_id(file_id, with_instances = True) |
| 60 | + print('AAA', diffgram_file.id) |
| 61 | + image = self.get_image_data(diffgram_file) |
58 | 62 | instance_data = self.get_file_instances(diffgram_file) |
59 | | - filename, file_extension = os.path.splitext(instance_data['diffgram_file']['image']['original_filename']) |
60 | | - print('instance_data', instance_data) |
| 63 | + filename, file_extension = os.path.splitext(instance_data['diffgram_file'].image['original_filename']) |
| 64 | + label_names_bytes = [x.encode() for x in instance_data['label_name_list']] |
61 | 65 | tf_example_dict = { |
62 | | - 'image/height': self.int64_feature(instance_data['diffgram_file']['height']), |
63 | | - 'image/width': self.int64_feature(instance_data['diffgram_file']['width']), |
64 | | - 'image/filename': self.bytes_feature(filename), |
65 | | - 'image/source_id': self.bytes_feature(filename), |
66 | | - 'image/encoded': self.bytes_feature(instance_data['image']), |
67 | | - 'image/format': self.bytes_feature(file_extension), |
| 66 | + 'image/height': self.int64_feature(instance_data['diffgram_file'].image['height']), |
| 67 | + 'image/width': self.int64_feature(instance_data['diffgram_file'].image['width']), |
| 68 | + 'image/filename': self.bytes_feature(filename.encode()), |
| 69 | + 'image/source_id': self.bytes_feature(filename.encode()), |
| 70 | + 'image/encoded': self.bytes_feature(image.tobytes()), |
| 71 | + 'image/format': self.bytes_feature(file_extension.encode()), |
68 | 72 | 'image/object/bbox/xmin': self.float_list_feature(instance_data['x_min_list']), |
69 | 73 | 'image/object/bbox/xmax': self.float_list_feature(instance_data['x_max_list']), |
70 | 74 | 'image/object/bbox/ymin': self.float_list_feature(instance_data['y_min_list']), |
71 | 75 | 'image/object/bbox/ymax': self.float_list_feature(instance_data['y_max_list']), |
72 | | - 'image/object/class/text': self.bytes_list_feature(instance_data['label_name_list']), |
| 76 | + 'image/object/class/text': self.bytes_list_feature(label_names_bytes), |
73 | 77 | 'image/object/class/label': self.int64_list_feature(instance_data['label_id_list']), |
74 | 78 | } |
75 | 79 | tf_example = tf.train.Example(features = tf.train.Features(feature = tf_example_dict)) |
76 | 80 | self.current_file_index += 1 |
77 | 81 | return tf_example |
78 | 82 |
|
79 | 83 | def get_dataset_obj(self): |
80 | | - return tf.data.Dataset.from_generator(self.__iter__) |
| 84 | + return tf.data.Dataset.from_generator(self.get_next_elm, output_signature = tf.TensorSpec(shape=(1,))) |
0 commit comments