From d0f6cdb5f840c1c80fae1efe9ad30b0fd6421f39 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 8 Jul 2025 13:19:45 +0100 Subject: [PATCH 01/18] add encoded feature dataset --- src/stamp/modeling/data.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index d9935c35..a8f68210 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -185,6 +185,51 @@ def __getitem__( ) +class SingleFeatureDataset(Dataset): + """ + Dataset for single feature vector per sample (e.g. slide-level or patient-level). + Each item is a (feature_vector, label_onehot) tuple. + """ + + def __init__( + self, + feature_files: Sequence[FeaturePath | BinaryIO], + ground_truths: Tensor, # shape: [num_samples, num_classes] + transform: Callable[[Tensor], Tensor] | None, + ): + if len(feature_files) != len(ground_truths): + raise ValueError("Number of feature files and ground truths must match.") + self.feature_files = feature_files + self.ground_truths = ground_truths + self.transform = transform + + def __len__(self): + return len(self.feature_files) + + def __getitem__(self, idx: int): + feature_file = self.feature_files[idx] + with h5py.File(feature_file, "r") as h5: + feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + # Accept [V] or [1, V] + if feats.ndim == 2 and feats.shape[0] == 1: + feats = feats[0] + elif feats.ndim == 1: + pass + else: + raise RuntimeError( + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}" + ) + if self.transform is not None: + feats = self.transform(feats) + label = self.ground_truths[idx] + return feats, label + + +# Aliases for clarity +PatientDataset = SingleFeatureDataset +SlideDataset = SingleFeatureDataset + + @dataclass class CoordsInfo: coords_um: np.ndarray From 25b04840b54ed8c47cd38f38cf381e2b0205ea67 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 8 Jul 2025 
13:20:26 +0100 Subject: [PATCH 02/18] add MLP for encoded features --- src/stamp/modeling/mlp_classifier.py | 124 +++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/stamp/modeling/mlp_classifier.py diff --git a/src/stamp/modeling/mlp_classifier.py b/src/stamp/modeling/mlp_classifier.py new file mode 100644 index 00000000..52d575a4 --- /dev/null +++ b/src/stamp/modeling/mlp_classifier.py @@ -0,0 +1,124 @@ +from typing import Iterable, Sequence + +import lightning +import numpy as np +import torch +from packaging.version import Version +from torch import Tensor, nn +from torchmetrics.classification import MulticlassAUROC + +import stamp +from stamp.types import Category, PandasLabel, PatientId + + +class MLPClassifier(nn.Module): + """ + Simple MLP for classification from a single feature vector. + """ + + def __init__( + self, + dim_input: int, + dim_hidden: int, + dim_output: int, + num_layers: int, + dropout: float, + ): + super().__init__() + layers = [] + in_dim = dim_input + for i in range(num_layers - 1): + layers.append(nn.Linear(in_dim, dim_hidden)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(dropout)) + in_dim = dim_hidden + layers.append(nn.Linear(in_dim, dim_output)) + self.mlp = nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.mlp(x) + + +class LitMLPClassifier(lightning.LightningModule): + """ + PyTorch Lightning wrapper for MLPClassifier. 
+ """ + + def __init__( + self, + *, + categories: Sequence[Category], + category_weights: torch.Tensor, + dim_input: int, + dim_hidden: int, + num_layers: int, + dropout: float, + ground_truth_label: PandasLabel, + train_patients: Iterable[PatientId], + valid_patients: Iterable[PatientId], + stamp_version: Version = Version(stamp.__version__), + **metadata, + ): + super().__init__() + self.save_hyperparameters(ignore=["category_weights"]) + self.model = MLPClassifier( + dim_input=dim_input, + dim_hidden=dim_hidden, + dim_output=len(categories), + num_layers=num_layers, + dropout=dropout, + ) + self.class_weights = category_weights + self.valid_auroc = MulticlassAUROC(len(categories)) + self.ground_truth_label = ground_truth_label + self.categories = np.array(categories) + self.train_patients = train_patients + self.valid_patients = valid_patients + + # TODO: Add version check with version 2.2.1, for both MLP and Transformer + + def forward(self, x: Tensor) -> Tensor: + return self.model(x) + + def _step(self, batch, step_name: str): + feats, targets = batch + logits = self.model(feats) + loss = nn.functional.cross_entropy( + logits, + targets.type_as(logits), + weight=self.class_weights.type_as(logits), + ) + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + if step_name == "validation": + self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) + self.log( + f"{step_name}_auroc", + self.valid_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + return loss + + def training_step(self, batch, batch_idx): + return self._step(batch, "training") + + def validation_step(self, batch, batch_idx): + return self._step(batch, "validation") + + def test_step(self, batch, batch_idx): + return self._step(batch, "test") + + def predict_step(self, batch, batch_idx): + feats, _ = batch + return self.model(feats) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), 
lr=1e-3) From d2b95a57ca2278b41b323433c9079e9caac805f9 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Wed, 9 Jul 2025 16:32:19 +0100 Subject: [PATCH 03/18] add patient level feature training --- src/stamp/modeling/data.py | 85 ++++++++++- src/stamp/modeling/deploy.py | 135 +++++++++++------ src/stamp/modeling/train.py | 281 +++++++++++++++++++++++++---------- 3 files changed, 373 insertions(+), 128 deletions(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index a8f68210..3ce6f9f6 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -35,6 +35,7 @@ ) _logger = logging.getLogger("stamp") +_logged_stamp_v1_warning = False __author__ = "Marko van Treeck" @@ -115,6 +116,81 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def detect_feature_type(feature_dir: Path) -> str: + """ + Detects feature type by inspecting all .h5 files in feature_dir. + + Returns: + "tile" if all files are tile-level, "patient" if all are patient-level. + If files have mixed types, raises an error. + If no .h5 files are found, raises an error. + """ + feature_types = set() + files_checked = 0 + + for file in feature_dir.glob("*.h5"): + files_checked += 1 + with h5py.File(file, "r") as h5: + feat_type = h5.attrs.get("feat_type") + if feat_type is not None: + feature_types.add(str(feat_type)) + else: + # If feat_type is missing, always treat as tile-level feature + feature_types.add("tile") + + if files_checked == 0: + raise RuntimeError("No .h5 feature files found in feature_dir.") + + if len(feature_types) > 1: + raise RuntimeError( + f"Multiple feature types detected in {feature_dir}: {feature_types}. " + "All feature files must have the same type." 
+ ) + + return feature_types.pop() + + +def load_patient_level_data( + *, + clini_table: Path, + feature_dir: Path, + patient_label: PandasLabel, + ground_truth_label: PandasLabel, + feature_ext: str = ".h5", +) -> dict[PatientId, PatientData]: + """ + Loads PatientData for patient-level features, matching patients in the clinical table + to feature files in feature_dir named {patient_id}.h5. + """ + + clini_df = _read_table( + clini_table, + usecols=[patient_label, ground_truth_label], + dtype=str, + ).dropna() + + patient_to_data: dict[PatientId, PatientData] = {} + missing_features = [] + for _, row in clini_df.iterrows(): + patient_id = PatientId(str(row[patient_label])) + ground_truth = row[ground_truth_label] + feature_file = feature_dir / f"{patient_id}{feature_ext}" + if feature_file.exists(): + patient_to_data[patient_id] = PatientData( + ground_truth=ground_truth, + feature_files=[FeaturePath(feature_file)], + ) + else: + missing_features.append(patient_id) + + if missing_features: + _logger.warning( + f"Some patients have no feature file in {feature_dir}: {missing_features}" + ) + + return patient_to_data + + @dataclass class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): """A dataset of bags of instances.""" @@ -269,9 +345,12 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: == 224 ): # Historic STAMP format - _logger.info( - f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" - ) + global _logged_stamp_v1_warning + if not _logged_stamp_v1_warning: + _logger.info( + f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" + ) + _logged_stamp_v1_warning = True tile_size_um = Microns(256.0) tile_size_px = TilePixels(224) coords_um = coords / 224 * 256 diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 0441cd53..0641ca0f 100644 --- a/src/stamp/modeling/deploy.py +++ 
b/src/stamp/modeling/deploy.py @@ -12,12 +12,16 @@ from stamp.modeling.data import ( PatientData, + PatientDataset, dataloader_from_patient_data, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.types import GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -44,10 +48,21 @@ def deploy_categorical_model_( num_workers: int, accelerator: str | Accelerator, ) -> None: + # --- Detect feature type and load correct model --- + feature_type = detect_feature_type(feature_dir) + _logger.info(f"Detected feature type: {feature_type}") + + if feature_type == "tile": + ModelClass = LitVisionTransformer + elif feature_type == "patient": + ModelClass = LitMLPClassifier + else: + raise RuntimeError( + f"Unsupported feature type for deployment: {feature_type}. Only 'tile' and 'patient' are supported." 
+ ) + models = [ - LitVisionTransformer.load_from_checkpoint( - checkpoint_path=checkpoint_path - ).eval() + ModelClass.load_from_checkpoint(checkpoint_path=checkpoint_path).eval() for checkpoint_path in checkpoint_paths ] @@ -78,37 +93,77 @@ def deploy_categorical_model_( output_dir.mkdir(exist_ok=True, parents=True) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, - ) - - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] - if clini_table is not None: - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, + # --- Data loading logic --- + if feature_type == "tile": + slide_to_patient = slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, patient_label=patient_label, + filename_label=filename_label, ) - else: + if clini_table is not None: + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + else: + patient_to_ground_truth = { + patient_id: None for patient_id in set(slide_to_patient.values()) + } + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=False, + ) + test_dl, _ = dataloader_from_patient_data( + patient_data=list(patient_to_data.values()), + bag_size=None, + categories=list(models[0].categories), + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + patient_ids = list(patient_to_data.keys()) + elif feature_type == "patient": + if slide_table is not None: + _logger.warning( + "slide_table is ignored for patient-level features during deployment." 
+ ) + if clini_table is None: + raise ValueError( + "clini_table is required for patient-level feature deployment." + ) + patient_to_data = load_patient_level_data( + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + ) + feature_files = [ + next(iter(pd.feature_files)) for pd in patient_to_data.values() + ] + labels = [pd.ground_truth for pd in patient_to_data.values()] + categories = list(models[0].categories) + onehot = torch.tensor(np.array(labels).reshape(-1, 1) == categories) + test_ds = PatientDataset(feature_files, onehot, transform=None) + test_dl = torch.utils.data.DataLoader( + test_ds, batch_size=1, shuffle=False, num_workers=num_workers + ) + patient_ids = list(patient_to_data.keys()) patient_to_ground_truth = { - patient_id: None for patient_id in set(slide_to_patient.values()) + pid: pd.ground_truth for pid, pd in patient_to_data.items() } - - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=False, - ) + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 for model_i, model in enumerate(models): predictions = _predict( model=model, - patient_to_data=patient_to_data, - num_workers=num_workers, + test_dl=test_dl, + patient_ids=patient_ids, accelerator=accelerator, ) all_predictions.append(predictions) @@ -130,7 +185,7 @@ def deploy_categorical_model_( patient_id: torch.stack( [predictions[patient_id] for predictions in all_predictions] ).mean(dim=0) - for patient_id in patient_to_data.keys() + for patient_id in patient_ids }, patient_label=patient_label, ground_truth_label=ground_truth_label, @@ -139,33 +194,23 @@ def deploy_categorical_model_( def _predict( *, - model: LitVisionTransformer, - patient_to_data: Mapping[PatientId, 
PatientData[GroundTruth | None]], - num_workers: int, + model: lightning.LightningModule, + test_dl: torch.utils.data.DataLoader, + patient_ids: Sequence[PatientId], accelerator: str | Accelerator, ) -> Mapping[PatientId, Float[torch.Tensor, "category"]]: # noqa: F821 model = model.eval() torch.set_float32_matmul_precision("medium") - patients_used_for_training: set[PatientId] = set(model.train_patients) | set( - model.valid_patients - ) - if overlap := patients_used_for_training & set(patient_to_data.keys()): + # Check for data leakage + patients_used_for_training: set[PatientId] = set( + getattr(model, "train_patients", []) + ) | set(getattr(model, "valid_patients", [])) + if overlap := patients_used_for_training & set(patient_ids): raise ValueError( f"some of the patients in the validation set were used during training: {overlap}" ) - test_dl, _ = dataloader_from_patient_data( - patient_data=list(patient_to_data.values()), - bag_size=None, # Use all the tiles for deployment - # Use same encoding scheme as during training - categories=list(model.categories), - batch_size=1, - shuffle=False, - num_workers=num_workers, - transform=None, - ) - trainer = lightning.Trainer( accelerator=accelerator, devices=1, # Needs to be 1, otherwise half the predictions are missing for some reason @@ -181,7 +226,7 @@ def _predict( dim=1, ) - return dict(zip(patient_to_data, predictions, strict=True)) + return dict(zip(patient_ids, predictions, strict=True)) def _to_prediction_df( diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index d273315a..8d74a99b 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -8,6 +8,7 @@ import lightning.pytorch import lightning.pytorch.accelerators import lightning.pytorch.accelerators.accelerator +import numpy as np import torch from lightning.pytorch.accelerators.accelerator import Accelerator from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint @@ -18,8 +19,11 @@ from 
stamp.modeling.data import ( BagDataset, PatientData, + PatientDataset, dataloader_from_patient_data, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) @@ -29,6 +33,7 @@ EncodedTargets, LitVisionTransformer, ) +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import Category, CoordinatesBatch, GroundTruth, PandasLabel, PatientId @@ -92,27 +97,46 @@ def train_categorical_model_( Categories of the ground truth. Set to `None` to automatically infer. """ - # Read and parse data from out clini and slide table - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, - ) + feature_type = detect_feature_type(feature_dir) + _logger.info(f"Detected feature type: {feature_type}") - # Clean data (remove slides without ground truth, missing features, etc.) 
- patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) + if feature_type == "tile": + # Tile-level: use slide_table + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + slide_to_patient = slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=True, + ) + elif feature_type == "patient": + # Patient-level: ignore slide_table + if slide_table is not None: + _logger.warning("slide_table is ignored for patient-level features.") + patient_to_data = load_patient_level_data( + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + ) + elif feature_type == "slide": + raise RuntimeError( + "Slide-level features are not supported for training. " + "Please rerun the encoding step with patient-level encoding." 
+ ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") - # Train the model + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, categories=categories, @@ -129,6 +153,7 @@ def train_categorical_model_( else None ), use_alibi=use_alibi, + feature_type=feature_type, ) train_model_( output_dir=output_dir, @@ -144,13 +169,13 @@ def train_categorical_model_( def train_model_( *, output_dir: Path, - model: LitVisionTransformer, + model: lightning.LightningModule, train_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], valid_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], max_epochs: int, patience: int, accelerator: str | Accelerator, -) -> LitVisionTransformer: +) -> lightning.LightningModule: """Trains a model. Returns: @@ -191,7 +216,7 @@ def train_model_( return LitVisionTransformer.load_from_checkpoint(model_checkpoint.best_model_path) -def setup_model_for_training( +def setup_dataloaders_for_training( *, patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], categories: Sequence[Category] | None, @@ -199,26 +224,27 @@ def setup_model_for_training( batch_size: int, num_workers: int, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, - use_alibi: bool, - # Metadata, has no effect on model training - ground_truth_label: PandasLabel, - clini_table: Path, - slide_table: Path, - feature_dir: Path, + feature_type: str, ) -> tuple[ - LitVisionTransformer, - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + DataLoader, + DataLoader, + Sequence[Category], + int, + Sequence[PatientId], + Sequence[PatientId], ]: - """Creates a model and dataloaders for training""" + """ + Creates train/val dataloaders for tile-level or patient-level features. 
- # Do a stratified train-validation split + Returns: + train_dl, valid_dl, categories, feature_dim, train_patients, valid_patients + """ + # Stratified split ground_truths = [ patient_data.ground_truth for patient_data in patient_to_data.values() if patient_data.ground_truth is not None ] - if len(ground_truths) != len(patient_to_data): raise ValueError( "patient_to_data must have a ground truth defined for all targets!" @@ -231,38 +257,118 @@ def setup_model_for_training( ), ) - train_dl, train_categories = dataloader_from_patient_data( - patient_data=[patient_to_data[patient] for patient in train_patients], - categories=categories, - bag_size=bag_size, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - transform=train_transform, - ) - del categories # Let's not accidentally reuse the original categories - valid_dl, _ = dataloader_from_patient_data( - patient_data=[patient_to_data[patient] for patient in valid_patients], - bag_size=None, # Use all the patient data for validation - categories=train_categories, - batch_size=1, - shuffle=False, - num_workers=num_workers, - transform=None, - ) - if overlap := set(train_patients) & set(valid_patients): + if feature_type == "tile": + # Use existing BagDataset logic + train_dl, train_categories = dataloader_from_patient_data( + patient_data=[patient_to_data[pid] for pid in train_patients], + categories=categories, + bag_size=bag_size, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + transform=train_transform, + ) + valid_dl, _ = dataloader_from_patient_data( + patient_data=[patient_to_data[pid] for pid in valid_patients], + bag_size=None, + categories=train_categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + bags, _, _, _ = next(iter(train_dl)) + dim_feats = bags.shape[-1] + return ( + train_dl, + valid_dl, + train_categories, + dim_feats, + train_patients, + valid_patients, + ) + + elif feature_type == "patient": + # Patient-level: one 
feature file per patient + train_feature_files = [ + next(iter(patient_to_data[pid].feature_files)) for pid in train_patients + ] + train_labels = [patient_to_data[pid].ground_truth for pid in train_patients] + valid_feature_files = [ + next(iter(patient_to_data[pid].feature_files)) for pid in valid_patients + ] + valid_labels = [patient_to_data[pid].ground_truth for pid in valid_patients] + + all_labels = train_labels + valid_labels + categories = ( + categories if categories is not None else list(sorted(set(all_labels))) + ) + train_onehot = torch.tensor(np.array(train_labels).reshape(-1, 1) == categories) + valid_onehot = torch.tensor(np.array(valid_labels).reshape(-1, 1) == categories) + + train_ds = PatientDataset( + train_feature_files, train_onehot, transform=train_transform + ) + valid_ds = PatientDataset(valid_feature_files, valid_onehot, transform=None) + train_dl = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + valid_dl = DataLoader( + valid_ds, batch_size=1, shuffle=False, num_workers=num_workers + ) + + feats, _ = next(iter(train_dl)) + dim_feats = feats.shape[-1] + return train_dl, valid_dl, categories, dim_feats, train_patients, valid_patients + + else: raise RuntimeError( - f"unreachable: unexpected overlap between training and validation set: {overlap}" + f"Unsupported feature type: {feature_type}. Only 'tile' and 'patient' are supported." 
) - # Sample one bag to infer the input dimensions of the model - bags, coords, bag_sizes, targets = cast( - tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], next(iter(train_dl)) + +def setup_model_for_training( + *, + patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + categories: Sequence[Category] | None, + bag_size: int, + batch_size: int, + num_workers: int, + train_transform: Callable[[torch.Tensor], torch.Tensor] | None, + use_alibi: bool, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel, + clini_table: Path, + slide_table: Path, + feature_dir: Path, + feature_type: str, +) -> tuple[ + lightning.LightningModule, + DataLoader, + DataLoader, +]: + """Creates a model and dataloaders for training""" + + train_dl, valid_dl, train_categories, dim_feats, train_patients, valid_patients = ( + setup_dataloaders_for_training( + patient_to_data=patient_to_data, + categories=categories, + bag_size=bag_size, + batch_size=batch_size, + num_workers=num_workers, + train_transform=train_transform, + feature_type=feature_type, + ) ) - _, _, dim_feats = bags.shape - # Weigh classes inversely to their occurrence - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + # Compute class weights + if feature_type == "tile": + category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + else: + # For slide/patient, count from one-hot labels + category_counts = cast(PatientDataset, train_dl.dataset).ground_truths.sum( + dim=0 + ) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() @@ -279,24 +385,39 @@ def setup_model_for_training( "You may want to consider removing these categories; the model will likely overfit on the few samples available." 
) - # Train the model - model = LitVisionTransformer( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_model=512, - dim_feedforward=2048, - n_heads=8, - n_layers=2, - dropout=0.25, - use_alibi=use_alibi, - # Metadata, has no effect on model training - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ) + # Model selection + if feature_type == "tile": + model = LitVisionTransformer( + categories=train_categories, + category_weights=category_weights, + dim_input=dim_feats, + dim_model=512, + dim_feedforward=2048, + n_heads=8, + n_layers=2, + dropout=0.25, + use_alibi=use_alibi, + ground_truth_label=ground_truth_label, + train_patients=train_patients, + valid_patients=valid_patients, + clini_table=clini_table, + slide_table=slide_table, + feature_dir=feature_dir, + ) + else: + model = LitMLPClassifier( + categories=train_categories, + category_weights=category_weights, + dim_input=dim_feats, + dim_hidden=512, + num_layers=2, + dropout=0.25, + ground_truth_label=ground_truth_label, + train_patients=train_patients, + valid_patients=valid_patients, + clini_table=clini_table, + slide_table=slide_table, + feature_dir=feature_dir, + ) return model, train_dl, valid_dl From 45abb7dd68925914a2bbdd493709049962abee00 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Wed, 9 Jul 2025 16:32:38 +0100 Subject: [PATCH 04/18] add feature type metadata --- src/stamp/encoding/encoder/__init__.py | 13 ++++++++++--- src/stamp/encoding/encoder/chief.py | 9 +++++++-- src/stamp/encoding/encoder/eagle.py | 8 ++++++-- src/stamp/encoding/encoder/gigapath.py | 4 +++- src/stamp/encoding/encoder/titan.py | 4 +++- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 94192812..2696af00 100644 --- 
a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -83,7 +83,9 @@ def encode_slides_( slide_embedding = self._generate_slide_embedding( feats, device, coords=coords ) - self._save_features_(output_path=output_path, feats=slide_embedding) + self._save_features_( + output_path=output_path, feats=slide_embedding, feat_type="slide" + ) def encode_patients_( self, @@ -142,7 +144,9 @@ def encode_patients_( patient_embedding = self._generate_patient_embedding( feats_list, device, **kwargs ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) @abstractmethod def _generate_slide_embedding( @@ -192,7 +196,9 @@ def _read_h5( ) return feats, coords, extractor - def _save_features_(self, output_path: Path, feats: np.ndarray) -> None: + def _save_features_( + self, output_path: Path, feats: np.ndarray, feat_type: str + ) -> None: with ( NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file, h5py.File(tmp_h5_file, "w") as f, @@ -204,6 +210,7 @@ def _save_features_(self, output_path: Path, feats: np.ndarray) -> None: f.attrs["precision"] = str(self.precision) f.attrs["stamp_version"] = stamp.__version__ f.attrs["code_hash"] = get_processing_code_hash(Path(__file__))[:8] + f.attrs["feat_type"] = feat_type # TODO: Add more metadata like tile-level extractor name # and maybe tile size in pixels and microns except Exception: diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2c33a5ca..f174d42a 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -115,7 +115,10 @@ def __init__(self) -> None: model=model, identifier=EncoderName.CHIEF, precision=torch.float32, - required_extractors=[ExtractorName.CHIEF_CTRANSPATH], + required_extractors=[ + ExtractorName.CHIEF_CTRANSPATH, + ExtractorName.CTRANSPATH, + ], ) def 
_generate_slide_embedding( @@ -192,7 +195,9 @@ def encode_patients_( .cpu() .numpy() ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) def initialize_weights(module): diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 85c0ae2d..b2fb293d 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -164,7 +164,9 @@ def encode_slides_( continue slide_embedding = self._generate_slide_embedding(feats, device, agg_feats) - self._save_features_(output_path=output_path, feats=slide_embedding) + self._save_features_( + output_path=output_path, feats=slide_embedding, feat_type="slide" + ) # TODO: Add @override decorator on each encoder once it is added to python def encode_patients_( @@ -233,4 +235,6 @@ def encode_patients_( patient_embedding = self._generate_patient_embedding( feats_list, device, agg_feats_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index c7931400..e2fb0ebb 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -183,7 +183,9 @@ def encode_patients_( all_feats_list, device, coords_list=all_coords_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) def _generate_patient_embedding( self, feats_list, device, coords_list: list | None = None, **kwargs diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 0307d68f..41dd19f1 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -166,4 +166,6 @@ 
def encode_patients_( patient_embedding = self._generate_patient_embedding( all_feats_list, device, all_coords_list ) - self._save_features_(output_path=output_path, feats=patient_embedding) + self._save_features_( + output_path=output_path, feats=patient_embedding, feat_type="patient" + ) From ebc14b371b21088ba7d09efabf4ea01021cab94c Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 10 Jul 2025 10:15:27 +0100 Subject: [PATCH 05/18] add adaptable checkpoint saver --- src/stamp/modeling/mlp_classifier.py | 2 +- src/stamp/modeling/train.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/stamp/modeling/mlp_classifier.py b/src/stamp/modeling/mlp_classifier.py index 52d575a4..b21142fc 100644 --- a/src/stamp/modeling/mlp_classifier.py +++ b/src/stamp/modeling/mlp_classifier.py @@ -60,7 +60,7 @@ def __init__( **metadata, ): super().__init__() - self.save_hyperparameters(ignore=["category_weights"]) + self.save_hyperparameters() self.model = MLPClassifier( dim_input=dim_input, dim_hidden=dim_hidden, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 8d74a99b..14ef6436 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -213,7 +213,9 @@ def train_model_( ) shutil.copy(model_checkpoint.best_model_path, output_dir / "model.ckpt") - return LitVisionTransformer.load_from_checkpoint(model_checkpoint.best_model_path) + # Reload the best model using the same class as the input model + ModelClass = type(model) + return ModelClass.load_from_checkpoint(model_checkpoint.best_model_path) def setup_dataloaders_for_training( From 373683dcd42fcb6f99e329661d1ab636e1182893 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 10 Jul 2025 13:30:29 +0100 Subject: [PATCH 06/18] refactor model and dataloader setup --- src/stamp/modeling/data.py | 40 ++++- src/stamp/modeling/deploy.py | 25 ++- src/stamp/modeling/train.py | 319 ++++++++++++++++++----------------- 3 files changed, 208 insertions(+), 176 deletions(-) 
diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 3ce6f9f6..523477ba 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -57,7 +57,7 @@ class PatientData(Generic[GroundTruthType]): feature_files: Iterable[FeaturePath | BinaryIO] -def dataloader_from_patient_data( +def tile_bag_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], bag_size: int | None, @@ -70,7 +70,7 @@ def dataloader_from_patient_data( DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], Sequence[Category], ]: - """Creates a dataloader from patient data, encoding the ground truths. + """Creates a dataloader from patient data for tile-level (bagged) features. Args: categories: @@ -116,6 +116,29 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def patient_feature_dataloader( + *, + patient_data: Sequence[PatientData[GroundTruth | None]], + categories: Sequence[Category] | None = None, + batch_size: int, + shuffle: bool, + num_workers: int, + transform: Callable[[Tensor], Tensor] | None, +) -> tuple[DataLoader, Sequence[Category]]: + """ + Creates a dataloader for patient-level features (one feature vector per patient). + """ + feature_files = [next(iter(p.feature_files)) for p in patient_data] + raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) + categories = ( + categories if categories is not None else list(np.unique(raw_ground_truths)) + ) + one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) + ds = PatientFeatureDataset(feature_files, one_hot, transform=transform) + dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + return dl, categories + + def detect_feature_type(feature_dir: Path) -> str: """ Detects feature type by inspecting all .h5 files in feature_dir. 
@@ -162,6 +185,9 @@ def load_patient_level_data( Loads PatientData for patient-level features, matching patients in the clinical table to feature files in feature_dir named {patient_id}.h5. """ + # TODO: I'm not proud at all of this. Any other alternative for mapping + # clinical data to the patient-level feature paths that avoids + # creating another slide table for encoded features is welcome :P. clini_df = _read_table( clini_table, @@ -261,7 +287,7 @@ def __getitem__( ) -class SingleFeatureDataset(Dataset): +class PatientFeatureDataset(Dataset): """ Dataset for single feature vector per sample (e.g. slide-level or patient-level). Each item is a (feature_vector, label_onehot) tuple. @@ -293,7 +319,8 @@ def __getitem__(self, idx: int): pass else: raise RuntimeError( - f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}" + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." + " Check that the features are patient-level."
) if self.transform is not None: feats = self.transform(feats) @@ -301,11 +328,6 @@ def __getitem__(self, idx: int): return feats, label -# Aliases for clarity -PatientDataset = SingleFeatureDataset -SlideDataset = SingleFeatureDataset - - @dataclass class CoordsInfo: coords_um: np.ndarray diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 0641ca0f..64a16dac 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -12,13 +12,14 @@ from stamp.modeling.data import ( PatientData, - PatientDataset, - dataloader_from_patient_data, + PatientFeatureDataset, detect_feature_type, filter_complete_patient_data_, load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.lightning_model import LitVisionTransformer from stamp.modeling.mlp_classifier import LitMLPClassifier @@ -116,9 +117,9 @@ def deploy_categorical_model_( slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) - test_dl, _ = dataloader_from_patient_data( + test_dl, _ = tile_bag_dataloader( patient_data=list(patient_to_data.values()), - bag_size=None, + bag_size=None, # We want all tiles to be seen by the model categories=list(models[0].categories), batch_size=1, shuffle=False, @@ -141,15 +142,13 @@ def deploy_categorical_model_( patient_label=patient_label, ground_truth_label=ground_truth_label, ) - feature_files = [ - next(iter(pd.feature_files)) for pd in patient_to_data.values() - ] - labels = [pd.ground_truth for pd in patient_to_data.values()] - categories = list(models[0].categories) - onehot = torch.tensor(np.array(labels).reshape(-1, 1) == categories) - test_ds = PatientDataset(feature_files, onehot, transform=None) - test_dl = torch.utils.data.DataLoader( - test_ds, batch_size=1, shuffle=False, num_workers=num_workers + test_dl, _ = patient_feature_dataloader( + patient_data=list(patient_to_data.values()), 
+ categories=list(models[0].categories), + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, ) patient_ids = list(patient_to_data.keys()) patient_to_ground_truth = { diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 14ef6436..85146882 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -8,7 +8,6 @@ import lightning.pytorch import lightning.pytorch.accelerators import lightning.pytorch.accelerators.accelerator -import numpy as np import torch from lightning.pytorch.accelerators.accelerator import Accelerator from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint @@ -19,13 +18,14 @@ from stamp.modeling.data import ( BagDataset, PatientData, - PatientDataset, - dataloader_from_patient_data, + PatientFeatureDataset, detect_feature_type, filter_complete_patient_data_, load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.lightning_model import ( Bags, @@ -66,7 +66,7 @@ def train_categorical_model_( use_vary_precision_transform: bool, use_alibi: bool, ) -> None: - """Trains a model. + """Trains a model based on the feature type. Args: clini_table: @@ -130,7 +130,7 @@ def train_categorical_model_( ) elif feature_type == "slide": raise RuntimeError( - "Slide-level features are not supported for training. " + "Slide-level features are not supported for training. "
) else: @@ -166,56 +166,82 @@ def train_categorical_model_( ) -def train_model_( +def setup_model_for_training( *, - output_dir: Path, - model: lightning.LightningModule, - train_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - valid_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - max_epochs: int, - patience: int, - accelerator: str | Accelerator, -) -> lightning.LightningModule: - """Trains a model. - - Returns: - The model with the best validation loss during training. - """ - torch.set_float32_matmul_precision("high") + patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + categories: Sequence[Category] | None, + bag_size: int, + batch_size: int, + num_workers: int, + train_transform: Callable[[torch.Tensor], torch.Tensor] | None, + use_alibi: bool, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel, + clini_table: Path, + slide_table: Path, + feature_dir: Path, + feature_type: str, +) -> tuple[ + lightning.LightningModule, + DataLoader, + DataLoader, +]: + """Creates a model and dataloaders for training""" - model_checkpoint = ModelCheckpoint( - monitor="validation_loss", - mode="min", - filename="checkpoint-{epoch:02d}-{validation_loss:0.3f}", - ) - trainer = lightning.Trainer( - default_root_dir=output_dir, - callbacks=[ - EarlyStopping(monitor="validation_loss", mode="min", patience=patience), - model_checkpoint, - ], - max_epochs=max_epochs, - # FIXME The number of accelerators is currently fixed to one for the - # following reasons: - # 1. `trainer.predict()` does not return any predictions if used with - # the default strategy no multiple GPUs - # 2. 
`barspoon.model.SafeMulticlassAUROC` breaks on multiple GPUs - accelerator=accelerator, - devices=1, - gradient_clip_val=0.5, - logger=CSVLogger(save_dir=output_dir), - log_every_n_steps=len(train_dl), + train_dl, valid_dl, train_categories, dim_feats, train_patients, valid_patients = ( + setup_dataloaders_for_training( + patient_to_data=patient_to_data, + categories=categories, + bag_size=bag_size, + batch_size=batch_size, + num_workers=num_workers, + train_transform=train_transform, + feature_type=feature_type, + ) ) - trainer.fit( - model=model, - train_dataloaders=train_dl, - val_dataloaders=valid_dl, + + category_weights = _compute_class_weights_and_check_categories( + train_dl=train_dl, + feature_type=feature_type, + train_categories=train_categories, ) - shutil.copy(model_checkpoint.best_model_path, output_dir / "model.ckpt") - # Reload the best model using the same class as the input model - ModelClass = type(model) - return ModelClass.load_from_checkpoint(model_checkpoint.best_model_path) + # Model selection + if feature_type == "tile": + model = LitVisionTransformer( + categories=train_categories, + category_weights=category_weights, + dim_input=dim_feats, + dim_model=512, + dim_feedforward=2048, + n_heads=8, + n_layers=2, + dropout=0.25, + use_alibi=use_alibi, + ground_truth_label=ground_truth_label, + train_patients=train_patients, + valid_patients=valid_patients, + clini_table=clini_table, + slide_table=slide_table, + feature_dir=feature_dir, + ) + else: + model = LitMLPClassifier( + categories=train_categories, + category_weights=category_weights, + dim_input=dim_feats, + dim_hidden=512, + num_layers=2, + dropout=0.25, + ground_truth_label=ground_truth_label, + train_patients=train_patients, + valid_patients=valid_patients, + clini_table=clini_table, + slide_table=slide_table, + feature_dir=feature_dir, + ) + + return model, train_dl, valid_dl def setup_dataloaders_for_training( @@ -261,7 +287,7 @@ def setup_dataloaders_for_training( if feature_type 
== "tile": # Use existing BagDataset logic - train_dl, train_categories = dataloader_from_patient_data( + train_dl, train_categories = tile_bag_dataloader( patient_data=[patient_to_data[pid] for pid in train_patients], categories=categories, bag_size=bag_size, @@ -270,7 +296,7 @@ def setup_dataloaders_for_training( num_workers=num_workers, transform=train_transform, ) - valid_dl, _ = dataloader_from_patient_data( + valid_dl, _ = tile_bag_dataloader( patient_data=[patient_to_data[pid] for pid in valid_patients], bag_size=None, categories=train_categories, @@ -291,86 +317,107 @@ def setup_dataloaders_for_training( ) elif feature_type == "patient": - # Patient-level: one feature file per patient - train_feature_files = [ - next(iter(patient_to_data[pid].feature_files)) for pid in train_patients - ] - train_labels = [patient_to_data[pid].ground_truth for pid in train_patients] - valid_feature_files = [ - next(iter(patient_to_data[pid].feature_files)) for pid in valid_patients - ] - valid_labels = [patient_to_data[pid].ground_truth for pid in valid_patients] - - all_labels = train_labels + valid_labels - categories = ( - categories if categories is not None else list(sorted(set(all_labels))) - ) - train_onehot = torch.tensor(np.array(train_labels).reshape(-1, 1) == categories) - valid_onehot = torch.tensor(np.array(valid_labels).reshape(-1, 1) == categories) - - train_ds = PatientDataset( - train_feature_files, train_onehot, transform=train_transform - ) - valid_ds = PatientDataset(valid_feature_files, valid_onehot, transform=None) - train_dl = DataLoader( - train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers + train_dl, train_categories = patient_feature_dataloader( + patient_data=[patient_to_data[pid] for pid in train_patients], + categories=categories, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + transform=train_transform, ) - valid_dl = DataLoader( - valid_ds, batch_size=1, shuffle=False, num_workers=num_workers + 
valid_dl, _ = patient_feature_dataloader( + patient_data=[patient_to_data[pid] for pid in valid_patients], + categories=train_categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, ) - feats, _ = next(iter(train_dl)) dim_feats = feats.shape[-1] - return train_dl, valid_dl, categories, dim_feats, train_patients, valid_patients - + return ( + train_dl, + valid_dl, + train_categories, + dim_feats, + train_patients, + valid_patients, + ) else: raise RuntimeError( f"Unsupported feature type: {feature_type}. Only 'tile' and 'patient' are supported." ) + - -def setup_model_for_training( +def train_model_( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], - categories: Sequence[Category] | None, - bag_size: int, - batch_size: int, - num_workers: int, - train_transform: Callable[[torch.Tensor], torch.Tensor] | None, - use_alibi: bool, - # Metadata, has no effect on model training - ground_truth_label: PandasLabel, - clini_table: Path, - slide_table: Path, - feature_dir: Path, - feature_type: str, -) -> tuple[ - lightning.LightningModule, - DataLoader, - DataLoader, -]: - """Creates a model and dataloaders for training""" + output_dir: Path, + model: lightning.LightningModule, + train_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + valid_dl: DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + max_epochs: int, + patience: int, + accelerator: str | Accelerator, +) -> lightning.LightningModule: + """Trains a model. - train_dl, valid_dl, train_categories, dim_feats, train_patients, valid_patients = ( - setup_dataloaders_for_training( - patient_to_data=patient_to_data, - categories=categories, - bag_size=bag_size, - batch_size=batch_size, - num_workers=num_workers, - train_transform=train_transform, - feature_type=feature_type, - ) + Returns: + The model with the best validation loss during training. 
+ """ + torch.set_float32_matmul_precision("high") + + model_checkpoint = ModelCheckpoint( + monitor="validation_loss", + mode="min", + filename="checkpoint-{epoch:02d}-{validation_loss:0.3f}", ) + trainer = lightning.Trainer( + default_root_dir=output_dir, + callbacks=[ + EarlyStopping(monitor="validation_loss", mode="min", patience=patience), + model_checkpoint, + ], + max_epochs=max_epochs, + # FIXME The number of accelerators is currently fixed to one for the + # following reasons: + # 1. `trainer.predict()` does not return any predictions if used with + # the default strategy on multiple GPUs + # 2. `barspoon.model.SafeMulticlassAUROC` breaks on multiple GPUs + accelerator=accelerator, + devices=1, + gradient_clip_val=0.5, + logger=CSVLogger(save_dir=output_dir), + log_every_n_steps=len(train_dl), + ) + trainer.fit( + model=model, + train_dataloaders=train_dl, + val_dataloaders=valid_dl, + ) + shutil.copy(model_checkpoint.best_model_path, output_dir / "model.ckpt") + + # Reload the best model using the same class as the input model + ModelClass = type(model) + return ModelClass.load_from_checkpoint(model_checkpoint.best_model_path) + - # Compute class weights +def _compute_class_weights_and_check_categories( + *, + train_dl: DataLoader, + feature_type: str, + train_categories: Sequence[str], +) -> torch.Tensor: + """ + Computes class weights and checks for category issues. + Logs warnings if there are too few or underpopulated categories. + Returns normalized category weights as a torch.Tensor.
+ """ if feature_type == "tile": category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) else: - # For slide/patient, count from one-hot labels - category_counts = cast(PatientDataset, train_dl.dataset).ground_truths.sum( - dim=0 - ) + category_counts = cast( + PatientFeatureDataset, train_dl.dataset + ).ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() @@ -378,7 +425,7 @@ def setup_model_for_training( raise ValueError(f"not enough categories to train on: {train_categories}") elif any(category_counts < 16): underpopulated_categories = { - category: count + category: int(count) for category, count in zip(train_categories, category_counts, strict=True) if count < 16 } @@ -386,40 +433,4 @@ def setup_model_for_training( f"Some categories do not have enough samples to meaningfully train a model: {underpopulated_categories}. " "You may want to consider removing these categories; the model will likely overfit on the few samples available." 
) - - # Model selection - if feature_type == "tile": - model = LitVisionTransformer( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_model=512, - dim_feedforward=2048, - n_heads=8, - n_layers=2, - dropout=0.25, - use_alibi=use_alibi, - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ) - else: - model = LitMLPClassifier( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_hidden=512, - num_layers=2, - dropout=0.25, - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ) - - return model, train_dl, valid_dl + return category_weights From dee2a5c6d5847c8a1dc08e02c08813ca63fd7b38 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 10 Jul 2025 15:53:02 +0100 Subject: [PATCH 07/18] add patient-level crossvalidation --- src/stamp/modeling/crossval.py | 94 ++++++++++++++++++++++++---------- src/stamp/modeling/data.py | 1 + 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index aaa21b49..8f6a4ac1 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -10,9 +10,13 @@ from stamp.modeling.data import ( PatientData, + detect_feature_type, filter_complete_patient_data_, + load_patient_level_data, + patient_feature_dataloader, patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, + tile_bag_dataloader, ) from stamp.modeling.deploy import _predict, _to_prediction_df from stamp.modeling.lightning_model import LitVisionTransformer @@ -64,30 +68,44 @@ def categorical_crossval_( use_vary_precision_transform: bool, use_alibi: bool, ) -> None: - patient_to_ground_truth: Final[dict[PatientId, 
GroundTruth]] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, + feature_type = detect_feature_type(feature_dir) + _logger.info(f"Detected feature type: {feature_type}") + + if feature_type == "tile": + patient_to_ground_truth: dict[PatientId, GroundTruth] = ( + patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=slide_table, + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + ) + patient_to_data: Mapping[PatientId, PatientData] = ( + filter_complete_patient_data_( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=True, + ) + ) + elif feature_type == "patient": + patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( + clini_table=clini_table, feature_dir=feature_dir, patient_label=patient_label, - filename_label=filename_label, - ) - ) - - # Clean data (remove slides without ground truth, missing features, etc.) 
- patient_to_data: Final[Mapping[Category, PatientData]] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, + ground_truth_label=ground_truth_label, ) - ) + patient_to_ground_truth: dict[PatientId, GroundTruth] = { + pid: pd.ground_truth for pid, pd in patient_to_data.items() + } + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") output_dir.mkdir(parents=True, exist_ok=True) splits_file = output_dir / "splits.json" @@ -169,6 +187,7 @@ def categorical_crossval_( else None ), use_alibi=use_alibi, + feature_type=feature_type, ) model = train_model_( output_dir=split_dir, @@ -184,14 +203,37 @@ def categorical_crossval_( # Deploy on test set if not (split_dir / "patient-preds.csv").exists(): + # Prepare test dataloader + test_patients = [ + pid for pid in split.test_patients if pid in patient_to_data + ] + test_patient_data = [patient_to_data[pid] for pid in test_patients] + if feature_type == "tile": + test_dl, _ = tile_bag_dataloader( + patient_data=test_patient_data, + bag_size=None, + categories=categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + elif feature_type == "patient": + test_dl, _ = patient_feature_dataloader( + patient_data=test_patient_data, + categories=categories, + batch_size=1, + shuffle=False, + num_workers=num_workers, + transform=None, + ) + else: + raise RuntimeError(f"Unsupported feature type: {feature_type}") + predictions = _predict( model=model, - patient_to_data={ - patient_id: patient_data - for patient_id, patient_data in patient_to_data.items() - if patient_id in split.test_patients - }, - num_workers=num_workers, + test_dl=test_dl, + patient_ids=test_patients, accelerator=accelerator, ) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 523477ba..fc3545a5 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py 
@@ -367,6 +367,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: == 224 ): # Historic STAMP format + # TODO: find a better way to get this warning just once global _logged_stamp_v1_warning if not _logged_stamp_v1_warning: _logger.info( From e2fc987a9c69e4939333037481c77a60e6b590ae Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 10 Jul 2025 15:54:13 +0100 Subject: [PATCH 08/18] change transformer feedforward dimension to 512 --- src/stamp/modeling/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 85146882..93f09a64 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -213,7 +213,7 @@ def setup_model_for_training( category_weights=category_weights, dim_input=dim_feats, dim_model=512, - dim_feedforward=2048, + dim_feedforward=512, n_heads=8, n_layers=2, dropout=0.25, From be9ae84bd4d51f00c0f20058f0c6955acefc0b43 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Fri, 11 Jul 2025 09:48:44 +0100 Subject: [PATCH 09/18] make slide_table optional --- src/stamp/modeling/config.py | 4 +++- src/stamp/modeling/crossval.py | 4 +++- src/stamp/modeling/train.py | 10 ++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 96edd3db..9a5b0e61 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -12,7 +12,9 @@ class TrainConfig(BaseModel): output_dir: Path = Field(description="The directory to save the results to") - clini_table: Path = Field(description="Excel or CSV to read clinical data from") + clini_table: Path | None = Field( + description="Excel or CSV to read clinical data from" + ) slide_table: Path = Field( description="Excel or CSV to read patient-slide associations from" ) diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 8f6a4ac1..040a199c 100644 --- a/src/stamp/modeling/crossval.py +++ 
b/src/stamp/modeling/crossval.py @@ -48,7 +48,7 @@ class _Splits(BaseModel): def categorical_crossval_( clini_table: Path, - slide_table: Path, + slide_table: Path | None, feature_dir: Path, output_dir: Path, patient_label: PandasLabel, @@ -72,6 +72,8 @@ def categorical_crossval_( _logger.info(f"Detected feature type: {feature_type}") if feature_type == "tile": + if slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") patient_to_ground_truth: dict[PatientId, GroundTruth] = ( patient_to_ground_truth_from_clini_table_( clini_table_path=clini_table, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 93f09a64..08acb36b 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -47,7 +47,7 @@ def train_categorical_model_( *, clini_table: Path, - slide_table: Path, + slide_table: Path | None, feature_dir: Path, output_dir: Path, patient_label: PandasLabel, @@ -101,7 +101,8 @@ def train_categorical_model_( _logger.info(f"Detected feature type: {feature_type}") if feature_type == "tile": - # Tile-level: use slide_table + if slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( clini_table_path=clini_table, ground_truth_label=ground_truth_label, @@ -178,7 +179,7 @@ def setup_model_for_training( # Metadata, has no effect on model training ground_truth_label: PandasLabel, clini_table: Path, - slide_table: Path, + slide_table: Path | None, feature_dir: Path, feature_type: str, ) -> tuple[ @@ -218,6 +219,7 @@ def setup_model_for_training( n_layers=2, dropout=0.25, use_alibi=use_alibi, + # Metadata, has no effect on model training ground_truth_label=ground_truth_label, train_patients=train_patients, valid_patients=valid_patients, @@ -233,11 +235,11 @@ def setup_model_for_training( dim_hidden=512, num_layers=2, dropout=0.25, + # Metadata, has no effect on model training 
ground_truth_label=ground_truth_label, train_patients=train_patients, valid_patients=valid_patients, clini_table=clini_table, - slide_table=slide_table, feature_dir=feature_dir, ) From df6a59fc9c9d04c0ae0c9049e773d9b2ace5c1ee Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 14 Jul 2025 13:13:42 +0100 Subject: [PATCH 10/18] add tests and major reformat --- src/stamp/modeling/config.py | 6 +- src/stamp/modeling/data.py | 4 +- src/stamp/modeling/deploy.py | 6 +- src/stamp/modeling/mlp_classifier.py | 2 +- src/stamp/modeling/train.py | 2 +- tests/random_data.py | 119 ++++++++++++++- tests/test_crossval.py | 41 +++-- tests/test_data.py | 35 ++++- tests/test_deployment.py | 144 +++++++++++++++++- ...test_deployment_backward_compatibility.py} | 33 ++-- tests/test_model.py | 50 ++++++ tests/test_train_deploy.py | 79 +++++++++- uv.lock | 6 +- 13 files changed, 472 insertions(+), 55 deletions(-) rename tests/{test_deployment_backward_compatability.py => test_deployment_backward_compatibility.py} (64%) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 9a5b0e61..0ce184ab 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -12,10 +12,8 @@ class TrainConfig(BaseModel): output_dir: Path = Field(description="The directory to save the results to") - clini_table: Path | None = Field( - description="Excel or CSV to read clinical data from" - ) - slide_table: Path = Field( + clini_table: Path = Field(description="Excel or CSV to read clinical data from") + slide_table: Path | None = Field( description="Excel or CSV to read patient-slide associations from" ) feature_dir: Path = Field(description="Directory containing feature files") diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index fc3545a5..545534fc 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -155,7 +155,9 @@ def detect_feature_type(feature_dir: Path) -> str:
feat_type = h5.attrs.get("feat_type") - if feat_type is not None: + encoder = h5.attrs.get("encoder") + + if feat_type is not None or encoder is not None: feature_types.add(str(feat_type)) else: # If feat_type is missing, always treat as tile-level feature diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 64a16dac..144dd2ce 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -11,8 +11,6 @@ from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( - PatientData, - PatientFeatureDataset, detect_feature_type, filter_complete_patient_data_, load_patient_level_data, @@ -41,7 +39,7 @@ def deploy_categorical_model_( output_dir: Path, checkpoint_paths: Sequence[Path], clini_table: Path | None, - slide_table: Path, + slide_table: Path | None, feature_dir: Path, ground_truth_label: PandasLabel | None, patient_label: PandasLabel, @@ -96,6 +94,8 @@ def deploy_categorical_model_( # --- Data loading logic --- if feature_type == "tile": + if slide_table is None: + raise ValueError("A slide table is required for tile-level modeling") slide_to_patient = slide_to_patient_from_slide_table_( slide_table_path=slide_table, feature_dir=feature_dir, diff --git a/src/stamp/modeling/mlp_classifier.py b/src/stamp/modeling/mlp_classifier.py index b21142fc..0a85f191 100644 --- a/src/stamp/modeling/mlp_classifier.py +++ b/src/stamp/modeling/mlp_classifier.py @@ -1,4 +1,4 @@ -from typing import Iterable, Sequence +from collections.abc import Iterable, Sequence import lightning import numpy as np diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 08acb36b..585914c3 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -349,7 +349,7 @@ def setup_dataloaders_for_training( raise RuntimeError( f"Unsupported feature type: {feature_type}. Only 'tile' and 'patient' are supported." 
) - + def train_model_( *, diff --git a/tests/random_data.py b/tests/random_data.py index f6777b23..d79a161b 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -16,12 +16,7 @@ import stamp from stamp.preprocessing.config import ExtractorName -from stamp.types import ( - Category, - Microns, - PatientId, - TilePixels, -) +from stamp.types import Category, FeaturePath, Microns, PatientId, TilePixels CliniPath: TypeAlias = Path SlidePath: TypeAlias = Path @@ -99,6 +94,56 @@ def create_random_dataset( return clini_path, slide_path, feat_dir, categories +def create_random_patient_level_dataset( + *, + dir: Path, + n_patients: int, + feat_dim: int, + categories: Sequence[str] | None = None, + n_categories: int | None = None, +) -> tuple[Path, Path, Path, Sequence[str]]: + """ + Creates a random dataset with one .h5 file per patient (patient-level features). + Returns (clini_path, slide_path, feat_dir, categories). + slide_path is a dummy file (not used for patient-level). + """ + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" # Not used, but keep interface consistent + feat_dir = dir / "feats" + feat_dir.mkdir() + + if categories is not None: + if n_categories is not None: + raise ValueError("only one of `categories` and `n_categories` can be set") + else: + if n_categories is None: + raise ValueError( + "either `categories` or `n_categories` has to be specified" + ) + categories = [random_string(8) for _ in range(n_categories)] + + patient_to_ground_truth = {} + for _ in range(n_patients): + patient_id = random_string(16) + patient_to_ground_truth[patient_id] = random.choice(categories) + # Create a single feature vector per patient + create_random_patient_level_feature_file( + tmp_path=feat_dir, + feat_dim=feat_dim, + feat_filename=patient_id, + ) + + pd.DataFrame( + patient_to_ground_truth.items(), + columns=["patient", "ground-truth"], + ).to_csv(clini_path, index=False) + + # slide_path is not used for patient-level, but return a dummy 
file for API compatibility + pd.DataFrame(columns=["slide_path", "patient"]).to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories + + def create_random_feature_file( *, tmp_path: Path, @@ -110,7 +155,7 @@ def create_random_feature_file( extractor_name: ExtractorName | str = "random-test-generator", feat_filename: str | None = None, coords: np.ndarray | None = None, -) -> Path: +) -> FeaturePath: """Creates a h5 file with random contents. Args: @@ -139,7 +184,38 @@ def create_random_feature_file( h5_file.attrs["unit"] = "um" h5_file.attrs["tile_size_um"] = tile_size_um h5_file.attrs["tile_size_px"] = tile_size_px - return Path(feature_file_path) + return FeaturePath(feature_file_path) + + +def create_random_patient_level_feature_file( + *, + tmp_path: Path, + feat_dim: int, + feat_filename: str | None = None, + encoder: str = "test-encoder", + precision: str = "float32", + feat_type: str = "patient", + code_hash: str = "testhash", + version: str | None = None, +) -> FeaturePath: + """ + Creates a random patient-level feature .h5 file with the correct metadata. + Returns the path to the created file. 
+ """ + if feat_filename is None: + feat_filename = random_string(16) + feature_file_path = tmp_path / f"{feat_filename}.h5" + feats = torch.rand(1, feat_dim) + version = version or stamp.__version__ + with h5py.File(feature_file_path, "w") as h5: + h5["feats"] = feats.numpy() + h5.attrs["version"] = version + h5.attrs["encoder"] = encoder + h5.attrs["precision"] = precision + h5.attrs["stamp_version"] = version + h5.attrs["code_hash"] = code_hash + h5.attrs["feat_type"] = feat_type + return FeaturePath(feature_file_path) def random_patient_preds(*, n_patients: int, categories: list[str]) -> pd.DataFrame: @@ -200,5 +276,32 @@ def make_feature_file( h5.attrs["unit"] = "um" h5.attrs["tile_size_um"] = tile_size_um h5.attrs["tile_size_px"] = tile_size_px + h5.attrs["feat_type"] = "tile" + + return file + +def make_patient_level_feature_file( + *, + feats: torch.Tensor, + encoder: str = "test-encoder", + precision: str = "float32", + code_hash: str = "testhash", + version: str | None = None, +) -> io.BytesIO: + """ + Creates an in-memory patient-level feature .h5 file with the correct metadata. + Returns a BytesIO object. 
+ """ + version = version or stamp.__version__ + file = io.BytesIO() + with h5py.File(file, "w") as h5: + h5["feats"] = feats.numpy() + h5.attrs["version"] = version + h5.attrs["encoder"] = encoder + h5.attrs["precision"] = precision + h5.attrs["stamp_version"] = version + h5.attrs["code_hash"] = code_hash + h5.attrs["feat_type"] = "patient" + file.seek(0) return file diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 880b074e..394afd9a 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -5,21 +5,22 @@ import numpy as np import pytest import torch -from random_data import create_random_dataset +from random_data import create_random_dataset, create_random_patient_level_dataset from stamp.modeling.crossval import categorical_crossval_ @pytest.mark.slow @pytest.mark.filterwarnings("ignore:No positive samples in targets") +@pytest.mark.parametrize("feature_type", ["tile", "patient"]) def test_crossval_integration( - *, tmp_path: Path, - n_patients: int = 800, + feature_type: str, + n_patients: int = 80, max_slides_per_patient: int = 3, min_tiles_per_slide: int = 8, - max_tiles_per_slide: int = 2**10, - feat_dim: int = 25, + max_tiles_per_slide: int = 32, + feat_dim: int = 8, n_categories: int = 3, use_alibi: bool = False, use_vary_precision_transform: bool = False, @@ -28,15 +29,27 @@ def test_crossval_integration( torch.manual_seed(0) np.random.seed(0) - clini_path, slide_path, feature_dir, categories = create_random_dataset( - dir=tmp_path, - n_categories=n_categories, - n_patients=n_patients, - max_slides_per_patient=max_slides_per_patient, - min_tiles_per_slide=min_tiles_per_slide, - max_tiles_per_slide=max_tiles_per_slide, - feat_dim=feat_dim, - ) + if feature_type == "tile": + clini_path, slide_path, feature_dir, categories = create_random_dataset( + dir=tmp_path, + n_categories=n_categories, + n_patients=n_patients, + max_slides_per_patient=max_slides_per_patient, + min_tiles_per_slide=min_tiles_per_slide, + 
max_tiles_per_slide=max_tiles_per_slide, + feat_dim=feat_dim, + ) + elif feature_type == "patient": + clini_path, slide_path, feature_dir, categories = ( + create_random_patient_level_dataset( + dir=tmp_path, + n_categories=n_categories, + n_patients=n_patients, + feat_dim=feat_dim, + ) + ) + else: + raise ValueError(f"Unknown feature_type: {feature_type}") output_dir = tmp_path / "output" diff --git a/tests/test_data.py b/tests/test_data.py index 4e15678f..edd678ee 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -5,13 +5,18 @@ import h5py import pytest import torch -from random_data import make_feature_file, make_old_feature_file +from random_data import ( + create_random_patient_level_feature_file, + make_feature_file, + make_old_feature_file, +) from torch.utils.data import DataLoader from stamp.modeling.data import ( BagDataset, CoordsInfo, PatientData, + PatientFeatureDataset, filter_complete_patient_data_, get_coords, ) @@ -80,7 +85,7 @@ def test_get_cohort_df(tmp_path: Path) -> None: "feature_file_creator", [make_feature_file, make_old_feature_file], ) -def test_dataset( +def test_bag_dataset( feature_file_creator, bag_size: BagSize = BagSize(5), dim_feats: int = 34, @@ -125,6 +130,32 @@ def test_dataset( assert (bag_sizes <= bag_size).all() +def test_patient_feature_dataset( + tmp_path: Path, dim_feats: int = 16, batch_size: int = 2 +) -> None: + # Create 3 random patient-level feature files on disk + files = [ + create_random_patient_level_feature_file(tmp_path=tmp_path, feat_dim=dim_feats) + for _ in range(3) + ] + # One-hot encoded labels for 3 samples, 4 categories + labels = torch.eye(4)[:3] + + ds = PatientFeatureDataset(files, labels, transform=None) + assert len(ds) == 3 + + # Test single dataset item + feats, label = ds[0] + assert feats.shape == (dim_feats,) + assert torch.allclose(label, labels[0]) + + # Test batching + dl = DataLoader(ds, batch_size=batch_size, shuffle=False) + feats_batch, labels_batch = next(iter(dl)) + assert 
feats_batch.shape == (batch_size, dim_feats) + assert labels_batch.shape == (batch_size, 4) + + def test_get_coords_with_mpp() -> None: # Test new feature file with valid mpp calculation file_bytes = make_feature_file( diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 2848bdfa..8bb0d6ec 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,18 +1,27 @@ +from pathlib import Path + import numpy as np -import numpy.typing as npt import pytest import torch -from random_data import make_old_feature_file +from random_data import create_random_patient_level_feature_file, make_old_feature_file -from stamp.modeling.data import PatientData +from stamp.modeling.data import ( + PatientData, + patient_feature_dataloader, + tile_bag_dataloader, +) from stamp.modeling.deploy import _predict, _to_prediction_df from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.types import GroundTruth, PatientId @pytest.mark.filterwarnings("ignore:GPU available but not used") +@pytest.mark.filterwarnings( + "ignore:The 'predict_dataloader' does not have many workers which may be a bottleneck" +) def test_predict( - categories: npt.NDArray = np.array(["foo", "bar", "baz"]), + categories: list[str] = ["foo", "bar", "baz"], n_heads: int = 7, dim_input: int = 12, ) -> None: @@ -42,10 +51,20 @@ def test_predict( ) } + test_dl, _ = tile_bag_dataloader( + patient_data=list(patient_to_data.values()), + bag_size=None, + categories=list(model.categories), + batch_size=1, + shuffle=False, + num_workers=2, + transform=None, + ) + predictions = _predict( model=model, - patient_to_data=patient_to_data, - num_workers=2, + test_dl=test_dl, + patient_ids=list(patient_to_data.keys()), accelerator="cpu", ) @@ -75,10 +94,20 @@ def test_predict( ), } + more_test_dl, _ = tile_bag_dataloader( + patient_data=list(more_patients_to_data.values()), + bag_size=None, + categories=list(model.categories), 
+ batch_size=1, + shuffle=False, + num_workers=2, + transform=None, + ) + more_predictions = _predict( model=model, - patient_to_data=more_patients_to_data, - num_workers=2, + test_dl=more_test_dl, + patient_ids=list(more_patients_to_data.keys()), accelerator="cpu", ) @@ -91,6 +120,105 @@ def test_predict( ), "the same inputs should repeatedly yield the same results" +def test_predict_patient_level( + tmp_path: Path, categories: list[str] = ["foo", "bar", "baz"], dim_feats: int = 12 +): + model = LitMLPClassifier( + categories=categories, + category_weights=torch.rand(len(categories)), + dim_input=dim_feats, + dim_hidden=32, + num_layers=2, + dropout=0.2, + ground_truth_label="test", + train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + + # Create 3 random patient-level feature files on disk + patient_ids = [PatientId(f"pat{i}") for i in range(5, 8)] + labels = ["foo", "bar", "baz"] + files = [ + create_random_patient_level_feature_file( + tmp_path=tmp_path, feat_dim=dim_feats, feat_filename=str(pid) + ) + for pid in patient_ids + ] + patient_to_data = { + pid: PatientData( + ground_truth=label, + feature_files={file}, + ) + for pid, label, file in zip(patient_ids, labels, files) + } + + test_dl, _ = patient_feature_dataloader( + patient_data=list(patient_to_data.values()), + categories=categories, + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + + predictions = _predict( + model=model, + test_dl=test_dl, + patient_ids=patient_ids, + accelerator="cpu", + ) + + assert len(predictions) == len(patient_to_data) + for pid in patient_ids: + assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + + # Check if scores are consistent between runs and different for different patients + more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] + more_labels = ["foo", "bar", "baz"] + more_files = [ + create_random_patient_level_feature_file( + tmp_path=tmp_path, feat_dim=dim_feats, 
feat_filename=str(pid) + ) + for pid in more_patient_ids + ] + more_patient_to_data = { + pid: PatientData( + ground_truth=label, + feature_files={file}, + ) + for pid, label, file in zip(more_patient_ids, more_labels, more_files) + } + # Add the original patient for repeatability check + all_patient_ids = more_patient_ids + [patient_ids[0]] + + more_test_dl, _ = patient_feature_dataloader( + patient_data=[more_patient_to_data[pid] for pid in more_patient_ids] + + [patient_to_data[patient_ids[0]]], + categories=categories, + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + + more_predictions = _predict( + model=model, + test_dl=more_test_dl, + patient_ids=all_patient_ids, + accelerator="cpu", + ) + + assert len(more_predictions) == len(all_patient_ids) + # Different patients should give different results + assert not torch.allclose( + more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + ), "different inputs should give different results" + # The same patient should yield the same result + assert torch.allclose( + predictions[patient_ids[0]], more_predictions[patient_ids[0]] + ), "the same inputs should repeatedly yield the same results" + + def test_to_prediction_df() -> None: n_heads = 7 model = LitVisionTransformer( diff --git a/tests/test_deployment_backward_compatability.py b/tests/test_deployment_backward_compatibility.py similarity index 64% rename from tests/test_deployment_backward_compatability.py rename to tests/test_deployment_backward_compatibility.py index 53a97829..0594d450 100644 --- a/tests/test_deployment_backward_compatability.py +++ b/tests/test_deployment_backward_compatibility.py @@ -2,16 +2,16 @@ import torch from stamp.cache import download_file -from stamp.modeling.data import PatientData +from stamp.modeling.data import PatientData, tile_bag_dataloader from stamp.modeling.deploy import _predict from stamp.modeling.lightning_model import LitVisionTransformer -from stamp.types import 
FeaturePath +from stamp.types import FeaturePath, PatientId @pytest.mark.filterwarnings( "ignore:The 'predict_dataloader' does not have many workers" ) -def test_backwards_compatability() -> None: +def test_backwards_compatibility() -> None: example_checkpoint_path = download_file( url="https://github.com/KatherLab/STAMP/releases/download/2.0.0-dev8/example-model.ckpt", file_name="example-model.ckpt", @@ -25,15 +25,28 @@ def test_backwards_compatability() -> None: model = LitVisionTransformer.load_from_checkpoint(example_checkpoint_path) + # Prepare PatientData and DataLoader for the test patient + patient_id = PatientId("TestPatient") + patient_to_data = { + patient_id: PatientData( + ground_truth=None, + feature_files=[FeaturePath(example_feature_path)], + ) + } + test_dl, _ = tile_bag_dataloader( + patient_data=list(patient_to_data.values()), + bag_size=None, + categories=list(model.categories), + batch_size=1, + shuffle=False, + num_workers=1, + transform=None, + ) + predictions = _predict( model=model, - patient_to_data={ - "TestPatient": PatientData( - ground_truth=None, - feature_files=[FeaturePath(example_feature_path)], - ) - }, - num_workers=1, + test_dl=test_dl, + patient_ids=[patient_id], accelerator="gpu" if torch.cuda.is_available() else "cpu", ) diff --git a/tests/test_model.py b/tests/test_model.py index af30aa40..ccdd4aa5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.modeling.vision_transformer import VisionTransformer @@ -69,3 +70,52 @@ def test_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_mlp_classifier_dims( + num_classes: int = 3, + batch_size: int = 6, + input_dim: int = 32, + dim_hidden: int = 64, + num_layers: int = 2, +) -> None: + model = LitMLPClassifier( + categories=[str(i) for i in range(num_classes)], + category_weights=torch.ones(num_classes), + dim_input=input_dim, + dim_hidden=dim_hidden, + 
num_layers=num_layers, + dropout=0.1, + ground_truth_label="test", + train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + feats = torch.rand((batch_size, input_dim)) + logits = model.forward(feats) + assert logits.shape == (batch_size, num_classes) + + +def test_mlp_inference_reproducibility( + num_classes: int = 4, + batch_size: int = 7, + input_dim: int = 33, + dim_hidden: int = 64, + num_layers: int = 3, +) -> None: + model = LitMLPClassifier( + categories=[str(i) for i in range(num_classes)], + category_weights=torch.ones(num_classes), + dim_input=input_dim, + dim_hidden=dim_hidden, + num_layers=num_layers, + dropout=0.1, + ground_truth_label="test", + train_patients=["pat1", "pat2"], + valid_patients=["pat3", "pat4"], + ) + model = model.eval() + feats = torch.rand((batch_size, input_dim)) + with torch.inference_mode(): + logits1 = model.forward(feats) + logits2 = model.forward(feats) + assert torch.allclose(logits1, logits2) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 62ad9717..2e415856 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -5,7 +5,7 @@ import numpy as np import pytest import torch -from random_data import create_random_dataset +from random_data import create_random_dataset, create_random_patient_level_dataset from stamp.modeling.deploy import deploy_categorical_model_ from stamp.modeling.train import train_categorical_model_ @@ -90,3 +90,80 @@ def test_train_deploy_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +@pytest.mark.parametrize( + "use_alibi,use_vary_precision_transform", + [ + pytest.param(False, False, id="no experimental features"), + pytest.param(True, False, id="use alibi"), + pytest.param(False, True, id="use vary_precision_transform"), + ], +) +def test_train_deploy_patient_level_integration( + 
*, + tmp_path: Path, + feat_dim: int = 25, + use_alibi: bool, + use_vary_precision_transform: bool, +) -> None: + random.seed(0) + torch.manual_seed(0) + np.random.seed(0) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + train_clini_path, train_slide_path, train_feature_dir, categories = ( + create_random_patient_level_dataset( + dir=tmp_path / "train", + n_categories=3, + n_patients=400, + feat_dim=feat_dim, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_patient_level_dataset( + dir=tmp_path / "deploy", + categories=categories, + n_patients=50, + feat_dim=feat_dim, + ) + ) + + train_categorical_model_( + clini_table=train_clini_path, + slide_table=None, # Not needed for patient-level + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label="ground-truth", + filename_label="slide_path", # Not used for patient-level + categories=categories, + # Dataset and -loader parameters + bag_size=1, # Not used for patient-level, but required by signature + num_workers=min(os.cpu_count() or 1, 16), + # Training paramenters + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + # Experimental features + use_vary_precision_transform=use_vary_precision_transform, + use_alibi=use_alibi, + ) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=None, # Not needed for patient-level + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label="ground-truth", + filename_label="slide_path", # Not used for patient-level + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + ) diff --git a/uv.lock b/uv.lock index d29d37c7..a3983d5d 100644 --- a/uv.lock +++ b/uv.lock @@ -1640,6 +1640,9 @@ wheels = [ name = 
"nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -2962,7 +2965,6 @@ all = [ { name = "einops-exts" }, { name = "environs" }, { name = "gdown" }, - { name = "gigapath" }, { name = "huggingface-hub" }, { name = "madeleine" }, { name = "musk" }, @@ -3058,7 +3060,7 @@ requires-dist = [ { name = "sacremoses", marker = "extra == 'prism'", specifier = "==0.1.1" }, { name = "scikit-learn", specifier = ">=1.5.2" }, { name = "scipy", specifier = ">=1.15.1" }, - { name = "stamp", extras = ["conch", "ctranspath", "uni", "virchow2", "chief-ctranspath", "conch1-5", "gigapath", "prism", "madeleine", "musk", "plip"], marker = "extra == 'all'" }, + { name = "stamp", extras = ["conch", "ctranspath", "uni", "virchow2", "chief-ctranspath", "conch1-5", "prism", "madeleine", "musk", "plip"], marker = "extra == 'all'" }, { name = "timm", specifier = ">=0.9.11" }, { name = "torch", specifier = ">=2.5.1" }, { name = "torch", marker = "extra == 'chief-ctranspath'", specifier = ">=2.0.0" }, From 7cbbc78007a897a3e00e5466c2311b8c6eb72b7b Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 14 Jul 2025 13:20:05 +0100 Subject: [PATCH 11/18] reuse read tables function from data.py --- src/stamp/encoding/encoder/__init__.py | 9 ++------- src/stamp/modeling/data.py 
| 8 ++++---- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 2696af00..d9035f7a 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -6,7 +6,6 @@ import h5py import numpy as np -import pandas as pd import torch from torch import Tensor from tqdm import tqdm @@ -14,7 +13,7 @@ import stamp from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName -from stamp.modeling.data import CoordsInfo, get_coords +from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel @@ -115,7 +114,7 @@ def encode_patients_( if self.precision == torch.float16: self.model.half() - slide_table = self._read_slide_table(slide_table_path) + slide_table = read_table(slide_table_path) patient_groups = slide_table.groupby(patient_label) for patient_id, group in (progress := tqdm(patient_groups)): @@ -165,10 +164,6 @@ def _generate_patient_embedding( """Generate patient embedding. Must be implemented by subclasses.""" pass - @staticmethod - def _read_slide_table(slide_table_path: Path) -> pd.DataFrame: - return pd.read_csv(slide_table_path) - def _validate_and_read_features(self, h5_path: str) -> tuple[Tensor, CoordsInfo]: feats, coords, extractor = self._read_h5(h5_path) if extractor not in self.required_extractors: diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 545534fc..23e3ca08 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -191,7 +191,7 @@ def load_patient_level_data( # clinical data to the patient-level feature paths that avoids # creating another slide table for encoded featuress is welcome :P. 
- clini_df = _read_table( + clini_df = read_table( clini_table, usecols=[patient_label, ground_truth_label], dtype=str, @@ -435,7 +435,7 @@ def patient_to_ground_truth_from_clini_table_( ground_truth_label: PandasLabel, ) -> dict[PatientId, GroundTruth]: """Loads the patients and their ground truths from a clini table.""" - clini_df = _read_table( + clini_df = read_table( clini_table_path, usecols=[patient_label, ground_truth_label], dtype=str, @@ -469,7 +469,7 @@ def slide_to_patient_from_slide_table_( filename_label: PandasLabel, ) -> dict[FeaturePath, PatientId]: """Creates a slide-to-patient mapping from a slide table.""" - slide_df = _read_table( + slide_df = read_table( slide_table_path, usecols=[patient_label, filename_label], dtype=str, @@ -485,7 +485,7 @@ def slide_to_patient_from_slide_table_( return slide_to_patient -def _read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: +def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: if not isinstance(path, Path): return pd.read_csv(path, **kwargs) elif path.suffix == ".xlsx": From c1949ee4ff50c24ff0e79e12f5466e58665726f8 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 14 Jul 2025 14:12:56 +0100 Subject: [PATCH 12/18] add extractor metadata and update docs --- getting-started.md | 40 +++++++++++++++++++++-------- src/stamp/preprocessing/__init__.py | 1 + 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/getting-started.md b/getting-started.md index 78eaa9a7..af3990b6 100644 --- a/getting-started.md +++ b/getting-started.md @@ -63,15 +63,6 @@ Stamp currently supports the following feature extractors: As some of the above require you to request access to the model on huggingface, we will stick with ctranspath for this example. -In order to use a feature extractor, -you also have to install their respective dependencies. 
-You can do so by specifying the feature extractor you want to use -when installing stamp: -```sh -# Install stamp including the dependencies for all feature extractors -pip install "git+https://github.com/KatherLab/stamp@v2[all]" -``` - Open the `stamp-test-experiment/config.yaml` we created in the last step and modify the `output_dir`, `wsi_dir` and `cache_dir` entries in the `preprocessing` section @@ -126,6 +117,12 @@ as well as `.jpg`s showing from which parts of the slide features are extracted. Most of the background should be marked in red, meaning ignored that it was ignored during feature extraction. +> In case you want to use a gated model (e.g. Virchow2), you need to login in your console using: +> ``` +>huggingface-cli login +> ``` +> More info about this [here](https://huggingface.co/docs/huggingface_hub/en/guides/cli). + > **If you are using the UNI or CONCH models** > and working in an environment where your home directory storage is limited, > you may want to also specify your huggingface storage directory @@ -367,4 +364,27 @@ patient_encoding: stamp --config stamp-test-experiment/config.yaml encode_patients ``` -The output `.h5` features will have the patient's id as name. \ No newline at end of file +The output `.h5` features will have the patient's id as name. + +## Training with Patient-Level Features + +Once you have patient-level features, +you can train models directly on these features. This is useful because: +- **Efficient with Limited Data**: Patient-level modeling often performs better when data is scarce, since pretrained encoders can extract robust features from each slide as a whole. +- **Faster Training & Reduced Overfitting**: With fewer parameters to train compared to tile-level models, patient-level models train more quickly and are less prone to overfitting. 
+- **Enables Interpretable Cohort Analysis**: Patient-level features can be used for unsupervised analyses, such as clustering, making it easier to interpret and explore patient subgroups within your cohort. + +> **Note:** Slide-level features are not supported for modeling because the ground truth +> labels in the clinical table are at the patient level. + +To train a model using patient-level features, you can use the same command as before: +```sh +stamp --config stamp-test-experiment/config.yaml crossval +``` + +The key differences for patient-level modeling are: +- The `feature_dir` should contain patient-level `.h5` files (one per patient) +- The `slide_table` is not needed since there's a direct mapping from patient ID to feature file +- STAMP will automatically detect that these are patient-level features and use a MultiLayer Perceptron (MLP) classifier instead of the Vision Transformer + +You can then run statistics as done with tile-level features. \ No newline at end of file diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index 6e2bdc66..f3eb4085 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -326,6 +326,7 @@ def extract_( h5_fp.attrs["tile_size_um"] = tile_size_um # changed in v2.1.0 h5_fp.attrs["tile_size_px"] = tile_size_px h5_fp.attrs["code_hash"] = code_hash + h5_fp.attrs["feat_type"] = "tile" except Exception: _logger.exception(f"error while writing {feature_output_path}") if tmp_h5_file is not None: From e798090e0600128ab6570612b4a9f81e20b70891 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Wed, 16 Jul 2025 16:28:14 +0100 Subject: [PATCH 13/18] add prism --- src/stamp/encoding/__init__.py | 5 +++++ src/stamp/encoding/config.py | 3 +-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 95492230..5e68d43a 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py 
@@ -64,6 +64,11 @@ def init_slide_encoder_( selected_encoder: Encoder = Madeleine() + case EncoderName.PRISM: + from stamp.encoding.encoder.prism import Prism + + selected_encoder: Encoder = Prism() + case Encoder(): selected_encoder = encoder diff --git a/src/stamp/encoding/config.py b/src/stamp/encoding/config.py index 5158b3b9..e743fcfd 100644 --- a/src/stamp/encoding/config.py +++ b/src/stamp/encoding/config.py @@ -13,8 +13,7 @@ class EncoderName(StrEnum): TITAN = "titan" GIGAPATH = "gigapath" MADELEINE = "madeleine" - # PRISM = "paigeai-prism" - # waiting for paige-ai authors to fix it + PRISM = "prism" class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True): From 8a866be7a2a23ec8e1361b8927a19b19af06cf7f Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Fri, 18 Jul 2025 16:48:38 +0100 Subject: [PATCH 14/18] add advanced config move training hyperparams and dataloading stuff into a separate config section. --- src/stamp/__main__.py | 63 +++----- src/stamp/config.py | 9 +- src/stamp/config.yaml | 50 +++++-- src/stamp/encoding/__init__.py | 5 + src/stamp/modeling/config.py | 59 ++++++-- src/stamp/modeling/crossval.py | 102 ++++++------- src/stamp/modeling/lightning_model.py | 2 + src/stamp/modeling/mlp_classifier.py | 2 + src/stamp/modeling/registry.py | 34 +++++ src/stamp/modeling/train.py | 203 +++++++++++--------------- tests/test_encoders.py | 4 +- 11 files changed, 286 insertions(+), 247 deletions(-) create mode 100644 src/stamp/modeling/registry.py diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 89792122..3806cddb 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -7,6 +7,12 @@ import yaml from stamp.config import StampConfig +from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + VitModelParams, +) STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") @@ -126,32 +132,20 @@ def _run_cli(args: argparse.Namespace) -> None: if config.training is None: raise ValueError("no 
training configuration supplied") + # use default advanced config in case none is provided + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()) + ) + _add_file_handle_(_logger, output_dir=config.training.output_dir) _logger.info( "using the following configuration:\n" f"{yaml.dump(config.training.model_dump(mode='json'))}" ) - # We pass every parameter explicitly so our type checker can do its work. + train_categorical_model_( - output_dir=config.training.output_dir, - clini_table=config.training.clini_table, - slide_table=config.training.slide_table, - feature_dir=config.training.feature_dir, - patient_label=config.training.patient_label, - ground_truth_label=config.training.ground_truth_label, - filename_label=config.training.filename_label, - categories=config.training.categories, - # Dataset and -loader parameters - bag_size=config.training.bag_size, - num_workers=config.training.num_workers, - # Training paramenters - batch_size=config.training.batch_size, - max_epochs=config.training.max_epochs, - patience=config.training.patience, - accelerator=config.training.accelerator, - # Experimental features - use_vary_precision_transform=config.training.use_vary_precision_transform, - use_alibi=config.training.use_alibi, + config=config.training, advanced=config.advanced_config ) case "deploy": @@ -189,27 +183,16 @@ def _run_cli(args: argparse.Namespace) -> None: "using the following configuration:\n" f"{yaml.dump(config.crossval.model_dump(mode='json'))}" ) + + # use default advanced config in case none is provided + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()) + ) + categorical_crossval_( - output_dir=config.crossval.output_dir, - clini_table=config.crossval.clini_table, - slide_table=config.crossval.slide_table, - feature_dir=config.crossval.feature_dir, - 
patient_label=config.crossval.patient_label, - ground_truth_label=config.crossval.ground_truth_label, - filename_label=config.crossval.filename_label, - categories=config.crossval.categories, - n_splits=config.crossval.n_splits, - # Dataset and -loader parameters - bag_size=config.crossval.bag_size, - num_workers=config.crossval.num_workers, - # Crossval paramenters - batch_size=config.crossval.batch_size, - max_epochs=config.crossval.max_epochs, - patience=config.crossval.patience, - accelerator=config.crossval.accelerator, - # Experimental Features - use_vary_precision_transform=config.crossval.use_vary_precision_transform, - use_alibi=config.crossval.use_alibi, + config=config.crossval, + advanced=config.advanced_config, ) case "statistics": diff --git a/src/stamp/config.py b/src/stamp/config.py index ca283dba..3d847324 100644 --- a/src/stamp/config.py +++ b/src/stamp/config.py @@ -2,7 +2,12 @@ from stamp.encoding.config import PatientEncodingConfig, SlideEncodingConfig from stamp.heatmaps.config import HeatmapConfig -from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig +from stamp.modeling.config import ( + AdvancedConfig, + CrossvalConfig, + DeploymentConfig, + TrainConfig, +) from stamp.preprocessing.config import PreprocessingConfig from stamp.statistics import StatsConfig @@ -23,3 +28,5 @@ class StampConfig(BaseModel): slide_encoding: SlideEncodingConfig | None = None patient_encoding: PatientEncodingConfig | None = None + + advanced_config: AdvancedConfig | None = None diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 3ba0f0d2..83cf887e 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -79,16 +79,14 @@ crossval: # Number of folds to split the data into for cross-validation #n_splits: 5 - # Experimental features: + # Path to a YAML file with advanced training parameters. 
+ #params_path: "path/to/train_params.yaml" - # Please try uncommenting the settings below - # and report if they improve / reduce model performance! + # Experimental features: # Change the precision of features during training #use_vary_precision_transform: true - # Use ALiBi positional embedding - # use_alibi: true training: @@ -126,17 +124,14 @@ training: # If unspecified, they will be inferred from the table itself. #categories: ["mutated", "wild type"] - # Experimental features: + # Path to a YAML file with advanced training parameters. + #params_path: "path/to/model_params.yaml" - # Please try uncommenting the settings below - # and report if they improve / reduce model performance! + # Experimental features: # Change the precision of features during training #use_vary_precision_transform: true - # Use ALiBi positional embedding - # use_alibi: true - deployment: output_dir: "/path/to/save/files/to" @@ -272,3 +267,36 @@ patient_encoding: # Add a hash of the entire preprocessing codebase in the feature folder name. #generate_hash: True + + +advanced_config: + max_epochs: 64 + patience: 16 + batch_size: 64 + # Only for tile-level training. Reducing its amount could affect + # model performance. Reduces memory consumption. Default value works + # fine for most cases. + bag_size: 512 + # Optional parameters + #num_workers: 16 # Default chosen by cpu cores + + # Select a model. Not working yet, added for future support. + # Now it uses a ViT for tile features and a MLP for patient features. 
+ #model_name: "vit" + + model_params: + # Tile-level training models: + vit: # Vision Transformer + dim_model: 512 + dim_feedforward: 512 + n_heads: 8 + n_layers: 2 + dropout: 0.25 + # Experimental feature: Use ALiBi positional embedding + use_alibi: false + + # Patient-level training models: + mlp: # Multilayer Perceptron + dim_hidden: 512 + num_layers: 2 + dropout: 0.25 diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 5e68d43a..3148f635 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -150,6 +150,11 @@ def init_patient_encoder_( selected_encoder: Encoder = Madeleine() + case EncoderName.PRISM: + from stamp.encoding.encoder.prism import Prism + + selected_encoder: Encoder = Prism() + case Encoder(): selected_encoder = encoder diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 0ce184ab..0e1c76fd 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -1,10 +1,12 @@ import os +from collections.abc import Sequence from pathlib import Path import torch from pydantic import BaseModel, ConfigDict, Field -from stamp.types import PandasLabel +from stamp.modeling.registry import ModelName +from stamp.types import Category, PandasLabel class TrainConfig(BaseModel): @@ -21,24 +23,18 @@ class TrainConfig(BaseModel): ground_truth_label: PandasLabel = Field( description="Name of categorical column in clinical table to train on" ) - categories: list[str] | None = None + categories: Sequence[Category] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" - # Dataset and -loader parameters - bag_size: int = 512 - num_workers: int = min(os.cpu_count() or 1, 16) - - # Training paramenters - batch_size: int = 64 - max_epochs: int = 64 - patience: int = 16 - accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + params_path: Path | None = Field( + default=None, + description="Optional: Path to a YAML file with 
advanced training parameters.", + ) # Experimental features use_vary_precision_transform: bool = False - use_alibi: bool = False class CrossvalConfig(TrainConfig): @@ -61,3 +57,42 @@ class DeploymentConfig(BaseModel): num_workers: int = min(os.cpu_count() or 1, 16) accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + + +class VitModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + dim_model: int = 512 + dim_feedforward: int = 512 + n_heads: int = 8 + n_layers: int = 2 + dropout: float = 0.25 + # Experimental feature: Use ALiBi positional embedding + use_alibi: bool = False + + +class MlpModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + dim_hidden: int = 512 + num_layers: int = 2 + dropout: float = 0.25 + + +class ModelParams(BaseModel): + model_config = ConfigDict(extra="forbid") + vit: VitModelParams + mlp: MlpModelParams + + +class AdvancedConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + bag_size: int = 512 + num_workers: int = min(os.cpu_count() or 1, 16) + batch_size: int = 64 + max_epochs: int = 64 + patience: int = 16 + accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + model_name: ModelName | None = Field( + default=None, + description='Optional: "vit" or "mlp". 
Defaults based on feature type.', + ) + model_params: ModelParams diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 040a199c..37bdf381 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,13 +1,12 @@ import logging from collections.abc import Mapping, Sequence -from pathlib import Path from typing import Any, Final import numpy as np -from lightning.pytorch.accelerators.accelerator import Accelerator from pydantic import BaseModel from sklearn.model_selection import StratifiedKFold +from stamp.modeling.config import AdvancedConfig, CrossvalConfig from stamp.modeling.data import ( PatientData, detect_feature_type, @@ -20,13 +19,12 @@ ) from stamp.modeling.deploy import _predict, _to_prediction_df from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - Category, FeaturePath, GroundTruth, - PandasLabel, PatientId, ) @@ -47,46 +45,28 @@ class _Splits(BaseModel): def categorical_crossval_( - clini_table: Path, - slide_table: Path | None, - feature_dir: Path, - output_dir: Path, - patient_label: PandasLabel, - ground_truth_label: PandasLabel, - filename_label: PandasLabel, - categories: Sequence[Category] | None, - n_splits: int, - # Dataset and -loader parameters - bag_size: int, - num_workers: int, - # Training paramenters - batch_size: int, - max_epochs: int, - patience: int, - accelerator: str | Accelerator, - # Experimental features - use_vary_precision_transform: bool, - use_alibi: bool, + config: CrossvalConfig, + advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(feature_dir) + feature_type = detect_feature_type(config.feature_dir) _logger.info(f"Detected feature type: {feature_type}") if feature_type == "tile": - if slide_table is None: + if 
config.slide_table is None: raise ValueError("A slide table is required for tile-level modeling") patient_to_ground_truth: dict[PatientId, GroundTruth] = ( patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, + clini_table_path=config.clini_table, + ground_truth_label=config.ground_truth_label, + patient_label=config.patient_label, ) ) slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, + slide_table_path=config.slide_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + filename_label=config.filename_label, ) ) patient_to_data: Mapping[PatientId, PatientData] = ( @@ -98,10 +78,10 @@ def categorical_crossval_( ) elif feature_type == "patient": patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - clini_table=clini_table, - feature_dir=feature_dir, - patient_label=patient_label, - ground_truth_label=ground_truth_label, + clini_table=config.clini_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, ) patient_to_ground_truth: dict[PatientId, GroundTruth] = { pid: pd.ground_truth for pid, pd in patient_to_data.items() @@ -109,12 +89,12 @@ def categorical_crossval_( else: raise RuntimeError(f"Unsupported feature type: {feature_type}") - output_dir.mkdir(parents=True, exist_ok=True) - splits_file = output_dir / "splits.json" + config.output_dir.mkdir(parents=True, exist_ok=True) + splits_file = config.output_dir / "splits.json" # Generate the splits, or load them from the splits file if they already exist if not splits_file.exists(): - splits = _get_splits(patient_to_data=patient_to_data, n_splits=n_splits) + splits = _get_splits(patient_to_data=patient_to_data, n_splits=config.n_splits) with 
open(splits_file, "w") as fp: fp.write(splits.model_dump_json(indent=4)) else: @@ -140,7 +120,7 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) - categories = categories or sorted( + categories = config.categories or sorted( { patient_data.ground_truth for patient_data in patient_to_data.values() @@ -149,7 +129,7 @@ def categorical_crossval_( ) for split_i, split in enumerate(splits.splits): - split_dir = output_dir / f"split-{split_i}" + split_dir = config.output_dir / f"split-{split_i}" if (split_dir / "patient-preds.csv").exists(): _logger.info( @@ -161,13 +141,11 @@ def categorical_crossval_( # Train the model if not (split_dir / "model.ckpt").exists(): model, train_dl, valid_dl = setup_model_for_training( - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, - ground_truth_label=ground_truth_label, - bag_size=bag_size, - num_workers=num_workers, - batch_size=batch_size, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, + ground_truth_label=config.ground_truth_label, + advanced=advanced, patient_to_data={ patient_id: patient_data for patient_id, patient_data in patient_to_data.items() @@ -185,10 +163,9 @@ def categorical_crossval_( ), train_transform=( VaryPrecisionTransform(min_fraction_bits=1) - if use_vary_precision_transform + if config.use_vary_precision_transform else None ), - use_alibi=use_alibi, feature_type=feature_type, ) model = train_model_( @@ -196,12 +173,17 @@ def categorical_crossval_( model=model, train_dl=train_dl, valid_dl=valid_dl, - max_epochs=max_epochs, - patience=patience, - accelerator=accelerator, + max_epochs=advanced.max_epochs, + patience=advanced.patience, + accelerator=advanced.accelerator, ) else: - model = LitVisionTransformer.load_from_checkpoint(split_dir / "model.ckpt") + if feature_type == "tile": + model = LitVisionTransformer.load_from_checkpoint( + split_dir / "model.ckpt" + ) + else: + model = 
LitMLPClassifier.load_from_checkpoint(split_dir / "model.ckpt") # Deploy on test set if not (split_dir / "patient-preds.csv").exists(): @@ -217,7 +199,7 @@ def categorical_crossval_( categories=categories, batch_size=1, shuffle=False, - num_workers=num_workers, + num_workers=advanced.num_workers, transform=None, ) elif feature_type == "patient": @@ -226,7 +208,7 @@ def categorical_crossval_( categories=categories, batch_size=1, shuffle=False, - num_workers=num_workers, + num_workers=advanced.num_workers, transform=None, ) else: @@ -236,15 +218,15 @@ def categorical_crossval_( model=model, test_dl=test_dl, patient_ids=test_patients, - accelerator=accelerator, + accelerator=advanced.accelerator, ) _to_prediction_df( categories=categories, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, - patient_label=patient_label, - ground_truth_label=ground_truth_label, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) diff --git a/src/stamp/modeling/lightning_model.py b/src/stamp/modeling/lightning_model.py index c30c87ef..4849c4e1 100644 --- a/src/stamp/modeling/lightning_model.py +++ b/src/stamp/modeling/lightning_model.py @@ -57,6 +57,8 @@ class LitVisionTransformer(lightning.LightningModule): **metadata: Additional metadata to store with the model. """ + supported_features = ["tile"] + def __init__( self, *, diff --git a/src/stamp/modeling/mlp_classifier.py b/src/stamp/modeling/mlp_classifier.py index 0a85f191..98650f08 100644 --- a/src/stamp/modeling/mlp_classifier.py +++ b/src/stamp/modeling/mlp_classifier.py @@ -44,6 +44,8 @@ class LitMLPClassifier(lightning.LightningModule): PyTorch Lightning wrapper for MLPClassifier. 
""" + supported_features = ["patient"] + def __init__( self, *, diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py new file mode 100644 index 00000000..7be976bd --- /dev/null +++ b/src/stamp/modeling/registry.py @@ -0,0 +1,34 @@ +from enum import StrEnum +from typing import Sequence, Type, TypedDict + +import lightning + +from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.modeling.mlp_classifier import LitMLPClassifier + + +class ModelName(StrEnum): + """Enum for available model names.""" + + VIT = "vit" + MLP = "mlp" + + +class ModelInfo(TypedDict): + """A dictionary to map a model to supported feature types. For example, + a linear classifier is not compatible with tile-evel feats.""" + + model_class: Type[lightning.LightningModule] + supported_features: Sequence[str] + + +MODEL_REGISTRY: dict[ModelName, ModelInfo] = { + ModelName.VIT: { + "model_class": LitVisionTransformer, + "supported_features": LitVisionTransformer.supported_features, + }, + ModelName.MLP: { + "model_class": LitMLPClassifier, + "supported_features": LitMLPClassifier.supported_features, + }, +} diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 585914c3..ff798030 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -15,6 +15,7 @@ from sklearn.model_selection import train_test_split from torch.utils.data.dataloader import DataLoader +from stamp.modeling.config import AdvancedConfig, TrainConfig from stamp.modeling.data import ( BagDataset, PatientData, @@ -31,9 +32,8 @@ Bags, BagSizes, EncodedTargets, - LitVisionTransformer, ) -from stamp.modeling.mlp_classifier import LitMLPClassifier +from stamp.modeling.registry import MODEL_REGISTRY, ModelName from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import Category, CoordinatesBatch, GroundTruth, PandasLabel, PatientId @@ -46,73 +46,26 @@ def train_categorical_model_( *, - clini_table: Path, - slide_table: Path | None, 
- feature_dir: Path, - output_dir: Path, - patient_label: PandasLabel, - ground_truth_label: PandasLabel, - filename_label: PandasLabel, - categories: Sequence[Category] | None, - # Dataset and -loader parameters - bag_size: int, - num_workers: int, - # Training paramenters - batch_size: int, - max_epochs: int, - patience: int, - accelerator: str | Accelerator, - # Experimental features - use_vary_precision_transform: bool, - use_alibi: bool, + config: TrainConfig, + advanced: AdvancedConfig, ) -> None: - """Trains a model based on the feature type. - - Args: - clini_table: - An excel or csv file to read the clinical information from. - Must at least have the columns specified in the arguments - - `patient_label` (containing a unique patient ID) - and `ground_truth_label` (containing the ground truth to train for). - slide_table: - An excel or csv file to read the patient-slide associations from. - Must at least have the columns specified in the arguments - `patient_label` (containing the patient ID) - and `filename_label` - (containing a filename relative to `feature_dir` - in which some of the patient's features are stored). - feature_dir: - See `slide_table`. - output_dir: - Path into which to output the artifacts (trained model etc.) - generated during training. - patient_label: - See `clini_table`, `slide_table`. - ground_truth_label: - See `clini_table`. - filename_label: - See `slide_table`. - categories: - Categories of the ground truth. - Set to `None` to automatically infer. 
- """ - feature_type = detect_feature_type(feature_dir) + """Trains a model based on the feature type.""" + feature_type = detect_feature_type(config.feature_dir) _logger.info(f"Detected feature type: {feature_type}") if feature_type == "tile": - if slide_table is None: + if config.slide_table is None: raise ValueError("A slide table is required for tile-level modeling") patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=clini_table, - ground_truth_label=ground_truth_label, - patient_label=patient_label, + clini_table_path=config.clini_table, + ground_truth_label=config.ground_truth_label, + patient_label=config.patient_label, ) slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=slide_table, - feature_dir=feature_dir, - patient_label=patient_label, - filename_label=filename_label, + slide_table_path=config.slide_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + filename_label=config.filename_label, ) patient_to_data = filter_complete_patient_data_( patient_to_ground_truth=patient_to_ground_truth, @@ -121,13 +74,13 @@ def train_categorical_model_( ) elif feature_type == "patient": # Patient-level: ignore slide_table - if slide_table is not None: + if config.slide_table is not None: _logger.warning("slide_table is ignored for patient-level features.") patient_to_data = load_patient_level_data( - clini_table=clini_table, - feature_dir=feature_dir, - patient_label=patient_label, - ground_truth_label=ground_truth_label, + clini_table=config.clini_table, + feature_dir=config.feature_dir, + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, ) elif feature_type == "slide": raise RuntimeError( @@ -140,30 +93,27 @@ def train_categorical_model_( # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, - categories=categories, - bag_size=bag_size, - batch_size=batch_size, 
- num_workers=num_workers, - ground_truth_label=ground_truth_label, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, + categories=config.categories, + advanced=advanced, + ground_truth_label=config.ground_truth_label, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, train_transform=( VaryPrecisionTransform(min_fraction_bits=1) - if use_vary_precision_transform + if config.use_vary_precision_transform else None ), - use_alibi=use_alibi, feature_type=feature_type, ) train_model_( - output_dir=output_dir, + output_dir=config.output_dir, model=model, train_dl=train_dl, valid_dl=valid_dl, - max_epochs=max_epochs, - patience=patience, - accelerator=accelerator, + max_epochs=advanced.max_epochs, + patience=advanced.patience, + accelerator=advanced.accelerator, ) @@ -171,17 +121,14 @@ def setup_model_for_training( *, patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], categories: Sequence[Category] | None, - bag_size: int, - batch_size: int, - num_workers: int, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, - use_alibi: bool, + feature_type: str, + advanced: AdvancedConfig, # Metadata, has no effect on model training ground_truth_label: PandasLabel, clini_table: Path, slide_table: Path | None, feature_dir: Path, - feature_type: str, ) -> tuple[ lightning.LightningModule, DataLoader, @@ -193,56 +140,70 @@ def setup_model_for_training( setup_dataloaders_for_training( patient_to_data=patient_to_data, categories=categories, - bag_size=bag_size, - batch_size=batch_size, - num_workers=num_workers, + bag_size=advanced.bag_size, + batch_size=advanced.batch_size, + num_workers=advanced.num_workers, train_transform=train_transform, feature_type=feature_type, ) ) + _logger.info( + "Training dataloaders: bag_size=%s, batch_size=%s, num_workers=%s", + advanced.bag_size, + advanced.batch_size, + advanced.num_workers, + ) + category_weights = 
_compute_class_weights_and_check_categories( train_dl=train_dl, feature_type=feature_type, train_categories=train_categories, ) - # Model selection - if feature_type == "tile": - model = LitVisionTransformer( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_model=512, - dim_feedforward=512, - n_heads=8, - n_layers=2, - dropout=0.25, - use_alibi=use_alibi, - # Metadata, has no effect on model training - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - slide_table=slide_table, - feature_dir=feature_dir, + # 1. Default to a model if none is specified + if advanced.model_name is None: + advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP + _logger.info( + f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" ) - else: - model = LitMLPClassifier( - categories=train_categories, - category_weights=category_weights, - dim_input=dim_feats, - dim_hidden=512, - num_layers=2, - dropout=0.25, - # Metadata, has no effect on model training - ground_truth_label=ground_truth_label, - train_patients=train_patients, - valid_patients=valid_patients, - clini_table=clini_table, - feature_dir=feature_dir, + + # 2. Validate that the chosen model supports the feature type + model_info = MODEL_REGISTRY[advanced.model_name] + if feature_type not in model_info["supported_features"]: + raise ValueError( + f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " + f"Supported types are: {model_info['supported_features']}" ) + # 3. Get model-specific hyperparameters + model_specific_params = advanced.model_params.model_dump()[ + advanced.model_name.value + ] + + # 4. 
Prepare common parameters + common_params = { + "categories": train_categories, + "category_weights": category_weights, + "dim_input": dim_feats, + # Metadata, has no effect on model training + "model_name": advanced.model_name.value, + "ground_truth_label": ground_truth_label, + "train_patients": train_patients, + "valid_patients": valid_patients, + "clini_table": clini_table, + "slide_table": slide_table, + "feature_dir": feature_dir, + } + + # 4. Instantiate the model dynamically + ModelClass = model_info["model_class"] + all_params = {**common_params, **model_specific_params} + _logger.info( + f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" + ) + model = ModelClass(**all_params) + return model, train_dl, valid_dl diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 43e3e1f8..c740d923 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -39,12 +39,12 @@ EncoderName.GIGAPATH: ExtractorName.GIGAPATH, EncoderName.MADELEINE: ExtractorName.CONCH, EncoderName.TITAN: ExtractorName.CONCH1_5, - # EncoderName.PRISM: ExtractorName.VIRCHOW_FULL, + EncoderName.PRISM: ExtractorName.VIRCHOW_FULL, } @pytest.mark.slow -@pytest.mark.parametrize("encoder", EncoderName) +@pytest.mark.parametrize("encoder", [EncoderName.PRISM]) @pytest.mark.filterwarnings("ignore:Importing from timm.models.layers is deprecated") @pytest.mark.filterwarnings( "ignore:You are using `torch.load` with `weights_only=False`" From 9d2090b4eafd13babed9746cf1444a9e1a64a426 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 21 Jul 2025 14:33:01 +0100 Subject: [PATCH 15/18] adapt tests for new config --- tests/test_config.py | 92 ++++++++++++++++++++++++-------------- tests/test_crossval.py | 30 ++++++++++--- tests/test_train_deploy.py | 38 ++++++++++++---- 3 files changed, 113 insertions(+), 47 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index c9a0c4fc..dafdd58c 100644 --- a/tests/test_config.py +++ 
b/tests/test_config.py @@ -3,7 +3,15 @@ from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig -from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig +from stamp.modeling.config import ( + AdvancedConfig, + CrossvalConfig, + DeploymentConfig, + MlpModelParams, + ModelParams, + TrainConfig, + VitModelParams, +) from stamp.preprocessing.config import ( ExtractorName, Microns, @@ -18,26 +26,18 @@ def test_config_parsing() -> None: config = StampConfig.model_validate( { "crossval": { - "accelerator": "gpu", - "bag_size": 512, - "batch_size": 64, "categories": None, "clini_table": "clini.xlsx", "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "max_epochs": 64, - "n_splits": 5, - "num_workers": 16, "output_dir": "test-crossval", - "patience": 16, "patient_label": "PATIENT", "slide_table": "slide.csv", - "use_alibi": True, "use_vary_precision_transform": False, + "n_splits": 5, }, "deployment": { - "accelerator": "gpu", "checkpoint_paths": [ "test-crossval/split-0/model.ckpt", "test-crossval/split-1/model.ckpt", @@ -49,7 +49,6 @@ def test_config_parsing() -> None: "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "num_workers": 16, "output_dir": "test-deploy", "patient_label": "PATIENT", "slide_table": "slide.csv", @@ -91,23 +90,39 @@ def test_config_parsing() -> None: "true_class": "MSIH", }, "training": { - "accelerator": "gpu", - "bag_size": 512, - "batch_size": 64, "categories": None, "clini_table": "clini.xlsx", "feature_dir": "CRC", "filename_label": "FILENAME", "ground_truth_label": "isMSIH", - "max_epochs": 64, - "num_workers": 16, "output_dir": "test-alibi", - "patience": 16, "patient_label": "PATIENT", "slide_table": "slide.csv", - "use_alibi": True, "use_vary_precision_transform": False, }, + "advanced_config": { + "bag_size": 512, + "num_workers": 16, + "batch_size": 64, + "max_epochs": 64, + "patience": 16, + "accelerator": "gpu", + 
"model_params": { + "vit": { + "dim_model": 512, + "dim_feedforward": 512, + "n_heads": 8, + "n_layers": 2, + "dropout": 0.25, + "use_alibi": True, + }, + "mlp": { + "dim_hidden": 512, + "num_layers": 2, + "dropout": 0.25, + }, + }, + }, } ) @@ -134,14 +149,8 @@ def test_config_parsing() -> None: categories=None, patient_label="PATIENT", filename_label="FILENAME", - bag_size=512, - num_workers=16, - batch_size=64, - max_epochs=64, - patience=16, - accelerator="gpu", + params_path=None, use_vary_precision_transform=False, - use_alibi=True, ), crossval=CrossvalConfig( output_dir=Path("test-crossval"), @@ -152,14 +161,8 @@ def test_config_parsing() -> None: categories=None, patient_label="PATIENT", filename_label="FILENAME", - bag_size=512, - num_workers=16, - batch_size=64, - max_epochs=64, - patience=16, - accelerator="gpu", + params_path=None, use_vary_precision_transform=False, - use_alibi=True, n_splits=5, ), deployment=DeploymentConfig( @@ -177,8 +180,6 @@ def test_config_parsing() -> None: ground_truth_label="isMSIH", patient_label="PATIENT", filename_label="FILENAME", - num_workers=16, - accelerator="gpu", ), statistics=StatsConfig( output_dir=Path("test-stats"), @@ -213,4 +214,27 @@ def test_config_parsing() -> None: bottomk=5, default_slide_mpp=SlideMPP(1.0), ), + advanced_config=AdvancedConfig( + bag_size=512, + num_workers=16, + batch_size=64, + max_epochs=64, + patience=16, + accelerator="gpu", + model_params=ModelParams( + vit=VitModelParams( + dim_model=512, + dim_feedforward=512, + n_heads=8, + n_layers=2, + dropout=0.25, + use_alibi=True, + ), + mlp=MlpModelParams( + dim_hidden=512, + num_layers=2, + dropout=0.25, + ), + ), + ), ) diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 394afd9a..342e39fc 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -7,6 +7,13 @@ import torch from random_data import create_random_dataset, create_random_patient_level_dataset +from stamp.modeling.config import ( + AdvancedConfig, + 
CrossvalConfig, + MlpModelParams, + ModelParams, + VitModelParams, +) from stamp.modeling.crossval import categorical_crossval_ @@ -53,15 +60,20 @@ def test_crossval_integration( output_dir = tmp_path / "output" - categorical_crossval_( + config = CrossvalConfig( clini_table=clini_path, slide_table=slide_path, - feature_dir=feature_dir, output_dir=output_dir, patient_label="patient", ground_truth_label="ground-truth", filename_label="slide_path", categories=categories, + feature_dir=feature_dir, + n_splits=2, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( # Dataset and -loader parameters bag_size=max_tiles_per_slide // 2, num_workers=min(os.cpu_count() or 1, 7), @@ -70,8 +82,16 @@ def test_crossval_integration( max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", - n_splits=2, # Experimental features - use_vary_precision_transform=use_vary_precision_transform, - use_alibi=use_alibi, + model_params=ModelParams( + vit=VitModelParams( + use_alibi=use_alibi, + ), + mlp=MlpModelParams(), + ), + ) + + categorical_crossval_( + config=config, + advanced=advanced, ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 2e415856..03d48c48 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -7,6 +7,13 @@ import torch from random_data import create_random_dataset, create_random_patient_level_dataset +from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + TrainConfig, + VitModelParams, +) from stamp.modeling.deploy import deploy_categorical_model_ from stamp.modeling.train import train_categorical_model_ @@ -56,7 +63,7 @@ def test_train_deploy_integration( feat_dim=feat_dim, ) - train_categorical_model_( + config = TrainConfig( clini_table=train_clini_path, slide_table=train_slide_path, feature_dir=train_feature_dir, @@ -65,6 +72,10 @@ def test_train_deploy_integration( ground_truth_label="ground-truth", 
filename_label="slide_path", categories=categories, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( # Dataset and -loader parameters bag_size=500, num_workers=min(os.cpu_count() or 1, 16), @@ -73,11 +84,13 @@ def test_train_deploy_integration( max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", - # Experimental features - use_vary_precision_transform=use_vary_precision_transform, - use_alibi=use_alibi, + model_params=ModelParams( + vit=VitModelParams(use_alibi=use_alibi), mlp=MlpModelParams() + ), ) + train_categorical_model_(config=config, advanced=advanced) + deploy_categorical_model_( output_dir=tmp_path / "deploy_output", checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], @@ -133,7 +146,7 @@ def test_train_deploy_patient_level_integration( ) ) - train_categorical_model_( + config = TrainConfig( clini_table=train_clini_path, slide_table=None, # Not needed for patient-level feature_dir=train_feature_dir, @@ -142,6 +155,10 @@ def test_train_deploy_patient_level_integration( ground_truth_label="ground-truth", filename_label="slide_path", # Not used for patient-level categories=categories, + use_vary_precision_transform=use_vary_precision_transform, + ) + + advanced = AdvancedConfig( # Dataset and -loader parameters bag_size=1, # Not used for patient-level, but required by signature num_workers=min(os.cpu_count() or 1, 16), @@ -150,9 +167,14 @@ def test_train_deploy_patient_level_integration( max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", - # Experimental features - use_vary_precision_transform=use_vary_precision_transform, - use_alibi=use_alibi, + model_params=ModelParams( + vit=VitModelParams(use_alibi=use_alibi), mlp=MlpModelParams() + ), + ) + + train_categorical_model_( + config=config, + advanced=advanced, ) deploy_categorical_model_( From 919cd29bcbed2833cbe1eed3f83ae877e86f0623 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 
21 Jul 2025 15:18:22 +0100 Subject: [PATCH 16/18] remove test type --- tests/test_encoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index c740d923..21d87a02 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -44,7 +44,7 @@ @pytest.mark.slow -@pytest.mark.parametrize("encoder", [EncoderName.PRISM]) +@pytest.mark.parametrize("encoder", EncoderName) @pytest.mark.filterwarnings("ignore:Importing from timm.models.layers is deprecated") @pytest.mark.filterwarnings( "ignore:You are using `torch.load` with `weights_only=False`" From 8d7f1a0f78fb950547ce559594bc3602b25f2602 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 21 Jul 2025 15:34:59 +0100 Subject: [PATCH 17/18] remove ctranspath as chief supported extractor CHIEF's paper does not explicitly state whether ctranspath can be used for tile-level feature extraction. chief-ctranspath is now the default feature extractor so people can run the whole pipeline, including encoding, without requesting any model access. --- src/stamp/config.yaml | 2 +- src/stamp/encoding/encoder/chief.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 83cf887e..e2e8eda2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -6,7 +6,7 @@ preprocessing: # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", # "virchow-full", "musk", "mstar", "plip" # Some of them require requesting access to the respective authors beforehand. - extractor: "ctranspath" + extractor: "chief-ctranspath" # Device to run feature extraction on ("cpu", "cuda", "cuda:0", etc.)
device: "cuda" diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index f174d42a..eaab9750 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -117,7 +117,6 @@ def __init__(self) -> None: precision=torch.float32, required_extractors=[ ExtractorName.CHIEF_CTRANSPATH, - ExtractorName.CTRANSPATH, ], ) From 961479a995ce18a8f2adcf1721fc003da92fe8b8 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 21 Jul 2025 15:48:21 +0100 Subject: [PATCH 18/18] update docs --- getting-started.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/getting-started.md b/getting-started.md index eeec0099..e12429da 100644 --- a/getting-started.md +++ b/getting-started.md @@ -157,6 +157,7 @@ meaning ignored that it was ignored during feature extraction. [COBRA2]: https://huggingface.co/KatherLab/COBRA [EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine +[PRISM]: https://huggingface.co/paige-ai/Prism @@ -272,6 +273,7 @@ STAMP currently supports the following encoders: - [COBRA2] - [EAGLE] - [MADELEINE] +- [PRISM] Slide encoders take as input the already extracted tile-level features in the preprocessing step. 
Each encoder accepts only certain extractors and most @@ -279,12 +281,13 @@ work only on CUDA devices: | Encoder | Required Extractor | Compatible Devices | |--|--|--| -| CHIEF | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only | +| CHIEF | CHIEF-CTRANSPATH | CUDA only | | TITAN | CONCH1.5 | CUDA, cpu, mps | GIGAPATH | GIGAPATH | CUDA only | COBRA2 | CONCH, UNI, VIRCHOW2 or H-OPTIMUS-0 | CUDA only | EAGLE | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only | MADELEINE | CONCH | CUDA only +| PRISM | VIRCHOW_FULL | CUDA only As with feature extractors, most of these models require you to request @@ -388,8 +391,8 @@ stamp --config stamp-test-experiment/config.yaml crossval ``` The key differences for patient-level modeling are: -- The `feature_dir` should contain patient-level `.h5` files (one per patient) -- The `slide_table` is not needed since there's a direct mapping from patient ID to feature file -- STAMP will automatically detect that these are patient-level features and use a MultiLayer Perceptron (MLP) classifier instead of the Vision Transformer +- The `feature_dir` should contain patient-level `.h5` files (one per patient). +- The `slide_table` is not needed since there's a direct mapping from patient ID to feature file. +- STAMP will automatically detect that these are patient-level features and use a MultiLayer Perceptron (MLP) classifier instead of the Vision Transformer. You can then run statistics as done with tile-level features. \ No newline at end of file