diff --git a/peerqa/data_loader.py b/peerqa/data_loader.py index 8ac3457..03c688b 100644 --- a/peerqa/data_loader.py +++ b/peerqa/data_loader.py @@ -1,17 +1,22 @@ from typing import Iterator, Literal, Union +from pathlib import Path import pandas as pd from tqdm.auto import tqdm class PaperLoader: - def __init__(self, papers_file: str): - self.df = pd.read_json(papers_file, lines=True) + def __init__(self, papers_file: Union[str, Path]): + papers_file = Path(papers_file) + if papers_file.exists() and papers_file.stat().st_size > 0: + self.df = pd.read_json(papers_file, lines=True) + else: + self.df = pd.DataFrame({"paper_id": []}) def __call__( self, granularity: Literal["sentences", "paragraphs"], - template: str = None, + template: Union[str, None] = None, show_progress: bool = True, ) -> Iterator[tuple[str, list[str], list[str]]]: """Yields a the paper_id, sentence/paragraph index and sentence/paragraph.""" @@ -85,6 +90,8 @@ def get_title(self, paper_id: str) -> str: return title def has_paper_id(self, paper_id: str) -> bool: + if self.df.empty: + return False return paper_id in self.df.paper_id.unique()