Skip to content

Commit 04edb30

Browse files
committed
wip: pytorch tensorflow
1 parent 15241d4 commit 04edb30

File tree

4 files changed

+46
-21
lines changed

4 files changed

+46
-21
lines changed

sdk/diffgram/core/directory.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ..regular.regular import refresh_from_dict
33
import logging
44
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
5+
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
56

67
def get_directory_list(self):
78
"""
@@ -131,6 +132,15 @@ def to_pytorch(self, transform = None):
131132
)
132133
return pytorch_dataset
133134

135+
def to_tensorflow(self):
    """
    Build a TensorFlow dataset object from the files in this directory.

    Gathers the file IDs via ``self.all_file_ids()``, wraps them in a
    ``DiffgramTensorflowDataset`` bound to ``self.client`` and returns
    the result of that wrapper's ``get_dataset_obj()``.

    :return: the object produced by ``DiffgramTensorflowDataset.get_dataset_obj()``
    """
    # Resolve the IDs first so the wrapper is constructed from a complete list.
    ids = self.all_file_ids()
    wrapper = DiffgramTensorflowDataset(
        project = self.client,
        diffgram_file_id_list = ids
    )
    return wrapper.get_dataset_obj()
143+
134144
def new(self, name: str):
135145
"""
136146
Create a new directory and update directory list.

sdk/diffgram/core/sliced_directory.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from diffgram.core.directory import Directory
22
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
3+
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
34

45

56
class SlicedDirectory(Directory):
@@ -15,7 +16,6 @@ def all_file_ids(self):
1516
page_num = 1
1617
result = []
1718
while page_num is not None:
18-
print('slcied query', self.query)
1919
diffgram_files = self.list_files(limit = 1000,
2020
page_num = page_num,
2121
file_view_mode = 'ids_only',
@@ -37,3 +37,12 @@ def to_pytorch(self, transform = None):
3737

3838
)
3939
return pytorch_dataset
40+
41+
def to_tensorflow(self):
    """
    Build a TensorFlow dataset object from the files in this sliced directory.

    The file IDs come from ``self.all_file_ids()``; they are handed to a
    ``DiffgramTensorflowDataset`` together with ``self.client``, and the
    value of its ``get_dataset_obj()`` is returned.

    :return: the object produced by ``DiffgramTensorflowDataset.get_dataset_obj()``
    """
    sliced_ids = self.all_file_ids()
    tensorflow_wrapper = DiffgramTensorflowDataset(
        project = self.client,
        diffgram_file_id_list = sliced_ids
    )
    return tensorflow_wrapper.get_dataset_obj()

sdk/diffgram/tensorflow_diffgram/diffgram_tensorflow_dataset.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
22
import os
3-
3+
try:
4+
import tensorflow as tf # type: ignore
5+
except ModuleNotFoundError:
6+
raise ModuleNotFoundError(
7+
"'tensorflow' module should be installed to convert the Dataset into tensorflow format"
8+
)
49

510
class DiffgramTensorflowDataset(DiffgramDatasetIterator):
611

7-
def __init__(self, project, diffgram_file_id_list = None):
12+
def __init__(self, project, diffgram_file_id_list):
813
"""
914
1015
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
1116
:param diffgram_file_id_list (list): An arbitrary number of file IDs from Diffgram.
1217
:param transform (callable, optional): Optional transforms to be applied on a sample
1318
"""
14-
super(DiffgramDatasetIterator, self).__init__(project, diffgram_file_id_list)
15-
global tf
16-
try:
17-
import tensorflow as tf # type: ignore
18-
except ModuleNotFoundError:
19-
raise ModuleNotFoundError(
20-
"'tensorflow' module should be installed to convert the Dataset into tensorflow format"
21-
)
19+
super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list)
20+
2221
self.diffgram_file_id_list = diffgram_file_id_list
2322

2423
self.project = project
@@ -52,29 +51,34 @@ def __iter__(self):
5251
self.current_file_index = 0
5352
return self
5453

