Merged
1 change: 1 addition & 0 deletions slide2vec/configs/default_model.yaml
@@ -16,6 +16,7 @@ model:
tile_size: ${tiling.params.tile_size}
restrict_to_tissue: false # whether to restrict tile content to tissue pixels only when feeding tile through encoder
patch_size: 256 # if level is "region", size used to unroll the region into patches
token_size: 16 # size of the tokens (i.e. ViT patch size) used when the model is a custom pretrained ViT
save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide"
save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')
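Side note on the new `token_size` option (not part of the diff itself): it is the patch size of the custom pretrained ViT, so together with `tile_size` it fixes how many tokens the encoder sees per tile. A minimal sketch of that relationship, assuming a 256-pixel tile and the default `token_size: 16`:

```python
# Hedged illustration, not code from this PR: how tile_size and token_size
# relate for a custom pretrained ViT (the "dino" branch added in models.py).
tile_size = 256    # pixels per tile side (tiling.params.tile_size)
token_size = 16    # pixels per ViT patch/token side (new token_size option)

tokens_per_side = tile_size // token_size   # 16 patches along each side
num_tokens = tokens_per_side ** 2           # 256 tokens per tile, plus any class token
print(tokens_per_side, num_tokens)          # 16 256
```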

2 changes: 1 addition & 1 deletion slide2vec/hs2p
Submodule hs2p updated 1 file
+23 −5 hs2p/tiling.py
40 changes: 21 additions & 19 deletions slide2vec/models/models.py
@@ -13,7 +13,6 @@
from timm.data.transforms_factory import create_transform

from conch.open_clip_custom import create_model_from_pretrained
from musk import modeling as musk_modeling
from musk import utils as musk_utils

import slide2vec.distributed as distributed
@@ -70,11 +69,12 @@ def __init__(
pretrained_weights=options.pretrained_weights,
input_size=options.tile_size,
)
elif options.name is None and options.arch:
elif options.name == "dino" and options.arch:
model = DINOViT(
arch=options.arch,
pretrained_weights=options.pretrained_weights,
input_size=options.tile_size,
patch_size=options.token_size,
)
elif options.level == "region":
if options.name == "virchow":
@@ -259,7 +259,17 @@ def __init__(
def load_weights(self):
if distributed.is_main_process():
print(f"Loading pretrained weights from: {self.pretrained_weights}")
state_dict = torch.load(self.pretrained_weights, map_location="cpu")

# Fix for loading checkpoints saved with numpy 2.0+ in an environment with numpy < 2.0
try:
import numpy._core
except ImportError:
import numpy as np
import sys
sys.modules["numpy._core"] = np.core
sys.modules["numpy._core.multiarray"] = np.core.multiarray

state_dict = torch.load(self.pretrained_weights, map_location="cpu", weights_only=False)
if self.ckpt_key:
state_dict = state_dict[self.ckpt_key]
nn.modules.utils.consume_prefix_in_state_dict_if_present(
@@ -282,21 +292,13 @@ def build_encoder(self):
return encoder

def get_transforms(self):
if self.input_size > 224:
transform = transforms.Compose(
[
MaybeToTensor(),
transforms.CenterCrop(224),
make_normalize_transform(),
]
)
else:
transforms.Compose(
[
MaybeToTensor(),
make_normalize_transform(),
]
)
transform = transforms.Compose(
[
MaybeToTensor(),
transforms.CenterCrop(self.input_size),
make_normalize_transform(),
]
)
return transform

def forward(self, x):
@@ -344,7 +346,7 @@ def __init__(
def load_weights(self):
if distributed.is_main_process():
print(f"Loading pretrained weights from: {self.pretrained_weights}")
state_dict = torch.load(self.pretrained_weights, map_location="cpu")
state_dict = torch.load(self.pretrained_weights, map_location="cpu", weights_only=False)
if self.ckpt_key:
state_dict = state_dict[self.ckpt_key]
nn.modules.utils.consume_prefix_in_state_dict_if_present(
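For context on the checkpoint-loading change above, a standalone sketch (the checkpoint path below is a placeholder, not a file from this repository): checkpoints pickled under numpy >= 2.0 reference the `numpy._core` module, which does not exist in numpy < 2.0, so the diff aliases that module to `numpy.core` before `torch.load` unpickles the file with `weights_only=False`.

```python
# Sketch of the numpy-compatibility workaround added in models.py, shown in
# isolation. "checkpoint.pth" is a placeholder path.
import sys

import torch

try:
    import numpy._core  # exists on numpy >= 2.0; nothing to patch
except ImportError:
    import numpy as np

    # numpy < 2.0: alias the renamed module so pickle can resolve references
    # such as numpy._core.multiarray written by numpy >= 2.0 environments
    sys.modules["numpy._core"] = np.core
    sys.modules["numpy._core.multiarray"] = np.core.multiarray

state_dict = torch.load("checkpoint.pth", map_location="cpu", weights_only=False)
```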
1 change: 0 additions & 1 deletion slide2vec/utils/__init__.py
@@ -2,7 +2,6 @@
initialize_wandb,
fix_random_seeds,
get_sha,
load_csv,
update_state_dict,
)
from .log_utils import setup_logging
15 changes: 0 additions & 15 deletions slide2vec/utils/utils.py
@@ -111,21 +111,6 @@ def initialize_wandb(
return run


def load_csv(cfg):
df = pd.read_csv(cfg.csv)
if "wsi_path" in df.columns:
wsi_paths = [Path(x) for x in df.wsi_path.values.tolist()]
elif "slide_path" in df.columns:
wsi_paths = [Path(x) for x in df.slide_path.values.tolist()]
if "mask_path" in df.columns:
mask_paths = [Path(x) for x in df.mask_path.values.tolist()]
elif "segmentation_mask_path" in df.columns:
mask_paths = [Path(x) for x in df.segmentation_mask_path.values.tolist()]
else:
mask_paths = [None for _ in wsi_paths]
return wsi_paths, mask_paths


def update_state_dict(
*,
model_dict: dict,