diff --git a/src/benchmark/__init__.py b/src/benchmark/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/benchmark/annotation_benchmark.py b/src/benchmark/annotation_benchmark.py
new file mode 100644
index 0000000..0f54681
--- /dev/null
+++ b/src/benchmark/annotation_benchmark.py
@@ -0,0 +1,66 @@
+from typing import List
+from src.utils import get_pmcid_annotation
+
+
+class AnnotationBenchmark:
+    def __init__(self):
+        pass
+
+    def get_var_drug_ann_score(self, var_drug_ann: List[dict]):
+        return 1.0
+
+    def get_var_pheno_ann_score(self, var_pheno_ann: List[dict]):
+        return 1.0
+
+    def get_var_fa_ann_score(self, var_fa_ann: List[dict]):
+        return 1.0
+
+    def get_study_parameters_score(self, study_parameters: List[dict]):
+        return 1.0
+
+    def calculate_total_score(
+        self,
+        var_drug_ann: List[dict],
+        var_pheno_ann: List[dict],
+        var_fa_ann: List[dict],
+        study_parameters: List[dict],
+    ):
+        # Return average of all scores
+        scores = [
+            self.get_var_drug_ann_score(var_drug_ann),
+            self.get_var_pheno_ann_score(var_pheno_ann),
+            self.get_var_fa_ann_score(var_fa_ann),
+            self.get_study_parameters_score(study_parameters),
+        ]
+        return sum(scores) / len(scores)
+
+    def run(self, pmcid: str):
+        pmcid_annotation = get_pmcid_annotation(pmcid)
+
+        var_drug_ann = pmcid_annotation.get("varDrugAnn", [])
+        var_pheno_ann = pmcid_annotation.get("varPhenoAnn", [])
+        var_fa_ann = pmcid_annotation.get("varFaAnn", [])
+        study_parameters = pmcid_annotation.get("studyParameters", [])
+
+        total_score = self.calculate_total_score(
+            var_drug_ann, var_pheno_ann, var_fa_ann, study_parameters
+        )
+        print(f"Score for pmcid {pmcid}: {total_score}")
+        return total_score
+
+    def run_all(self):
+        benchmark_pmcids = []
+        with open("persistent_data/benchmark_pmcids.txt", "r") as f:
+            benchmark_pmcids = f.read().splitlines()
+        scores = []
+        for pmcid in benchmark_pmcids:
+            scores.append(self.run(pmcid))
+
+        overall_score = sum(scores) / len(scores)
+        print(f"Average score: {overall_score}")
+        return overall_score
+
+
+if __name__ == "__main__":
+    benchmark = AnnotationBenchmark()
+    benchmark.run_all()
diff --git a/src/utils.py b/src/utils.py
index 5076921..5f0ea3d 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -5,10 +5,19 @@ from termcolor import colored
 from src.article_parser import MarkdownParser
 from pydantic import BaseModel, ValidationError
+from pathlib import Path
 
 _true_variant_cache: Optional[dict] = None
 
 
+def get_pmcid_annotation(
+    pmcid: str, annotations_by_pmcid: Path = Path("data/annotations_by_pmcid.json")
+) -> dict:
+    with open(annotations_by_pmcid, "r") as f:
+        annotations = json.load(f)
+    return annotations.get(pmcid, {})
+
+
 def extractVariantsRegex(text):
     # Note, seems to extract a ton of variants, not just the ones that are being studied
     # Think it might only be applicable to rsIDs
@@ -79,7 +88,7 @@
     return true_positives, true_negatives, false_positives, false_negatives
 
 
-def get_true_variants(pmcid: str) -> List[str]:
+def get_true_variants(pmcid: str, annotations_by_pmcid: Path) -> List[str]:
     """
     Get the actual annotated variants for a given PMCID.
     Uses module-level caching to load the JSON file only once.
@@ -88,7 +97,7 @@
 
     if _true_variant_cache is None:
         try:
-            with open("data/benchmark/true_variant_list.json", "r") as f:
+            with open(annotations_by_pmcid, "r") as f:
                 _true_variant_cache = json.load(f)
         except FileNotFoundError:
             logger.error(