11from diffgram .core .diffgram_dataset_iterator import DiffgramDatasetIterator
22import os
3+
34try :
45 import tensorflow as tf # type: ignore
56except ModuleNotFoundError :
67 raise ModuleNotFoundError (
78 "'tensorflow' module should be installed to convert the Dataset into tensorflow format"
89 )
910
11+
1012class DiffgramTensorflowDataset (DiffgramDatasetIterator ):
1113
1214 def __init__ (self , project , diffgram_file_id_list ):
@@ -47,17 +49,13 @@ def __validate_file_ids(self):
4749 raise Exception (
4850 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
4951
50- def __iter__ (self ):
51- self .current_file_index = 0
52- return self
53-
54- def get_next_elm (self ):
55- yield self .__next__ ()
52+ def __getitem__ (self , idx ):
53+ tf_example = self .get_tf_train_example (idx )
54+ return tf_example
5655
57- def __next__ (self ):
58- file_id = self .diffgram_file_id_list [self . current_file_index ]
56+ def get_tf_train_example (self , idx ):
57+ file_id = self .diffgram_file_id_list [idx ]
5958 diffgram_file = self .project .file .get_by_id (file_id , with_instances = True )
60- print ('AAA' , diffgram_file .id )
6159 image = self .get_image_data (diffgram_file )
6260 instance_data = self .get_file_instances (diffgram_file )
6361 filename , file_extension = os .path .splitext (instance_data ['diffgram_file' ].image ['original_filename' ])
@@ -77,8 +75,12 @@ def __next__(self):
7775 'image/object/class/label' : self .int64_list_feature (instance_data ['label_id_list' ]),
7876 }
7977 tf_example = tf .train .Example (features = tf .train .Features (feature = tf_example_dict ))
78+ return tf_example
79+
80+ def __next__ (self ):
81+ tf_example = self .get_tf_train_example (self .current_file_index )
8082 self .current_file_index += 1
8183 return tf_example
8284
83- def get_dataset_obj (self ):
84- return tf .data .Dataset .from_generator (self .get_next_elm , output_signature = tf .TensorSpec (shape = (1 ,)))
85+ # def get_dataset_obj(self):
86+ # return tf.data.Dataset.from_generator(self.get_next_elm, output_signature = tf.TensorSpec(shape = (1,)))
0 commit comments