54+
def get_next_elm(self):
    """
    Generator that yields one element per file ID held by this dataset.

    The previous implementation contained a single ``yield``, so the
    generator was exhausted after the first element. This version walks
    the whole ``diffgram_file_id_list``.

    The cursor is reset at the start (mirroring ``__iter__``) so that the
    generator can be restarted, e.g. when a consumer calls this function
    again for a new pass over the data.

    :return: a generator producing whatever ``self.__next__()`` returns
    """
    self.current_file_index = 0
    # __next__ advances current_file_index, so the loop terminates once
    # every ID has been consumed; an empty list yields nothing.
    while self.current_file_index < len(self.diffgram_file_id_list):
        yield self.__next__()
56+
5557
def __next__(self):
    """
    Fetch the file at the current cursor and convert it to a tf.train.Example.

    The file is loaded (with instances) through ``self.project.file.get_by_id``,
    its image and instance data are gathered with ``self.get_image_data`` and
    ``self.get_file_instances``, and the values are packed into a feature dict.

    :return: a ``tf.train.Example`` describing the file and its bounding boxes
    :raises StopIteration: when every ID in ``diffgram_file_id_list`` has been consumed
    """
    # Iterator protocol: signal exhaustion with StopIteration instead of
    # letting the list lookup below fail with IndexError.
    if self.current_file_index >= len(self.diffgram_file_id_list):
        raise StopIteration

    file_id = self.diffgram_file_id_list[self.current_file_index]
    diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
    image = self.get_image_data(diffgram_file)
    instance_data = self.get_file_instances(diffgram_file)

    image_info = instance_data['diffgram_file'].image
    # NOTE(review): os.path.splitext keeps the leading dot in the extension
    # (e.g. '.jpg'); confirm that is what consumers of 'image/format' expect.
    filename, file_extension = os.path.splitext(image_info['original_filename'])
    # Bytes features need bytes, so every str value is encoded before use.
    label_names_bytes = [x.encode() for x in instance_data['label_name_list']]

    tf_example_dict = {
        'image/height': self.int64_feature(image_info['height']),
        'image/width': self.int64_feature(image_info['width']),
        'image/filename': self.bytes_feature(filename.encode()),
        'image/source_id': self.bytes_feature(filename.encode()),
        # NOTE(review): tobytes() presumably emits the raw in-memory buffer
        # rather than an encoded JPEG/PNG stream - confirm this matches what
        # readers of 'image/encoded' expect.
        'image/encoded': self.bytes_feature(image.tobytes()),
        'image/format': self.bytes_feature(file_extension.encode()),
        'image/object/bbox/xmin': self.float_list_feature(instance_data['x_min_list']),
        'image/object/bbox/xmax': self.float_list_feature(instance_data['x_max_list']),
        'image/object/bbox/ymin': self.float_list_feature(instance_data['y_min_list']),
        'image/object/bbox/ymax': self.float_list_feature(instance_data['y_max_list']),
        'image/object/class/text': self.bytes_list_feature(label_names_bytes),
        'image/object/class/label': self.int64_list_feature(instance_data['label_id_list']),
    }
    tf_example = tf.train.Example(features = tf.train.Features(feature = tf_example_dict))
    # Advance only after the example was built successfully.
    self.current_file_index += 1
    return tf_example
7882

7983
def get_dataset_obj(self):
    """
    Expose this dataset as a ``tf.data.Dataset``.

    ``tf.data.Dataset.from_generator`` requires every yielded element to be
    convertible to a tensor matching ``output_signature``. A
    ``tf.train.Example`` proto is not, so each example is serialized and the
    dataset is declared as scalar ``tf.string`` elements. Consumers can
    decode them with ``tf.io.parse_single_example``.

    :return: a ``tf.data.Dataset`` of serialized ``tf.train.Example`` strings
    """
    def _serialized_examples():
        # from_generator invokes this callable for every pass over the
        # dataset, so the cursor is rewound each time.
        self.current_file_index = 0
        while self.current_file_index < len(self.diffgram_file_id_list):
            yield self.__next__().SerializeToString()

    return tf.data.Dataset.from_generator(
        _serialized_examples,
        output_signature = tf.TensorSpec(shape = (), dtype = tf.string)
    )

sdk/diffgram/tensorflow_diffgram/pytorch_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def display_masks():
3838
dataset = project.directory.get('Default')
3939

4040
pytorch_dataset = dataset.to_pytorch()
41+
tf_dataset = dataset.to_tensorflow()
42+
4143

4244
sliced_dataset = dataset.slice(query = 'labels.sheep > 0 or labels.sofa > 0')
4345

0 commit comments

Comments
 (0)