@@ -127,6 +127,10 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC):
127127 Attributes:
128128 single_class (int): The ID of the single class to predict.
129129 chebi_version_train (int): The version of ChEBI to use for training and validation.
130+ dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
131+ dynamic_df_train (pd.DataFrame): DataFrame to store the training data split.
132+ dynamic_df_test (pd.DataFrame): DataFrame to store the test data split.
133+ dynamic_df_val (pd.DataFrame): DataFrame to store the validation data split.
130134 """
131135
132136 def __init__ (
@@ -144,6 +148,16 @@ def __init__(
144148 self .dynamic_df_test = None
145149 self .dynamic_df_val = None
146150
151+ if self .chebi_version_train is not None :
152+ # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given
153+ # This is to get the data from respective directory related to "chebi_version_train"
154+ _init_kwargs = kwargs
155+ _init_kwargs ["chebi_version" ] = self .chebi_version_train
156+ self ._chebi_version_train_obj = self .__class__ (
157+ single_class = self .single_class ,
158+ ** _init_kwargs ,
159+ )
160+
147161 def extract_class_hierarchy (self , chebi_path ):
148162 """
149163 Extracts the class hierarchy from the ChEBI ontology.
@@ -238,15 +252,20 @@ def _setup_pruned_test_set(
238252 """Create a test set with the same leaf nodes, but use only classes that appear in the training set"""
239253 # TODO: find a more efficient way to do this
240254 filename_old = "classes.txt"
241- filename_new = f"classes_v{ self .chebi_version_train } .txt"
255+ # filename_new = f"classes_v{self.chebi_version_train}.txt"
242256 # dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
243257
244258 # Load original classes (from the current ChEBI version - chebi_version)
245259 with open (os .path .join (self .processed_dir_main , filename_old ), "r" ) as file :
246260 orig_classes = file .readlines ()
247261
248262 # Load new classes (from the training ChEBI version - chebi_version_train)
249- with open (os .path .join (self .processed_dir_main , filename_new ), "r" ) as file :
263+ with open (
264+ os .path .join (
265+ self ._chebi_version_train_obj .processed_dir_main , filename_old
266+ ),
267+ "r" ,
268+ ) as file :
250269 new_classes = file .readlines ()
251270
252271 # Create a mapping which give index of a class from chebi_version, if the corresponding
@@ -277,42 +296,74 @@ def _setup_pruned_test_set(
277296 def setup_processed (self ):
278297 print ("Transform data" )
279298 os .makedirs (self .processed_dir , exist_ok = True )
280- for k in self .processed_file_names_dict .keys ():
281- # processed_name = (
282- # "test.pt" if k == "test" else self.processed_file_names_dict[k]
283- # )
284- processed_name = self .processed_file_names_dict [k ]
285- if k == "data_chebi_train" and self .chebi_version_train is None :
286- # To skip the encoding of data for "chebi_version_train", if it's not given
287- continue
288-
289- if not os .path .isfile (os .path .join (self .processed_dir , processed_name )):
290- print (
291- "Missing encoded data, transform processed data into encoded data" ,
292- k ,
293- )
294- torch .save (
295- self ._load_data_from_file (
296- os .path .join (
297- self .processed_dir_main , self .raw_file_names_dict [k ]
298- )
299- ),
300- os .path .join (self .processed_dir , processed_name ),
301- )
302-
303299 # -------- Commented the code for Data Handling Restructure for Issue No.10
304300 # -------- https://github.com/ChEB-AI/python-chebai/issues/10
301+ # for k in self.processed_file_names_dict.keys():
302+ # processed_name = (
303+ # "test.pt" if k == "test" else self.processed_file_names_dict[k]
304+ # )
305+ # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
306+ # print("transform", k)
307+ # torch.save(
308+ # self._load_data_from_file(
309+ # os.path.join(self.raw_dir, self.raw_file_names_dict[k])
310+ # ),
311+ # os.path.join(self.processed_dir, processed_name),
312+ # )
305313 # # create second test set with classes used in train
306314 # if self.chebi_version_train is not None and not os.path.isfile(
307- # os.path.join(
308- # self.processed_dir, self.processed_file_names_dict["data_chebi_train"]
309- # )
315+ # os.path.join(self.processed_dir, self.processed_file_names_dict["test"])
310316 # ):
311317 # print("transform test (select classes)")
312318 # self._setup_pruned_test_set()
319+ #
320+ # processed_name = self.processed_file_names_dict[k]
321+ # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
322+ # print(
323+ # "Missing encoded data, transform processed data into encoded data",
324+ # k,
325+ # )
326+ # torch.save(
327+ # self._load_data_from_file(
328+ # os.path.join(
329+ # self.processed_dir_main, self.raw_file_names_dict[k]
330+ # )
331+ # ),
332+ # os.path.join(self.processed_dir, processed_name),
333+ # )
334+
335+ # Transform the processed data into encoded data
336+ processed_name = self .processed_file_names_dict ["data" ]
337+ if not os .path .isfile (os .path .join (self .processed_dir , processed_name )):
338+ print (
339+ f"Missing encoded data related to version { self .chebi_version } , transform processed data into encoded data:" ,
340+ processed_name ,
341+ )
342+ torch .save (
343+ self ._load_data_from_file (
344+ os .path .join (
345+ self .processed_dir_main ,
346+ self .raw_file_names_dict ["data" ],
347+ )
348+ ),
349+ os .path .join (self .processed_dir , processed_name ),
350+ )
351+
352+ # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
353+ if self .chebi_version_train is not None and not os .path .isfile (
354+ os .path .join (
355+ self ._chebi_version_train_obj .processed_dir ,
356+ self ._chebi_version_train_obj .raw_file_names_dict ["data" ],
357+ )
358+ ):
359+ print (
360+ f"Missing encoded data related to train version: { self .chebi_version_train } "
361+ )
362+ print ("Call the setup method related to it" )
363+ self ._chebi_version_train_obj .setup ()
313364
314365 def get_test_split (self , df : pd .DataFrame , seed : int = None ):
315- print ("Get test data split" )
366+ print ("\n Get test data split" )
316367
317368 # df_list = df.values.tolist()
318369 # df_list = [row[1] for row in df_list]
@@ -441,7 +492,6 @@ def processed_file_names_dict(self) -> dict:
441492 # else:
442493 # res[set] = f"{set}{train_v_str}.pt"
443494 res ["data" ] = "data.pt"
444- res ["data_chebi_train" ] = f"data{ train_v_str } .pt"
445495 return res
446496
447497 @property
@@ -464,7 +514,6 @@ def raw_file_names_dict(self) -> dict:
464514 # else:
465515 # res[set] = f"{set}{train_v_str}.pkl"
466516 res ["data" ] = "data.pkl"
467- res ["data_chebi_train" ] = f"data{ train_v_str } .pkl"
468517 return res
469518
470519 @property
@@ -560,42 +609,55 @@ def prepare_data(self, *args, **kwargs):
560609 df = self .graph_to_raw_dataset (g , self .raw_file_names_dict ["data" ])
561610 self .save_processed (df , filename = self .raw_file_names_dict ["data" ])
562611
563- # Data from chebi_version_train
564- if self .chebi_version_train is not None and not os .path .isfile (
565- os .path .join (
566- self .processed_dir_main ,
567- self .raw_file_names_dict ["data_chebi_train" ],
568- )
569- ):
570- chebi_path = self ._load_chebi (self .chebi_version_train )
571- g = self .extract_class_hierarchy (chebi_path )
572- df = self .graph_to_raw_dataset (
573- g , self .raw_file_names_dict ["data_chebi_train" ]
574- )
575- self .save_processed (
576- df , filename = self .raw_file_names_dict ["data_chebi_train" ]
577- )
612+ if self .chebi_version_train is not None :
613+ if not os .path .isfile (
614+ os .path .join (
615+ self ._chebi_version_train_obj .processed_dir_main ,
616+ self ._chebi_version_train_obj .raw_file_names_dict ["data" ],
617+ )
618+ ):
619+ print (
620+ f"Missing processed data related to train version: { self .chebi_version_train } "
621+ )
622+ print ("Call the prepare_data method related to it" )
623+ # Generate the "chebi_version_train" data if it doesn't exist
624+ self ._chebi_version_train_obj .prepare_data (* args , ** kwargs )
578625
579626 def _get_dynamic_splits (self ):
580627 """Generate data splits during run-time and saves in class variables"""
581628
582629 # Load encoded data derived from "chebi_version"
583- data_chebi_version = torch .load (
584- os .path .join (self .processed_dir , self .processed_file_names_dict ["data" ])
585- )
630+ try :
631+ filename = self .processed_file_names_dict ["data" ]
632+ data_chebi_version = torch .load (os .path .join (self .processed_dir , filename ))
633+ except FileNotFoundError :
634+ raise FileNotFoundError (
635+ f"File { filename } doesn't exists. "
636+ f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
637+ )
638+
586639 df_chebi_version = pd .DataFrame (data_chebi_version )
587640 train_df_chebi_ver , df_test_chebi_ver = self .get_test_split (
588641 df_chebi_version , seed = self .dynamic_data_split_seed
589642 )
590643
591644 if self .chebi_version_train is not None :
592645 # Load encoded data derived from "chebi_version_train"
593- data_chebi_train_version = torch .load (
594- os .path .join (
595- self .processed_dir ,
596- self .processed_file_names_dict ["data_chebi_train" ],
646+ try :
647+ filename_train = (
648+ self ._chebi_version_train_obj .processed_file_names_dict ["data" ]
597649 )
598- )
650+ data_chebi_train_version = torch .load (
651+ os .path .join (
652+ self ._chebi_version_train_obj .processed_dir , filename_train
653+ )
654+ )
655+ except FileNotFoundError :
656+ raise FileNotFoundError (
657+ f"File { filename_train } doesn't exists related to chebi_version_train { self .chebi_version_train } ."
658+ f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
659+ )
660+
599661 df_chebi_train_version = pd .DataFrame (data_chebi_train_version )
600662 # Get train/val split of data based on "chebi_version_train", but
601663 # using test set from "chebi_version"
@@ -744,12 +806,12 @@ def select_classes(self, g, split_name, *args, **kwargs):
744806 )
745807 )
746808 filename = "classes.txt"
747- if (
748- self .chebi_version_train
749- is not None
750- # and self.raw_file_names_dict["test"] != split_name
751- ):
752- filename = f"classes_v{ self .chebi_version_train } .txt"
809+ # if (
810+ # self.chebi_version_train
811+ # is not None
812+ # # and self.raw_file_names_dict["test"] != split_name
813+ # ):
814+ # filename = f"classes_v{self.chebi_version_train}.txt"
753815 with open (os .path .join (self .processed_dir_main , filename ), "wt" ) as fout :
754816 fout .writelines (str (node ) + "\n " for node in nodes )
755817 return nodes
0 commit comments