Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added src/benchmark/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions src/benchmark/annotation_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List
from src.utils import get_pmcid_annotation


class AnnotationBenchmark:
    """Score extracted annotation sets for PubMed Central articles.

    Each per-category scorer is currently a placeholder that always returns
    1.0; the total score is the arithmetic mean of the four category scores.
    """

    def get_var_drug_ann_score(self, var_drug_ann: List[dict]) -> float:
        """Score variant-drug annotations. TODO: placeholder, always 1.0."""
        return 1.0

    def get_var_pheno_ann_score(self, var_pheno_ann: List[dict]) -> float:
        """Score variant-phenotype annotations. TODO: placeholder, always 1.0."""
        return 1.0

    def get_var_fa_ann_score(self, var_fa_ann: List[dict]) -> float:
        """Score variant functional-analysis annotations. TODO: placeholder, always 1.0."""
        return 1.0

    def get_study_parameters_score(self, study_parameters: List[dict]) -> float:
        """Score study parameters. TODO: placeholder, always 1.0."""
        return 1.0

    def calculate_total_score(
        self,
        var_drug_ann: List[dict],
        var_pheno_ann: List[dict],
        var_fa_ann: List[dict],
        study_parameters: List[dict],
    ) -> float:
        """Return the arithmetic mean of the four category scores."""
        scores = [
            self.get_var_drug_ann_score(var_drug_ann),
            self.get_var_pheno_ann_score(var_pheno_ann),
            self.get_var_fa_ann_score(var_fa_ann),
            self.get_study_parameters_score(study_parameters),
        ]
        return sum(scores) / len(scores)

    def run(self, pmcid: str) -> float:
        """Fetch the annotation record for *pmcid*, score it, and return the total score."""
        pmcid_annotation = get_pmcid_annotation(pmcid)

        # Missing categories default to empty lists so scoring never KeyErrors.
        var_drug_ann = pmcid_annotation.get("varDrugAnn", [])
        var_pheno_ann = pmcid_annotation.get("varPhenoAnn", [])
        var_fa_ann = pmcid_annotation.get("varFaAnn", [])
        study_parameters = pmcid_annotation.get("studyParameters", [])

        total_score = self.calculate_total_score(
            var_drug_ann, var_pheno_ann, var_fa_ann, study_parameters
        )
        print(f"Score for pmcid {pmcid}: {total_score}")
        return total_score

    def run_all(self) -> float:
        """Score every PMCID listed in persistent_data/benchmark_pmcids.txt.

        Returns:
            The average total score across all listed PMCIDs.

        Raises:
            ValueError: if the file contains no PMCIDs (the original code would
                fail here with an opaque ZeroDivisionError).
        """
        with open("persistent_data/benchmark_pmcids.txt", "r") as f:
            # Skip blank lines so stray empty IDs are never passed to run().
            benchmark_pmcids = [line for line in f.read().splitlines() if line.strip()]

        if not benchmark_pmcids:
            raise ValueError(
                "No PMCIDs found in persistent_data/benchmark_pmcids.txt"
            )

        scores = [self.run(pmcid) for pmcid in benchmark_pmcids]

        overall_score = sum(scores) / len(scores)
        print(f"Average score: {overall_score}")
        return overall_score


if __name__ == "__main__":
    # Script entry point: score every benchmark PMCID and report the average.
    AnnotationBenchmark().run_all()
13 changes: 11 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@
from termcolor import colored
from src.article_parser import MarkdownParser
from pydantic import BaseModel, ValidationError
from pathlib import Path

# Module-level cache used by get_true_variants(): holds the parsed annotations
# JSON after the first load so the file is read at most once per process.
_true_variant_cache: Optional[dict] = None


def get_pmcid_annotation(
    pmcid: str, annotations_by_pmcid: Path = Path("data/annotations_by_pmcid.json")
) -> dict:
    """Return the annotation record for *pmcid* from the annotations JSON file.

    Args:
        pmcid: PubMed Central identifier to look up.
        annotations_by_pmcid: Path to a JSON file mapping PMCIDs to annotation
            dicts.

    Returns:
        The annotation dict for *pmcid*, or {} if the PMCID is not present.
    """
    with open(annotations_by_pmcid, "r") as f:
        # Bind the loaded data to a new name; the original rebound the Path
        # parameter itself to a dict, which was confusing to read.
        annotations = json.load(f)
    return annotations.get(pmcid, {})


def extractVariantsRegex(text):
# Note, seems to extract a ton of variants, not just the ones that are being studied
# Think it might only be applicable to rsIDs
Expand Down Expand Up @@ -79,7 +88,7 @@ def compare_lists(
return true_positives, true_negatives, false_positives, false_negatives


def get_true_variants(pmcid: str) -> List[str]:
def get_true_variants(pmcid: str, annotations_by_pmcid: Path) -> List[str]:
"""
Get the actual annotated variants for a given PMCID.
Uses module-level caching to load the JSON file only once.
Expand All @@ -88,7 +97,7 @@ def get_true_variants(pmcid: str) -> List[str]:

if _true_variant_cache is None:
try:
with open("data/benchmark/true_variant_list.json", "r") as f:
with open(annotations_by_pmcid, "r") as f:
_true_variant_cache = json.load(f)
except FileNotFoundError:
logger.error(
Expand Down
Loading