Skip to content

Commit 8c9dfe1

Browse files
committed
Updated chebi.py for train_version restructure
1 parent 667b079 commit 8c9dfe1

File tree

1 file changed

+122
-60
lines changed

1 file changed

+122
-60
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 122 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC):
127127
Attributes:
128128
single_class (int): The ID of the single class to predict.
129129
chebi_version_train (int): The version of ChEBI to use for training and validation.
130+
dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
131+
dynamic_df_train (pd.DataFrame): DataFrame to store the training data split.
132+
dynamic_df_test (pd.DataFrame): DataFrame to store the test data split.
133+
dynamic_df_val (pd.DataFrame): DataFrame to store the validation data split.
130134
"""
131135

132136
def __init__(
@@ -144,6 +148,16 @@ def __init__(
144148
self.dynamic_df_test = None
145149
self.dynamic_df_val = None
146150

151+
if self.chebi_version_train is not None:
152+
# Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given
153+
# This is to get the data from respective directory related to "chebi_version_train"
154+
_init_kwargs = kwargs
155+
_init_kwargs["chebi_version"] = self.chebi_version_train
156+
self._chebi_version_train_obj = self.__class__(
157+
single_class=self.single_class,
158+
**_init_kwargs,
159+
)
160+
147161
def extract_class_hierarchy(self, chebi_path):
148162
"""
149163
Extracts the class hierarchy from the ChEBI ontology.
@@ -238,15 +252,20 @@ def _setup_pruned_test_set(
238252
"""Create a test set with the same leaf nodes, but use only classes that appear in the training set"""
239253
# TODO: find a more efficient way to do this
240254
filename_old = "classes.txt"
241-
filename_new = f"classes_v{self.chebi_version_train}.txt"
255+
# filename_new = f"classes_v{self.chebi_version_train}.txt"
242256
# dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
243257

244258
# Load original classes (from the current ChEBI version - chebi_version)
245259
with open(os.path.join(self.processed_dir_main, filename_old), "r") as file:
246260
orig_classes = file.readlines()
247261

248262
# Load new classes (from the training ChEBI version - chebi_version_train)
249-
with open(os.path.join(self.processed_dir_main, filename_new), "r") as file:
263+
with open(
264+
os.path.join(
265+
self._chebi_version_train_obj.processed_dir_main, filename_old
266+
),
267+
"r",
268+
) as file:
250269
new_classes = file.readlines()
251270

252271
# Create a mapping which give index of a class from chebi_version, if the corresponding
@@ -277,42 +296,74 @@ def _setup_pruned_test_set(
277296
def setup_processed(self):
278297
print("Transform data")
279298
os.makedirs(self.processed_dir, exist_ok=True)
280-
for k in self.processed_file_names_dict.keys():
281-
# processed_name = (
282-
# "test.pt" if k == "test" else self.processed_file_names_dict[k]
283-
# )
284-
processed_name = self.processed_file_names_dict[k]
285-
if k == "data_chebi_train" and self.chebi_version_train is None:
286-
# To skip the encoding of data for "chebi_version_train", if it's not given
287-
continue
288-
289-
if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
290-
print(
291-
"Missing encoded data, transform processed data into encoded data",
292-
k,
293-
)
294-
torch.save(
295-
self._load_data_from_file(
296-
os.path.join(
297-
self.processed_dir_main, self.raw_file_names_dict[k]
298-
)
299-
),
300-
os.path.join(self.processed_dir, processed_name),
301-
)
302-
303299
# -------- Commented the code for Data Handling Restructure for Issue No.10
304300
# -------- https://github.com/ChEB-AI/python-chebai/issues/10
301+
# for k in self.processed_file_names_dict.keys():
302+
# processed_name = (
303+
# "test.pt" if k == "test" else self.processed_file_names_dict[k]
304+
# )
305+
# if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
306+
# print("transform", k)
307+
# torch.save(
308+
# self._load_data_from_file(
309+
# os.path.join(self.raw_dir, self.raw_file_names_dict[k])
310+
# ),
311+
# os.path.join(self.processed_dir, processed_name),
312+
# )
305313
# # create second test set with classes used in train
306314
# if self.chebi_version_train is not None and not os.path.isfile(
307-
# os.path.join(
308-
# self.processed_dir, self.processed_file_names_dict["data_chebi_train"]
309-
# )
315+
# os.path.join(self.processed_dir, self.processed_file_names_dict["test"])
310316
# ):
311317
# print("transform test (select classes)")
312318
# self._setup_pruned_test_set()
319+
#
320+
# processed_name = self.processed_file_names_dict[k]
321+
# if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
322+
# print(
323+
# "Missing encoded data, transform processed data into encoded data",
324+
# k,
325+
# )
326+
# torch.save(
327+
# self._load_data_from_file(
328+
# os.path.join(
329+
# self.processed_dir_main, self.raw_file_names_dict[k]
330+
# )
331+
# ),
332+
# os.path.join(self.processed_dir, processed_name),
333+
# )
334+
335+
# Transform the processed data into encoded data
336+
processed_name = self.processed_file_names_dict["data"]
337+
if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
338+
print(
339+
f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:",
340+
processed_name,
341+
)
342+
torch.save(
343+
self._load_data_from_file(
344+
os.path.join(
345+
self.processed_dir_main,
346+
self.raw_file_names_dict["data"],
347+
)
348+
),
349+
os.path.join(self.processed_dir, processed_name),
350+
)
351+
352+
# Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
353+
if self.chebi_version_train is not None and not os.path.isfile(
354+
os.path.join(
355+
self._chebi_version_train_obj.processed_dir,
356+
self._chebi_version_train_obj.raw_file_names_dict["data"],
357+
)
358+
):
359+
print(
360+
f"Missing encoded data related to train version: {self.chebi_version_train}"
361+
)
362+
print("Call the setup method related to it")
363+
self._chebi_version_train_obj.setup()
313364

314365
def get_test_split(self, df: pd.DataFrame, seed: int = None):
315-
print("Get test data split")
366+
print("\nGet test data split")
316367

317368
# df_list = df.values.tolist()
318369
# df_list = [row[1] for row in df_list]
@@ -441,7 +492,6 @@ def processed_file_names_dict(self) -> dict:
441492
# else:
442493
# res[set] = f"{set}{train_v_str}.pt"
443494
res["data"] = "data.pt"
444-
res["data_chebi_train"] = f"data{train_v_str}.pt"
445495
return res
446496

447497
@property
@@ -464,7 +514,6 @@ def raw_file_names_dict(self) -> dict:
464514
# else:
465515
# res[set] = f"{set}{train_v_str}.pkl"
466516
res["data"] = "data.pkl"
467-
res["data_chebi_train"] = f"data{train_v_str}.pkl"
468517
return res
469518

470519
@property
@@ -560,42 +609,55 @@ def prepare_data(self, *args, **kwargs):
560609
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"])
561610
self.save_processed(df, filename=self.raw_file_names_dict["data"])
562611

563-
# Data from chebi_version_train
564-
if self.chebi_version_train is not None and not os.path.isfile(
565-
os.path.join(
566-
self.processed_dir_main,
567-
self.raw_file_names_dict["data_chebi_train"],
568-
)
569-
):
570-
chebi_path = self._load_chebi(self.chebi_version_train)
571-
g = self.extract_class_hierarchy(chebi_path)
572-
df = self.graph_to_raw_dataset(
573-
g, self.raw_file_names_dict["data_chebi_train"]
574-
)
575-
self.save_processed(
576-
df, filename=self.raw_file_names_dict["data_chebi_train"]
577-
)
612+
if self.chebi_version_train is not None:
613+
if not os.path.isfile(
614+
os.path.join(
615+
self._chebi_version_train_obj.processed_dir_main,
616+
self._chebi_version_train_obj.raw_file_names_dict["data"],
617+
)
618+
):
619+
print(
620+
f"Missing processed data related to train version: {self.chebi_version_train}"
621+
)
622+
print("Call the prepare_data method related to it")
623+
# Generate the "chebi_version_train" data if it doesn't exist
624+
self._chebi_version_train_obj.prepare_data(*args, **kwargs)
578625

579626
def _get_dynamic_splits(self):
580627
"""Generate data splits during run-time and saves in class variables"""
581628

582629
# Load encoded data derived from "chebi_version"
583-
data_chebi_version = torch.load(
584-
os.path.join(self.processed_dir, self.processed_file_names_dict["data"])
585-
)
630+
try:
631+
filename = self.processed_file_names_dict["data"]
632+
data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
633+
except FileNotFoundError:
634+
raise FileNotFoundError(
635+
f"File {filename} doesn't exists. "
636+
f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
637+
)
638+
586639
df_chebi_version = pd.DataFrame(data_chebi_version)
587640
train_df_chebi_ver, df_test_chebi_ver = self.get_test_split(
588641
df_chebi_version, seed=self.dynamic_data_split_seed
589642
)
590643

591644
if self.chebi_version_train is not None:
592645
# Load encoded data derived from "chebi_version_train"
593-
data_chebi_train_version = torch.load(
594-
os.path.join(
595-
self.processed_dir,
596-
self.processed_file_names_dict["data_chebi_train"],
646+
try:
647+
filename_train = (
648+
self._chebi_version_train_obj.processed_file_names_dict["data"]
597649
)
598-
)
650+
data_chebi_train_version = torch.load(
651+
os.path.join(
652+
self._chebi_version_train_obj.processed_dir, filename_train
653+
)
654+
)
655+
except FileNotFoundError:
656+
raise FileNotFoundError(
657+
f"File {filename_train} doesn't exists related to chebi_version_train {self.chebi_version_train}."
658+
f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
659+
)
660+
599661
df_chebi_train_version = pd.DataFrame(data_chebi_train_version)
600662
# Get train/val split of data based on "chebi_version_train", but
601663
# using test set from "chebi_version"
@@ -744,12 +806,12 @@ def select_classes(self, g, split_name, *args, **kwargs):
744806
)
745807
)
746808
filename = "classes.txt"
747-
if (
748-
self.chebi_version_train
749-
is not None
750-
# and self.raw_file_names_dict["test"] != split_name
751-
):
752-
filename = f"classes_v{self.chebi_version_train}.txt"
809+
# if (
810+
# self.chebi_version_train
811+
# is not None
812+
# # and self.raw_file_names_dict["test"] != split_name
813+
# ):
814+
# filename = f"classes_v{self.chebi_version_train}.txt"
753815
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
754816
fout.writelines(str(node) + "\n" for node in nodes)
755817
return nodes

0 commit comments

Comments
 (0)