diff --git a/.gitignore b/.gitignore index b24c1b5..f815d0e 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ cython_debug/ # custom output/ +outputs/ archive/ diff --git a/slide2vec/__init__.py b/slide2vec/__init__.py index 8c0d5d5..1e97492 100644 --- a/slide2vec/__init__.py +++ b/slide2vec/__init__.py @@ -1 +1,6 @@ __version__ = "2.0.0" + +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "hs2p")) diff --git a/slide2vec/configs/default.yaml b/slide2vec/configs/default.yaml index 0b772d2..82fab7e 100644 --- a/slide2vec/configs/default.yaml +++ b/slide2vec/configs/default.yaml @@ -28,7 +28,7 @@ tiling: use_otsu: false # use otsu's method instead of simple binary thresholding tissue_pixel_value: 1 # value of tissue pixel in pre-computed segmentation masks filter_params: - ref_tile_size: 16 # reference tile size at spacing tiling.spacing + ref_tile_size: ${tiling.params.tile_size} # reference tile size at spacing tiling.spacing a_t: 4 # area filter threshold for tissue (positive integer, the minimum size of detected foreground contours to consider, relative to the reference tile size ref_tile_size, e.g. a value 10 means only detected foreground contours of size greater than 10 [ref_tile_size, ref_tile_size] tiles at spacing tiling.spacing will be kept) a_h: 2 # area filter threshold for holes (positive integer, the minimum size of detected holes/cavities in foreground contours to avoid, once again relative to the reference tile size ref_tile_size) max_n_holes: 8 # maximum of holes to consider per detected foreground contours (positive integer, higher values lead to more accurate patching but increase computational cost ; keeps the biggest holes) @@ -43,6 +43,7 @@ model: pretrained_weights: # path to the pretrained weights when using a custom model batch_size: 256 tile_size: ${tiling.params.tile_size} + restrict_to_tissue: false # whether to restrict tile content to tissue pixels only when feeding tile through encoder patch_size: 256 # if level is "region", size used to unroll the region into patches save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide" save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism') diff --git a/slide2vec/data/dataset.py b/slide2vec/data/dataset.py index 3a2f034..6e2b741 100644 --- a/slide2vec/data/dataset.py +++ b/slide2vec/data/dataset.py @@ -1,3 +1,4 @@ +import cv2 import torch import numpy as np import wholeslidedata as wsd @@ -5,27 +6,67 @@ from transformers.image_processing_utils import BaseImageProcessor from PIL import Image from pathlib import Path +from typing import Callable + +from slide2vec.hs2p.hs2p.wsi import WholeSlideImage, SegmentationParameters, SamplingParameters, FilterParameters +from slide2vec.hs2p.hs2p.wsi.utils import HasEnoughTissue class TileDataset(torch.utils.data.Dataset): - def __init__(self, wsi_path, tile_dir, target_spacing, backend, transforms=None): + def __init__( + self, + wsi_path: Path, + mask_path: Path, + coordinates_dir: Path, + target_spacing: float, + tolerance: float, + backend: str, + segment_params: SegmentationParameters | None = None, + sampling_params: SamplingParameters | None = None, + filter_params: FilterParameters | None = None, + transforms: BaseImageProcessor | Callable | None = None, + restrict_to_tissue: bool = False, + ): self.path = wsi_path + self.mask_path = mask_path self.target_spacing = target_spacing self.backend = backend self.name = wsi_path.stem.replace(" ", "_") - self.load_coordinates(tile_dir) + self.load_coordinates(coordinates_dir) self.transforms = transforms + self.restrict_to_tissue = restrict_to_tissue + + if restrict_to_tissue: + _wsi = WholeSlideImage( + path=self.path, + mask_path=self.mask_path, + backend=self.backend, + segment_params=segment_params, + sampling_params=sampling_params, + ) + contours, holes = _wsi.detect_contours( + target_spacing=target_spacing, + tolerance=tolerance, + filter_params=filter_params, + ) + scale = _wsi.level_downsamples[_wsi.seg_level] + self.contours = _wsi.scaleContourDim(contours, (1.0 / scale[0], 1.0 / scale[1])) + self.holes = _wsi.scaleHolesDim(holes, (1.0 / scale[0], 1.0 / scale[1])) + self.tissue_mask = _wsi.annotation_mask["tissue"] + self.seg_spacing = _wsi.get_level_spacing(_wsi.seg_level) + self.spacing_at_level_0 = _wsi.get_level_spacing(0) - def load_coordinates(self, tile_dir): - coordinates = np.load(Path(tile_dir, f"{self.name}.npy"), allow_pickle=True) + def load_coordinates(self, coordinates_dir): + coordinates = np.load(Path(coordinates_dir, f"{self.name}.npy"), allow_pickle=True) self.x = coordinates["x"] self.y = coordinates["y"] self.coordinates = (np.array([self.x, self.y]).T).astype(int) self.scaled_coordinates = self.scale_coordinates() + self.contour_index = coordinates["contour_index"] + self.target_tile_size = coordinates["target_tile_size"] self.tile_level = coordinates["tile_level"] + self.resize_factor = coordinates["resize_factor"] self.tile_size_resized = coordinates["tile_size_resized"] - resize_factor = coordinates["resize_factor"] - self.tile_size = np.round(self.tile_size_resized / resize_factor).astype(int) self.tile_size_lv0 = coordinates["tile_size_lv0"][0] def scale_coordinates(self): @@ -55,11 +96,30 @@ def __getitem__(self, idx): spacing=tile_spacing, center=False, ) + if self.restrict_to_tissue: + contour_idx = self.contour_index[idx] + contour = self.contours[contour_idx] + holes = self.holes[contour_idx] + tissue_checker = HasEnoughTissue( + contour=contour, + contour_holes=holes, + tissue_mask=self.tissue_mask, + tile_size=self.target_tile_size[idx], + tile_spacing=tile_spacing, + resize_factor=self.resize_factor[idx], + seg_spacing=self.seg_spacing, + spacing_at_level_0=self.spacing_at_level_0, + ) + tissue_mask = tissue_checker.get_tile_mask(self.x[idx], self.y[idx]) + # ensure mask is the same size as the tile + assert tissue_mask.shape[:2] == tile_arr.shape[:2], "Mask and tile shapes do not match" + # apply mask + tile_arr = cv2.bitwise_and(tile_arr, tile_arr, mask=tissue_mask) tile = Image.fromarray(tile_arr).convert("RGB") - if self.tile_size[idx] != self.tile_size_resized[idx]: - tile = tile.resize((self.tile_size[idx], self.tile_size[idx])) + if self.target_tile_size[idx] != self.tile_size_resized[idx]: + tile = tile.resize((self.target_tile_size[idx], self.target_tile_size[idx])) if self.transforms: - if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`) + if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`) tile = self.transforms(tile, return_tensors="pt")["pixel_values"].squeeze(0) else: # general callable such as torchvision transforms tile = self.transforms(tile) diff --git a/slide2vec/embed.py b/slide2vec/embed.py index 06fb022..4c5d951 100644 --- a/slide2vec/embed.py +++ b/slide2vec/embed.py @@ -18,6 +18,7 @@ from slide2vec.utils.config import get_cfg_from_file, setup_distributed from slide2vec.models import ModelFactory from slide2vec.data import TileDataset, RegionUnfolding +from slide2vec.hs2p.hs2p.wsi import SamplingParameters torchvision.disable_beta_transforms_warning() @@ -60,13 +61,31 @@ def create_transforms(cfg, model): raise ValueError(f"Unknown model level: {cfg.model.level}") -def create_dataset(wsi_fp, coordinates_dir, spacing, backend, transforms): +def create_dataset( + wsi_path, + mask_path, + coordinates_dir, + target_spacing, + tolerance, + backend, + segment_params, + sampling_params, + filter_params, + transforms, + restrict_to_tissue: bool, +): return TileDataset( - wsi_fp, - coordinates_dir, - spacing, + wsi_path=wsi_path, + mask_path=mask_path, + coordinates_dir=coordinates_dir, + target_spacing=target_spacing, + tolerance=tolerance, backend=backend, + segment_params=segment_params, + sampling_params=sampling_params, + filter_params=filter_params, transforms=transforms, + restrict_to_tissue=restrict_to_tissue, ) @@ -176,12 +195,30 @@ def main(args): if not run_on_cpu: torch.distributed.barrier() + pixel_mapping = {k: v for e in cfg.tiling.sampling_params.pixel_mapping for k, v in e.items()} + tissue_percentage = {k: v for e in cfg.tiling.sampling_params.tissue_percentage for k, v in e.items()} + if "tissue" not in tissue_percentage: + tissue_percentage["tissue"] = cfg.tiling.params.min_tissue_percentage + if cfg.tiling.sampling_params.color_mapping is not None: + color_mapping = {k: v for e in cfg.tiling.sampling_params.color_mapping for k, v in e.items()} + else: + color_mapping = None + + sampling_params = SamplingParameters( + pixel_mapping=pixel_mapping, + color_mapping=color_mapping, + tissue_percentage=tissue_percentage, + ) + # select slides that were successfully tiled but not yet processed for feature extraction tiled_df = process_df[process_df.tiling_status == "success"] mask = tiled_df["feature_status"] != "success" process_stack = tiled_df[mask] total = len(process_stack) + wsi_paths_to_process = [Path(x) for x in process_stack.wsi_path.values.tolist()] + mask_paths_to_process = [Path(x) for x in process_stack.mask_path.values.tolist()] + combined_paths = zip(wsi_paths_to_process, mask_paths_to_process) features_dir = Path(cfg.output_dir, "features") if distributed.is_main_process(): @@ -201,8 +238,8 @@ def main(args): transforms = create_transforms(cfg, model) print(f"transforms: {transforms}") - for wsi_fp in tqdm.tqdm( - wsi_paths_to_process, + for wsi_fp, mask_fp in tqdm.tqdm( + combined_paths, desc="Inference", unit="slide", total=total, @@ -211,7 +248,19 @@ def main(args): position=1, ): try: - dataset = create_dataset(wsi_fp, coordinates_dir, cfg.tiling.params.spacing, cfg.tiling.backend, transforms) + dataset = create_dataset( + wsi_path=wsi_fp, + mask_path=mask_fp, + coordinates_dir=coordinates_dir, + target_spacing=cfg.tiling.params.spacing, + tolerance=cfg.tiling.params.tolerance, + backend=cfg.tiling.backend, + segment_params=cfg.tiling.seg_params, + sampling_params=sampling_params, + filter_params=cfg.tiling.filter_params, + transforms=transforms, + restrict_to_tissue=cfg.model.restrict_to_tissue, + ) if distributed.is_enabled_and_multiple_gpus(): sampler = torch.utils.data.DistributedSampler( dataset, diff --git a/test/gt/test-wsi.npy b/test/gt/test-wsi.npy index c68ca54..d3eadd4 100644 Binary files a/test/gt/test-wsi.npy and b/test/gt/test-wsi.npy differ