From 82db35db59d27bebb8e9155fb9e01c4997a1ce66 Mon Sep 17 00:00:00 2001 From: Arellano Tavara Date: Sun, 23 Nov 2025 13:45:06 -0600 Subject: [PATCH 1/2] Add cancer genomics datasets and tasks (ClinVar, COSMIC, TCGA-PRAD) This PR adds support for cancer genomics data to PyHealth: Datasets: - ClinVarDataset: Variant clinical significance annotations - COSMICDataset: Cancer somatic mutations catalogue - TCGAPRADDataset: TCGA Prostate Adenocarcinoma multi-omics data Tasks: - VariantClassificationClinVar: Predict pathogenic/benign variants - MutationPathogenicityPrediction: FATHMM-based mutation prediction - CancerSurvivalPrediction: Patient survival outcome prediction - CancerMutationBurden: High vs low TMB classification Features: - YAML configs for all three datasets - Helper methods for data cleaning (_safe_float, _extract_genes, etc.) - Class constants for category mappings (ACMG/AMP guidelines) - Comprehensive docstrings with examples - 43 unit tests (all passing) --- pyhealth/datasets/__init__.py | 3 + pyhealth/datasets/clinvar.py | 169 +++++++++++ pyhealth/datasets/configs/clinvar.yaml | 16 + pyhealth/datasets/configs/cosmic.yaml | 15 + pyhealth/datasets/configs/tcga_prad.yaml | 23 ++ pyhealth/datasets/cosmic.py | 197 +++++++++++++ pyhealth/datasets/tcga_prad.py | 275 ++++++++++++++++++ pyhealth/tasks/__init__.py | 5 + pyhealth/tasks/cancer_survival.py | 254 ++++++++++++++++ pyhealth/tasks/variant_classification.py | 234 +++++++++++++++ test-resources/clinvar/clinvar-pyhealth.csv | 13 + test-resources/cosmic/cosmic-pyhealth.csv | 13 + .../tcga_prad/tcga_prad_clinical-pyhealth.csv | 6 + .../tcga_prad_mutations-pyhealth.csv | 20 ++ tests/core/test_clinvar.py | 121 ++++++++ tests/core/test_cosmic.py | 134 +++++++++ tests/core/test_tcga_prad.py | 195 +++++++++++++ 17 files changed, 1693 insertions(+) create mode 100644 pyhealth/datasets/clinvar.py create mode 100644 pyhealth/datasets/configs/clinvar.yaml create mode 100644 pyhealth/datasets/configs/cosmic.yaml create mode 100644 pyhealth/datasets/configs/tcga_prad.yaml create mode 100644 pyhealth/datasets/cosmic.py create mode 100644 pyhealth/datasets/tcga_prad.py create mode 100644 pyhealth/tasks/cancer_survival.py create mode 100644 pyhealth/tasks/variant_classification.py create mode 100644 test-resources/clinvar/clinvar-pyhealth.csv create mode 100644 test-resources/cosmic/cosmic-pyhealth.csv create mode 100644 test-resources/tcga_prad/tcga_prad_clinical-pyhealth.csv create mode 100644 test-resources/tcga_prad/tcga_prad_mutations-pyhealth.csv create mode 100644 tests/core/test_clinvar.py create mode 100644 tests/core/test_cosmic.py create mode 100644 tests/core/test_tcga_prad.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index ced02afd7..7d6a65f16 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -49,6 +49,8 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset from .chestxray14 import ChestXray14Dataset +from .clinvar import ClinVarDataset +from .cosmic import COSMICDataset from .covid19_cxr import COVID19CXRDataset from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset @@ -64,6 +66,7 @@ def __init__(self, *args, **kwargs): from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset +from .tcga_prad import TCGAPRADDataset from .splitter import ( split_by_patient, split_by_patient_conformal, diff --git a/pyhealth/datasets/clinvar.py b/pyhealth/datasets/clinvar.py new file mode 100644 index 000000000..fee9c21c5 --- /dev/null +++ b/pyhealth/datasets/clinvar.py @@ -0,0 +1,169 @@ +"""ClinVar dataset for PyHealth. + +This module provides the ClinVarDataset class for loading and processing +ClinVar variant data for machine learning tasks. +""" + +import logging +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class ClinVarDataset(BaseDataset): + """ClinVar dataset for variant classification. + + ClinVar is a freely accessible, public archive of reports of the relationships + among human variations and phenotypes, with supporting evidence. This dataset + enables variant pathogenicity prediction tasks. + + Dataset is available at: + https://ftp.ncbi.nlm.nih.gov/pub/clinvar/ + + Args: + root: Root directory of the raw data containing the ClinVar files. + tables: Optional list of additional tables to load beyond defaults. + dataset_name: Optional name of the dataset. Defaults to "clinvar". + config_path: Optional path to the configuration file. If not provided, + uses the default config in the configs directory. + + Attributes: + root: Root directory of the raw data. + dataset_name: Name of the dataset. + config_path: Path to the configuration file. + + Examples: + >>> from pyhealth.datasets import ClinVarDataset + >>> dataset = ClinVarDataset(root="/path/to/clinvar") + >>> dataset.stats() + >>> samples = dataset.set_task() + >>> print(samples[0]) + """ + + def __init__( + self, + root: str, + tables: List[str] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "clinvar.yaml" + + # Prepare standardized CSV if not exists + pyhealth_csv = os.path.join(root, "clinvar-pyhealth.csv") + if not os.path.exists(pyhealth_csv): + logger.info("Preparing ClinVar metadata...") + self.prepare_metadata(root) + + default_tables = ["variants"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "clinvar", + config_path=config_path, + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Prepare metadata for the ClinVar dataset. + + Converts raw ClinVar variant_summary.txt to standardized CSV format. + + Args: + root: Root directory containing the ClinVar files. + """ + # Try to find the raw ClinVar file + possible_files = [ + "variant_summary.txt", + "variant_summary.txt.gz", + "clinvar_variant_summary.txt", + "clinvar.vcf", + ] + + raw_file = None + for fname in possible_files: + fpath = os.path.join(root, fname) + if os.path.exists(fpath): + raw_file = fpath + break + + if raw_file is None: + logger.warning( + f"No raw ClinVar file found in {root}. " + "Please download from https://ftp.ncbi.nlm.nih.gov/pub/clinvar/ " + "and place variant_summary.txt in the root directory." + ) + # Create empty placeholder + pd.DataFrame( + columns=[ + "gene_symbol", + "clinical_significance", + "review_status", + "chromosome", + "position", + "reference_allele", + "alternate_allele", + "variant_type", + "assembly", + ] + ).to_csv(os.path.join(root, "clinvar-pyhealth.csv"), index=False) + return + + logger.info(f"Processing ClinVar file: {raw_file}") + + # Read the raw file + if raw_file.endswith(".gz"): + df = pd.read_csv(raw_file, sep="\t", compression="gzip", low_memory=False) + else: + df = pd.read_csv(raw_file, sep="\t", low_memory=False) + + # Standardize column names + column_mapping = { + "GeneSymbol": "gene_symbol", + "ClinicalSignificance": "clinical_significance", + "ReviewStatus": "review_status", + "Chromosome": "chromosome", + "PositionVCF": "position", + "ReferenceAlleleVCF": "reference_allele", + "AlternateAlleleVCF": "alternate_allele", + "Type": "variant_type", + "Assembly": "assembly", + } + + # Select and rename columns that exist + available_cols = [c for c in column_mapping.keys() if c in df.columns] + df_out = df[available_cols].rename( + columns={k: v for k, v in column_mapping.items() if k in available_cols} + ) + + # Filter for GRCh38 assembly if assembly column exists + if "assembly" in df_out.columns: + df_out = df_out[df_out["assembly"] == "GRCh38"] + + # Save to standardized CSV + output_path = os.path.join(root, "clinvar-pyhealth.csv") + df_out.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_out)} variants to {output_path}") + + @property + def default_task(self): + """Returns the default task for this dataset. + + Returns: + VariantClassificationClinVar: The default classification task. + """ + from pyhealth.tasks import VariantClassificationClinVar + + return VariantClassificationClinVar() diff --git a/pyhealth/datasets/configs/clinvar.yaml b/pyhealth/datasets/configs/clinvar.yaml new file mode 100644 index 000000000..961d2ab72 --- /dev/null +++ b/pyhealth/datasets/configs/clinvar.yaml @@ -0,0 +1,16 @@ +version: "1.0" +tables: + variants: + file_path: "clinvar-pyhealth.csv" + patient_id: null + timestamp: null + attributes: + - "gene_symbol" + - "clinical_significance" + - "review_status" + - "chromosome" + - "position" + - "reference_allele" + - "alternate_allele" + - "variant_type" + - "assembly" diff --git a/pyhealth/datasets/configs/cosmic.yaml b/pyhealth/datasets/configs/cosmic.yaml new file mode 100644 index 000000000..c873b2d4a --- /dev/null +++ b/pyhealth/datasets/configs/cosmic.yaml @@ -0,0 +1,15 @@ +version: "1.0" +tables: + mutations: + file_path: "cosmic-pyhealth.csv" + patient_id: "sample_id" + timestamp: null + attributes: + - "gene_name" + - "hgvsc" + - "hgvsp" + - "mutation_description" + - "fathmm_prediction" + - "primary_site" + - "primary_histology" + - "mutation_somatic_status" diff --git a/pyhealth/datasets/configs/tcga_prad.yaml b/pyhealth/datasets/configs/tcga_prad.yaml new file mode 100644 index 000000000..ca507c7f8 --- /dev/null +++ b/pyhealth/datasets/configs/tcga_prad.yaml @@ -0,0 +1,23 @@ +version: "1.0" +tables: + mutations: + file_path: "tcga_prad_mutations-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "hugo_symbol" + - "variant_classification" + - "variant_type" + - "hgvsc" + - "hgvsp" + - "tumor_sample_barcode" + clinical: + file_path: "tcga_prad_clinical-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "age_at_diagnosis" + - "gleason_score" + - "vital_status" + - "days_to_death" + - "tumor_stage" diff --git a/pyhealth/datasets/cosmic.py b/pyhealth/datasets/cosmic.py new file mode 100644 index 000000000..570681e2f --- /dev/null +++ b/pyhealth/datasets/cosmic.py @@ -0,0 +1,197 @@ +"""COSMIC dataset for PyHealth. + +This module provides the COSMICDataset class for loading and processing +COSMIC (Catalogue Of Somatic Mutations In Cancer) data for machine learning tasks. +""" + +import logging +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class COSMICDataset(BaseDataset): + """COSMIC dataset for cancer somatic mutation analysis. + + COSMIC (Catalogue Of Somatic Mutations In Cancer) is the world's largest + and most comprehensive resource for exploring the impact of somatic + mutations in human cancer. This dataset enables mutation pathogenicity + prediction and cancer gene analysis tasks. + + Dataset is available at: + https://cancer.sanger.ac.uk/cosmic/download + + Note: + COSMIC requires registration and license agreement for data access. + + Args: + root: Root directory of the raw data containing the COSMIC files. + tables: Optional list of additional tables to load beyond defaults. + dataset_name: Optional name of the dataset. Defaults to "cosmic". + config_path: Optional path to the configuration file. If not provided, + uses the default config in the configs directory. + + Attributes: + root: Root directory of the raw data. + dataset_name: Name of the dataset. + config_path: Path to the configuration file. + + Examples: + >>> from pyhealth.datasets import COSMICDataset + >>> dataset = COSMICDataset(root="/path/to/cosmic") + >>> dataset.stats() + >>> samples = dataset.set_task() + >>> print(samples[0]) + """ + + def __init__( + self, + root: str, + tables: List[str] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "cosmic.yaml" + + # Prepare standardized CSV if not exists + pyhealth_csv = os.path.join(root, "cosmic-pyhealth.csv") + if not os.path.exists(pyhealth_csv): + logger.info("Preparing COSMIC metadata...") + self.prepare_metadata(root) + + default_tables = ["mutations"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "cosmic", + config_path=config_path, + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Prepare metadata for the COSMIC dataset. + + Converts raw COSMIC TSV/CSV files to standardized CSV format. + + Args: + root: Root directory containing the COSMIC files. + """ + # Try to find the raw COSMIC file + possible_files = [ + "CosmicMutantExportCensus.tsv", + "CosmicMutantExportCensus.tsv.gz", + "CosmicMutantExport.tsv", + "CosmicMutantExport.tsv.gz", + "cosmic_mutations.tsv", + "cosmic_mutations.csv", + ] + + raw_file = None + for fname in possible_files: + fpath = os.path.join(root, fname) + if os.path.exists(fpath): + raw_file = fpath + break + + if raw_file is None: + logger.warning( + f"No raw COSMIC file found in {root}. " + "Please download from https://cancer.sanger.ac.uk/cosmic/download " + "and place CosmicMutantExportCensus.tsv in the root directory." + ) + # Create empty placeholder + pd.DataFrame( + columns=[ + "sample_id", + "gene_name", + "hgvsc", + "hgvsp", + "mutation_description", + "fathmm_prediction", + "primary_site", + "primary_histology", + "mutation_somatic_status", + ] + ).to_csv(os.path.join(root, "cosmic-pyhealth.csv"), index=False) + return + + logger.info(f"Processing COSMIC file: {raw_file}") + + # Read the raw file + sep = "\t" if ".tsv" in raw_file else "," + if raw_file.endswith(".gz"): + df = pd.read_csv(raw_file, sep=sep, compression="gzip", low_memory=False) + else: + df = pd.read_csv(raw_file, sep=sep, low_memory=False) + + # Standardize column names (COSMIC uses various naming conventions) + column_mapping = { + "ID_SAMPLE": "sample_id", + "GENE_NAME": "gene_name", + "Gene name": "gene_name", + "HGVSC": "hgvsc", + "HGVSP": "hgvsp", + "MUTATION_DESCRIPTION": "mutation_description", + "Mutation Description": "mutation_description", + "FATHMM_PREDICTION": "fathmm_prediction", + "FATHMM prediction": "fathmm_prediction", + "PRIMARY_SITE": "primary_site", + "Primary site": "primary_site", + "PRIMARY_HISTOLOGY": "primary_histology", + "Primary histology": "primary_histology", + "MUTATION_SOMATIC_STATUS": "mutation_somatic_status", + "Mutation somatic status": "mutation_somatic_status", + } + + # Rename columns that exist + rename_dict = {k: v for k, v in column_mapping.items() if k in df.columns} + df = df.rename(columns=rename_dict) + + # Select columns that exist in our schema + output_cols = [ + "sample_id", + "gene_name", + "hgvsc", + "hgvsp", + "mutation_description", + "fathmm_prediction", + "primary_site", + "primary_histology", + "mutation_somatic_status", + ] + available_cols = [c for c in output_cols if c in df.columns] + + # If sample_id doesn't exist, create from index + if "sample_id" not in df.columns: + df["sample_id"] = df.index.astype(str) + available_cols = ["sample_id"] + [c for c in available_cols if c != "sample_id"] + + df_out = df[available_cols] + + # Save to standardized CSV + output_path = os.path.join(root, "cosmic-pyhealth.csv") + df_out.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_out)} mutations to {output_path}") + + @property + def default_task(self): + """Returns the default task for this dataset. + + Returns: + MutationPathogenicityPrediction: The default prediction task. + """ + from pyhealth.tasks import MutationPathogenicityPrediction + + return MutationPathogenicityPrediction() diff --git a/pyhealth/datasets/tcga_prad.py b/pyhealth/datasets/tcga_prad.py new file mode 100644 index 000000000..94c6f8e0e --- /dev/null +++ b/pyhealth/datasets/tcga_prad.py @@ -0,0 +1,275 @@ +"""TCGA-PRAD dataset for PyHealth. + +This module provides the TCGAPRADDataset class for loading and processing +TCGA Prostate Adenocarcinoma (PRAD) data for machine learning tasks. +""" + +import logging +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class TCGAPRADDataset(BaseDataset): + """TCGA Prostate Adenocarcinoma (PRAD) dataset. + + The Cancer Genome Atlas (TCGA) PRAD dataset contains multi-omics data + for prostate adenocarcinoma patients, including somatic mutations, + clinical data, and survival outcomes. This dataset enables cancer + survival prediction and mutation analysis tasks. + + Dataset is available at: + https://portal.gdc.cancer.gov/projects/TCGA-PRAD + + Args: + root: Root directory of the raw data containing the TCGA-PRAD files. + tables: Optional list of additional tables to load beyond defaults. + dataset_name: Optional name of the dataset. Defaults to "tcga_prad". + config_path: Optional path to the configuration file. If not provided, + uses the default config in the configs directory. + + Attributes: + root: Root directory of the raw data. + dataset_name: Name of the dataset. + config_path: Path to the configuration file. + + Examples: + >>> from pyhealth.datasets import TCGAPRADDataset + >>> dataset = TCGAPRADDataset(root="/path/to/tcga_prad") + >>> dataset.stats() + >>> samples = dataset.set_task() + >>> print(samples[0]) + """ + + def __init__( + self, + root: str, + tables: List[str] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "tcga_prad.yaml" + + # Prepare standardized CSVs if not exists + mutations_csv = os.path.join(root, "tcga_prad_mutations-pyhealth.csv") + clinical_csv = os.path.join(root, "tcga_prad_clinical-pyhealth.csv") + + if not os.path.exists(mutations_csv) or not os.path.exists(clinical_csv): + logger.info("Preparing TCGA-PRAD metadata...") + self.prepare_metadata(root) + + default_tables = ["mutations", "clinical"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "tcga_prad", + config_path=config_path, + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Prepare metadata for the TCGA-PRAD dataset. + + Converts raw TCGA MAF and clinical files to standardized CSV format. + + Args: + root: Root directory containing the TCGA-PRAD files. + """ + # Process mutations file + TCGAPRADDataset._prepare_mutations(root) + # Process clinical file + TCGAPRADDataset._prepare_clinical(root) + + @staticmethod + def _prepare_mutations(root: str) -> None: + """Prepare mutations data from MAF file.""" + # Try to find the raw mutations file + possible_files = [ + "PRAD_mutations.csv", + "TCGA.PRAD.mutect.maf", + "TCGA.PRAD.mutect.maf.gz", + "PRAD.maf", + "PRAD.maf.gz", + "mutations.maf", + ] + + raw_file = None + for fname in possible_files: + fpath = os.path.join(root, fname) + if os.path.exists(fpath): + raw_file = fpath + break + + output_path = os.path.join(root, "tcga_prad_mutations-pyhealth.csv") + + if raw_file is None: + logger.warning( + f"No raw TCGA-PRAD mutations file found in {root}. " + "Please download from GDC portal or use TCGAmutations R package." + ) + # Create empty placeholder + pd.DataFrame( + columns=[ + "patient_id", + "hugo_symbol", + "variant_classification", + "variant_type", + "hgvsc", + "hgvsp", + "tumor_sample_barcode", + ] + ).to_csv(output_path, index=False) + return + + logger.info(f"Processing TCGA-PRAD mutations file: {raw_file}") + + # Read the raw file + if raw_file.endswith(".gz"): + df = pd.read_csv( + raw_file, sep="\t", compression="gzip", comment="#", low_memory=False + ) + elif raw_file.endswith(".maf"): + df = pd.read_csv(raw_file, sep="\t", comment="#", low_memory=False) + else: + df = pd.read_csv(raw_file, low_memory=False) + + # Standardize column names + column_mapping = { + "Hugo_Symbol": "hugo_symbol", + "Variant_Classification": "variant_classification", + "Variant_Type": "variant_type", + "HGVSc": "hgvsc", + "HGVSp_Short": "hgvsp", + "HGVSp": "hgvsp", + "Tumor_Sample_Barcode": "tumor_sample_barcode", + } + + rename_dict = {k: v for k, v in column_mapping.items() if k in df.columns} + df = df.rename(columns=rename_dict) + + # Extract patient_id from tumor_sample_barcode (first 12 characters) + if "tumor_sample_barcode" in df.columns: + df["patient_id"] = df["tumor_sample_barcode"].str[:12] + else: + df["patient_id"] = df.index.astype(str) + + # Select output columns + output_cols = [ + "patient_id", + "hugo_symbol", + "variant_classification", + "variant_type", + "hgvsc", + "hgvsp", + "tumor_sample_barcode", + ] + available_cols = [c for c in output_cols if c in df.columns] + df_out = df[available_cols] + + df_out.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_out)} mutations to {output_path}") + + @staticmethod + def _prepare_clinical(root: str) -> None: + """Prepare clinical data file.""" + # Try to find the raw clinical file + possible_files = [ + "PRAD_clinical.csv", + "clinical.tsv", + "clinical.csv", + "nationwidechildrens.org_clinical_patient_prad.txt", + ] + + raw_file = None + for fname in possible_files: + fpath = os.path.join(root, fname) + if os.path.exists(fpath): + raw_file = fpath + break + + output_path = os.path.join(root, "tcga_prad_clinical-pyhealth.csv") + + if raw_file is None: + logger.warning( + f"No raw TCGA-PRAD clinical file found in {root}. " + "Please download from GDC portal." + ) + # Create empty placeholder + pd.DataFrame( + columns=[ + "patient_id", + "age_at_diagnosis", + "gleason_score", + "vital_status", + "days_to_death", + "tumor_stage", + ] + ).to_csv(output_path, index=False) + return + + logger.info(f"Processing TCGA-PRAD clinical file: {raw_file}") + + # Read the raw file + sep = "\t" if raw_file.endswith(".tsv") or raw_file.endswith(".txt") else "," + df = pd.read_csv(raw_file, sep=sep, low_memory=False) + + # Standardize column names (TCGA uses various naming conventions) + column_mapping = { + "submitter_id": "patient_id", + "bcr_patient_barcode": "patient_id", + "case_id": "patient_id", + "age_at_diagnosis": "age_at_diagnosis", + "age_at_initial_pathologic_diagnosis": "age_at_diagnosis", + "gleason_score": "gleason_score", + "primary_gleason_grade": "gleason_score", + "vital_status": "vital_status", + "days_to_death": "days_to_death", + "tumor_stage": "tumor_stage", + "ajcc_pathologic_stage": "tumor_stage", + "pathologic_stage": "tumor_stage", + } + + rename_dict = {k: v for k, v in column_mapping.items() if k in df.columns} + df = df.rename(columns=rename_dict) + + # If patient_id doesn't exist, create from index + if "patient_id" not in df.columns: + df["patient_id"] = df.index.astype(str) + + # Select output columns + output_cols = [ + "patient_id", + "age_at_diagnosis", + "gleason_score", + "vital_status", + "days_to_death", + "tumor_stage", + ] + available_cols = [c for c in output_cols if c in df.columns] + df_out = df[available_cols].drop_duplicates(subset=["patient_id"]) + + df_out.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_out)} clinical records to {output_path}") + + @property + def default_task(self): + """Returns the default task for this dataset. + + Returns: + CancerSurvivalPrediction: The default prediction task. + """ + from pyhealth.tasks import CancerSurvivalPrediction + + return CancerSurvivalPrediction() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 82c890e5e..fb3c6966a 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,5 +1,6 @@ from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot +from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction from .bmd_hs_disease_classification import BMDHSDiseaseClassification from .cardiology_detect import ( cardiology_isAD_fn, @@ -57,3 +58,7 @@ ) from .sleep_staging_v2 import SleepStagingSleepEDF from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn +from .variant_classification import ( + MutationPathogenicityPrediction, + VariantClassificationClinVar, +) diff --git a/pyhealth/tasks/cancer_survival.py b/pyhealth/tasks/cancer_survival.py new file mode 100644 index 000000000..b3e854a2f --- /dev/null +++ b/pyhealth/tasks/cancer_survival.py @@ -0,0 +1,254 @@ +"""Cancer survival prediction tasks for PyHealth. + +This module provides tasks for predicting cancer patient survival outcomes +using multi-omics data from TCGA datasets. +""" + +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + + +class CancerSurvivalPrediction(BaseTask): + """Task for predicting cancer patient survival outcomes. + + This task predicts whether a cancer patient is alive or deceased based on + their mutation profile and clinical features from TCGA datasets. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The input schema specifying required inputs. + output_schema (Dict[str, str]): The output schema specifying outputs. + VITAL_STATUS_DEAD (tuple): Values indicating deceased status. + VITAL_STATUS_ALIVE (tuple): Values indicating alive status. + + Note: + Patients without clinical data or with unknown vital status are + excluded from the output samples. + + Examples: + >>> from pyhealth.datasets import TCGAPRADDataset + >>> from pyhealth.tasks import CancerSurvivalPrediction + >>> dataset = TCGAPRADDataset(root="/path/to/tcga_prad") + >>> task = CancerSurvivalPrediction() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "CancerSurvivalPrediction" + input_schema: Dict[str, str] = { + "mutations": "sequence", + "age_at_diagnosis": "tensor", + "gleason_score": "tensor", + } + output_schema: Dict[str, str] = {"vital_status": "binary"} + + # Vital status category mappings + VITAL_STATUS_DEAD: tuple = ("dead", "deceased", "1") + VITAL_STATUS_ALIVE: tuple = ("alive", "living", "0") + + def _safe_float(self, value: Any, default: float = 0.0) -> float: + """Safely convert value to float, handling None and NaN. + + Args: + value: Value to convert. + default: Default value if conversion fails. + + Returns: + Float representation of value or default. + """ + if value is None or str(value) == "nan": + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + def _extract_genes(self, mutations: List[Any]) -> List[str]: + """Extract valid gene symbols from mutation events. + + Args: + mutations: List of mutation event objects. + + Returns: + List of gene symbol strings, excluding None and NaN values. + """ + genes: List[str] = [] + for mut in mutations: + gene = getattr(mut, "hugo_symbol", None) + if gene is not None and str(gene) != "nan": + genes.append(str(gene)) + return genes + + def _parse_vital_status(self, raw_value: Any) -> Optional[int]: + """Parse vital status to binary label. + + Args: + raw_value: Raw vital status value from clinical data. + + Returns: + 1 for deceased, 0 for alive, None if value is invalid or unknown. + """ + if raw_value is None or str(raw_value) == "nan": + return None + + value_lower = str(raw_value).lower() + + if value_lower in self.VITAL_STATUS_DEAD: + return 1 + elif value_lower in self.VITAL_STATUS_ALIVE: + return 0 + + return None + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process patient mutation and clinical data for survival prediction. + + Args: + patient: A patient object containing mutation and clinical data. + + Returns: + List[Dict[str, Any]]: A list containing a single dictionary with + patient features and survival label. Returns an empty list if + clinical data is missing or vital status is unknown. + + Note: + Returns empty list for patients with: + - No clinical events + - Missing or null vital status + - Unrecognized vital status values + """ + mutations = patient.get_events(event_type="mutations") + clinical = patient.get_events(event_type="clinical") + + if len(clinical) == 0: + return [] + + clin = clinical[0] + + # Parse vital status + vital_status = self._parse_vital_status( + getattr(clin, "vital_status", None) + ) + if vital_status is None: + return [] + + # Extract features + mutated_genes = self._extract_genes(mutations) + age = self._safe_float(getattr(clin, "age_at_diagnosis", None)) + gleason = self._safe_float(getattr(clin, "gleason_score", None)) + + return [ + { + "patient_id": patient.patient_id, + "mutations": mutated_genes, + "age_at_diagnosis": age, + "gleason_score": gleason, + "vital_status": vital_status, + } + ] + + +class CancerMutationBurden(BaseTask): + """Task for predicting high vs low tumor mutation burden. + + This task classifies patients based on their tumor mutation burden (TMB), + which is associated with immunotherapy response. TMB is approximated by + counting the number of mutated genes. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The input schema specifying required inputs. + output_schema (Dict[str, str]): The output schema specifying outputs. + TMB_THRESHOLD (int): Mutation count threshold for high TMB classification. + + Note: + This is a simplified TMB calculation based on gene count. Clinical TMB + is typically measured as mutations per megabase of sequenced DNA. + + Examples: + >>> from pyhealth.datasets import TCGAPRADDataset + >>> from pyhealth.tasks import CancerMutationBurden + >>> dataset = TCGAPRADDataset(root="/path/to/tcga_prad") + >>> task = CancerMutationBurden() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "CancerMutationBurden" + input_schema: Dict[str, str] = { + "mutations": "sequence", + "age_at_diagnosis": "tensor", + } + output_schema: Dict[str, str] = {"high_tmb": "binary"} + + # TMB threshold (number of mutated genes for high TMB classification) + TMB_THRESHOLD: int = 10 + + def _safe_float(self, value: Any, default: float = 0.0) -> float: + """Safely convert value to float, handling None and NaN. + + Args: + value: Value to convert. + default: Default value if conversion fails. + + Returns: + Float representation of value or default. + """ + if value is None or str(value) == "nan": + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + def _extract_genes(self, mutations: List[Any]) -> List[str]: + """Extract valid gene symbols from mutation events. + + Args: + mutations: List of mutation event objects. + + Returns: + List of gene symbol strings, excluding None and NaN values. + """ + genes: List[str] = [] + for mut in mutations: + gene = getattr(mut, "hugo_symbol", None) + if gene is not None and str(gene) != "nan": + genes.append(str(gene)) + return genes + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process patient data to predict tumor mutation burden. + + Args: + patient: A patient object containing mutation data. + + Returns: + List[Dict[str, Any]]: A list containing a single dictionary with + patient features and TMB classification label. + + Note: + High TMB is defined as having >= TMB_THRESHOLD mutated genes. + All patients with mutation data are included in the output. + """ + mutations = patient.get_events(event_type="mutations") + clinical = patient.get_events(event_type="clinical") + + # Extract mutated genes + mutated_genes = self._extract_genes(mutations) + + # Classify TMB based on mutation count + high_tmb = 1 if len(mutated_genes) >= self.TMB_THRESHOLD else 0 + + # Get age if clinical data available + age = 0.0 + if len(clinical) > 0: + age = self._safe_float(getattr(clinical[0], "age_at_diagnosis", None)) + + return [ + { + "patient_id": patient.patient_id, + "mutations": mutated_genes, + "age_at_diagnosis": age, + "high_tmb": high_tmb, + } + ] diff --git a/pyhealth/tasks/variant_classification.py b/pyhealth/tasks/variant_classification.py new file mode 100644 index 000000000..5234d0f05 --- /dev/null +++ b/pyhealth/tasks/variant_classification.py @@ -0,0 +1,234 @@ +"""Variant classification tasks for PyHealth. + +This module provides tasks for classifying genetic variants based on +their clinical significance using ClinVar and COSMIC datasets. +""" + +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + + +class VariantClassificationClinVar(BaseTask): + """Task for classifying variant clinical significance using ClinVar data. + + This task predicts the clinical significance of genetic variants + (e.g., Pathogenic, Benign, Uncertain significance) based on variant + features from the ClinVar database. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The input schema specifying required inputs. + output_schema (Dict[str, str]): The output schema specifying outputs. + CLINICAL_SIGNIFICANCE_CATEGORIES (Dict[str, str]): Mapping of raw values + to standardized clinical significance labels. + + Note: + Variants with conflicting interpretations or non-standard clinical + significance values are excluded from the output samples. + + Examples: + >>> from pyhealth.datasets import ClinVarDataset + >>> from pyhealth.tasks import VariantClassificationClinVar + >>> dataset = ClinVarDataset(root="/path/to/clinvar") + >>> task = VariantClassificationClinVar() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "VariantClassificationClinVar" + input_schema: Dict[str, str] = { + "gene_symbol": "text", + "variant_type": "text", + "chromosome": "text", + } + output_schema: Dict[str, str] = {"clinical_significance": "multiclass"} + + # Standard clinical significance categories (ACMG/AMP guidelines) + CLINICAL_SIGNIFICANCE_CATEGORIES: Dict[str, str] = { + "pathogenic": "Pathogenic", + "likely pathogenic": "Likely pathogenic", + "benign": "Benign", + "likely benign": "Likely benign", + "uncertain significance": "Uncertain significance", + "vus": "Uncertain significance", + } + + def _normalize_clinical_significance( + self, raw_value: Optional[str] + ) -> Optional[str]: + """Normalize clinical significance to standard ACMG/AMP categories. + + Args: + raw_value: Raw clinical significance string from ClinVar. + + Returns: + Normalized category string, or None if value is invalid or + does not map to a standard category. + """ + if raw_value is None or raw_value == "" or str(raw_value) == "nan": + return None + + value_lower = str(raw_value).lower() + + # Check for exact or partial matches + if "likely pathogenic" in value_lower: + return self.CLINICAL_SIGNIFICANCE_CATEGORIES["likely pathogenic"] + elif "pathogenic" in value_lower and "likely" not in value_lower: + return self.CLINICAL_SIGNIFICANCE_CATEGORIES["pathogenic"] + elif "likely benign" in value_lower: + return self.CLINICAL_SIGNIFICANCE_CATEGORIES["likely benign"] + elif "benign" in value_lower and "likely" not in value_lower: + return self.CLINICAL_SIGNIFICANCE_CATEGORIES["benign"] + elif "uncertain" in value_lower or "vus" in value_lower: + return self.CLINICAL_SIGNIFICANCE_CATEGORIES["uncertain significance"] + + # Conflicting or unrecognized categories + return None + + def _safe_str(self, value: Any, default: str = "") -> str: + """Safely convert value to string, handling None and NaN. + + Args: + value: Value to convert. + default: Default value if conversion fails. + + Returns: + String representation of value or default. + """ + if value is None or str(value) == "nan": + return default + return str(value) + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a variant record to extract features and label. + + Args: + patient: A patient object containing variant data. + In ClinVar, each "patient" represents a single variant record. + + Returns: + List[Dict[str, Any]]: A list containing a single dictionary with + variant features and clinical significance label. Returns an empty + list if the variant has no events, invalid clinical significance, + or conflicting interpretations. + + Note: + Returns empty list for variants with: + - No variant events + - Missing or empty clinical significance + - Conflicting interpretations + - Non-standard clinical significance values + """ + events = patient.get_events(event_type="variants") + + if len(events) == 0: + return [] + + event = events[0] + + # Normalize clinical significance + raw_clinical_sig = getattr(event, "clinical_significance", None) + label = self._normalize_clinical_significance(raw_clinical_sig) + + if label is None: + return [] + + return [ + { + "patient_id": patient.patient_id, + "gene_symbol": self._safe_str(getattr(event, "gene_symbol", "")), + "variant_type": self._safe_str(getattr(event, "variant_type", "")), + "chromosome": self._safe_str(getattr(event, "chromosome", "")), + "clinical_significance": label, + } + ] + + +class MutationPathogenicityPrediction(BaseTask): + """Task for predicting mutation pathogenicity using COSMIC data. + + This task predicts whether a somatic mutation is pathogenic or neutral + based on FATHMM predictions and mutation features from the COSMIC database. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The input schema specifying required inputs. + output_schema (Dict[str, str]): The output schema specifying outputs. + VALID_FATHMM_PREDICTIONS (tuple): Valid FATHMM prediction values. + + Note: + Only mutations with valid FATHMM predictions (PATHOGENIC or NEUTRAL) + are included in the output samples. + + Examples: + >>> from pyhealth.datasets import COSMICDataset + >>> from pyhealth.tasks import MutationPathogenicityPrediction + >>> dataset = COSMICDataset(root="/path/to/cosmic") + >>> task = MutationPathogenicityPrediction() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "MutationPathogenicityPrediction" + input_schema: Dict[str, str] = { + "gene_name": "text", + "mutation_description": "text", + "primary_site": "text", + } + output_schema: Dict[str, str] = {"fathmm_prediction": "binary"} + + # Valid FATHMM prediction categories + VALID_FATHMM_PREDICTIONS: tuple = ("PATHOGENIC", "NEUTRAL") + + def _safe_str(self, value: Any, default: str = "") -> str: + """Safely convert value to string, handling None and NaN. + + Args: + value: Value to convert. + default: Default value if conversion fails. + + Returns: + String representation of value or default. + """ + if value is None or str(value) == "nan": + return default + return str(value) + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process mutation records to extract features and pathogenicity label. + + Args: + patient: A patient object containing mutation data. + In COSMIC, each "patient" represents a sample with mutations. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, one per valid mutation, + each containing mutation features and binary pathogenicity label. + Returns an empty list if no mutations have valid FATHMM predictions. + + Note: + Only mutations with FATHMM predictions of "PATHOGENIC" or "NEUTRAL" + are included. Mutations with missing or other prediction values + are excluded. + """ + events = patient.get_events(event_type="mutations") + samples: List[Dict[str, Any]] = [] + + for event in events: + fathmm = getattr(event, "fathmm_prediction", None) + + if fathmm not in self.VALID_FATHMM_PREDICTIONS: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "gene_name": self._safe_str(getattr(event, "gene_name", "")), + "mutation_description": self._safe_str( + getattr(event, "mutation_description", "") + ), + "primary_site": self._safe_str(getattr(event, "primary_site", "")), + "fathmm_prediction": 1 if fathmm == "PATHOGENIC" else 0, + } + ) + + return samples diff --git a/test-resources/clinvar/clinvar-pyhealth.csv b/test-resources/clinvar/clinvar-pyhealth.csv new file mode 100644 index 000000000..cc402cb62 --- /dev/null +++ b/test-resources/clinvar/clinvar-pyhealth.csv @@ -0,0 +1,13 @@ +gene_symbol,clinical_significance,review_status,chromosome,position,reference_allele,alternate_allele,variant_type,assembly +BRCA1,Pathogenic,reviewed by expert panel,17,43044295,G,A,single nucleotide variant,GRCh38 +BRCA2,Likely pathogenic,criteria provided multiple submitters,13,32316461,C,T,single nucleotide variant,GRCh38 +TP53,Pathogenic,reviewed by expert panel,17,7675088,G,C,single nucleotide variant,GRCh38 +EGFR,Benign,criteria provided single submitter,7,55181378,A,G,single nucleotide variant,GRCh38 +KRAS,Likely benign,criteria provided multiple submitters,12,25227342,G,A,single nucleotide variant,GRCh38 +PIK3CA,Uncertain significance,criteria provided single submitter,3,179203765,C,T,single nucleotide variant,GRCh38 +PTEN,Pathogenic,reviewed by expert panel,10,87933147,G,T,single nucleotide variant,GRCh38 +APC,Likely pathogenic,criteria provided multiple submitters,5,112839514,C,A,single nucleotide variant,GRCh38 +MLH1,Benign,criteria provided single submitter,3,37050340,A,C,single nucleotide variant,GRCh38 +MSH2,Uncertain significance,criteria provided conflicting,2,47656940,T,C,single nucleotide variant,GRCh38 +ATM,Pathogenic,criteria provided multiple submitters,11,108244035,G,A,single nucleotide variant,GRCh38 +CHEK2,Likely pathogenic,reviewed by expert panel,22,28695868,C,T,single nucleotide variant,GRCh38 diff --git a/test-resources/cosmic/cosmic-pyhealth.csv b/test-resources/cosmic/cosmic-pyhealth.csv new file mode 100644 index 000000000..2d140163d --- /dev/null +++ b/test-resources/cosmic/cosmic-pyhealth.csv @@ -0,0 +1,13 @@ +sample_id,gene_name,hgvsc,hgvsp,mutation_description,fathmm_prediction,primary_site,primary_histology,mutation_somatic_status +TCGA-001,TP53,c.818G>A,p.R273H,Substitution - Missense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-001,PTEN,c.388C>T,p.R130*,Substitution - Nonsense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-002,SPOP,c.422G>A,p.W131*,Substitution - Missense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-002,FOXA1,c.629C>T,p.P210L,Substitution - Missense,NEUTRAL,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-003,ERG,c.1012A>G,p.K338E,Substitution - Missense,NEUTRAL,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-003,TMPRSS2,c.457G>A,p.E153K,Substitution - Missense,NEUTRAL,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-004,BRCA2,c.5946delT,p.S1982fs,Deletion - Frameshift,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-004,ATM,c.2376G>A,p.W792*,Substitution - Nonsense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-005,CDK12,c.4723C>T,p.R1575*,Substitution - Nonsense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-005,APC,c.4348C>T,p.R1450*,Substitution - Nonsense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-006,PIK3CA,c.3140A>G,p.H1047R,Substitution - Missense,PATHOGENIC,prostate,adenocarcinoma,Confirmed somatic variant +TCGA-006,KRAS,c.35G>T,p.G12V,Substitution - Missense,NEUTRAL,prostate,adenocarcinoma,Confirmed somatic variant diff --git a/test-resources/tcga_prad/tcga_prad_clinical-pyhealth.csv b/test-resources/tcga_prad/tcga_prad_clinical-pyhealth.csv new file mode 100644 index 000000000..fbd1b49d5 --- /dev/null +++ b/test-resources/tcga_prad/tcga_prad_clinical-pyhealth.csv @@ -0,0 +1,6 @@ +patient_id,age_at_diagnosis,gleason_score,vital_status,days_to_death,tumor_stage +TCGA-2A-A8VL,65,7,Dead,1250,Stage III +TCGA-CH-5737,58,6,Alive,,Stage II +TCGA-EJ-7783,72,9,Dead,890,Stage IV +TCGA-G9-6348,61,7,Alive,,Stage II +TCGA-HC-7740,55,6,Alive,,Stage I diff --git a/test-resources/tcga_prad/tcga_prad_mutations-pyhealth.csv b/test-resources/tcga_prad/tcga_prad_mutations-pyhealth.csv new file mode 100644 index 000000000..6a52f7c3b --- /dev/null +++ b/test-resources/tcga_prad/tcga_prad_mutations-pyhealth.csv @@ -0,0 +1,20 @@ +patient_id,hugo_symbol,variant_classification,variant_type,hgvsc,hgvsp,tumor_sample_barcode +TCGA-2A-A8VL,TP53,Missense_Mutation,SNP,c.818G>A,p.R273H,TCGA-2A-A8VL-01A +TCGA-2A-A8VL,PTEN,Nonsense_Mutation,SNP,c.388C>T,p.R130*,TCGA-2A-A8VL-01A +TCGA-2A-A8VL,SPOP,Missense_Mutation,SNP,c.422G>A,p.W131C,TCGA-2A-A8VL-01A +TCGA-CH-5737,FOXA1,Missense_Mutation,SNP,c.629C>T,p.P210L,TCGA-CH-5737-01A +TCGA-CH-5737,ERG,Missense_Mutation,SNP,c.1012A>G,p.K338E,TCGA-CH-5737-01A +TCGA-EJ-7783,BRCA2,Frame_Shift_Del,DEL,c.5946delT,p.S1982fs,TCGA-EJ-7783-01A +TCGA-EJ-7783,ATM,Nonsense_Mutation,SNP,c.2376G>A,p.W792*,TCGA-EJ-7783-01A +TCGA-G9-6348,CDK12,Nonsense_Mutation,SNP,c.4723C>T,p.R1575*,TCGA-G9-6348-01A +TCGA-G9-6348,APC,Nonsense_Mutation,SNP,c.4348C>T,p.R1450*,TCGA-G9-6348-01A +TCGA-G9-6348,PIK3CA,Missense_Mutation,SNP,c.3140A>G,p.H1047R,TCGA-G9-6348-01A +TCGA-G9-6348,RB1,Nonsense_Mutation,SNP,c.1735C>T,p.R579*,TCGA-G9-6348-01A +TCGA-G9-6348,BRCA1,Missense_Mutation,SNP,c.5123C>A,p.A1708E,TCGA-G9-6348-01A +TCGA-G9-6348,CHEK2,Missense_Mutation,SNP,c.470T>C,p.I157T,TCGA-G9-6348-01A +TCGA-G9-6348,MSH2,Missense_Mutation,SNP,c.1865G>A,p.R622Q,TCGA-G9-6348-01A +TCGA-G9-6348,MLH1,Missense_Mutation,SNP,c.350C>T,p.T117M,TCGA-G9-6348-01A +TCGA-G9-6348,PALB2,Frame_Shift_Del,DEL,c.172delA,p.N58fs,TCGA-G9-6348-01A +TCGA-G9-6348,RAD51D,Missense_Mutation,SNP,c.620C>T,p.T207M,TCGA-G9-6348-01A +TCGA-HC-7740,KRAS,Missense_Mutation,SNP,c.35G>T,p.G12V,TCGA-HC-7740-01A +TCGA-HC-7740,BRAF,Missense_Mutation,SNP,c.1799T>A,p.V600E,TCGA-HC-7740-01A diff --git a/tests/core/test_clinvar.py b/tests/core/test_clinvar.py new file mode 100644 index 000000000..1a3eb4b78 --- /dev/null +++ b/tests/core/test_clinvar.py @@ -0,0 +1,121 @@ +""" +Unit tests for the ClinVarDataset and VariantClassificationClinVar classes. + +Author: + Abraham Arellano +""" +import os +import shutil +import unittest +from pathlib import Path + +from pyhealth.datasets import ClinVarDataset +from pyhealth.tasks import VariantClassificationClinVar + + +class TestClinVarDataset(unittest.TestCase): + """Test cases for ClinVarDataset.""" + + @classmethod + def setUpClass(cls): + """Set up test resources path.""" + cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "clinvar" + + def test_dataset_initialization(self): + """Test that the dataset initializes correctly.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "clinvar") + + def test_stats(self): + """Test that stats() runs without error.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + dataset.stats() + + def test_num_patients(self): + """Test the number of unique patient IDs (variants).""" + dataset = ClinVarDataset(root=str(self.test_resources)) + # Each row is a separate "patient" since patient_id is null + self.assertEqual(len(dataset.unique_patient_ids), 12) + + def test_get_patient(self): + """Test retrieving a patient/variant record.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + patient_id = dataset.unique_patient_ids[0] + patient = dataset.get_patient(patient_id) + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, patient_id) + + def test_get_events(self): + """Test getting events from a patient.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + patient_id = dataset.unique_patient_ids[0] + patient = dataset.get_patient(patient_id) + events = patient.get_events(event_type="variants") + self.assertEqual(len(events), 1) + + def test_event_attributes(self): + """Test that event attributes are correctly loaded.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + patient_id = dataset.unique_patient_ids[0] + patient = dataset.get_patient(patient_id) + events = patient.get_events(event_type="variants") + event = events[0] + + # Check that attributes exist + self.assertIn("gene_symbol", event) + self.assertIn("clinical_significance", event) + self.assertIn("chromosome", event) + + def test_default_task(self): + """Test that the default task is VariantClassificationClinVar.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + self.assertIsInstance(dataset.default_task, VariantClassificationClinVar) + + def test_set_task(self): + """Test setting and running the variant classification task.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + task = VariantClassificationClinVar() + samples = dataset.set_task(task) + + # Should have samples for variants with valid clinical significance + self.assertGreater(len(samples), 0) + + def test_task_output_format(self): + """Test that task output has the correct format.""" + dataset = ClinVarDataset(root=str(self.test_resources)) + task = VariantClassificationClinVar() + samples = dataset.set_task(task) + + if len(samples) > 0: + sample = samples[0] + self.assertIn("patient_id", sample) + self.assertIn("gene_symbol", sample) + self.assertIn("clinical_significance", sample) + + +class TestVariantClassificationClinVar(unittest.TestCase): + """Test cases for VariantClassificationClinVar task.""" + + def test_task_attributes(self): + """Test task class attributes.""" + task = VariantClassificationClinVar() + self.assertEqual(task.task_name, "VariantClassificationClinVar") + self.assertIn("gene_symbol", task.input_schema) + self.assertIn("clinical_significance", task.output_schema) + + def test_input_schema(self): + """Test input schema definition.""" + task = VariantClassificationClinVar() + self.assertEqual(task.input_schema["gene_symbol"], "text") + self.assertEqual(task.input_schema["variant_type"], "text") + self.assertEqual(task.input_schema["chromosome"], "text") + + def test_output_schema(self): + """Test output schema definition.""" + task = VariantClassificationClinVar() + self.assertEqual(task.output_schema["clinical_significance"], "multiclass") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_cosmic.py b/tests/core/test_cosmic.py new file mode 100644 index 000000000..72c8456d7 --- /dev/null +++ b/tests/core/test_cosmic.py @@ -0,0 +1,134 @@ +""" +Unit tests for the COSMICDataset and MutationPathogenicityPrediction classes. + +Author: + Abraham Arellano +""" +import os +import shutil +import unittest +from pathlib import Path + +from pyhealth.datasets import COSMICDataset +from pyhealth.tasks import MutationPathogenicityPrediction + + +class TestCOSMICDataset(unittest.TestCase): + """Test cases for COSMICDataset.""" + + @classmethod + def setUpClass(cls): + """Set up test resources path.""" + cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "cosmic" + + def test_dataset_initialization(self): + """Test that the dataset initializes correctly.""" + dataset = COSMICDataset(root=str(self.test_resources)) + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "cosmic") + + def test_stats(self): + """Test that stats() runs without error.""" + dataset = COSMICDataset(root=str(self.test_resources)) + dataset.stats() + + def test_num_patients(self): + """Test the number of unique patient IDs (samples).""" + dataset = COSMICDataset(root=str(self.test_resources)) + # Should have 6 unique sample IDs (TCGA-001 through TCGA-006) + self.assertEqual(len(dataset.unique_patient_ids), 6) + + def test_get_patient(self): + """Test retrieving a patient/sample record.""" + dataset = COSMICDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-001") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "TCGA-001") + + def test_get_events(self): + """Test getting mutation events from a sample.""" + dataset = COSMICDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-001") + events = patient.get_events(event_type="mutations") + # TCGA-001 has 2 mutations in test data + self.assertEqual(len(events), 2) + + def test_event_attributes(self): + """Test that event attributes are correctly loaded.""" + dataset = COSMICDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-001") + events = patient.get_events(event_type="mutations") + event = events[0] + + # Check that attributes exist + self.assertIn("gene_name", event) + self.assertIn("fathmm_prediction", event) + self.assertIn("primary_site", event) + + def test_default_task(self): + """Test that the default task is MutationPathogenicityPrediction.""" + dataset = COSMICDataset(root=str(self.test_resources)) + self.assertIsInstance(dataset.default_task, MutationPathogenicityPrediction) + + def test_set_task(self): + """Test setting and running the pathogenicity prediction task.""" + dataset = COSMICDataset(root=str(self.test_resources)) + task = MutationPathogenicityPrediction() + samples = dataset.set_task(task) + + # Should have samples for mutations with FATHMM predictions + self.assertGreater(len(samples), 0) + + def test_task_output_format(self): + """Test that task output has the correct format.""" + dataset = COSMICDataset(root=str(self.test_resources)) + task = MutationPathogenicityPrediction() + samples = dataset.set_task(task) + + if len(samples) > 0: + sample = samples[0] + self.assertIn("patient_id", sample) + self.assertIn("gene_name", sample) + self.assertIn("fathmm_prediction", sample) + self.assertIn(sample["fathmm_prediction"], [0, 1]) + + def test_pathogenic_vs_neutral_labels(self): + """Test that FATHMM predictions are correctly converted to binary.""" + dataset = COSMICDataset(root=str(self.test_resources)) + task = MutationPathogenicityPrediction() + samples = dataset.set_task(task) + + # Count pathogenic and neutral + pathogenic_count = sum(1 for s in samples if s["fathmm_prediction"] == 1) + neutral_count = sum(1 for s in samples if s["fathmm_prediction"] == 0) + + # Both should be present in test data + self.assertGreater(pathogenic_count, 0) + self.assertGreater(neutral_count, 0) + + +class TestMutationPathogenicityPrediction(unittest.TestCase): + """Test cases for MutationPathogenicityPrediction task.""" + + def test_task_attributes(self): + """Test task class attributes.""" + task = MutationPathogenicityPrediction() + self.assertEqual(task.task_name, "MutationPathogenicityPrediction") + self.assertIn("gene_name", task.input_schema) + self.assertIn("fathmm_prediction", task.output_schema) + + def test_input_schema(self): + """Test input schema definition.""" + task = MutationPathogenicityPrediction() + self.assertEqual(task.input_schema["gene_name"], "text") + self.assertEqual(task.input_schema["mutation_description"], "text") + self.assertEqual(task.input_schema["primary_site"], "text") + + def test_output_schema(self): + """Test output schema definition.""" + task = MutationPathogenicityPrediction() + self.assertEqual(task.output_schema["fathmm_prediction"], "binary") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_tcga_prad.py b/tests/core/test_tcga_prad.py new file mode 100644 index 000000000..08468a481 --- /dev/null +++ b/tests/core/test_tcga_prad.py @@ -0,0 +1,195 @@ +""" +Unit tests for the TCGAPRADDataset and CancerSurvivalPrediction classes. + +Author: + Abraham Arellano +""" +import os +import shutil +import unittest +from pathlib import Path + +from pyhealth.datasets import TCGAPRADDataset +from pyhealth.tasks import CancerSurvivalPrediction, CancerMutationBurden + + +class TestTCGAPRADDataset(unittest.TestCase): + """Test cases for TCGAPRADDataset.""" + + @classmethod + def setUpClass(cls): + """Set up test resources path.""" + cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "tcga_prad" + + def test_dataset_initialization(self): + """Test that the dataset initializes correctly.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "tcga_prad") + + def test_stats(self): + """Test that stats() runs without error.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + dataset.stats() + + def test_num_patients(self): + """Test the number of unique patient IDs.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + # Should have 5 unique patients + self.assertEqual(len(dataset.unique_patient_ids), 5) + + def test_get_patient(self): + """Test retrieving a patient record.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-2A-A8VL") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "TCGA-2A-A8VL") + + def test_get_mutation_events(self): + """Test getting mutation events from a patient.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-2A-A8VL") + events = patient.get_events(event_type="mutations") + # TCGA-2A-A8VL has 3 mutations in test data + self.assertEqual(len(events), 3) + + def test_get_clinical_events(self): + """Test getting clinical events from a patient.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-2A-A8VL") + events = patient.get_events(event_type="clinical") + # Each patient has 1 clinical record + self.assertEqual(len(events), 1) + + def test_mutation_attributes(self): + """Test that mutation attributes are correctly loaded.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-2A-A8VL") + events = patient.get_events(event_type="mutations") + event = events[0] + + # Check that attributes exist + self.assertIn("hugo_symbol", event) + self.assertIn("variant_classification", event) + self.assertIn("variant_type", event) + + def test_clinical_attributes(self): + """Test that clinical attributes are correctly loaded.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + patient = dataset.get_patient("TCGA-2A-A8VL") + events = patient.get_events(event_type="clinical") + event = events[0] + + # Check that attributes exist + self.assertIn("age_at_diagnosis", event) + self.assertIn("gleason_score", event) + self.assertIn("vital_status", event) + + def test_default_task(self): + """Test that the default task is CancerSurvivalPrediction.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + self.assertIsInstance(dataset.default_task, CancerSurvivalPrediction) + + def test_set_task_survival(self): + """Test setting and running the survival prediction task.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + task = CancerSurvivalPrediction() + samples = dataset.set_task(task) + + # Should have samples for patients with clinical data + self.assertGreater(len(samples), 0) + + def test_task_output_format(self): + """Test that task output has the correct format.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + task = CancerSurvivalPrediction() + samples = dataset.set_task(task) + + if len(samples) > 0: + sample = samples[0] + self.assertIn("patient_id", sample) + self.assertIn("mutations", sample) + self.assertIn("vital_status", sample) + self.assertIn(sample["vital_status"], [0, 1]) + # After processing, mutations is converted to tensor by SequenceProcessor + self.assertTrue(hasattr(sample["mutations"], '__len__')) + + def test_vital_status_labels(self): + """Test that vital status is correctly converted to binary.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + task = CancerSurvivalPrediction() + samples = dataset.set_task(task) + + # Count alive and dead + dead_count = sum(1 for s in samples if s["vital_status"] == 1) + alive_count = sum(1 for s in samples if s["vital_status"] == 0) + + # Both should be present in test data + self.assertGreater(dead_count, 0) + self.assertGreater(alive_count, 0) + + +class TestCancerSurvivalPrediction(unittest.TestCase): + """Test cases for CancerSurvivalPrediction task.""" + + def test_task_attributes(self): + """Test task class attributes.""" + task = CancerSurvivalPrediction() + self.assertEqual(task.task_name, "CancerSurvivalPrediction") + self.assertIn("mutations", task.input_schema) + self.assertIn("vital_status", task.output_schema) + + def test_input_schema(self): + """Test input schema definition.""" + task = CancerSurvivalPrediction() + self.assertEqual(task.input_schema["mutations"], "sequence") + self.assertEqual(task.input_schema["age_at_diagnosis"], "tensor") + self.assertEqual(task.input_schema["gleason_score"], "tensor") + + def test_output_schema(self): + """Test output schema definition.""" + task = CancerSurvivalPrediction() + self.assertEqual(task.output_schema["vital_status"], "binary") + + +class TestCancerMutationBurden(unittest.TestCase): + """Test cases for CancerMutationBurden task.""" + + @classmethod + def setUpClass(cls): + """Set up test resources path.""" + cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "tcga_prad" + + def test_task_attributes(self): + """Test task class attributes.""" + task = CancerMutationBurden() + self.assertEqual(task.task_name, "CancerMutationBurden") + self.assertIn("mutations", task.input_schema) + self.assertIn("high_tmb", task.output_schema) + + def test_set_task(self): + """Test setting and running the mutation burden task.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + task = CancerMutationBurden() + samples = dataset.set_task(task) + + # Should have samples + self.assertGreater(len(samples), 0) + + def test_output_format(self): + """Test that task output has the correct format.""" + dataset = TCGAPRADDataset(root=str(self.test_resources)) + task = CancerMutationBurden() + samples = dataset.set_task(task) + + if len(samples) > 0: + sample = samples[0] + self.assertIn("patient_id", sample) + self.assertIn("mutations", sample) + self.assertIn("high_tmb", sample) + # high_tmb is converted to int by BinaryLabelProcessor + self.assertIn(int(sample["high_tmb"]), [0, 1]) + + +if __name__ == "__main__": + unittest.main() From e281ef094bc8cd771b56aebaffa7de464fb74358 Mon Sep 17 00:00:00 2001 From: Arellano Tavara Date: Mon, 24 Nov 2025 10:30:54 -0600 Subject: [PATCH 2/2] Add documentation for cancer genomics datasets and tasks - Add RST docs for ClinVarDataset, COSMICDataset, TCGAPRADDataset - Add RST docs for variant classification and cancer survival tasks - Update datasets.rst and tasks.rst index files --- docs/api/datasets.rst | 3 +++ docs/api/datasets/pyhealth.datasets.COSMICDataset.rst | 9 +++++++++ docs/api/datasets/pyhealth.datasets.ClinVarDataset.rst | 9 +++++++++ docs/api/datasets/pyhealth.datasets.TCGAPRADDataset.rst | 9 +++++++++ docs/api/tasks.rst | 4 ++++ docs/api/tasks/pyhealth.tasks.CancerMutationBurden.rst | 7 +++++++ .../tasks/pyhealth.tasks.CancerSurvivalPrediction.rst | 7 +++++++ .../pyhealth.tasks.MutationPathogenicityPrediction.rst | 7 +++++++ .../pyhealth.tasks.VariantClassificationClinVar.rst | 7 +++++++ 9 files changed, 62 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.COSMICDataset.rst create mode 100644 docs/api/datasets/pyhealth.datasets.ClinVarDataset.rst create mode 100644 docs/api/datasets/pyhealth.datasets.TCGAPRADDataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.CancerMutationBurden.rst create mode 100644 docs/api/tasks/pyhealth.tasks.CancerSurvivalPrediction.rst create mode 100644 docs/api/tasks/pyhealth.tasks.MutationPathogenicityPrediction.rst create mode 100644 docs/api/tasks/pyhealth.tasks.VariantClassificationClinVar.rst diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 913439ee8..3412e5ac5 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -48,5 +48,8 @@ Available Datasets datasets/pyhealth.datasets.ChestXray14Dataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset + datasets/pyhealth.datasets.ClinVarDataset + datasets/pyhealth.datasets.COSMICDataset + datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils diff --git a/docs/api/datasets/pyhealth.datasets.COSMICDataset.rst b/docs/api/datasets/pyhealth.datasets.COSMICDataset.rst new file mode 100644 index 000000000..e38cdc338 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.COSMICDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.COSMICDataset +=============================== + +The COSMIC (Catalogue of Somatic Mutations in Cancer) dataset provides comprehensive information about somatic mutations in human cancers. For more information see `COSMIC `_. This dataset was contributed as part of the Prostate-VarBench benchmarking work (`arXiv:2511.09576 `_). + +.. autoclass:: pyhealth.datasets.COSMICDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.ClinVarDataset.rst b/docs/api/datasets/pyhealth.datasets.ClinVarDataset.rst new file mode 100644 index 000000000..dacfbadbd --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ClinVarDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.ClinVarDataset +================================ + +The ClinVar dataset provides information about genomic variants and their clinical significance based on ACMG/AMP guidelines. For more information see `ClinVar `_. This dataset was contributed as part of the Prostate-VarBench benchmarking work (`arXiv:2511.09576 `_). + +.. autoclass:: pyhealth.datasets.ClinVarDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.TCGAPRADDataset.rst b/docs/api/datasets/pyhealth.datasets.TCGAPRADDataset.rst new file mode 100644 index 000000000..3ac881ea7 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.TCGAPRADDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.TCGAPRADDataset +================================= + +The Cancer Genome Atlas Prostate Adenocarcinoma (TCGA-PRAD) dataset provides multi-omics data including somatic mutations and clinical information for prostate cancer patients. For more information see `TCGA-PRAD `_. This dataset was contributed as part of the Prostate-VarBench benchmarking work (`arXiv:2511.09576 `_). + +.. autoclass:: pyhealth.datasets.TCGAPRADDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index fcf783921..934886e2d 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -95,3 +95,7 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification + Variant Classification (ClinVar) + Mutation Pathogenicity (COSMIC) + Cancer Survival Prediction (TCGA) + Cancer Mutation Burden (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.CancerMutationBurden.rst b/docs/api/tasks/pyhealth.tasks.CancerMutationBurden.rst new file mode 100644 index 000000000..4e965d939 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.CancerMutationBurden.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.CancerMutationBurden +=================================== + +.. autoclass:: pyhealth.tasks.CancerMutationBurden + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.CancerSurvivalPrediction.rst b/docs/api/tasks/pyhealth.tasks.CancerSurvivalPrediction.rst new file mode 100644 index 000000000..3684ed37e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.CancerSurvivalPrediction.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.CancerSurvivalPrediction +======================================= + +.. autoclass:: pyhealth.tasks.CancerSurvivalPrediction + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.MutationPathogenicityPrediction.rst b/docs/api/tasks/pyhealth.tasks.MutationPathogenicityPrediction.rst new file mode 100644 index 000000000..089918dc6 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.MutationPathogenicityPrediction.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.MutationPathogenicityPrediction +============================================== + +.. autoclass:: pyhealth.tasks.MutationPathogenicityPrediction + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.VariantClassificationClinVar.rst b/docs/api/tasks/pyhealth.tasks.VariantClassificationClinVar.rst new file mode 100644 index 000000000..0e3decfcd --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.VariantClassificationClinVar.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.VariantClassificationClinVar +=========================================== + +.. autoclass:: pyhealth.tasks.VariantClassificationClinVar + :members: + :undoc-members: + :show-inheritance: