Skip to content

Commit 667b079

Browse files
committed
add filename parameter to load_processed_data
1 parent c1b6b0d commit 667b079

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,6 @@ def prepare_data(self, *args, **kwargs):
576576
df, filename=self.raw_file_names_dict["data_chebi_train"]
577577
)
578578

579-
def setup(self, **kwargs):
580-
super().setup(**kwargs)
581-
582579
def _get_dynamic_splits(self):
583580
"""Generate data splits at run time and save them in class variables."""
584581

@@ -640,34 +637,44 @@ def dynamic_split_dfs(self):
640637
"test": self.dynamic_df_test,
641638
}
642639

643-
def load_processed_data(self, kind: str = None) -> List:
640+
def load_processed_data(self, kind: str = None, filename: str = None) -> List:
644641
"""
645642
Load processed data from a file.
646643
647644
Args:
648645
kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None.
646+
filename (str, optional): The name of the file to load the dataset from. Defaults to None.
649647
650648
Returns:
651649
List: The loaded processed data.
652650
653651
Raises:
654-
ValueError: If kind is None.
652+
ValueError: If both kind and filename are None.
653+
FileNotFoundError: If the specified file does not exist.
655654
"""
656-
if kind is None:
657-
raise ValueError("kind is required to load the correct dataset")
658-
# if both kind and filename are given, use filename
659-
if kind is not None:
655+
if kind is None and filename is None:
656+
raise ValueError(
657+
"Either kind or filename is required to load the correct dataset, both are None"
658+
)
659+
660+
# An explicit filename takes precedence; resolve from kind only when no filename is given
661+
if kind is not None and filename is None:
660662
try:
661-
# processed_file_names_dict is only implemented for _ChEBIDataExtractor
662663
if self.use_inner_cross_validation and kind != "test":
663664
filename = self.processed_file_names_dict[
664665
f"fold_{self.fold_index}_{kind}"
665666
]
666667
else:
667668
data_df = self.dynamic_split_dfs[kind]
668-
except NotImplementedError:
669-
filename = f"{kind}"
670-
return data_df.to_dict(orient="records")
669+
return data_df.to_dict(orient="records")
670+
except KeyError:
671+
kind = f"{kind}"
672+
673+
# Load from file: filename was either passed in or resolved from kind above
674+
try:
675+
return torch.load(os.path.join(self.processed_dir, filename))
676+
except FileNotFoundError:
677+
raise FileNotFoundError(f"File {filename} doesn't exist")
671678

672679

673680
class JCIExtendedBase(_ChEBIDataExtractor):

0 commit comments

Comments
 (0)