Skip to content

Commit 066f10d

Browse files
committed
fix: tf dataset get item
1 parent 4d61f77 commit 066f10d

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

sdk/diffgram/core/directory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, client, file_id_list_sliced = None):
7979
self.client = client
8080
self.id = None
8181
self.file_list_metadata = {}
82-
82+
8383
if file_id_list_sliced is None:
8484
self.file_id_list = self.all_file_ids()
8585
else:
@@ -145,8 +145,7 @@ def to_tensorflow(self):
145145
project = self.client,
146146
diffgram_file_id_list = file_id_list
147147
)
148-
tf_dataset = diffgram_tensorflow_dataset.get_dataset_obj()
149-
return tf_dataset
148+
return diffgram_tensorflow_dataset
150149

151150
def new(self, name: str):
152151
"""

sdk/diffgram/core/sliced_directory.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,4 @@ def to_tensorflow(self):
4646
project = self.client,
4747
diffgram_file_id_list = file_id_list
4848
)
49-
tf_dataset = diffgram_tensorflow_dataset.get_dataset_obj()
50-
return tf_dataset
49+
return diffgram_tensorflow_dataset

sdk/diffgram/tensorflow_diffgram/diffgram_tensorflow_dataset.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
22
import os
3+
34
try:
45
import tensorflow as tf # type: ignore
56
except ModuleNotFoundError:
67
raise ModuleNotFoundError(
78
"'tensorflow' module should be installed to convert the Dataset into tensorflow format"
89
)
910

11+
1012
class DiffgramTensorflowDataset(DiffgramDatasetIterator):
1113

1214
def __init__(self, project, diffgram_file_id_list):
@@ -47,17 +49,13 @@ def __validate_file_ids(self):
4749
raise Exception(
4850
'Some file IDs do not belong to the project. Please provide only files from the same project.')
4951

50-
def __iter__(self):
51-
self.current_file_index = 0
52-
return self
53-
54-
def get_next_elm(self):
55-
yield self.__next__()
52+
def __getitem__(self, idx):
53+
tf_example = self.get_tf_train_example(idx)
54+
return tf_example
5655

57-
def __next__(self):
58-
file_id = self.diffgram_file_id_list[self.current_file_index]
56+
def get_tf_train_example(self, idx):
57+
file_id = self.diffgram_file_id_list[idx]
5958
diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
60-
print('AAA', diffgram_file.id)
6159
image = self.get_image_data(diffgram_file)
6260
instance_data = self.get_file_instances(diffgram_file)
6361
filename, file_extension = os.path.splitext(instance_data['diffgram_file'].image['original_filename'])
@@ -77,8 +75,12 @@ def __next__(self):
7775
'image/object/class/label': self.int64_list_feature(instance_data['label_id_list']),
7876
}
7977
tf_example = tf.train.Example(features = tf.train.Features(feature = tf_example_dict))
78+
return tf_example
79+
80+
def __next__(self):
81+
tf_example = self.get_tf_train_example(self.current_file_index)
8082
self.current_file_index += 1
8183
return tf_example
8284

83-
def get_dataset_obj(self):
84-
return tf.data.Dataset.from_generator(self.get_next_elm, output_signature = tf.TensorSpec(shape=(1,)))
85+
# def get_dataset_obj(self):
86+
# return tf.data.Dataset.from_generator(self.get_next_elm, output_signature = tf.TensorSpec(shape = (1,)))

0 commit comments

Comments
 (0)