diff --git a/requirements.txt b/requirements.txt index da8cd9b..3b8c18e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) # SPDX-License-Identifier: Apache-2.0 +sentence-transformers>=3.0.0 diff --git a/src/benchmark/annotation_benchmark.py b/src/benchmark/annotation_benchmark.py index 0f54681..e251c7e 100644 --- a/src/benchmark/annotation_benchmark.py +++ b/src/benchmark/annotation_benchmark.py @@ -1,5 +1,7 @@ from typing import List from src.utils import get_pmcid_annotation +from src.benchmark.fa_benchmark import evaluate_functional_analysis +from src.benchmark.drug_benchmark import evaluate_drug_annotations class AnnotationBenchmark: @@ -7,13 +9,21 @@ def __init__(self): pass def get_var_drug_ann_score(self, var_drug_ann: List[dict]): - return 1.0 + try: + result = evaluate_drug_annotations(var_drug_ann) + return float(result.get("overall_score", 0.0)) + except Exception: + return 1.0 def get_var_pheno_ann_score(self, var_pheno_ann: List[dict]): return 1.0 def get_var_fa_ann_score(self, var_fa_ann: List[dict]): - return 1.0 + try: + result = evaluate_functional_analysis(var_fa_ann) + return float(result.get("overall_score", 0.0)) + except Exception: + return 1.0 def get_study_parameters_score(self, study_parameters: List[dict]): return 1.0 diff --git a/src/benchmark/drug_benchmark.py b/src/benchmark/drug_benchmark.py new file mode 100644 index 0000000..49b274a --- /dev/null +++ b/src/benchmark/drug_benchmark.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List, Any, Optional, Tuple +from difflib import SequenceMatcher +import numpy as np +import re +from sentence_transformers import SentenceTransformer + + +_model: Optional[SentenceTransformer] = None + + +def _get_model() -> SentenceTransformer: + global _model + if _model is None: + _model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO") + return _model + + +def evaluate_drug_annotations(samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Parallel benchmark for drug entries. + Input is a list with exactly two dicts: + - samples[0] = ground truth annotation dict + - samples[1] = prediction annotation dict + """ + + if not isinstance(samples, list) or len(samples) != 2: + raise ValueError("Expected a list with exactly two dicts: [ground_truth, prediction].") + gt, pred = samples[0], samples[1] + if not isinstance(gt, dict) or not isinstance(pred, dict): + raise ValueError("Both items must be dicts: [ground_truth_dict, prediction_dict].") + + # Variant expansion and alignment (mirroring FA) + def parse_variant_list(variants_text: Optional[str]) -> List[str]: + if not variants_text: + return [] + tokens = re.split(r'[,;|\s]+(?:\+\s*)?', variants_text) + return [t.strip() for t in tokens if t and t.strip()] + + def normalize_variant(variant: str) -> str: + v = variant.strip() + if v.lower().startswith('rs'): + return v.lower() + return re.sub(r'\s+', '', v) + + def expand_annotations_by_variant(annotations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + expanded: List[Dict[str, Any]] = [] + for ann in annotations: + variants_field = ann.get('Variant/Haplotypes') + tokens = parse_variant_list(variants_field) + if len(tokens) <= 1: + expanded.append(ann) + continue + for tok in tokens: + new_ann = dict(ann) + new_ann['Variant/Haplotypes'] = normalize_variant(tok) + new_ann['_expanded_from_multi_variant'] = True + expanded.append(new_ann) + return expanded + + def align_by_variant( + ground_truth_list: List[Dict[str, Any]], + predictions_list: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[str]]: + rs_re = re.compile(r"rs\d+", re.IGNORECASE) + gt_expanded = expand_annotations_by_variant(ground_truth_list or []) + pred_expanded = expand_annotations_by_variant(predictions_list or []) + + pred_index: List[Tuple[set, str, Dict[str, Any]]] = [] + for rec in pred_expanded: + raw = (rec.get('Variant/Haplotypes') or '').strip() + raw_norm = normalize_variant(raw).lower() + rsids = set(m.group(0).lower() for m in rs_re.finditer(raw)) + pred_index.append((rsids, raw_norm, rec)) + + aligned_gt: List[Dict[str, Any]] = [] + aligned_pred: List[Dict[str, Any]] = [] + display_keys: List[str] = [] + + for gt_rec in gt_expanded: + gt_raw = (gt_rec.get('Variant/Haplotypes') or '').strip() + gt_norm = normalize_variant(gt_raw).lower() + gt_rs = set(m.group(0).lower() for m in rs_re.finditer(gt_raw)) + + match = None + if gt_rs: + for rsids, raw_norm, pred_rec in pred_index: + if rsids & gt_rs: + match = pred_rec + break + if match is None and gt_norm: + for rsids, raw_norm, pred_rec in pred_index: + if gt_norm in raw_norm: + match = pred_rec + break + + if match is not None: + aligned_gt.append(gt_rec) + aligned_pred.append(match) + disp = next(iter(gt_rs)) if gt_rs else gt_norm + display_keys.append(disp) + + return aligned_gt, aligned_pred, display_keys + + # Prepare lists and align + gt_list_raw: List[Dict[str, Any]] = [gt] + pred_list_raw: List[Dict[str, Any]] = [pred] + gt_list, pred_list, _ = align_by_variant(gt_list_raw, pred_list_raw) + if not gt_list: + # nothing aligned; return empty result structure + return {'total_samples': 0, 'field_scores': {}, 'overall_score': 0.0, 'detailed_results': []} + + model = _get_model() + + def normalize_drug_name(name: str) -> str: + # lowercase, strip, collapse whitespace, standardize separators + n = name.lower().strip() + # normalize hyphens to space and standardize slashes + n = n.replace('-', ' ') + # replace slashes with space-slash-space to normalize + n = n.replace('/', ' / ') + n = re.sub(r"\s+", " ", n) + return n + + def parse_drug_list(value: Optional[str]) -> List[str]: + if not value: + return [] + # split on commas or the word 'or' (for individual tokens), keep slashes inside tokens + parts = re.split(r",|\bor\b", value, flags=re.IGNORECASE) + tokens = [] + for part in parts: + token = normalize_drug_name(part) + if token: + tokens.append(token) + # remove empties and duplicates while preserving order + seen = set() + unique = [] + for t in tokens: + if t and t not in seen: + seen.add(t) + unique.append(t) + return unique + + def exact_match(gt_val: Any, pred_val: Any) -> float: + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + return 1.0 if str(gt_val).strip().lower() == str(pred_val).strip().lower() else 0.0 + + def semantic_similarity(gt_val: Any, pred_val: Any) -> float: + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + gt_str = str(gt_val).strip() + pred_str = str(pred_val).strip() + if gt_str == pred_str: + return 1.0 + try: + embeddings = model.encode([gt_str, pred_str]) + gt_embedding = embeddings[0] + pred_embedding = embeddings[1] + similarity = float( + np.dot(gt_embedding, pred_embedding) + / (np.linalg.norm(gt_embedding) * np.linalg.norm(pred_embedding)) + ) + return similarity + except Exception: + return SequenceMatcher(None, gt_str.lower(), pred_str.lower()).ratio() + + def variant_substring_match(gt_val: Any, pred_val: Any) -> float: + """Return 1.0 if GT substring appears in prediction (case-insensitive). + Also accept canonical gene-phenotype group labels via category equality. + """ + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + gt_str = re.sub(r"\s+", " ", str(gt_val).strip().lower()) + pred_str = re.sub(r"\s+", " ", str(pred_val).strip().lower()) + # canonical phenotype-group labels that may appear in this field + phenotype_group_labels = { + 'poor metabolizers', + 'intermediate metabolizers', + 'extensive metabolizers', + 'ultra-rapid metabolizers', + 'intermediate activity', + 'reduced function', + 'normal function', + } + if gt_str in phenotype_group_labels or pred_str in phenotype_group_labels: + return 1.0 if gt_str == pred_str else 0.0 + if not gt_str: + return 1.0 if not pred_str else 0.0 + return 1.0 if gt_str in pred_str else 0.0 + + def parse_allele_tokens(text: Optional[str]) -> List[str]: + if not text: + return [] + # split on '+' and whitespace and commas/semicolons + parts = re.split(r"[+/,;\s]+", str(text)) + tokens = [p.strip().lower() for p in parts if p and p.strip()] + return tokens + + def alleles_set_coverage(gt_val: Any, pred_val: Any) -> float: + """Order-insensitive coverage for allele/group fields. + Scores fraction of GT tokens present in Pred tokens (1.0 if both empty). + """ + gt_tokens = parse_allele_tokens(gt_val) + pred_tokens = parse_allele_tokens(pred_val) + if not gt_tokens and not pred_tokens: + return 1.0 + if not gt_tokens or not pred_tokens: + return 0.0 + pred_set = set(pred_tokens) + covered = sum(1 for t in gt_tokens if t in pred_set) + return covered / len(gt_tokens) + + def drugs_coverage(gt: Dict[str, Any], pred: Dict[str, Any]) -> float: + """Operator-aware coverage for Drug(s). + Uses `Multiple drugs And/or` to decide coverage rule. Defaults to 'or' if missing. + """ + gt_drugs_raw = gt.get('Drug(s)') + pred_drugs_raw = pred.get('Drug(s)') + gt_tokens = parse_drug_list(gt_drugs_raw) + pred_tokens = parse_drug_list(pred_drugs_raw) + if not gt_tokens and not pred_tokens: + return 1.0 + if not gt_tokens or not pred_tokens: + return 0.0 + + operator = (gt.get('Multiple drugs And/or') or pred.get('Multiple drugs And/or') or 'or').strip().lower() + + def token_match(g: str, p: str) -> bool: + if g == p: + return True + # allow strong fuzzy match for short drug tokens + return SequenceMatcher(None, g, p).ratio() >= 0.85 + + # build coverage matrix + covered = [False] * len(gt_tokens) + for i, gtok in enumerate(gt_tokens): + for ptok in pred_tokens: + if token_match(gtok, ptok): + covered[i] = True + break + + if operator == 'and': + return 1.0 if all(covered) else sum(1 for c in covered if c) / len(covered) + # default 'or': partial credit by fraction covered, with minimum of whether any matched + frac = sum(1 for c in covered if c) / len(covered) + return max(frac, 1.0 if any(covered) else 0.0) + + # Map evaluators to drug schema fields; Drug(s) handled separately below. + def category_equal(a: Any, b: Any) -> float: + a_norm = (re.sub(r"\s+", " ", str(a).strip().lower()) if a is not None else None) + b_norm = (re.sub(r"\s+", " ", str(b).strip().lower()) if b is not None else None) + if a_norm is None and b_norm is None: + return 1.0 + if a_norm is None or b_norm is None: + return 0.0 + return 1.0 if a_norm == b_norm else 0.0 + + field_evaluators = { + 'Variant/Haplotypes': variant_substring_match, + 'Gene': semantic_similarity, + 'PMID': exact_match, + 'Phenotype Category': category_equal, + 'Significance': category_equal, + 'Alleles': alleles_set_coverage, + 'Specialty Population': semantic_similarity, + 'Metabolizer types': semantic_similarity, + 'isPlural': category_equal, + 'Is/Is Not associated': category_equal, + 'Direction of effect': category_equal, + 'PD/PK terms': semantic_similarity, + 'Multiple drugs And/or': category_equal, + 'Population types': semantic_similarity, + 'Population Phenotypes or diseases': semantic_similarity, + 'Multiple phenotypes or diseases And/or': category_equal, + 'Comparison Allele(s) or Genotype(s)': alleles_set_coverage, + 'Comparison Metabolizer types': semantic_similarity, + } + + results: Dict[str, Any] = {'total_samples': len(gt_list), 'field_scores': {}, 'overall_score': 0.0} + + for field, evaluator in field_evaluators.items(): + scores: List[float] = [] + for g, p in zip(gt_list, pred_list): + scores.append(evaluator(g.get(field), p.get(field))) + results['field_scores'][field] = {'mean_score': sum(scores) / len(scores), 'scores': scores} + + # Compute Drug(s) scores with operator-aware logic + drug_scores: List[float] = [] + for g, p in zip(gt_list, pred_list): + drug_scores.append(drugs_coverage(g, p)) + results['field_scores']['Drug(s)'] = {'mean_score': sum(drug_scores) / len(drug_scores), 'scores': drug_scores} + + results['detailed_results'] = [] + for i, (g, p) in enumerate(zip(gt_list, pred_list)): + sample_result: Dict[str, Any] = {'sample_id': i, 'field_scores': {}} + for field, evaluator in field_evaluators.items(): + sample_result['field_scores'][field] = evaluator(g.get(field), p.get(field)) + sample_result['field_scores']['Drug(s)'] = drugs_coverage(g, p) + # No dependency penalties wired yet for drug entries; can be added later if needed + sample_result['dependency_issues'] = [] + results['detailed_results'].append(sample_result) + + for field in list(field_evaluators.keys()) + ['Drug(s)']: + field_scores = [s['field_scores'][field] for s in results['detailed_results']] + results['field_scores'][field] = {'mean_score': sum(field_scores) / len(field_scores), 'scores': field_scores} + + field_means = [v['mean_score'] for v in results['field_scores'].values()] + results['overall_score'] = sum(field_means) / len(field_means) if field_means else 0.0 + return results + + diff --git a/src/benchmark/fa_benchmark.py b/src/benchmark/fa_benchmark.py new file mode 100644 index 0000000..cb81080 --- /dev/null +++ b/src/benchmark/fa_benchmark.py @@ -0,0 +1,383 @@ +from typing import Dict, List, Any, Optional, Tuple +from difflib import SequenceMatcher +import numpy as np +import re +from sentence_transformers import SentenceTransformer + + +_model: Optional[SentenceTransformer] = None + + +def _get_model() -> SentenceTransformer: + global _model + if _model is None: + _model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO") + return _model + + +def parse_variant_list(variants_text: Optional[str]) -> List[str]: + if not variants_text: + return [] + tokens = re.split(r'[,;|\s]+(?:\+\s*)?', variants_text) + return [t.strip() for t in tokens if t and t.strip()] + + +def normalize_variant(variant: str) -> str: + v = variant.strip() + if v.lower().startswith('rs'): + return v.lower() + return re.sub(r'\s+', '', v) + + +def expand_annotations_by_variant(annotations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + expanded: List[Dict[str, Any]] = [] + for ann in annotations: + variants_field = ann.get('Variant/Haplotypes') + tokens = parse_variant_list(variants_field) + if len(tokens) <= 1: + expanded.append(ann) + continue + for tok in tokens: + new_ann = dict(ann) + new_ann['Variant/Haplotypes'] = normalize_variant(tok) + new_ann['_expanded_from_multi_variant'] = True + expanded.append(new_ann) + return expanded + + +def align_fa_annotations_by_variant( + ground_truth_fa: List[Dict[str, Any]], + predictions_fa: List[Dict[str, Any]], +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[str]]: + """ + Align FA annotations by variant string with robust matching: + 1) Expand multi-variant records to one per variant + 2) Prefer rsID intersection; fallback to normalized substring containment + Returns aligned (gt_list, pred_list, display_keys) + """ + rs_re = re.compile(r"rs\d+", re.IGNORECASE) + + gt_expanded = expand_annotations_by_variant(ground_truth_fa or []) + pred_expanded = expand_annotations_by_variant(predictions_fa or []) + + pred_index: List[Tuple[set, str, Dict[str, Any]]] = [] + for rec in pred_expanded: + raw = (rec.get('Variant/Haplotypes') or '').strip() + raw_norm = normalize_variant(raw).lower() + rsids = set(m.group(0).lower() for m in rs_re.finditer(raw)) + pred_index.append((rsids, raw_norm, rec)) + + aligned_gt: List[Dict[str, Any]] = [] + aligned_pred: List[Dict[str, Any]] = [] + display_keys: List[str] = [] + + for gt_rec in gt_expanded: + gt_raw = (gt_rec.get('Variant/Haplotypes') or '').strip() + gt_norm = normalize_variant(gt_raw).lower() + gt_rs = set(m.group(0).lower() for m in rs_re.finditer(gt_raw)) + + match = None + if gt_rs: + for rsids, raw_norm, pred_rec in pred_index: + if rsids & gt_rs: + match = pred_rec + break + if match is None and gt_norm: + for rsids, raw_norm, pred_rec in pred_index: + if gt_norm in raw_norm: + match = pred_rec + break + + if match is not None: + aligned_gt.append(gt_rec) + aligned_pred.append(match) + disp = next(iter(gt_rs)) if gt_rs else gt_norm + display_keys.append(disp) + + return aligned_gt, aligned_pred, display_keys + +def evaluate_fa_from_articles( + ground_truth_article: Dict[str, Any], + predictions_article: Dict[str, Any], +) -> Dict[str, Any]: + """ + Given two article dicts with var_fa_ann lists, align by variant and evaluate. + Returns standard results plus results['aligned_variants'] and a 'status'. + """ + gt_fa = (ground_truth_article or {}).get('var_fa_ann', []) or [] + pred_fa = (predictions_article or {}).get('var_fa_ann', []) or [] + + if not gt_fa or not pred_fa: + return { + 'total_samples': 0, + 'field_scores': {}, + 'overall_score': 0.0, + 'detailed_results': [], + 'aligned_variants': [], + 'status': 'missing_var_fa_ann', + } + + aligned_gt, aligned_pred, display = align_fa_annotations_by_variant(gt_fa, pred_fa) + if not aligned_gt: + return { + 'total_samples': 0, + 'field_scores': {}, + 'overall_score': 0.0, + 'detailed_results': [], + 'aligned_variants': [], + 'status': 'no_overlap_after_alignment', + } + + results = _evaluate_functional_analysis_pairs(aligned_gt, aligned_pred, None) + results['aligned_variants'] = display + results['status'] = 'ok' + return results + + +def validate_external_data(annotation: Dict[str, Any]) -> List[str]: + issues: List[str] = [] + rsid_pattern = re.compile(r'^rs\d+$', re.IGNORECASE) + star_allele_pattern = re.compile(r'^[A-Z0-9]+\*\d+$') + gene_pattern = re.compile(r'^[A-Z0-9]+$') + + variants = annotation.get('Variant/Haplotypes', '') + if variants: + variant_list = re.split(r'[,;|\s]+(?:\+\s*)?', variants) + for variant in (v.strip() for v in variant_list if v.strip()): + if variant.lower().startswith('rs'): + if not rsid_pattern.match(variant): + issues.append(f"Invalid rsID format: {variant}") + elif '*' in variant: + if not star_allele_pattern.match(variant): + issues.append(f"Invalid star allele format: {variant}") + + gene = annotation.get('Gene', '') + if gene and not gene_pattern.match(gene): + issues.append(f"Invalid gene name format: {gene}") + return issues + + +def validate_field_dependencies(annotation: Dict[str, Any]) -> List[str]: + issues: List[str] = [] + gene_product = annotation.get('Gene/gene product') + gene = annotation.get('Gene') + if gene_product and not gene: + issues.append("Gene field required when Gene/gene product is specified") + + comparison_alleles = annotation.get('Comparison Allele(s) or Genotype(s)') + variants = annotation.get('Variant/Haplotypes') + if comparison_alleles and not variants: + issues.append("Variant/Haplotypes required when Comparison Allele(s) is specified") + + direction = annotation.get('Direction of effect') + association = annotation.get('Is/Is Not associated') + if direction and association != 'Associated with': + issues.append("Direction of effect requires 'Associated with' status") + + functional_terms = annotation.get('Functional terms') + gene_product = annotation.get('Gene/gene product') + if functional_terms and not gene_product: + issues.append("Gene/gene product recommended when Functional terms is specified") + return issues + + +def validate_annotation_consistency(annotation: Dict[str, Any], study_params: List[Dict[str, Any]]) -> List[str]: + issues: List[str] = [] + variant_id = annotation.get('Variant Annotation ID') + if variant_id: + found_in_study = any(sp.get('Variant Annotation ID') == variant_id for sp in study_params) + if not found_in_study: + issues.append(f"Variant Annotation ID {variant_id} not found in study_parameters") + + annotation_pmid = annotation.get('PMID') + if annotation_pmid and study_params: + study_pmids = {sp.get('PMID') for sp in study_params if sp.get('PMID')} + if study_pmids and annotation_pmid not in study_pmids: + issues.append(f"PMID {annotation_pmid} inconsistent with study parameters") + return issues + + +def validate_all_dependencies(annotation: Dict[str, Any], study_params: Optional[List[Dict[str, Any]]] = None) -> List[str]: + issues: List[str] = [] + issues.extend(validate_external_data(annotation)) + issues.extend(validate_field_dependencies(annotation)) + if study_params: + issues.extend(validate_annotation_consistency(annotation, study_params)) + return issues + + +def evaluate_functional_analysis(samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Evaluate FA when provided a list with exactly two dicts: + - samples[0] = ground truth annotation dict + - samples[1] = prediction annotation dict + + Args: + samples: [ground_truth_dict, prediction_dict] + + Returns: + Dict with overall and per-field scores. + """ + + if not isinstance(samples, list) or len(samples) != 2: + raise ValueError("Expected a list with exactly two dicts: [ground_truth, prediction].") + gt, pred = samples[0], samples[1] + if not isinstance(gt, dict) or not isinstance(pred, dict): + raise ValueError("Both items must be dicts: [ground_truth_dict, prediction_dict].") + + gt_list: List[Dict[str, Any]] = [gt] + pred_list: List[Dict[str, Any]] = [pred] + return _evaluate_functional_analysis_pairs(gt_list, pred_list, None) + + +def _evaluate_functional_analysis_pairs( + gt_list: List[Dict[str, Any]], + pred_list: List[Dict[str, Any]], + study_parameters: Optional[List[Dict[str, Any]]], +) -> Dict[str, Any]: + model = _get_model() + + def exact_match(gt_val: Any, pred_val: Any) -> float: + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + return 1.0 if str(gt_val).strip().lower() == str(pred_val).strip().lower() else 0.0 + + def semantic_similarity(gt_val: Any, pred_val: Any) -> float: + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + gt_str = str(gt_val).strip() + pred_str = str(pred_val).strip() + if gt_str == pred_str: + return 1.0 + try: + embeddings = model.encode([gt_str, pred_str]) + gt_embedding = embeddings[0] + pred_embedding = embeddings[1] + similarity = float(np.dot(gt_embedding, pred_embedding) / (np.linalg.norm(gt_embedding) * np.linalg.norm(pred_embedding))) + return similarity + except Exception: + return SequenceMatcher(None, gt_str.lower(), pred_str.lower()).ratio() + + def variant_coverage(gt_variants: str, pred_variants: str) -> float: + if not gt_variants or not pred_variants: + return 1.0 if not gt_variants and not pred_variants else 0.0 + + def is_wildtype(variant: str) -> bool: + variant_lower = variant.lower().strip() + wildtype_indicators = ['wild type', 'wildtype', 'wt', '*1'] + return any(wt in variant_lower for wt in wildtype_indicators) + + def is_rsid(variant: str) -> bool: + return variant.strip().lower().startswith('rs') + + def is_star_allele(variant: str) -> bool: + return '*' in variant.strip() + + gt_list = [v.strip() for v in re.split(r'[,;|\s]+(?:\+\s*)?', gt_variants) if v.strip()] + pred_list = [v.strip() for v in re.split(r'[,;|\s]+(?:\+\s*)?', pred_variants) if v.strip()] + gt_list_filtered = [v for v in gt_list if not is_wildtype(v)] + if not gt_list_filtered: + return 1.0 if pred_list else 0.0 + + covered_count = 0 + for gt_var in gt_list_filtered: + gt_var_lower = gt_var.lower() + if is_rsid(gt_var) or is_star_allele(gt_var): + if any(gt_var_lower == pred_var.lower() for pred_var in pred_list): + covered_count += 1 + else: + best_match_score = 0.0 + for pred_var in pred_list: + if gt_var_lower in pred_var.lower() or pred_var.lower() in gt_var_lower: + best_match_score = 1.0 + break + similarity = SequenceMatcher(None, gt_var_lower, pred_var.lower()).ratio() + best_match_score = max(best_match_score, similarity) + if best_match_score > 0.8: + covered_count += 1 + return covered_count / len(gt_list_filtered) + + def variant_substring_match(gt_val: Any, pred_val: Any) -> float: + if gt_val is None and pred_val is None: + return 1.0 + if gt_val is None or pred_val is None: + return 0.0 + gt_str = str(gt_val).strip().lower() + pred_str = str(pred_val).strip().lower() + if not gt_str: + return 1.0 if not pred_str else 0.0 + return 1.0 if gt_str in pred_str else 0.0 + + field_evaluators = { + 'Variant/Haplotypes': variant_substring_match, + 'Gene': semantic_similarity, + 'Drug(s)': semantic_similarity, + 'PMID': exact_match, + 'Phenotype Category': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Significance': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Alleles': semantic_similarity, + 'Specialty Population': semantic_similarity, + 'Assay type': semantic_similarity, + 'Metabolizer types': semantic_similarity, + 'isPlural': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Is/Is Not associated': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Direction of effect': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Functional terms': semantic_similarity, + 'Gene/gene product': semantic_similarity, + 'When treated with/exposed to/when assayed with': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Multiple drugs And/or': lambda gt, pred: 1.0 if (gt and pred and gt.lower().strip() == pred.lower().strip()) else (1.0 if not gt and not pred else 0.0), + 'Cell type': semantic_similarity, + 'Comparison Allele(s) or Genotype(s)': semantic_similarity, + 'Comparison Metabolizer types': semantic_similarity + } + + results: Dict[str, Any] = {'total_samples': len(gt_list), 'field_scores': {}, 'overall_score': 0.0} + + for field, evaluator in field_evaluators.items(): + scores: List[float] = [] + for gt, pred in zip(gt_list, pred_list): + scores.append(evaluator(gt.get(field), pred.get(field))) + results['field_scores'][field] = {'mean_score': sum(scores) / len(scores), 'scores': scores} + + results['detailed_results'] = [] + for i, (gt, pred) in enumerate(zip(gt_list, pred_list)): + sample_result: Dict[str, Any] = {'sample_id': i, 'field_scores': {}} + for field, evaluator in field_evaluators.items(): + sample_result['field_scores'][field] = evaluator(gt.get(field), pred.get(field)) + dependency_issues = validate_all_dependencies(pred, study_parameters) + sample_result['dependency_issues'] = dependency_issues + if dependency_issues: + penalty_per_issue = 0.05 + total_penalty = min(len(dependency_issues) * penalty_per_issue, 0.3) + fields_to_penalize = set() + for issue in dependency_issues: + if "Gene" in issue or "gene" in issue: + fields_to_penalize.update(['Gene', 'Gene/gene product']) + elif "Variant" in issue or "variant" in issue: + fields_to_penalize.update(['Variant/Haplotypes', 'Comparison Allele(s) or Genotype(s)']) + elif "Direction" in issue or "Associated" in issue: + fields_to_penalize.update(['Direction of effect', 'Is/Is Not associated']) + elif "Functional" in issue: + fields_to_penalize.update(['Functional terms', 'Gene/gene product']) + elif "rsID" in issue or "star allele" in issue: + fields_to_penalize.add('Variant/Haplotypes') + else: + fields_to_penalize.update(sample_result['field_scores'].keys()) + for field in fields_to_penalize: + original_score = sample_result['field_scores'][field] + sample_result['field_scores'][field] = original_score * (1 - total_penalty) + results['detailed_results'].append(sample_result) + + for field in list(field_evaluators.keys()): + field_scores = [s['field_scores'][field] for s in results['detailed_results']] + results['field_scores'][field] = {'mean_score': sum(field_scores) / len(field_scores), 'scores': field_scores} + + field_means = [v['mean_score'] for v in results['field_scores'].values()] + results['overall_score'] = sum(field_means) / len(field_means) if field_means else 0.0 + return results + +