@@ -576,9 +576,6 @@ def prepare_data(self, *args, **kwargs):
576576 df , filename = self .raw_file_names_dict ["data_chebi_train" ]
577577 )
578578
579- def setup (self , ** kwargs ):
580- super ().setup (** kwargs )
581-
582579 def _get_dynamic_splits (self ):
583580 """Generate data splits during run-time and saves in class variables"""
584581
@@ -640,34 +637,44 @@ def dynamic_split_dfs(self):
640637 "test" : self .dynamic_df_test ,
641638 }
642639
643- def load_processed_data (self , kind : str = None ) -> List :
640+ def load_processed_data (self , kind : str = None , filename : str = None ) -> List :
644641 """
645642 Load processed data from a file.
646643
647644 Args:
648645 kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None.
646+ filename (str, optional): The name of the file to load the dataset from. Defaults to None.
649647
650648 Returns:
651649 List: The loaded processed data.
652650
653651 Raises:
654- ValueError: If kind is None.
652+ ValueError: If both kind and filename are None.
653+ FileNotFoundError: If the specified file does not exist.
655654 """
656- if kind is None :
657- raise ValueError ("kind is required to load the correct dataset" )
658- # if both kind and filename are given, use filename
659- if kind is not None :
655+ if kind is None and filename is None :
656+ raise ValueError (
657+ "Either kind or filename is required to load the correct dataset, both are None"
658+ )
659+
660+ # If both kind and filename are given, use filename
661+ if kind is not None and filename is None :
660662 try :
661- # processed_file_names_dict is only implemented for _ChEBIDataExtractor
662663 if self .use_inner_cross_validation and kind != "test" :
663664 filename = self .processed_file_names_dict [
664665 f"fold_{ self .fold_index } _{ kind } "
665666 ]
666667 else :
667668 data_df = self .dynamic_split_dfs [kind ]
668- except NotImplementedError :
669- filename = f"{ kind } "
670- return data_df .to_dict (orient = "records" )
669+ return data_df .to_dict (orient = "records" )
670+ except KeyError :
671+ kind = f"{ kind } "
672+
673+ # If filename is provided
674+ try :
675+ return torch .load (os .path .join (self .processed_dir , filename ))
676+ except FileNotFoundError :
677+ raise FileNotFoundError (f"File { filename } doesn't exist" )
671678
672679
673680class JCIExtendedBase (_ChEBIDataExtractor ):
0 commit comments