diff --git a/getting-started.md b/getting-started.md index a5efbbe6..e12429da 100644 --- a/getting-started.md +++ b/getting-started.md @@ -122,6 +122,12 @@ as well as `.jpg`s showing from which parts of the slide features are extracted. Most of the background should be marked in red, meaning ignored that it was ignored during feature extraction. +> In case you want to use a gated model (e.g. Virchow2), you need to log in from your console using: +> ``` +> huggingface-cli login +> ``` +> More info about this [here](https://huggingface.co/docs/huggingface_hub/en/guides/cli). + > **If you are using the UNI or CONCH models** > and working in an environment where your home directory storage is limited, > you may want to also specify your huggingface storage directory @@ -151,6 +157,7 @@ meaning ignored that it was ignored during feature extraction. [COBRA2]: https://huggingface.co/KatherLab/COBRA [EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine +[PRISM]: https://huggingface.co/paige-ai/Prism @@ -266,6 +273,7 @@ STAMP currently supports the following encoders: - [COBRA2] - [EAGLE] - [MADELEINE] +- [PRISM] Slide encoders take as input the already extracted tile-level features in the preprocessing step. 
Each encoder accepts only certain extractors and most @@ -273,12 +281,13 @@ work only on CUDA devices: | Encoder | Required Extractor | Compatible Devices | |--|--|--| -| CHIEF | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only | +| CHIEF | CHIEF-CTRANSPATH | CUDA only | | TITAN | CONCH1.5 | CUDA, cpu, mps | GIGAPATH | GIGAPATH | CUDA only | COBRA2 | CONCH, UNI, VIRCHOW2 or H-OPTIMUS-0 | CUDA only | EAGLE | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only | MADELEINE | CONCH | CUDA only +| PRISM | VIRCHOW_FULL | CUDA only As with feature extractors, most of these models require you to request @@ -363,4 +372,27 @@ patient_encoding: stamp --config stamp-test-experiment/config.yaml encode_patients ``` -The output `.h5` features will have the patient's id as name. \ No newline at end of file +The output `.h5` features will have the patient's id as name. + +## Training with Patient-Level Features + +Once you have patient-level features, +you can train models directly on these features. This is useful because: +- **Efficient with Limited Data**: Patient-level modeling often performs better when data is scarce, since pretrained encoders can extract robust features from each slide as a whole. +- **Faster Training & Reduced Overfitting**: With fewer parameters to train compared to tile-level models, patient-level models train more quickly and are less prone to overfitting. +- **Enables Interpretable Cohort Analysis**: Patient-level features can be used for unsupervised analyses, such as clustering, making it easier to interpret and explore patient subgroups within your cohort. + +> **Note:** Slide-level features are not supported for modeling because the ground truth +> labels in the clinical table are at the patient level. 
+ +To train a model using patient-level features, you can use the same command as before: +```sh +stamp --config stamp-test-experiment/config.yaml crossval +``` + +The key differences for patient-level modeling are: +- The `feature_dir` should contain patient-level `.h5` files (one per patient). +- The `slide_table` is not needed since there's a direct mapping from patient ID to feature file. +- STAMP will automatically detect that these are patient-level features and use a MultiLayer Perceptron (MLP) classifier instead of the Vision Transformer. + +You can then run statistics as done with tile-level features. \ No newline at end of file diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 89792122..3806cddb 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -7,6 +7,12 @@ import yaml from stamp.config import StampConfig +from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + VitModelParams, +) STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") @@ -126,32 +132,20 @@ def _run_cli(args: argparse.Namespace) -> None: if config.training is None: raise ValueError("no training configuration supplied") + # use default advanced config in case none is provided + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()) + ) + _add_file_handle_(_logger, output_dir=config.training.output_dir) _logger.info( "using the following configuration:\n" f"{yaml.dump(config.training.model_dump(mode='json'))}" ) - # We pass every parameter explicitly so our type checker can do its work. 
+ train_categorical_model_( - output_dir=config.training.output_dir, - clini_table=config.training.clini_table, - slide_table=config.training.slide_table, - feature_dir=config.training.feature_dir, - patient_label=config.training.patient_label, - ground_truth_label=config.training.ground_truth_label, - filename_label=config.training.filename_label, - categories=config.training.categories, - # Dataset and -loader parameters - bag_size=config.training.bag_size, - num_workers=config.training.num_workers, - # Training paramenters - batch_size=config.training.batch_size, - max_epochs=config.training.max_epochs, - patience=config.training.patience, - accelerator=config.training.accelerator, - # Experimental features - use_vary_precision_transform=config.training.use_vary_precision_transform, - use_alibi=config.training.use_alibi, + config=config.training, advanced=config.advanced_config ) case "deploy": @@ -189,27 +183,16 @@ def _run_cli(args: argparse.Namespace) -> None: "using the following configuration:\n" f"{yaml.dump(config.crossval.model_dump(mode='json'))}" ) + + # use default advanced config in case none is provided + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()) + ) + categorical_crossval_( - output_dir=config.crossval.output_dir, - clini_table=config.crossval.clini_table, - slide_table=config.crossval.slide_table, - feature_dir=config.crossval.feature_dir, - patient_label=config.crossval.patient_label, - ground_truth_label=config.crossval.ground_truth_label, - filename_label=config.crossval.filename_label, - categories=config.crossval.categories, - n_splits=config.crossval.n_splits, - # Dataset and -loader parameters - bag_size=config.crossval.bag_size, - num_workers=config.crossval.num_workers, - # Crossval paramenters - batch_size=config.crossval.batch_size, - max_epochs=config.crossval.max_epochs, - patience=config.crossval.patience, - 
accelerator=config.crossval.accelerator, - # Experimental Features - use_vary_precision_transform=config.crossval.use_vary_precision_transform, - use_alibi=config.crossval.use_alibi, + config=config.crossval, + advanced=config.advanced_config, ) case "statistics": diff --git a/src/stamp/config.py b/src/stamp/config.py index ca283dba..3d847324 100644 --- a/src/stamp/config.py +++ b/src/stamp/config.py @@ -2,7 +2,12 @@ from stamp.encoding.config import PatientEncodingConfig, SlideEncodingConfig from stamp.heatmaps.config import HeatmapConfig -from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig +from stamp.modeling.config import ( + AdvancedConfig, + CrossvalConfig, + DeploymentConfig, + TrainConfig, +) from stamp.preprocessing.config import PreprocessingConfig from stamp.statistics import StatsConfig @@ -23,3 +28,5 @@ class StampConfig(BaseModel): slide_encoding: SlideEncodingConfig | None = None patient_encoding: PatientEncodingConfig | None = None + + advanced_config: AdvancedConfig | None = None diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 3ba0f0d2..e2e8eda2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -6,7 +6,7 @@ preprocessing: # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", # "virchow-full", "musk", "mstar", "plip" # Some of them require requesting access to the respective authors beforehand. - extractor: "ctranspath" + extractor: "chief-ctranspath" # Device to run feature extraction on ("cpu", "cuda", "cuda:0", etc.) device: "cuda" @@ -79,16 +79,14 @@ crossval: # Number of folds to split the data into for cross-validation #n_splits: 5 - # Experimental features: + # Path to a YAML file with advanced training parameters. + #params_path: "path/to/train_params.yaml" - # Please try uncommenting the settings below - # and report if they improve / reduce model performance! 
+ # Experimental features: # Change the precision of features during training #use_vary_precision_transform: true - # Use ALiBi positional embedding - # use_alibi: true training: @@ -126,17 +124,14 @@ training: # If unspecified, they will be inferred from the table itself. #categories: ["mutated", "wild type"] - # Experimental features: + # Path to a YAML file with advanced training parameters. + #params_path: "path/to/model_params.yaml" - # Please try uncommenting the settings below - # and report if they improve / reduce model performance! + # Experimental features: # Change the precision of features during training #use_vary_precision_transform: true - # Use ALiBi positional embedding - # use_alibi: true - deployment: output_dir: "/path/to/save/files/to" @@ -272,3 +267,36 @@ patient_encoding: # Add a hash of the entire preprocessing codebase in the feature folder name. #generate_hash: True + + +advanced_config: + max_epochs: 64 + patience: 16 + batch_size: 64 + # Only for tile-level training. Reducing its amount could affect + # model performance. Reduces memory consumption. Default value works + # fine for most cases. + bag_size: 512 + # Optional parameters + #num_workers: 16 # Default chosen by cpu cores + + # Select a model. Not working yet, added for future support. + # Now it uses a ViT for tile features and a MLP for patient features. 
+ #model_name: "vit" + + model_params: + # Tile-level training models: + vit: # Vision Transformer + dim_model: 512 + dim_feedforward: 512 + n_heads: 8 + n_layers: 2 + dropout: 0.25 + # Experimental feature: Use ALiBi positional embedding + use_alibi: false + + # Patient-level training models: + mlp: # Multilayer Perceptron + dim_hidden: 512 + num_layers: 2 + dropout: 0.25 diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 95492230..3148f635 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -64,6 +64,11 @@ def init_slide_encoder_( selected_encoder: Encoder = Madeleine() + case EncoderName.PRISM: + from stamp.encoding.encoder.prism import Prism + + selected_encoder: Encoder = Prism() + case Encoder(): selected_encoder = encoder @@ -145,6 +150,11 @@ def init_patient_encoder_( selected_encoder: Encoder = Madeleine() + case EncoderName.PRISM: + from stamp.encoding.encoder.prism import Prism + + selected_encoder: Encoder = Prism() + case Encoder(): selected_encoder = encoder diff --git a/src/stamp/encoding/config.py b/src/stamp/encoding/config.py index 5158b3b9..e743fcfd 100644 --- a/src/stamp/encoding/config.py +++ b/src/stamp/encoding/config.py @@ -13,8 +13,7 @@ class EncoderName(StrEnum): TITAN = "titan" GIGAPATH = "gigapath" MADELEINE = "madeleine" - # PRISM = "paigeai-prism" - # waiting for paige-ai authors to fix it + PRISM = "prism" class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True): diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 94192812..d9035f7a 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -6,7 +6,6 @@ import h5py import numpy as np -import pandas as pd import torch from torch import Tensor from tqdm import tqdm @@ -14,7 +13,7 @@ import stamp from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName -from stamp.modeling.data import 
CoordsInfo, get_coords +from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel @@ -83,7 +82,9 @@ def encode_slides_( slide_embedding = self._generate_slide_embedding( feats, device, coords=coords ) - self._save_features_(output_path=output_path, feats=slide_embedding) + self._save_features_( + output_path=output_path, feats=slide_embedding, feat_type="slide" + ) def encode_patients_( self, @@ -113,7 +114,7 @@ def encode_patients_( if self.precision == torch.float16: self.model.half() - slide_table = self._read_slide_table(slide_table_path) + slide_table = read_table(slide_table_path) patient_groups = slide_table.groupby(patient_label) for patient_id, group in (progress := tqdm(patient_groups)): @@ -142,7 +143,9 @@ def encode_patients_( patient_embedding = self._generate_patient_embedding( feats_list, device, **kwargs ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) @abstractmethod def _generate_slide_embedding( @@ -161,10 +164,6 @@ def _generate_patient_embedding( """Generate patient embedding. 
Must be implemented by subclasses.""" pass - @staticmethod - def _read_slide_table(slide_table_path: Path) -> pd.DataFrame: - return pd.read_csv(slide_table_path) - def _validate_and_read_features(self, h5_path: str) -> tuple[Tensor, CoordsInfo]: feats, coords, extractor = self._read_h5(h5_path) if extractor not in self.required_extractors: @@ -192,7 +191,9 @@ def _read_h5( ) return feats, coords, extractor - def _save_features_(self, output_path: Path, feats: np.ndarray) -> None: + def _save_features_( + self, output_path: Path, feats: np.ndarray, feat_type: str + ) -> None: with ( NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file, h5py.File(tmp_h5_file, "w") as f, @@ -204,6 +205,7 @@ def _save_features_(self, output_path: Path, feats: np.ndarray) -> None: f.attrs["precision"] = str(self.precision) f.attrs["stamp_version"] = stamp.__version__ f.attrs["code_hash"] = get_processing_code_hash(Path(__file__))[:8] + f.attrs["feat_type"] = feat_type # TODO: Add more metadata like tile-level extractor name # and maybe tile size in pixels and microns except Exception: diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2c33a5ca..eaab9750 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -115,7 +115,9 @@ def __init__(self) -> None: model=model, identifier=EncoderName.CHIEF, precision=torch.float32, - required_extractors=[ExtractorName.CHIEF_CTRANSPATH], + required_extractors=[ + ExtractorName.CHIEF_CTRANSPATH, + ], ) def _generate_slide_embedding( @@ -192,7 +194,9 @@ def encode_patients_( .cpu() .numpy() ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) def initialize_weights(module): diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 85c0ae2d..b2fb293d 100644 --- a/src/stamp/encoding/encoder/eagle.py 
+++ b/src/stamp/encoding/encoder/eagle.py @@ -164,7 +164,9 @@ def encode_slides_( continue slide_embedding = self._generate_slide_embedding(feats, device, agg_feats) - self._save_features_(output_path=output_path, feats=slide_embedding) + self._save_features_( + output_path=output_path, feats=slide_embedding, feat_type="slide" + ) # TODO: Add @override decorator on each encoder once it is added to python def encode_patients_( @@ -233,4 +235,6 @@ def encode_patients_( patient_embedding = self._generate_patient_embedding( feats_list, device, agg_feats_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index c7931400..e2fb0ebb 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -183,7 +183,9 @@ def encode_patients_( all_feats_list, device, coords_list=all_coords_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) def _generate_patient_embedding( self, feats_list, device, coords_list: list | None = None, **kwargs diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 0307d68f..41dd19f1 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -166,4 +166,6 @@ def encode_patients_( patient_embedding = self._generate_patient_embedding( all_feats_list, device, all_coords_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 96edd3db..0e1c76fd 100644 --- a/src/stamp/modeling/config.py +++ 
b/src/stamp/modeling/config.py @@ -1,10 +1,12 @@ import os +from collections.abc import Sequence from pathlib import Path import torch from pydantic import BaseModel, ConfigDict, Field -from stamp.types import PandasLabel +from stamp.modeling.registry import ModelName +from stamp.types import Category, PandasLabel class TrainConfig(BaseModel): @@ -13,7 +15,7 @@ class TrainConfig(BaseModel): output_dir: Path = Field(description="The directory to save the results to") clini_table: Path = Field(description="Excel or CSV to read clinical data from") - slide_table: Path = Field( + slide_table: Path | None = Field( description="Excel or CSV to read patient-slide associations from" ) feature_dir: Path = Field(description="Directory containing feature files") @@ -21,24 +23,18 @@ class TrainConfig(BaseModel): ground_truth_label: PandasLabel = Field( description="Name of categorical column in clinical table to train on" ) - categories: list[str] | None = None + categories: Sequence[Category] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" - # Dataset and -loader parameters - bag_size: int = 512 - num_workers: int = min(os.cpu_count() or 1, 16) - - # Training paramenters - batch_size: int = 64 - max_epochs: int = 64 - patience: int = 16 - accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + params_path: Path | None = Field( + default=None, + description="Optional: Path to a YAML file with advanced training parameters.", + ) # Experimental features use_vary_precision_transform: bool = False - use_alibi: bool = False class CrossvalConfig(TrainConfig): @@ -61,3 +57,42 @@ class DeploymentConfig(BaseModel): num_workers: int = min(os.cpu_count() or 1, 16) accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + + +class VitModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + dim_model: int = 512 + dim_feedforward: int = 512 + n_heads: int = 8 + n_layers: int = 2 + dropout: float = 0.25 + # 
Experimental feature: Use ALiBi positional embedding + use_alibi: bool = False + + +class MlpModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + dim_hidden: int = 512 + num_layers: int = 2 + dropout: float = 0.25 + + +class ModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + vit: VitModelParams + mlp: MlpModelParams + + +class AdvancedConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + bag_size: int = 512 + num_workers: int = min(os.cpu_count() or 1, 16) + batch_size: int = 64 + max_epochs: int = 64 + patience: int = 16 + accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + model_name: ModelName | None = Field( + default=None, + description='Optional: "vit" or "mlp". Defaults based on feature type.', + ) + model_params: ModelParams diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index aaa21b49..37bdf381 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,28 +1,30 @@ import logging from collections.abc import Mapping, Sequence -from pathlib import Path from typing import Any, Final import numpy as np -from lightning.pytorch.accelerators.accelerator import Accelerator from pydantic import BaseModel from sklearn.model_selection import StratifiedKFold +from stamp.modeling.config import AdvancedConfig, CrossvalConfig from stamp.modeling.data import ( PatientData, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.deploy import _predict, _to_prediction_df from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - Category, FeaturePath, 
GroundTruth, - PandasLabel, PatientId, ) @@ -43,58 +45,56 @@ class _Splits(BaseModel): def categorical_crossval_( - clini_table: Path, - slide_table: Path, - feature_dir: Path, - output_dir: Path, - patient_label: PandasLabel, - ground_truth_label: PandasLabel, - filename_label: PandasLabel, - categories: Sequence[Category] | None, - n_splits: int, - # Dataset and -loader parameters - bag_size: int, - num_workers: int, - # Training paramenters - batch_size: int, - max_epochs: int, - patience: int, - accelerator: str | Accelerator, - # Experimental features - use_vary_precision_transform: bool, - use_alibi: bool, + config: CrossvalConfig, + advanced: AdvancedConfig, ) -> None: - patient_to_ground_truth: Final[dict[PatientId, GroundTruth]] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, + feature_type = detect_feature_type(config.feature_dir) + _logger.info(f"Detected feature type: {feature_type}") + + if feature_type == "tile": + if config.slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") + patient_to_ground_truth: dict[PatientId, GroundTruth] = ( + patient_to_ground_truth_from_clini_table_( + clini_table_path=config.clini_table, + ground_truth_label=config.ground_truth_label, + patient_label=config.patient_label, + ) ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=config.slide_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + filename_label=config.filename_label, + ) ) - ) - - # Clean data (remove slides without ground truth, missing features, etc.) 
- patient_to_data: Final[Mapping[Category, PatientData]] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, + patient_to_data: Mapping[PatientId, PatientData] = ( + filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=True, + ) ) - ) + elif feature_type == "patient": + patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( + clini_table=config.clini_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ) + patient_to_ground_truth: dict[PatientId, GroundTruth] = { + pid: pd.ground_truth for pid, pd in patient_to_data.items() + } + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") - output_dir.mkdir(parents=True, exist_ok=True) - splits_file = output_dir / "splits.json" + config.output_dir.mkdir(parents=True, exist_ok=True) + splits_file = config.output_dir / "splits.json" # Generate the splits, or load them from the splits file if they already exist if not splits_file.exists(): - splits = _get_splits(patient_to_data=patient_to_data, n_splits=n_splits) + splits = _get_splits(patient_to_data=patient_to_data, n_splits=config.n_splits) with open(splits_file, "w") as fp: fp.write(splits.model_dump_json(indent=4)) else: @@ -120,7 +120,7 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) - categories = categories or sorted( + categories = config.categories or sorted( { patient_data.ground_truth for patient_data in patient_to_data.values() @@ -129,7 +129,7 @@ def categorical_crossval_( ) for split_i, split in enumerate(splits.splits): - split_dir = output_dir / f"split-{split_i}" + split_dir = config.output_dir / f"split-{split_i}" if (split_dir / "patient-preds.csv").exists(): _logger.info( @@ -141,13 +141,11 @@ def 
categorical_crossval_( # Train the model if not (split_dir / "model.ckpt").exists(): model, train_dl, valid_dl = setup_model_for_training( - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ground_truth_label=ground_truth_label, - bag_size=bag_size, - num_workers=num_workers, - batch_size=batch_size, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, + ground_truth_label=config.ground_truth_label, + advanced=advanced, patient_to_data={ patient_id: patient_data for patient_id, patient_data in patient_to_data.items() @@ -165,42 +163,70 @@ def categorical_crossval_( ), train_transform=( VaryPrecisionTransform(min_fraction_bits=1) - if use_vary_precision_transform + if config.use_vary_precision_transform else None ), - use_alibi=use_alibi, + feature_type=feature_type, ) model = train_model_( output_dir=split_dir, model=model, train_dl=train_dl, valid_dl=valid_dl, - max_epochs=max_epochs, - patience=patience, - accelerator=accelerator, + max_epochs=advanced.max_epochs, + patience=advanced.patience, + accelerator=advanced.accelerator, ) else: - model = LitVisionTransformer.load_from_checkpoint(split_dir / "model.ckpt") + if feature_type == "tile": + model = LitVisionTransformer.load_from_checkpoint( + split_dir / "model.ckpt" + ) + else: + model = LitMLPClassifier.load_from_checkpoint(split_dir / "model.ckpt") # Deploy on test set if not (split_dir / "patient-preds.csv").exists(): + # Prepare test dataloader + test_patients = [ + pid for pid in split.test_patients if pid in patient_to_data + ] + test_patient_data = [patient_to_data[pid] for pid in test_patients] + if feature_type == "tile": + test_dl, _ = tile_bag_dataloader( + patient_data=test_patient_data, + bag_size=None, + categories=categories, + batch_size=1, + shuffle=False, + num_workers=advanced.num_workers, + transform=None, + ) + elif feature_type == "patient": + test_dl, _ = patient_feature_dataloader( + 
patient_data=test_patient_data, + categories=categories, + batch_size=1, + shuffle=False, + num_workers=advanced.num_workers, + transform=None, + ) + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") + predictions = _predict( model=model, - patient_to_data={ - patient_id: patient_data - for patient_id, patient_data in patient_to_data.items() - if patient_id in split.test_patients - }, - num_workers=num_workers, - accelerator=accelerator, + test_dl=test_dl, + patient_ids=test_patients, + accelerator=advanced.accelerator, ) _to_prediction_df( categories=categories, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, - patient_label=patient_label, - ground_truth_label=ground_truth_label, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index d9935c35..23e3ca08 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -35,6 +35,7 @@ ) _logger = logging.getLogger("stamp") +_logged_stamp_v1_warning = False __author__ = "Marko van Treeck" @@ -56,7 +57,7 @@ class PatientData(Generic[GroundTruthType]): feature_files: Iterable[FeaturePath | BinaryIO] -def dataloader_from_patient_data( +def tile_bag_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], bag_size: int | None, @@ -69,7 +70,7 @@ def dataloader_from_patient_data( DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], Sequence[Category], ]: - """Creates a dataloader from patient data, encoding the ground truths. + """Creates a dataloader from patient data for tile-level (bagged) features. 
Args: categories: @@ -115,6 +116,109 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def patient_feature_dataloader( + *, + patient_data: Sequence[PatientData[GroundTruth | None]], + categories: Sequence[Category] | None = None, + batch_size: int, + shuffle: bool, + num_workers: int, + transform: Callable[[Tensor], Tensor] | None, +) -> tuple[DataLoader, Sequence[Category]]: + """ + Creates a dataloader for patient-level features (one feature vector per patient). + """ + feature_files = [next(iter(p.feature_files)) for p in patient_data] + raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) + categories = ( + categories if categories is not None else list(np.unique(raw_ground_truths)) + ) + one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) + ds = PatientFeatureDataset(feature_files, one_hot, transform=transform) + dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + return dl, categories + + +def detect_feature_type(feature_dir: Path) -> str: + """ + Detects feature type by inspecting all .h5 files in feature_dir. + + Returns: + "tile" if all files are tile-level, "patient" if all are patient-level. + If files have mixed types, raises an error. + If no .h5 files are found, raises an error. + """ + feature_types = set() + files_checked = 0 + + for file in feature_dir.glob("*.h5"): + files_checked += 1 + with h5py.File(file, "r") as h5: + feat_type = h5.attrs.get("feat_type") + encoder = h5.attrs.get("encoder") + + if feat_type is not None or encoder is not None: + feature_types.add(str(feat_type)) + else: + # If feat_type is missing, always treat as tile-level feature + feature_types.add("tile") + + if files_checked == 0: + raise RuntimeError("No .h5 feature files found in feature_dir.") + + if len(feature_types) > 1: + raise RuntimeError( + f"Multiple feature types detected in {feature_dir}: {feature_types}. 
" + "All feature files must have the same type." + ) + + return feature_types.pop() + + +def load_patient_level_data( + *, + clini_table: Path, + feature_dir: Path, + patient_label: PandasLabel, + ground_truth_label: PandasLabel, + feature_ext: str = ".h5", +) -> dict[PatientId, PatientData]: + """ + Loads PatientData for patient-level features, matching patients in the clinical table + to feature files in feature_dir named {patient_id}.h5. + """ + # TODO: I'm not proud at all of this. Any other alternative for mapping + # clinical data to the patient-level feature paths that avoids + # creating another slide table for encoded featuress is welcome :P. + + clini_df = read_table( + clini_table, + usecols=[patient_label, ground_truth_label], + dtype=str, + ).dropna() + + patient_to_data: dict[PatientId, PatientData] = {} + missing_features = [] + for _, row in clini_df.iterrows(): + patient_id = PatientId(str(row[patient_label])) + ground_truth = row[ground_truth_label] + feature_file = feature_dir / f"{patient_id}{feature_ext}" + if feature_file.exists(): + patient_to_data[patient_id] = PatientData( + ground_truth=ground_truth, + feature_files=[FeaturePath(feature_file)], + ) + else: + missing_features.append(patient_id) + + if missing_features: + _logger.warning( + f"Some patients have no feature file in {feature_dir}: {missing_features}" + ) + + return patient_to_data + + @dataclass class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): """A dataset of bags of instances.""" @@ -185,6 +289,47 @@ def __getitem__( ) +class PatientFeatureDataset(Dataset): + """ + Dataset for single feature vector per sample (e.g. slide-level or patient-level). + Each item is a (feature_vector, label_onehot) tuple. 
+ """ + + def __init__( + self, + feature_files: Sequence[FeaturePath | BinaryIO], + ground_truths: Tensor, # shape: [num_samples, num_classes] + transform: Callable[[Tensor], Tensor] | None, + ): + if len(feature_files) != len(ground_truths): + raise ValueError("Number of feature files and ground truths must match.") + self.feature_files = feature_files + self.ground_truths = ground_truths + self.transform = transform + + def __len__(self): + return len(self.feature_files) + + def __getitem__(self, idx: int): + feature_file = self.feature_files[idx] + with h5py.File(feature_file, "r") as h5: + feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + # Accept [V] or [1, V] + if feats.ndim == 2 and feats.shape[0] == 1: + feats = feats[0] + elif feats.ndim == 1: + pass + else: + raise RuntimeError( + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." + "Check that the features are patient-level." + ) + if self.transform is not None: + feats = self.transform(feats) + label = self.ground_truths[idx] + return feats, label + + @dataclass class CoordsInfo: coords_um: np.ndarray @@ -224,9 +369,13 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: == 224 ): # Historic STAMP format - _logger.info( - f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" - ) + # TODO: find a better way to get this warning just once + global _logged_stamp_v1_warning + if not _logged_stamp_v1_warning: + _logger.info( + f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" + ) + _logged_stamp_v1_warning = True tile_size_um = Microns(256.0) tile_size_px = TilePixels(224) coords_um = coords / 224 * 256 @@ -286,7 +435,7 @@ def patient_to_ground_truth_from_clini_table_( ground_truth_label: PandasLabel, ) -> dict[PatientId, GroundTruth]: """Loads the patients and their ground truths from a clini 
table.""" - clini_df = _read_table( + clini_df = read_table( clini_table_path, usecols=[patient_label, ground_truth_label], dtype=str, @@ -320,7 +469,7 @@ def slide_to_patient_from_slide_table_( filename_label: PandasLabel, ) -> dict[FeaturePath, PatientId]: """Creates a slide-to-patient mapping from a slide table.""" - slide_df = _read_table( + slide_df = read_table( slide_table_path, usecols=[patient_label, filename_label], dtype=str, @@ -336,7 +485,7 @@ def slide_to_patient_from_slide_table_( return slide_to_patient -def _read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: +def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: if not isinstance(path, Path): return pd.read_csv(path, **kwargs) elif path.suffix == ".xlsx": diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 0441cd53..144dd2ce 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -11,13 +11,16 @@ from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( - PatientData, - dataloader_from_patient_data, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.types import GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -36,7 +39,7 @@ def deploy_categorical_model_( output_dir: Path, checkpoint_paths: Sequence[Path], clini_table: Path | None, - slide_table: Path, + slide_table: Path | None, feature_dir: Path, ground_truth_label: PandasLabel | None, patient_label: PandasLabel, @@ -44,10 +47,21 @@ def deploy_categorical_model_( num_workers: int, accelerator: str | Accelerator, ) -> None: + # --- Detect feature type and load correct model --- + feature_type = 
detect_feature_type(feature_dir) + _logger.info(f"Detected feature type: {feature_type}") + + if feature_type == "tile": + ModelClass = LitVisionTransformer + elif feature_type == "patient": + ModelClass = LitMLPClassifier + else: + raise RuntimeError( + f"Unsupported feature type for deployment: {feature_type}. Only 'tile' and 'patient' are supported." + ) + models = [ - LitVisionTransformer.load_from_checkpoint( - checkpoint_path=checkpoint_path - ).eval() + ModelClass.load_from_checkpoint(checkpoint_path=checkpoint_path).eval() for checkpoint_path in checkpoint_paths ] @@ -78,37 +92,77 @@ def deploy_categorical_model_( output_dir.mkdir(exist_ok=True, parents=True) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, - ) - - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] - if clini_table is not None: - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, + # --- Data loading logic --- + if feature_type == "tile": + if slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") + slide_to_patient = slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, patient_label=patient_label, + filename_label=filename_label, ) - else: + if clini_table is not None: + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + else: + patient_to_ground_truth = { + patient_id: None for patient_id in set(slide_to_patient.values()) + } + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=False, + ) + test_dl, _ = tile_bag_dataloader( + 
patient_data=list(patient_to_data.values()), + bag_size=None, # We want all tiles to be seen by the model + categories=list(models[0].categories), + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + patient_ids = list(patient_to_data.keys()) + elif feature_type == "patient": + if slide_table is not None: + _logger.warning( + "slide_table is ignored for patient-level features during deployment." + ) + if clini_table is None: + raise ValueError( + "clini_table is required for patient-level feature deployment." + ) + patient_to_data = load_patient_level_data( + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + ) + test_dl, _ = patient_feature_dataloader( + patient_data=list(patient_to_data.values()), + categories=list(models[0].categories), + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + patient_ids = list(patient_to_data.keys()) patient_to_ground_truth = { - patient_id: None for patient_id in set(slide_to_patient.values()) + pid: pd.ground_truth for pid, pd in patient_to_data.items() } - - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=False, - ) + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 for model_i, model in enumerate(models): predictions = _predict( model=model, - patient_to_data=patient_to_data, - num_workers=num_workers, + test_dl=test_dl, + patient_ids=patient_ids, accelerator=accelerator, ) all_predictions.append(predictions) @@ -130,7 +184,7 @@ def deploy_categorical_model_( patient_id: torch.stack( [predictions[patient_id] for predictions in all_predictions] ).mean(dim=0) - for patient_id in patient_to_data.keys() + for patient_id in patient_ids }, 
patient_label=patient_label, ground_truth_label=ground_truth_label, @@ -139,33 +193,23 @@ def deploy_categorical_model_( def _predict( *, - model: LitVisionTransformer, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None]], - num_workers: int, + model: lightning.LightningModule, + test_dl: torch.utils.data.DataLoader, + patient_ids: Sequence[PatientId], accelerator: str | Accelerator, ) -> Mapping[PatientId, Float[torch.Tensor, "category"]]: # noqa: F821 model = model.eval() torch.set_float32_matmul_precision("medium") - patients_used_for_training: set[PatientId] = set(model.train_patients) | set( - model.valid_patients - ) - if overlap := patients_used_for_training & set(patient_to_data.keys()): + # Check for data leakage + patients_used_for_training: set[PatientId] = set( + getattr(model, "train_patients", []) + ) | set(getattr(model, "valid_patients", [])) + if overlap := patients_used_for_training & set(patient_ids): raise ValueError( f"some of the patients in the validation set were used during training: {overlap}" ) - test_dl, _ = dataloader_from_patient_data( - patient_data=list(patient_to_data.values()), - bag_size=None, # Use all the tiles for deployment - # Use same encoding scheme as during training - categories=list(model.categories), - batch_size=1, - shuffle=False, - num_workers=num_workers, - transform=None, - ) - trainer = lightning.Trainer( accelerator=accelerator, devices=1, # Needs to be 1, otherwise half the predictions are missing for some reason @@ -181,7 +225,7 @@ def _predict( dim=1, ) - return dict(zip(patient_to_data, predictions, strict=True)) + return dict(zip(patient_ids, predictions, strict=True)) def _to_prediction_df( diff --git a/src/stamp/modeling/lightning_model.py b/src/stamp/modeling/lightning_model.py index c30c87ef..4849c4e1 100644 --- a/src/stamp/modeling/lightning_model.py +++ b/src/stamp/modeling/lightning_model.py @@ -57,6 +57,8 @@ class LitVisionTransformer(lightning.LightningModule): **metadata: 
Additional metadata to store with the model. """ + supported_features = ["tile"] + def __init__( self, *, diff --git a/src/stamp/modeling/mlp_classifier.py b/src/stamp/modeling/mlp_classifier.py new file mode 100644 index 00000000..98650f08 --- /dev/null +++ b/src/stamp/modeling/mlp_classifier.py @@ -0,0 +1,126 @@ +from collections.abc import Iterable, Sequence + +import lightning +import numpy as np +import torch +from packaging.version import Version +from torch import Tensor, nn +from torchmetrics.classification import MulticlassAUROC + +import stamp +from stamp.types import Category, PandasLabel, PatientId + + +class MLPClassifier(nn.Module): + """ + Simple MLP for classification from a single feature vector. + """ + + def __init__( + self, + dim_input: int, + dim_hidden: int, + dim_output: int, + num_layers: int, + dropout: float, + ): + super().__init__() + layers = [] + in_dim = dim_input + for i in range(num_layers - 1): + layers.append(nn.Linear(in_dim, dim_hidden)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(dropout)) + in_dim = dim_hidden + layers.append(nn.Linear(in_dim, dim_output)) + self.mlp = nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.mlp(x) + + +class LitMLPClassifier(lightning.LightningModule): + """ + PyTorch Lightning wrapper for MLPClassifier. 
+ """ + + supported_features = ["patient"] + + def __init__( + self, + *, + categories: Sequence[Category], + category_weights: torch.Tensor, + dim_input: int, + dim_hidden: int, + num_layers: int, + dropout: float, + ground_truth_label: PandasLabel, + train_patients: Iterable[PatientId], + valid_patients: Iterable[PatientId], + stamp_version: Version = Version(stamp.__version__), + **metadata, + ): + super().__init__() + self.save_hyperparameters() + self.model = MLPClassifier( + dim_input=dim_input, + dim_hidden=dim_hidden, + dim_output=len(categories), + num_layers=num_layers, + dropout=dropout, + ) + self.class_weights = category_weights + self.valid_auroc = MulticlassAUROC(len(categories)) + self.ground_truth_label = ground_truth_label + self.categories = np.array(categories) + self.train_patients = train_patients + self.valid_patients = valid_patients + + # TODO: Add version check with version 2.2.1, for both MLP and Transformer + + def forward(self, x: Tensor) -> Tensor: + return self.model(x) + + def _step(self, batch, step_name: str): + feats, targets = batch + logits = self.model(feats) + loss = nn.functional.cross_entropy( + logits, + targets.type_as(logits), + weight=self.class_weights.type_as(logits), + ) + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + if step_name == "validation": + self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) + self.log( + f"{step_name}_auroc", + self.valid_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + return loss + + def training_step(self, batch, batch_idx): + return self._step(batch, "training") + + def validation_step(self, batch, batch_idx): + return self._step(batch, "validation") + + def test_step(self, batch, batch_idx): + return self._step(batch, "test") + + def predict_step(self, batch, batch_idx): + feats, _ = batch + return self.model(feats) + + def configure_optimizers(self): + return 
torch.optim.Adam(self.parameters(), lr=1e-3) diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py new file mode 100644 index 00000000..7be976bd --- /dev/null +++ b/src/stamp/modeling/registry.py @@ -0,0 +1,34 @@ +from enum import StrEnum +from typing import Sequence, Type, TypedDict + +import lightning + +from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier + + +class ModelName(StrEnum): + """Enum for available model names.""" + + VIT = "vit" + MLP = "mlp" + + +class ModelInfo(TypedDict): + """A dictionary to map a model to supported feature types. For example, + a linear classifier is not compatible with tile-evel feats.""" + + model_class: Type[lightning.LightningModule] + supported_features: Sequence[str] + + +MODEL_REGISTRY: dict[ModelName, ModelInfo] = { + ModelName.VIT: { + "model_class": LitVisionTransformer, + "supported_features": LitVisionTransformer.supported_features, + }, + ModelName.MLP: { + "model_class": LitMLPClassifier, + "supported_features": LitMLPClassifier.supported_features, + }, +} diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index d273315a..ff798030 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -15,20 +15,25 @@ from sklearn.model_selection import train_test_split from torch.utils.data.dataloader import DataLoader +from stamp.modeling.config import AdvancedConfig, TrainConfig from stamp.modeling.data import ( BagDataset, PatientData, - dataloader_from_patient_data, + PatientFeatureDataset, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.lightning_model import ( Bags, BagSizes, EncodedTargets, - LitVisionTransformer, ) +from stamp.modeling.registry import MODEL_REGISTRY, ModelName from 
stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import Category, CoordinatesBatch, GroundTruth, PandasLabel, PatientId @@ -41,116 +46,282 @@ def train_categorical_model_( *, - clini_table: Path, - slide_table: Path, - feature_dir: Path, - output_dir: Path, - patient_label: PandasLabel, - ground_truth_label: PandasLabel, - filename_label: PandasLabel, - categories: Sequence[Category] | None, - # Dataset and -loader parameters - bag_size: int, - num_workers: int, - # Training paramenters - batch_size: int, - max_epochs: int, - patience: int, - accelerator: str | Accelerator, - # Experimental features - use_vary_precision_transform: bool, - use_alibi: bool, + config: TrainConfig, + advanced: AdvancedConfig, ) -> None: - """Trains a model. + """Trains a model based on the feature type.""" + feature_type = detect_feature_type(config.feature_dir) + _logger.info(f"Detected feature type: {feature_type}") - Args: - clini_table: - An excel or csv file to read the clinical information from. - Must at least have the columns specified in the arguments - - `patient_label` (containing a unique patient ID) - and `ground_truth_label` (containing the ground truth to train for). - slide_table: - An excel or csv file to read the patient-slide associations from. - Must at least have the columns specified in the arguments - `patient_label` (containing the patient ID) - and `filename_label` - (containing a filename relative to `feature_dir` - in which some of the patient's features are stored). - feature_dir: - See `slide_table`. - output_dir: - Path into which to output the artifacts (trained model etc.) - generated during training. - patient_label: - See `clini_table`, `slide_table`. - ground_truth_label: - See `clini_table`. - filename_label: - See `slide_table`. - categories: - Categories of the ground truth. - Set to `None` to automatically infer. 
- """ - # Read and parse data from out clini and slide table - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, - ) - - # Clean data (remove slides without ground truth, missing features, etc.) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) + if feature_type == "tile": + if config.slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=config.clini_table, + ground_truth_label=config.ground_truth_label, + patient_label=config.patient_label, + ) + slide_to_patient = slide_to_patient_from_slide_table_( + slide_table_path=config.slide_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + filename_label=config.filename_label, + ) + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=True, + ) + elif feature_type == "patient": + # Patient-level: ignore slide_table + if config.slide_table is not None: + _logger.warning("slide_table is ignored for patient-level features.") + patient_to_data = load_patient_level_data( + clini_table=config.clini_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ) + elif feature_type == "slide": + raise RuntimeError( + "Slide-level features are not supported for training." + "Please rerun the encoding step with patient-level encoding." 
+ ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") - # Train the model + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, - categories=categories, - bag_size=bag_size, - batch_size=batch_size, - num_workers=num_workers, - ground_truth_label=ground_truth_label, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, + categories=config.categories, + advanced=advanced, + ground_truth_label=config.ground_truth_label, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, train_transform=( VaryPrecisionTransform(min_fraction_bits=1) - if use_vary_precision_transform + if config.use_vary_precision_transform else None ), - use_alibi=use_alibi, + feature_type=feature_type, ) train_model_( - output_dir=output_dir, + output_dir=config.output_dir, model=model, train_dl=train_dl, valid_dl=valid_dl, - max_epochs=max_epochs, - patience=patience, - accelerator=accelerator, + max_epochs=advanced.max_epochs, + patience=advanced.patience, + accelerator=advanced.accelerator, + ) + + +def setup_model_for_training( + *, + patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + categories: Sequence[Category] | None, + train_transform: Callable[[torch.Tensor], torch.Tensor] | None, + feature_type: str, + advanced: AdvancedConfig, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel, + clini_table: Path, + slide_table: Path | None, + feature_dir: Path, +) -> tuple[ + lightning.LightningModule, + DataLoader, + DataLoader, +]: + """Creates a model and dataloaders for training""" + + train_dl, valid_dl, train_categories, dim_feats, train_patients, valid_patients = ( + setup_dataloaders_for_training( + patient_to_data=patient_to_data, + categories=categories, + bag_size=advanced.bag_size, + batch_size=advanced.batch_size, + num_workers=advanced.num_workers, + 
train_transform=train_transform, + feature_type=feature_type, + ) + ) + + _logger.info( + "Training dataloaders: bag_size=%s, batch_size=%s, num_workers=%s", + advanced.bag_size, + advanced.batch_size, + advanced.num_workers, + ) + + category_weights = _compute_class_weights_and_check_categories( + train_dl=train_dl, + feature_type=feature_type, + train_categories=train_categories, + ) + + # 1. Default to a model if none is specified + if advanced.model_name is None: + advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP + _logger.info( + f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" + ) + + # 2. Validate that the chosen model supports the feature type + model_info = MODEL_REGISTRY[advanced.model_name] + if feature_type not in model_info["supported_features"]: + raise ValueError( + f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " + f"Supported types are: {model_info['supported_features']}" + ) + + # 3. Get model-specific hyperparameters + model_specific_params = advanced.model_params.model_dump()[ + advanced.model_name.value + ] + + # 4. Prepare common parameters + common_params = { + "categories": train_categories, + "category_weights": category_weights, + "dim_input": dim_feats, + # Metadata, has no effect on model training + "model_name": advanced.model_name.value, + "ground_truth_label": ground_truth_label, + "train_patients": train_patients, + "valid_patients": valid_patients, + "clini_table": clini_table, + "slide_table": slide_table, + "feature_dir": feature_dir, + } + + # 4. 
Instantiate the model dynamically + ModelClass = model_info["model_class"] + all_params = {**common_params, **model_specific_params} + _logger.info( + f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" + ) + model = ModelClass(**all_params) + + return model, train_dl, valid_dl + + +def setup_dataloaders_for_training( + *, + patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + categories: Sequence[Category] | None, + bag_size: int, + batch_size: int, + num_workers: int, + train_transform: Callable[[torch.Tensor], torch.Tensor] | None, + feature_type: str, +) -> tuple[ + DataLoader, + DataLoader, + Sequence[Category], + int, + Sequence[PatientId], + Sequence[PatientId], +]: + """ + Creates train/val dataloaders for tile-level or patient-level features. + + Returns: + train_dl, valid_dl, categories, feature_dim, train_patients, valid_patients + """ + # Stratified split + ground_truths = [ + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + ] + if len(ground_truths) != len(patient_to_data): + raise ValueError( + "patient_to_data must have a ground truth defined for all targets!" 
+ ) + + train_patients, valid_patients = cast( + tuple[Sequence[PatientId], Sequence[PatientId]], + train_test_split( + list(patient_to_data), stratify=ground_truths, shuffle=True, random_state=0 + ), ) + if feature_type == "tile": + # Use existing BagDataset logic + train_dl, train_categories = tile_bag_dataloader( + patient_data=[patient_to_data[pid] for pid in train_patients], + categories=categories, + bag_size=bag_size, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + transform=train_transform, + ) + valid_dl, _ = tile_bag_dataloader( + patient_data=[patient_to_data[pid] for pid in valid_patients], + bag_size=None, + categories=train_categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + bags, _, _, _ = next(iter(train_dl)) + dim_feats = bags.shape[-1] + return ( + train_dl, + valid_dl, + train_categories, + dim_feats, + train_patients, + valid_patients, + ) + + elif feature_type == "patient": + train_dl, train_categories = patient_feature_dataloader( + patient_data=[patient_to_data[pid] for pid in train_patients], + categories=categories, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + transform=train_transform, + ) + valid_dl, _ = patient_feature_dataloader( + patient_data=[patient_to_data[pid] for pid in valid_patients], + categories=train_categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + feats, _ = next(iter(train_dl)) + dim_feats = feats.shape[-1] + return ( + train_dl, + valid_dl, + train_categories, + dim_feats, + train_patients, + valid_patients, + ) + else: + raise RuntimeError( + f"Unsupported feature type: {feature_type}. Only 'tile' and 'patient' are supported." 
+ ) + def train_model_( *, output_dir: Path, - model: LitVisionTransformer, + model: lightning.LightningModule, train_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], valid_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], max_epochs: int, patience: int, accelerator: str | Accelerator, -) -> LitVisionTransformer: +) -> lightning.LightningModule: """Trains a model. Returns: @@ -188,81 +359,28 @@ def train_model_( ) shutil.copy(model_checkpoint.best_model_path, output_dir / "model.ckpt") - return LitVisionTransformer.load_from_checkpoint(model_checkpoint.best_model_path) + # Reload the best model using the same class as the input model + ModelClass = type(model) + return ModelClass.load_from_checkpoint(model_checkpoint.best_model_path) -def setup_model_for_training( +def _compute_class_weights_and_check_categories( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], - categories: Sequence[Category] | None, - bag_size: int, - batch_size: int, - num_workers: int, - train_transform: Callable[[torch.Tensor], torch.Tensor] | None, - use_alibi: bool, - # Metadata, has no effect on model training - ground_truth_label: PandasLabel, - clini_table: Path, - slide_table: Path, - feature_dir: Path, -) -> tuple[ - LitVisionTransformer, - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], -]: - """Creates a model and dataloaders for training""" - - # Do a stratified train-validation split - ground_truths = [ - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - ] - - if len(ground_truths) != len(patient_to_data): - raise ValueError( - "patient_to_data must have a ground truth defined for all targets!" 
- ) - - train_patients, valid_patients = cast( - tuple[Sequence[PatientId], Sequence[PatientId]], - train_test_split( - list(patient_to_data), stratify=ground_truths, shuffle=True, random_state=0 - ), - ) - - train_dl, train_categories = dataloader_from_patient_data( - patient_data=[patient_to_data[patient] for patient in train_patients], - categories=categories, - bag_size=bag_size, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - transform=train_transform, - ) - del categories # Let's not accidentally reuse the original categories - valid_dl, _ = dataloader_from_patient_data( - patient_data=[patient_to_data[patient] for patient in valid_patients], - bag_size=None, # Use all the patient data for validation - categories=train_categories, - batch_size=1, - shuffle=False, - num_workers=num_workers, - transform=None, - ) - if overlap := set(train_patients) & set(valid_patients): - raise RuntimeError( - f"unreachable: unexpected overlap between training and validation set: {overlap}" - ) - - # Sample one bag to infer the input dimensions of the model - bags, coords, bag_sizes, targets = cast( - tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], next(iter(train_dl)) - ) - _, _, dim_feats = bags.shape - - # Weigh classes inversely to their occurrence - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + train_dl: DataLoader, + feature_type: str, + train_categories: Sequence[str], +) -> torch.Tensor: + """ + Computes class weights and checks for category issues. + Logs warnings if there are too few or underpopulated categories. + Returns normalized category weights as a torch.Tensor. 
+ """ + if feature_type == "tile": + category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + else: + category_counts = cast( + PatientFeatureDataset, train_dl.dataset + ).ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() @@ -270,7 +388,7 @@ def setup_model_for_training( raise ValueError(f"not enough categories to train on: {train_categories}") elif any(category_counts < 16): underpopulated_categories = { - category: count + category: int(count) for category, count in zip(train_categories, category_counts, strict=True) if count < 16 } @@ -278,25 +396,4 @@ def setup_model_for_training( f"Some categories do not have enough samples to meaningfully train a model: {underpopulated_categories}. " "You may want to consider removing these categories; the model will likely overfit on the few samples available." ) - - # Train the model - model = LitVisionTransformer( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_model=512, - dim_feedforward=2048, - n_heads=8, - n_layers=2, - dropout=0.25, - use_alibi=use_alibi, - # Metadata, has no effect on model training - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ) - - return model, train_dl, valid_dl + return category_weights diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index 6e2bdc66..f3eb4085 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -326,6 +326,7 @@ def extract_( h5_fp.attrs["tile_size_um"] = tile_size_um # changed in v2.1.0 h5_fp.attrs["tile_size_px"] = tile_size_px h5_fp.attrs["code_hash"] = code_hash + h5_fp.attrs["feat_type"] = "tile" except Exception: _logger.exception(f"error while writing {feature_output_path}") if 
tmp_h5_file is not None: diff --git a/tests/random_data.py b/tests/random_data.py index f6777b23..d79a161b 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -16,12 +16,7 @@ import stamp from stamp.preprocessing.config import ExtractorName -from stamp.types import ( - Category, - Microns, - PatientId, - TilePixels, -) +from stamp.types import Category, FeaturePath, Microns, PatientId, TilePixels CliniPath: TypeAlias = Path SlidePath: TypeAlias = Path @@ -99,6 +94,56 @@ def create_random_dataset( return clini_path, slide_path, feat_dir, categories +def create_random_patient_level_dataset( + *, + dir: Path, + n_patients: int, + feat_dim: int, + categories: Sequence[str] | None = None, + n_categories: int | None = None, +) -> tuple[Path, Path, Path, Sequence[str]]: + """ + Creates a random dataset with one .h5 file per patient (patient-level features). + Returns (clini_path, slide_path, feat_dir, categories). + slide_path is a dummy file (not used for patient-level). + """ + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" # Not used, but keep interface consistent + feat_dir = dir / "feats" + feat_dir.mkdir() + + if categories is not None: + if n_categories is not None: + raise ValueError("only one of `categories` and `n_categories` can be set") + else: + if n_categories is None: + raise ValueError( + "either `categories` or `n_categories` has to be specified" + ) + categories = [random_string(8) for _ in range(n_categories)] + + patient_to_ground_truth = {} + for _ in range(n_patients): + patient_id = random_string(16) + patient_to_ground_truth[patient_id] = random.choice(categories) + # Create a single feature vector per patient + create_random_patient_level_feature_file( + tmp_path=feat_dir, + feat_dim=feat_dim, + feat_filename=patient_id, + ) + + pd.DataFrame( + patient_to_ground_truth.items(), + columns=["patient", "ground-truth"], + ).to_csv(clini_path, index=False) + + # slide_path is not used for patient-level, but return a dummy 
file for API compatibility + pd.DataFrame(columns=["slide_path", "patient"]).to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories + + def create_random_feature_file( *, tmp_path: Path, @@ -110,7 +155,7 @@ def create_random_feature_file( extractor_name: ExtractorName | str = "random-test-generator", feat_filename: str | None = None, coords: np.ndarray | None = None, -) -> Path: +) -> FeaturePath: """Creates a h5 file with random contents. Args: @@ -139,7 +184,38 @@ def create_random_feature_file( h5_file.attrs["unit"] = "um" h5_file.attrs["tile_size_um"] = tile_size_um h5_file.attrs["tile_size_px"] = tile_size_px - return Path(feature_file_path) + return FeaturePath(feature_file_path) + + +def create_random_patient_level_feature_file( + *, + tmp_path: Path, + feat_dim: int, + feat_filename: str | None = None, + encoder: str = "test-encoder", + precision: str = "float32", + feat_type: str = "patient", + code_hash: str = "testhash", + version: str | None = None, +) -> FeaturePath: + """ + Creates a random patient-level feature .h5 file with the correct metadata. + Returns the path to the created file. 
+ """ + if feat_filename is None: + feat_filename = random_string(16) + feature_file_path = tmp_path / f"{feat_filename}.h5" + feats = torch.rand(1, feat_dim) + version = version or stamp.__version__ + with h5py.File(feature_file_path, "w") as h5: + h5["feats"] = feats.numpy() + h5.attrs["version"] = version + h5.attrs["encoder"] = encoder + h5.attrs["precision"] = precision + h5.attrs["stamp_version"] = version + h5.attrs["code_hash"] = code_hash + h5.attrs["feat_type"] = feat_type + return FeaturePath(feature_file_path) def random_patient_preds(*, n_patients: int, categories: list[str]) -> pd.DataFrame: @@ -200,5 +276,32 @@ def make_feature_file( h5.attrs["unit"] = "um" h5.attrs["tile_size_um"] = tile_size_um h5.attrs["tile_size_px"] = tile_size_px + h5.attrs["feat_type"] = "tile" + + return file + +def make_patient_level_feature_file( + *, + feats: torch.Tensor, + encoder: str = "test-encoder", + precision: str = "float32", + code_hash: str = "testhash", + version: str | None = None, +) -> io.BytesIO: + """ + Creates an in-memory patient-level feature .h5 file with the correct metadata. + Returns a BytesIO object. 
+ """ + version = version or stamp.__version__ + file = io.BytesIO() + with h5py.File(file, "w") as h5: + h5["feats"] = feats.numpy() + h5.attrs["version"] = version + h5.attrs["encoder"] = encoder + h5.attrs["precision"] = precision + h5.attrs["stamp_version"] = version + h5.attrs["code_hash"] = code_hash + h5.attrs["feat_type"] = "patient" + file.seek(0) return file diff --git a/tests/test_config.py b/tests/test_config.py index c9a0c4fc..dafdd58c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,15 @@ from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig -from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig +from stamp.modeling.config import ( + AdvancedConfig, + CrossvalConfig, + DeploymentConfig, + MlpModelParams, + ModelParams, + TrainConfig, + VitModelParams, +) from stamp.preprocessing.config import ( ExtractorName, Microns, @@ -18,26 +26,18 @@ def test_config_parsing() -> None: config = StampConfig.model_validate( { "crossval": { - "accelerator": "gpu", - "bag_size": 512, - "batch_size": 64, "categories": None, "clini_table": "clini.xlsx", "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "max_epochs": 64, - "n_splits": 5, - "num_workers": 16, "output_dir": "test-crossval", - "patience": 16, "patient_label": "PATIENT", "slide_table": "slide.csv", - "use_alibi": True, "use_vary_precision_transform": False, + "n_splits": 5, }, "deployment": { - "accelerator": "gpu", "checkpoint_paths": [ "test-crossval/split-0/model.ckpt", "test-crossval/split-1/model.ckpt", @@ -49,7 +49,6 @@ def test_config_parsing() -> None: "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "num_workers": 16, "output_dir": "test-deploy", "patient_label": "PATIENT", "slide_table": "slide.csv", @@ -91,23 +90,39 @@ def test_config_parsing() -> None: "true_class": "MSIH", }, "training": { - "accelerator": "gpu", - "bag_size": 512, - 
"batch_size": 64, "categories": None, "clini_table": "clini.xlsx", "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "max_epochs": 64, - "num_workers": 16, "output_dir": "test-alibi", - "patience": 16, "patient_label": "PATIENT", "slide_table": "slide.csv", - "use_alibi": True, "use_vary_precision_transform": False, }, + "advanced_config": { + "bag_size": 512, + "num_workers": 16, + "batch_size": 64, + "max_epochs": 64, + "patience": 16, + "accelerator": "gpu", + "model_params": { + "vit": { + "dim_model": 512, + "dim_feedforward": 512, + "n_heads": 8, + "n_layers": 2, + "dropout": 0.25, + "use_alibi": True, + }, + "mlp": { + "dim_hidden": 512, + "num_layers": 2, + "dropout": 0.25, + }, + }, + }, } ) @@ -134,14 +149,8 @@ def test_config_parsing() -> None: categories=None, patient_label="PATIENT", filename_label="FILENAME", - bag_size=512, - num_workers=16, - batch_size=64, - max_epochs=64, - patience=16, - accelerator="gpu", + params_path=None, use_vary_precision_transform=False, - use_alibi=True, ), crossval=CrossvalConfig( output_dir=Path("test-crossval"), @@ -152,14 +161,8 @@ def test_config_parsing() -> None: categories=None, patient_label="PATIENT", filename_label="FILENAME", - bag_size=512, - num_workers=16, - batch_size=64, - max_epochs=64, - patience=16, - accelerator="gpu", + params_path=None, use_vary_precision_transform=False, - use_alibi=True, n_splits=5, ), deployment=DeploymentConfig( @@ -177,8 +180,6 @@ def test_config_parsing() -> None: ground_truth_label="isMSIH", patient_label="PATIENT", filename_label="FILENAME", - num_workers=16, - accelerator="gpu", ), statistics=StatsConfig( output_dir=Path("test-stats"), @@ -213,4 +214,27 @@ def test_config_parsing() -> None: bottomk=5, default_slide_mpp=SlideMPP(1.0), ), + advanced_config=AdvancedConfig( + bag_size=512, + num_workers=16, + batch_size=64, + max_epochs=64, + patience=16, + accelerator="gpu", + model_params=ModelParams( + vit=VitModelParams( + dim_model=512, 
+ dim_feedforward=512, + n_heads=8, + n_layers=2, + dropout=0.25, + use_alibi=True, + ), + mlp=MlpModelParams( + dim_hidden=512, + num_layers=2, + dropout=0.25, + ), + ), + ), ) diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 880b074e..342e39fc 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -5,21 +5,29 @@ import numpy as np import pytest import torch -from random_data import create_random_dataset +from random_data import create_random_dataset, create_random_patient_level_dataset +from stamp.modeling.config import ( + AdvancedConfig, + CrossvalConfig, + MlpModelParams, + ModelParams, + VitModelParams, +) from stamp.modeling.crossval import categorical_crossval_ @pytest.mark.slow @pytest.mark.filterwarnings("ignore:No positive samples in targets") +@pytest.mark.parametrize("feature_type", ["tile", "patient"]) def test_crossval_integration( - *, tmp_path: Path, - n_patients: int = 800, + feature_type: str, + n_patients: int = 80, max_slides_per_patient: int = 3, min_tiles_per_slide: int = 8, - max_tiles_per_slide: int = 2**10, - feat_dim: int = 25, + max_tiles_per_slide: int = 32, + feat_dim: int = 8, n_categories: int = 3, use_alibi: bool = False, use_vary_precision_transform: bool = False, @@ -28,27 +36,44 @@ def test_crossval_integration( torch.manual_seed(0) np.random.seed(0) - clini_path, slide_path, feature_dir, categories = create_random_dataset( - dir=tmp_path, - n_categories=n_categories, - n_patients=n_patients, - max_slides_per_patient=max_slides_per_patient, - min_tiles_per_slide=min_tiles_per_slide, - max_tiles_per_slide=max_tiles_per_slide, - feat_dim=feat_dim, - ) + if feature_type == "tile": + clini_path, slide_path, feature_dir, categories = create_random_dataset( + dir=tmp_path, + n_categories=n_categories, + n_patients=n_patients, + max_slides_per_patient=max_slides_per_patient, + min_tiles_per_slide=min_tiles_per_slide, + max_tiles_per_slide=max_tiles_per_slide, + feat_dim=feat_dim, + ) + elif feature_type == 
"patient": + clini_path, slide_path, feature_dir, categories = ( + create_random_patient_level_dataset( + dir=tmp_path, + n_categories=n_categories, + n_patients=n_patients, + feat_dim=feat_dim, + ) + ) + else: + raise ValueError(f"Unknown feature_type: {feature_type}") output_dir = tmp_path / "output" - categorical_crossval_( + config = CrossvalConfig( clini_table=clini_path, slide_table=slide_path, - feature_dir=feature_dir, output_dir=output_dir, patient_label="patient", ground_truth_label="ground-truth", filename_label="slide_path", categories=categories, + feature_dir=feature_dir, + n_splits=2, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( # Dataset and -loader parameters bag_size=max_tiles_per_slide // 2, num_workers=min(os.cpu_count() or 1, 7), @@ -57,8 +82,16 @@ def test_crossval_integration( max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", - n_splits=2, # Experimental features - use_vary_precision_transform=use_vary_precision_transform, - use_alibi=use_alibi, + model_params=ModelParams( + vit=VitModelParams( + use_alibi=use_alibi, + ), + mlp=MlpModelParams(), + ), + ) + + categorical_crossval_( + config=config, + advanced=advanced, ) diff --git a/tests/test_data.py b/tests/test_data.py index 4e15678f..edd678ee 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -5,13 +5,18 @@ import h5py import pytest import torch -from random_data import make_feature_file, make_old_feature_file +from random_data import ( + create_random_patient_level_feature_file, + make_feature_file, + make_old_feature_file, +) from torch.utils.data import DataLoader from stamp.modeling.data import ( BagDataset, CoordsInfo, PatientData, + PatientFeatureDataset, filter_complete_patient_data_, get_coords, ) @@ -80,7 +85,7 @@ def test_get_cohort_df(tmp_path: Path) -> None: "feature_file_creator", [make_feature_file, make_old_feature_file], ) -def test_dataset( +def test_bag_dataset( 
feature_file_creator, bag_size: BagSize = BagSize(5), dim_feats: int = 34, @@ -125,6 +130,32 @@ def test_dataset( assert (bag_sizes <= bag_size).all() +def test_patient_feature_dataset( + tmp_path: Path, dim_feats: int = 16, batch_size: int = 2 +) -> None: + # Create 3 random patient-level feature files on disk + files = [ + create_random_patient_level_feature_file(tmp_path=tmp_path, feat_dim=dim_feats) + for _ in range(3) + ] + # One-hot encoded labels for 3 samples, 4 categories + labels = torch.eye(4)[:3] + + ds = PatientFeatureDataset(files, labels, transform=None) + assert len(ds) == 3 + + # Test single dataset item + feats, label = ds[0] + assert feats.shape == (dim_feats,) + assert torch.allclose(label, labels[0]) + + # Test batching + dl = DataLoader(ds, batch_size=batch_size, shuffle=False) + feats_batch, labels_batch = next(iter(dl)) + assert feats_batch.shape == (batch_size, dim_feats) + assert labels_batch.shape == (batch_size, 4) + + def test_get_coords_with_mpp() -> None: # Test new feature file with valid mpp calculation file_bytes = make_feature_file( diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 2848bdfa..8bb0d6ec 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,18 +1,27 @@ +from pathlib import Path + import numpy as np -import numpy.typing as npt import pytest import torch -from random_data import make_old_feature_file +from random_data import create_random_patient_level_feature_file, make_old_feature_file -from stamp.modeling.data import PatientData +from stamp.modeling.data import ( + PatientData, + patient_feature_dataloader, + tile_bag_dataloader, +) from stamp.modeling.deploy import _predict, _to_prediction_df from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.types import GroundTruth, PatientId @pytest.mark.filterwarnings("ignore:GPU available but not used") +@pytest.mark.filterwarnings( + "ignore:The 
'predict_dataloader' does not have many workers which may be a bottleneck" +) def test_predict( - categories: npt.NDArray = np.array(["foo", "bar", "baz"]), + categories: list[str] = ["foo", "bar", "baz"], n_heads: int = 7, dim_input: int = 12, ) -> None: @@ -42,10 +51,20 @@ def test_predict( ) } + test_dl, _ = tile_bag_dataloader( + patient_data=list(patient_to_data.values()), + bag_size=None, + categories=list(model.categories), + batch_size=1, + shuffle=False, + num_workers=2, + transform=None, + ) + predictions = _predict( model=model, - patient_to_data=patient_to_data, - num_workers=2, + test_dl=test_dl, + patient_ids=list(patient_to_data.keys()), accelerator="cpu", ) @@ -75,10 +94,20 @@ def test_predict( ), } + more_test_dl, _ = tile_bag_dataloader( + patient_data=list(more_patients_to_data.values()), + bag_size=None, + categories=list(model.categories), + batch_size=1, + shuffle=False, + num_workers=2, + transform=None, + ) + more_predictions = _predict( model=model, - patient_to_data=more_patients_to_data, - num_workers=2, + test_dl=more_test_dl, + patient_ids=list(more_patients_to_data.keys()), accelerator="cpu", ) @@ -91,6 +120,105 @@ def test_predict( ), "the same inputs should repeatedly yield the same results" +def test_predict_patient_level( + tmp_path: Path, categories: list[str] = ["foo", "bar", "baz"], dim_feats: int = 12 +): + model = LitMLPClassifier( + categories=categories, + category_weights=torch.rand(len(categories)), + dim_input=dim_feats, + dim_hidden=32, + num_layers=2, + dropout=0.2, + ground_truth_label="test", + train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + + # Create 3 random patient-level feature files on disk + patient_ids = [PatientId(f"pat{i}") for i in range(5, 8)] + labels = ["foo", "bar", "baz"] + files = [ + create_random_patient_level_feature_file( + tmp_path=tmp_path, feat_dim=dim_feats, feat_filename=str(pid) + ) + for pid in patient_ids + ] + patient_to_data = { + pid: PatientData( + 
ground_truth=label, + feature_files={file}, + ) + for pid, label, file in zip(patient_ids, labels, files) + } + + test_dl, _ = patient_feature_dataloader( + patient_data=list(patient_to_data.values()), + categories=categories, + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + + predictions = _predict( + model=model, + test_dl=test_dl, + patient_ids=patient_ids, + accelerator="cpu", + ) + + assert len(predictions) == len(patient_to_data) + for pid in patient_ids: + assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + + # Check if scores are consistent between runs and different for different patients + more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] + more_labels = ["foo", "bar", "baz"] + more_files = [ + create_random_patient_level_feature_file( + tmp_path=tmp_path, feat_dim=dim_feats, feat_filename=str(pid) + ) + for pid in more_patient_ids + ] + more_patient_to_data = { + pid: PatientData( + ground_truth=label, + feature_files={file}, + ) + for pid, label, file in zip(more_patient_ids, more_labels, more_files) + } + # Add the original patient for repeatability check + all_patient_ids = more_patient_ids + [patient_ids[0]] + + more_test_dl, _ = patient_feature_dataloader( + patient_data=[more_patient_to_data[pid] for pid in more_patient_ids] + + [patient_to_data[patient_ids[0]]], + categories=categories, + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + + more_predictions = _predict( + model=model, + test_dl=more_test_dl, + patient_ids=all_patient_ids, + accelerator="cpu", + ) + + assert len(more_predictions) == len(all_patient_ids) + # Different patients should give different results + assert not torch.allclose( + more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + ), "different inputs should give different results" + # The same patient should yield the same result + assert torch.allclose( + predictions[patient_ids[0]], 
more_predictions[patient_ids[0]] + ), "the same inputs should repeatedly yield the same results" + + def test_to_prediction_df() -> None: n_heads = 7 model = LitVisionTransformer( diff --git a/tests/test_deployment_backward_compatability.py b/tests/test_deployment_backward_compatibility.py similarity index 64% rename from tests/test_deployment_backward_compatability.py rename to tests/test_deployment_backward_compatibility.py index 53a97829..0594d450 100644 --- a/tests/test_deployment_backward_compatability.py +++ b/tests/test_deployment_backward_compatibility.py @@ -2,16 +2,16 @@ import torch from stamp.cache import download_file -from stamp.modeling.data import PatientData +from stamp.modeling.data import PatientData, tile_bag_dataloader from stamp.modeling.deploy import _predict from stamp.modeling.lightning_model import LitVisionTransformer -from stamp.types import FeaturePath +from stamp.types import FeaturePath, PatientId @pytest.mark.filterwarnings( "ignore:The 'predict_dataloader' does not have many workers" ) -def test_backwards_compatability() -> None: +def test_backwards_compatibility() -> None: example_checkpoint_path = download_file( url="https://github.com/KatherLab/STAMP/releases/download/2.0.0-dev8/example-model.ckpt", file_name="example-model.ckpt", @@ -25,15 +25,28 @@ def test_backwards_compatability() -> None: model = LitVisionTransformer.load_from_checkpoint(example_checkpoint_path) + # Prepare PatientData and DataLoader for the test patient + patient_id = PatientId("TestPatient") + patient_to_data = { + patient_id: PatientData( + ground_truth=None, + feature_files=[FeaturePath(example_feature_path)], + ) + } + test_dl, _ = tile_bag_dataloader( + patient_data=list(patient_to_data.values()), + bag_size=None, + categories=list(model.categories), + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + predictions = _predict( model=model, - patient_to_data={ - "TestPatient": PatientData( - ground_truth=None, - 
feature_files=[FeaturePath(example_feature_path)], - ) - }, - num_workers=1, + test_dl=test_dl, + patient_ids=[patient_id], accelerator="gpu" if torch.cuda.is_available() else "cpu", ) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 43e3e1f8..21d87a02 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -39,7 +39,7 @@ EncoderName.GIGAPATH: ExtractorName.GIGAPATH, EncoderName.MADELEINE: ExtractorName.CONCH, EncoderName.TITAN: ExtractorName.CONCH1_5, - # EncoderName.PRISM: ExtractorName.VIRCHOW_FULL, + EncoderName.PRISM: ExtractorName.VIRCHOW_FULL, } diff --git a/tests/test_model.py b/tests/test_model.py index af30aa40..ccdd4aa5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.modeling.vision_transformer import VisionTransformer @@ -69,3 +70,52 @@ def test_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_mlp_classifier_dims( + num_classes: int = 3, + batch_size: int = 6, + input_dim: int = 32, + dim_hidden: int = 64, + num_layers: int = 2, +) -> None: + model = LitMLPClassifier( + categories=[str(i) for i in range(num_classes)], + category_weights=torch.ones(num_classes), + dim_input=input_dim, + dim_hidden=dim_hidden, + num_layers=num_layers, + dropout=0.1, + ground_truth_label="test", + train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + feats = torch.rand((batch_size, input_dim)) + logits = model.forward(feats) + assert logits.shape == (batch_size, num_classes) + + +def test_mlp_inference_reproducibility( + num_classes: int = 4, + batch_size: int = 7, + input_dim: int = 33, + dim_hidden: int = 64, + num_layers: int = 3, +) -> None: + model = LitMLPClassifier( + categories=[str(i) for i in range(num_classes)], + category_weights=torch.ones(num_classes), + dim_input=input_dim, + dim_hidden=dim_hidden, + num_layers=num_layers, + dropout=0.1, + ground_truth_label="test", + 
train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + model = model.eval() + feats = torch.rand((batch_size, input_dim)) + with torch.inference_mode(): + logits1 = model.forward(feats) + logits2 = model.forward(feats) + assert torch.allclose(logits1, logits2) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 62ad9717..03d48c48 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -5,8 +5,15 @@ import numpy as np import pytest import torch -from random_data import create_random_dataset +from random_data import create_random_dataset, create_random_patient_level_dataset +from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + TrainConfig, + VitModelParams, +) from stamp.modeling.deploy import deploy_categorical_model_ from stamp.modeling.train import train_categorical_model_ @@ -56,7 +63,7 @@ def test_train_deploy_integration( feat_dim=feat_dim, ) - train_categorical_model_( + config = TrainConfig( clini_table=train_clini_path, slide_table=train_slide_path, feature_dir=train_feature_dir, @@ -65,6 +72,10 @@ def test_train_deploy_integration( ground_truth_label="ground-truth", filename_label="slide_path", categories=categories, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( # Dataset and -loader parameters bag_size=500, num_workers=min(os.cpu_count() or 1, 16), @@ -73,11 +84,13 @@ def test_train_deploy_integration( max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", - # Experimental features - use_vary_precision_transform=use_vary_precision_transform, - use_alibi=use_alibi, + model_params=ModelParams( + vit=VitModelParams(use_alibi=use_alibi), mlp=MlpModelParams() + ), ) + train_categorical_model_(config=config, advanced=advanced) + deploy_categorical_model_( output_dir=tmp_path / "deploy_output", checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], @@ -90,3 +103,89 @@ def 
test_train_deploy_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +@pytest.mark.parametrize( + "use_alibi,use_vary_precision_transform", + [ + pytest.param(False, False, id="no experimental features"), + pytest.param(True, False, id="use alibi"), + pytest.param(False, True, id="use vary_precision_transform"), + ], +) +def test_train_deploy_patient_level_integration( + *, + tmp_path: Path, + feat_dim: int = 25, + use_alibi: bool, + use_vary_precision_transform: bool, +) -> None: + random.seed(0) + torch.manual_seed(0) + np.random.seed(0) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + train_clini_path, train_slide_path, train_feature_dir, categories = ( + create_random_patient_level_dataset( + dir=tmp_path / "train", + n_categories=3, + n_patients=400, + feat_dim=feat_dim, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_patient_level_dataset( + dir=tmp_path / "deploy", + categories=categories, + n_patients=50, + feat_dim=feat_dim, + ) + ) + + config = TrainConfig( + clini_table=train_clini_path, + slide_table=None, # Not needed for patient-level + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label="ground-truth", + filename_label="slide_path", # Not used for patient-level + categories=categories, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( + # Dataset and -loader parameters + bag_size=1, # Not used for patient-level, but required by signature + num_workers=min(os.cpu_count() or 1, 16), + # Training parameters + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + model_params=ModelParams( + vit=VitModelParams(use_alibi=use_alibi), mlp=MlpModelParams() + ), + ) + + 
train_categorical_model_( + config=config, + advanced=advanced, + ) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=None, # Not needed for patient-level + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label="ground-truth", + filename_label="slide_path", # Not used for patient-level + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + )