diff --git a/src/benchmark/annotation_benchmark.py b/src/benchmark/annotation_benchmark.py
index e251c7e..b4dd6be 100644
--- a/src/benchmark/annotation_benchmark.py
+++ b/src/benchmark/annotation_benchmark.py
@@ -1,5 +1,7 @@
 from typing import List
+import json
 from src.utils import get_pmcid_annotation
+from src.benchmark.pheno_benchmark import evaluate_phenotype_annotations
 from src.benchmark.fa_benchmark import evaluate_functional_analysis
 from src.benchmark.drug_benchmark import evaluate_drug_annotations
 
@@ -15,8 +17,31 @@ def get_var_drug_ann_score(self, var_drug_ann: List[dict]):
         except Exception:
             return 1.0
 
-    def get_var_pheno_ann_score(self, var_pheno_ann: List[dict]):
-        return 1.0
+    def get_var_pheno_ann_score(self, var_pheno_ann: List[dict], pmcid: str):
+        # Load ground truth annotations
+        with open("persistent_data/benchmark_annotations.json", "r") as f:
+            ground_truth_data = json.load(f)
+
+        # Get ground truth for this PMCID
+        if pmcid not in ground_truth_data:
+            return 0.0
+
+        ground_truth_pheno_ann = ground_truth_data[pmcid].get("var_pheno_ann", [])
+
+        # If both are empty, perfect score
+        if not var_pheno_ann and not ground_truth_pheno_ann:
+            return 1.0
+
+        # If one is empty but not the other, score is 0
+        if not var_pheno_ann or not ground_truth_pheno_ann:
+            return 0.0
+
+        # Compare: [ground_truth, prediction]
+        try:
+            score = evaluate_phenotype_annotations([ground_truth_pheno_ann, var_pheno_ann])
+            return score / 100.0
+        except Exception:
+            return 0.0
 
     def get_var_fa_ann_score(self, var_fa_ann: List[dict]):
         try:
@@ -34,11 +59,12 @@ def calculate_total_score(
         var_pheno_ann: List[dict],
         var_fa_ann: List[dict],
         study_parameters: List[dict],
+        pmcid: str,
     ):
         # Return average of all scores
         scores = [
             self.get_var_drug_ann_score(var_drug_ann),
-            self.get_var_pheno_ann_score(var_pheno_ann),
+            self.get_var_pheno_ann_score(var_pheno_ann, pmcid),
             self.get_var_fa_ann_score(var_fa_ann),
             self.get_study_parameters_score(study_parameters),
         ]
@@ -53,7 +79,7 @@ def run(self, pmcid: str):
         study_parameters = pmcid_annotation.get("studyParameters", [])
 
         total_score = self.calculate_total_score(
-            var_drug_ann, var_pheno_ann, var_fa_ann, study_parameters
+            var_drug_ann, var_pheno_ann, var_fa_ann, study_parameters, pmcid
         )
         print(f"Score for pmcid {pmcid}: {total_score}")
         return total_score
diff --git a/src/benchmark/pheno_benchmark.py b/src/benchmark/pheno_benchmark.py
new file mode 100644
index 0000000..3821209
--- /dev/null
+++ b/src/benchmark/pheno_benchmark.py
@@ -0,0 +1,256 @@
+from typing import List, Dict, Any, Tuple, Set
+from dataclasses import dataclass
+import re
+
+
+class PhenotypeAnnotationBenchmark:
+    """Benchmark for evaluating phenotype annotation predictions."""
+
+    # Fields to compare (excluding metadata fields)
+    CORE_FIELDS = [
+        "Variant/Haplotypes",
+        "Gene",
+        "Drug(s)",
+        "Phenotype Category",
+        "Alleles",
+        "Is/Is Not associated",
+        "Direction of effect",
+        "Phenotype",
+        "When treated with/exposed to/when assayed with",
+        "Comparison Allele(s) or Genotype(s)",
+    ]
+
+    # Fields with weighted importance
+    FIELD_WEIGHTS = {
+        "Phenotype": 2.0,
+        "Drug(s)": 1.5,
+        "Direction of effect": 2.0,
+        "Alleles": 1.5,
+        "Is/Is Not associated": 1.0,
+        "Variant/Haplotypes": 1.0,
+        "Gene": 1.0,
+        "Phenotype Category": 0.5,
+        "When treated with/exposed to/when assayed with": 0.5,
+        "Comparison Allele(s) or Genotype(s)": 1.0,
+    }
+
+    def __init__(self, matching_threshold: float = 0.7):
+        """
+        Initialize benchmark.
+
+        Args:
+            matching_threshold: Minimum score to consider a match (0-1)
+        """
+        self.matching_threshold = matching_threshold
+
+    def _normalize_value(self, value: Any) -> str:
+        """Normalize a field value for comparison."""
+        if value is None:
+            return ""
+
+        # Convert to string and normalize
+        s = str(value).lower().strip()
+
+        # Remove extra whitespace
+        s = re.sub(r'\s+', ' ', s)
+
+        # Remove punctuation variations
+        s = re.sub(r'[,;]+', '', s)
+
+        return s
+
+    def _compare_field(self, pred_value: Any, gt_value: Any) -> float:
+        """
+        Compare two field values and return similarity score (0-1).
+
+        Args:
+            pred_value: Predicted value
+            gt_value: Ground truth value
+
+        Returns:
+            Similarity score between 0 and 1
+        """
+        pred_norm = self._normalize_value(pred_value)
+        ground_truth_norm = self._normalize_value(gt_value)
+
+        # Both empty or None
+        if not pred_norm and not ground_truth_norm:
+            return 1.0
+
+        # One empty, one not
+        if not pred_norm or not ground_truth_norm:
+            return 0.0
+
+        # Exact match
+        if pred_norm == ground_truth_norm:
+            return 1.0
+
+        # Check if one contains the other (useful for partial matches)
+        if pred_norm in ground_truth_norm or ground_truth_norm in pred_norm:
+            return 0.8
+
+        # Fall back to the Jaccard index over tokens: useful when the presence or
+        # absence of elements matters more than their frequency or order, and it
+        # helps catch cases where multiple entries were merged into one annotation.
+        pred_tokens = set(pred_norm.split())
+        gt_tokens = set(ground_truth_norm.split())
+
+        if pred_tokens and gt_tokens:
+            intersection = len(pred_tokens & gt_tokens)
+            union = len(pred_tokens | gt_tokens)
+            jaccard = intersection / union if union > 0 else 0.0
+            return jaccard
+
+        return 0.0
+
+    def _compare_annotations(self, pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
+        """
+        Compare a predicted annotation with a ground truth annotation.
+
+        Args:
+            pred: Predicted annotation
+            gt: Ground truth annotation
+
+        Returns:
+            Weighted similarity score between 0 and 1
+        """
+        field_scores = {}
+        weighted_sum = 0.0
+        total_weight = 0.0
+
+        for field in self.CORE_FIELDS:
+            weight = self.FIELD_WEIGHTS.get(field, 1.0)
+            similarity = self._compare_field(pred.get(field), gt.get(field))
+
+            field_scores[field] = similarity
+            weighted_sum += similarity * weight
+            total_weight += weight
+
+        # Calculate weighted average
+        matching_score = weighted_sum / total_weight
+
+        return matching_score
+
+    def _find_best_matches(
+        self,
+        predictions: List[Dict[str, Any]],
+        ground_truths: List[Dict[str, Any]]
+    ) -> List[Tuple[int, int, float]]:
+        """
+        Find best matches between predictions and ground truths.
+
+        Returns:
+            List of (pred_idx, gt_idx, score) tuples sorted by score descending
+        """
+        matches = []
+
+        for pred_idx, pred in enumerate(predictions):
+            for gt_idx, gt in enumerate(ground_truths):
+                match_score = self._compare_annotations(pred, gt)
+                if match_score >= self.matching_threshold:
+                    matches.append((pred_idx, gt_idx, match_score))
+
+        # Sort by score descending
+        matches.sort(key=lambda x: x[2], reverse=True)
+
+        return matches
+
+    def evaluate(
+        self,
+        samples: List[Any]
+    ) -> float:
+        """
+        Evaluate predictions against ground truths and return a similarity score.
+
+        Handles both single annotation pairs and lists of annotations.
+
+        Args:
+            samples: List with exactly 2 items:
+                - [ground_truth_dict, prediction_dict] for single comparison
+                - [ground_truth_list, prediction_list] for multiple comparisons
+
+        Returns:
+            Similarity score between 0 and 1
+        """
+        if not isinstance(samples, list) or len(samples) != 2:
+            raise ValueError("Expected a list with exactly two items: [ground_truth, prediction].")
+
+        gt, pred = samples[0], samples[1]
+
+        # Normalize to lists
+        if isinstance(gt, dict) and isinstance(pred, dict):
+            # Single annotation pair
+            gt_list = [gt]
+            pred_list = [pred]
+        elif isinstance(gt, list) and isinstance(pred, list):
+            # Multiple annotations
+            gt_list = gt
+            pred_list = pred
+        else:
+            raise ValueError("Both items must be either dicts or lists: [ground_truth, prediction].")
+
+        if not gt_list or not pred_list:
+            return 0.0
+
+        # Find all potential matches
+        all_matches = self._find_best_matches(pred_list, gt_list)
+
+        # Greedily assign matches (allowing many-to-one mapping)
+        matched_preds: Set[int] = set()
+        matched_gts: Set[int] = set()
+        match_scores = []
+
+        for pred_idx, gt_idx, score in all_matches:
+            # Allow multiple predictions to match the same ground truth (many-to-one)
+            # but each prediction can only match once (one-to-one from the pred side)
+            if pred_idx not in matched_preds:
+                matched_preds.add(pred_idx)
+                matched_gts.add(gt_idx)
+                match_scores.append(score)
+
+        # Calculate average similarity across all ground truths:
+        # matched GTs contribute their match score, unmatched GTs contribute 0
+        total_score = sum(match_scores)
+        total_possible = len(gt_list)
+
+        return total_score / total_possible
+
+
+def evaluate_phenotype_annotations(
+    samples: List[Any],
+    matching_threshold: float = 0.7
+) -> float:
+    """
+    Benchmark phenotype annotations and return an aggregate similarity score.
+
+    Handles both single annotation pairs and lists of annotations.
+
+    Args:
+        samples: List with exactly 2 items:
+            - [ground_truth_dict, prediction_dict] for single comparison
+            - [ground_truth_list, prediction_list] for multiple comparisons
+        matching_threshold: Minimum similarity score to consider a match (0-1)
+
+    Returns:
+        Similarity score between 0 and 100 representing how well the prediction(s)
+        match the ground truth(s) across all fields.
+
+    Examples:
+        # Single annotation pair
+        >>> ground_truth = {"Phenotype": "sensitivity", "Drug(s)": "etoposide", ...}
+        >>> prediction = {"Phenotype": "sensitivity", "Drug(s)": "etoposide", ...}
+        >>> score = evaluate_phenotype_annotations([ground_truth, prediction])
+        >>> print(f"Model Score: {score:.1f}/100")
+
+        # Multiple annotations
+        >>> ground_truths = [gt1, gt2, gt3]
+        >>> predictions = [pred1, pred2]
+        >>> score = evaluate_phenotype_annotations([ground_truths, predictions])
+        >>> print(f"Model Score: {score:.1f}/100")
+    """
+    benchmark = PhenotypeAnnotationBenchmark(matching_threshold=matching_threshold)
+    similarity = benchmark.evaluate(samples)
+
+    # Return as 0-100 scale
+    return similarity * 100
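
Usage sketch (illustrative, not part of the patch): a minimal example of calling the new helper directly. The annotation values below are hypothetical and only a subset of the CORE_FIELDS keys is filled in; missing keys compare as empty-vs-empty and score 1.0.

from src.benchmark.pheno_benchmark import evaluate_phenotype_annotations

# Hypothetical ground-truth and predicted annotations (lists of dicts keyed by CORE_FIELDS).
ground_truth = [{
    "Variant/Haplotypes": "rs1045642",
    "Gene": "ABCB1",
    "Drug(s)": "etoposide",
    "Phenotype Category": "Efficacy",
    "Is/Is Not associated": "Associated with",
    "Direction of effect": "increased",
    "Phenotype": "sensitivity",
}]
prediction = [{
    "Variant/Haplotypes": "rs1045642",
    "Gene": "ABCB1",
    "Drug(s)": "etoposide",
    "Phenotype Category": "Efficacy",
    "Is/Is Not associated": "Associated with",
    "Direction of effect": "increased",
    "Phenotype": "sensitivity",
}]

# Order is [ground_truth, prediction]; the helper returns a 0-100 score,
# which get_var_pheno_ann_score() rescales to 0-1 by dividing by 100.
score = evaluate_phenotype_annotations([ground_truth, prediction], matching_threshold=0.7)
print(f"Phenotype annotation score: {score:.1f}/100")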