Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ as well as `.jpg`s showing from which parts of the slide features are extracted.
Most of the background should be marked in red,
meaning that it was ignored during feature extraction.

> In case you want to use a gated model (e.g. Virchow2), you need to log in from your console using:
> ```
> huggingface-cli login
> ```
> More info about this [here](https://huggingface.co/docs/huggingface_hub/en/guides/cli).

> **If you are using the UNI or CONCH models**
> and working in an environment where your home directory storage is limited,
> you may want to also specify your huggingface storage directory
Expand Down Expand Up @@ -151,6 +157,7 @@ meaning ignored that it was ignored during feature extraction.
[COBRA2]: https://huggingface.co/KatherLab/COBRA
[EAGLE]: https://github.com/KatherLab/EAGLE
[MADELEINE]: https://huggingface.co/MahmoodLab/madeleine
[PRISM]: https://huggingface.co/paige-ai/Prism



Expand Down Expand Up @@ -266,19 +273,21 @@ STAMP currently supports the following encoders:
- [COBRA2]
- [EAGLE]
- [MADELEINE]
- [PRISM]

Slide encoders take as input the already extracted tile-level features in the
preprocessing step. Each encoder accepts only certain extractors and most
work only on CUDA devices:

| Encoder | Required Extractor | Compatible Devices |
|--|--|--|
| CHIEF | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only |
| CHIEF | CHIEF-CTRANSPATH | CUDA only |
| TITAN | CONCH1.5 | CUDA, cpu, mps |
| GIGAPATH | GIGAPATH | CUDA only |
| COBRA2 | CONCH, UNI, VIRCHOW2 or H-OPTIMUS-0 | CUDA only |
| EAGLE | CTRANSPATH, CHIEF-CTRANSPATH | CUDA only |
| MADELEINE | CONCH | CUDA only |
| PRISM | VIRCHOW_FULL | CUDA only


As with feature extractors, most of these models require you to request
Expand Down Expand Up @@ -363,4 +372,27 @@ patient_encoding:
stamp --config stamp-test-experiment/config.yaml encode_patients
```

The output `.h5` features will have the patient's id as name.
The output `.h5` features will have the patient's id as name.

## Training with Patient-Level Features

Once you have patient-level features,
you can train models directly on these features. This is useful because:
- **Efficient with Limited Data**: Patient-level modeling often performs better when data is scarce, since pretrained encoders can extract robust features from each slide as a whole.
- **Faster Training & Reduced Overfitting**: With fewer parameters to train compared to tile-level models, patient-level models train more quickly and are less prone to overfitting.
- **Enables Interpretable Cohort Analysis**: Patient-level features can be used for unsupervised analyses, such as clustering, making it easier to interpret and explore patient subgroups within your cohort.
Comment thread
EzicStar marked this conversation as resolved.

> **Note:** Slide-level features are not supported for modeling because the ground truth
> labels in the clinical table are at the patient level.

To train a model using patient-level features, you can use the same command as before:
```sh
stamp --config stamp-test-experiment/config.yaml crossval
```

The key differences for patient-level modeling are:
- The `feature_dir` should contain patient-level `.h5` files (one per patient).
- The `slide_table` is not needed since there's a direct mapping from patient ID to feature file.
- STAMP will automatically detect that these are patient-level features and use a MultiLayer Perceptron (MLP) classifier instead of the Vision Transformer.

You can then run statistics as done with tile-level features.
63 changes: 23 additions & 40 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import yaml

from stamp.config import StampConfig
from stamp.modeling.config import (
AdvancedConfig,
MlpModelParams,
ModelParams,
VitModelParams,
)

STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")

Expand Down Expand Up @@ -126,32 +132,20 @@ def _run_cli(args: argparse.Namespace) -> None:
if config.training is None:
raise ValueError("no training configuration supplied")

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
)

_add_file_handle_(_logger, output_dir=config.training.output_dir)
_logger.info(
"using the following configuration:\n"
f"{yaml.dump(config.training.model_dump(mode='json'))}"
)
# We pass every parameter explicitly so our type checker can do its work.

train_categorical_model_(
output_dir=config.training.output_dir,
clini_table=config.training.clini_table,
slide_table=config.training.slide_table,
feature_dir=config.training.feature_dir,
patient_label=config.training.patient_label,
ground_truth_label=config.training.ground_truth_label,
filename_label=config.training.filename_label,
categories=config.training.categories,
# Dataset and -loader parameters
bag_size=config.training.bag_size,
num_workers=config.training.num_workers,
# Training paramenters
batch_size=config.training.batch_size,
max_epochs=config.training.max_epochs,
patience=config.training.patience,
accelerator=config.training.accelerator,
# Experimental features
use_vary_precision_transform=config.training.use_vary_precision_transform,
use_alibi=config.training.use_alibi,
config=config.training, advanced=config.advanced_config
)

case "deploy":
Expand Down Expand Up @@ -189,27 +183,16 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.crossval.model_dump(mode='json'))}"
)

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
)

categorical_crossval_(
output_dir=config.crossval.output_dir,
clini_table=config.crossval.clini_table,
slide_table=config.crossval.slide_table,
feature_dir=config.crossval.feature_dir,
patient_label=config.crossval.patient_label,
ground_truth_label=config.crossval.ground_truth_label,
filename_label=config.crossval.filename_label,
categories=config.crossval.categories,
n_splits=config.crossval.n_splits,
# Dataset and -loader parameters
bag_size=config.crossval.bag_size,
num_workers=config.crossval.num_workers,
# Crossval paramenters
batch_size=config.crossval.batch_size,
max_epochs=config.crossval.max_epochs,
patience=config.crossval.patience,
accelerator=config.crossval.accelerator,
# Experimental Features
use_vary_precision_transform=config.crossval.use_vary_precision_transform,
use_alibi=config.crossval.use_alibi,
config=config.crossval,
advanced=config.advanced_config,
)

case "statistics":
Expand Down
9 changes: 8 additions & 1 deletion src/stamp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from stamp.encoding.config import PatientEncodingConfig, SlideEncodingConfig
from stamp.heatmaps.config import HeatmapConfig
from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig
from stamp.modeling.config import (
AdvancedConfig,
CrossvalConfig,
DeploymentConfig,
TrainConfig,
)
from stamp.preprocessing.config import PreprocessingConfig
from stamp.statistics import StatsConfig

Expand All @@ -23,3 +28,5 @@ class StampConfig(BaseModel):
slide_encoding: SlideEncodingConfig | None = None

patient_encoding: PatientEncodingConfig | None = None

advanced_config: AdvancedConfig | None = None
52 changes: 40 additions & 12 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ preprocessing:
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip"
# Some of them require requesting access to the respective authors beforehand.
extractor: "ctranspath"
extractor: "chief-ctranspath"

# Device to run feature extraction on ("cpu", "cuda", "cuda:0", etc.)
device: "cuda"
Expand Down Expand Up @@ -79,16 +79,14 @@ crossval:
# Number of folds to split the data into for cross-validation
#n_splits: 5

# Experimental features:
# Path to a YAML file with advanced training parameters.
#params_path: "path/to/train_params.yaml"

# Please try uncommenting the settings below
# and report if they improve / reduce model performance!
# Experimental features:

# Change the precision of features during training
#use_vary_precision_transform: true

# Use ALiBi positional embedding
# use_alibi: true


training:
Expand Down Expand Up @@ -126,17 +124,14 @@ training:
# If unspecified, they will be inferred from the table itself.
#categories: ["mutated", "wild type"]

# Experimental features:
# Path to a YAML file with advanced training parameters.
#params_path: "path/to/model_params.yaml"

# Please try uncommenting the settings below
# and report if they improve / reduce model performance!
# Experimental features:

# Change the precision of features during training
#use_vary_precision_transform: true

# Use ALiBi positional embedding
# use_alibi: true


deployment:
output_dir: "/path/to/save/files/to"
Expand Down Expand Up @@ -272,3 +267,36 @@ patient_encoding:

# Add a hash of the entire preprocessing codebase in the feature folder name.
#generate_hash: True


advanced_config:
max_epochs: 64
patience: 16
batch_size: 64
# Only for tile-level training. Reducing its amount could affect
# model performance. Reduces memory consumption. Default value works
# fine for most cases.
bag_size: 512
# Optional parameters
#num_workers: 16 # Default chosen by cpu cores

# Select a model. Not working yet, added for future support.
# Now it uses a ViT for tile features and an MLP for patient features.
#model_name: "vit"

model_params:
# Tile-level training models:
vit: # Vision Transformer
dim_model: 512
dim_feedforward: 512
n_heads: 8
n_layers: 2
dropout: 0.25
# Experimental feature: Use ALiBi positional embedding
use_alibi: false

# Patient-level training models:
mlp: # Multilayer Perceptron
dim_hidden: 512
num_layers: 2
dropout: 0.25
10 changes: 10 additions & 0 deletions src/stamp/encoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def init_slide_encoder_(

selected_encoder: Encoder = Madeleine()

case EncoderName.PRISM:
from stamp.encoding.encoder.prism import Prism

selected_encoder: Encoder = Prism()

case Encoder():
selected_encoder = encoder

Expand Down Expand Up @@ -145,6 +150,11 @@ def init_patient_encoder_(

selected_encoder: Encoder = Madeleine()

case EncoderName.PRISM:
from stamp.encoding.encoder.prism import Prism

selected_encoder: Encoder = Prism()

case Encoder():
selected_encoder = encoder

Expand Down
3 changes: 1 addition & 2 deletions src/stamp/encoding/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ class EncoderName(StrEnum):
TITAN = "titan"
GIGAPATH = "gigapath"
MADELEINE = "madeleine"
# PRISM = "paigeai-prism"
# waiting for paige-ai authors to fix it
PRISM = "prism"


class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True):
Expand Down
22 changes: 12 additions & 10 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

import h5py
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from tqdm import tqdm

import stamp
from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.modeling.data import CoordsInfo, get_coords
from stamp.modeling.data import CoordsInfo, get_coords, read_table
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel

Expand Down Expand Up @@ -83,7 +82,9 @@ def encode_slides_(
slide_embedding = self._generate_slide_embedding(
feats, device, coords=coords
)
self._save_features_(output_path=output_path, feats=slide_embedding)
self._save_features_(
output_path=output_path, feats=slide_embedding, feat_type="slide"
)

def encode_patients_(
self,
Expand Down Expand Up @@ -113,7 +114,7 @@ def encode_patients_(
if self.precision == torch.float16:
self.model.half()

slide_table = self._read_slide_table(slide_table_path)
slide_table = read_table(slide_table_path)
patient_groups = slide_table.groupby(patient_label)

for patient_id, group in (progress := tqdm(patient_groups)):
Expand Down Expand Up @@ -142,7 +143,9 @@ def encode_patients_(
patient_embedding = self._generate_patient_embedding(
feats_list, device, **kwargs
)
self._save_features_(output_path=output_path, feats=patient_embedding)
self._save_features_(
output_path=output_path, feats=patient_embedding, feat_type="patient"
)

@abstractmethod
def _generate_slide_embedding(
Expand All @@ -161,10 +164,6 @@ def _generate_patient_embedding(
"""Generate patient embedding. Must be implemented by subclasses."""
pass

@staticmethod
def _read_slide_table(slide_table_path: Path) -> pd.DataFrame:
return pd.read_csv(slide_table_path)

def _validate_and_read_features(self, h5_path: str) -> tuple[Tensor, CoordsInfo]:
feats, coords, extractor = self._read_h5(h5_path)
if extractor not in self.required_extractors:
Expand Down Expand Up @@ -192,7 +191,9 @@ def _read_h5(
)
return feats, coords, extractor

def _save_features_(self, output_path: Path, feats: np.ndarray) -> None:
def _save_features_(
self, output_path: Path, feats: np.ndarray, feat_type: str
) -> None:
with (
NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file,
h5py.File(tmp_h5_file, "w") as f,
Expand All @@ -204,6 +205,7 @@ def _save_features_(self, output_path: Path, feats: np.ndarray) -> None:
f.attrs["precision"] = str(self.precision)
f.attrs["stamp_version"] = stamp.__version__
f.attrs["code_hash"] = get_processing_code_hash(Path(__file__))[:8]
f.attrs["feat_type"] = feat_type
# TODO: Add more metadata like tile-level extractor name
# and maybe tile size in pixels and microns
except Exception:
Expand Down
Loading