diff --git a/ahcore/callbacks/abstract_writer_callback.py b/ahcore/callbacks/abstract_writer_callback.py index 3f50a90..d6cb099 100644 --- a/ahcore/callbacks/abstract_writer_callback.py +++ b/ahcore/callbacks/abstract_writer_callback.py @@ -179,8 +179,9 @@ def _on_epoch_start(self, trainer: "pl.Trainer") -> None: current_dataset: TiledWsiDataset assert self._total_dataset for current_dataset in self._total_dataset.datasets: # type: ignore - assert current_dataset.slide_image.identifier - self._dataset_sizes[current_dataset.slide_image.identifier] = len(current_dataset) + curr_filename = current_dataset._path + assert curr_filename + self._dataset_sizes[str(curr_filename)] = len(current_dataset) self._start_callback_workers() diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index e521868..3646fa2 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -14,7 +14,7 @@ from dlup.data.dataset import Dataset, TiledWsiDataset from torch.utils.data import DataLoader, DistributedSampler, Sampler -from ahcore.utils.data import DataDescription, basemodel_to_uuid +from ahcore.utils.data import DataDescription, OnTheFlyDataDescription, basemodel_to_uuid from ahcore.utils.io import fullname, get_cache_dir, get_logger from ahcore.utils.manifest import DataManager, datasets_from_data_description from ahcore.utils.types import DlupDatasetSample, _DlupDataset @@ -87,10 +87,12 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: ... + def __getitem__(self, index: int) -> DlupDatasetSample: + ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: + ... 
def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" @@ -109,7 +111,7 @@ class DlupDataModule(pl.LightningDataModule): def __init__( self, - data_description: DataDescription, + data_description: DataDescription | OnTheFlyDataDescription, pre_transform: Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]], batch_size: int = 32, # noqa,pylint: disable=unused-argument validate_batch_size: int | None = None, # noqa,pylint: disable=unused-argument @@ -122,8 +124,8 @@ def __init__( Parameters ---------- - data_description : DataDescription - See `ahcore.utils.data.DataDescription` for more information. + data_description : DataDescription | OnTheFlyDataDescription + See `ahcore.utils.data.DataDescription` and `ahcore.utils.data.DataDescription` for more information. pre_transform : Callable A pre-transform is a callable which is directly applied to the output of the dataset before collation in the dataloader. 
The transforms typically convert the image in the output to a tensor, convert the @@ -157,9 +159,9 @@ def __init__( ) # save all relevant hyperparams # Data settings - self.data_description: DataDescription = data_description + self.data_description = data_description - self._data_manager = DataManager(database_uri=data_description.manifest_database_uri) + self._data_manager = DataManager(data_description) self._batch_size = self.hparams.batch_size # type: ignore self._validate_batch_size = self.hparams.validate_batch_size # type: ignore diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index 39a5fa9..8f4265a 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -9,7 +9,12 @@ from typing import Dict, Optional, Tuple from pydantic import BaseModel +from sqlalchemy import create_engine, exists +from sqlalchemy.engine import Engine +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import Session, sessionmaker +from ahcore.utils.database_models import Base, Manifest, OnTheFlyBase from ahcore.utils.types import NonNegativeInt, PositiveFloat, PositiveInt @@ -42,6 +47,43 @@ def basemodel_to_uuid(base_model: BaseModel) -> uuid.UUID: return unique_id +def open_db_from_engine(engine: Engine) -> Session: + SessionLocal = sessionmaker(bind=engine) + return SessionLocal() + + +def open_db_from_uri( + uri: str, + ensure_exists: bool = True, +) -> Session: + """Open a database connection from a uri""" + + # Set up the engine if no engine is given and uri is given. + engine = create_engine(uri) + + if not ensure_exists: + # Create tables if they don't exist + create_tables(engine, base=Base) + else: + # Check if the "manifest" table exists + inspector = inspect(engine) + if "manifest" not in inspector.get_table_names(): + raise RuntimeError("Manifest table does not exist. 
Likely you have set the wrong URI.") + + # Check if the "manifest" table is not empty + with engine.connect() as connection: + result = connection.execute(exists().where(Manifest.id.isnot(None)).select()) + if not result.scalar(): + raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.") + + return open_db_from_engine(engine) + + +def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: + """Create the database tables.""" + base.metadata.create_all(bind=engine) + + class GridDescription(BaseModel): mpp: Optional[PositiveFloat] tile_size: Tuple[PositiveInt, PositiveInt] @@ -67,3 +109,24 @@ class DataDescription(BaseModel): convert_mask_to_rois: bool = True use_roi: bool = True apply_color_profile: bool = False + + +class OnTheFlyDataDescription(BaseModel): + # Required + data_dir: Path + glob_pattern: str + num_classes: NonNegativeInt + inference_grid: GridDescription + + # Preset? + convert_mask_to_rois: bool = True + use_roi: bool = True + apply_color_profile: bool = False + + # Explicitly optional + annotations_dir: Optional[Path] = None # May be used to provde a mask. + mask_label: Optional[str] = None + mask_threshold: Optional[float] = None # This is only used for training + roi_name: Optional[str] = None + index_map: Optional[Dict[str, int]] + remap_labels: Optional[Dict[str, str]] = None diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 21538b7..90afbac 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -210,3 +210,28 @@ class Split(Base): split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits") __table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),) + + +class OnTheFlyBase(DeclarativeBase): + """ + Base for creating an in-memory DB on-the-fly for, e.g., segmentation inference on a directory of WSIs. 
+ """ + + pass + + +class MinimalImage(OnTheFlyBase): + """Minimal image table for an in-memory db for instant inference""" + + # TODO Link to annotations or masks + __tablename__ = "image" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True, nullable=False) + relative_filename = Column(String, unique=True, nullable=False) + reader = Column(String) + height = Column(Integer) + width = Column(Integer) + mpp = Column(Float) diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index 246897d..9ca6283 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -85,6 +85,7 @@ def print_config( config: DictConfig, fields: Sequence[str] = ( "trainer", + "data_description", "model", "experiment", "transforms", @@ -241,7 +242,14 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: return model else: # Load checkpoint weights - lit_ckpt = torch.load(config.ckpt_path) + accelerator = config.trainer.accelerator + if accelerator == "cpu": + map_location = "cpu" + elif accelerator == "gpu": + map_location = "cuda" + else: + raise ValueError(f"Accelerator must be either cpu or gpu, but config.trainer.accelerator={accelerator}") + lit_ckpt = torch.load(config.ckpt_path, map_location=map_location) model.load_state_dict(lit_ckpt["state_dict"], strict=True) return model diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 6760e80..102ca34 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -16,26 +16,23 @@ from dlup.experimental_backends import ImageBackend # type: ignore from dlup.tiling import GridOrder, TilingMode from pydantic import BaseModel -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.inspection import inspect -from sqlalchemy.orm import Session, 
sessionmaker -from sqlalchemy.sql import exists +from sqlalchemy.orm import Session from ahcore.exceptions import RecordNotFoundError -from ahcore.utils.data import DataDescription +from ahcore.utils.data import DataDescription, OnTheFlyDataDescription, open_db_from_engine, open_db_from_uri from ahcore.utils.database_models import ( - Base, CategoryEnum, Image, ImageAnnotations, Manifest, Mask, + MinimalImage, Patient, Split, SplitDefinitions, ) from ahcore.utils.io import get_enum_key_from_value, get_logger +from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db from ahcore.utils.rois import compute_rois from ahcore.utils.types import PositiveFloat, PositiveInt, Rois @@ -156,15 +153,30 @@ def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, st class DataManager: - def __init__(self, database_uri: str) -> None: - self._database_uri = database_uri + def __init__(self, data_description: DataDescription | OnTheFlyDataDescription) -> None: + self._data_description = data_description + self._on_the_fly = isinstance(data_description, OnTheFlyDataDescription) + self._image_table: type[Image] | type[MinimalImage] + if self._on_the_fly: + self._database_engine = get_populated_in_memory_db( + data_description.data_dir, glob_pattern=data_description.glob_pattern # type: ignore + ) + self._image_table = MinimalImage + else: + self._database_uri = data_description.manifest_database_uri # type: ignore + self._image_table = Image + self.__session: Optional[Session] = None self._logger = get_logger(type(self).__name__) @property def _session(self) -> Session: if self.__session is None: - self.__session = open_db(self._database_uri) + if self._on_the_fly: + self.__session = open_db_from_engine(engine=self._database_engine) + else: + self.__session = open_db_from_uri(uri=self._database_uri) + return self.__session @staticmethod @@ -173,13 +185,31 @@ def _ensure_record(record: Any, description: str) -> None: if not record: raise 
RecordNotFoundError(f"{description} not found.") + def get_all_images( + self, + ) -> Generator[MinimalImage, None, None]: + """ + Queries the db minimal image table for all the provided images. + + This is an ultimately minimal query that is only meaningful for a `MinimalImage` table, since + connections to other information about the images in a fully populated table would be lost + """ + assert ( + self._on_the_fly + ), "This function should only be called when doing on-the-fly inference with a MinimalImage table \ + from an OnTheFlyDataDescription." + images = self._session.query(self._image_table).all() + self._logger.info(f"Found {len(images)} images in {self._data_description.data_dir}") + for image in images: + yield image # type: ignore + def get_records_by_split( self, manifest_name: str, split_version: str, split_category: Optional[str] = None, ) -> Generator[Patient, None, None]: - manifest = self._session.query(Manifest).filter_by(name=manifest_name).first() + manifest = self._session.query(Manifest).filter_by(name=manifest_name).first() # type: ignore try: self._ensure_record(manifest, f"Manifest with name {manifest_name}") except RecordNotFoundError as e: @@ -320,6 +350,76 @@ def close(self) -> None: def datasets_from_data_description( + db_manager: DataManager, + data_description: DataDescription | OnTheFlyDataDescription, + transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, + stage: Optional[str] = None, +) -> Generator[TiledWsiDataset, None, None]: + """ + I think a factory is not necessary here. We simply generate it, and we can use the + class of the datadescription to decide how to do this. 
+ + See the required parameters and typing in + `ahcore.utils.manifest.datasets_from_on_the_fly_data_description` or + `ahcore.utils.manifest.datasets_from_data_description_with_uri` + """ + # This is the same as checking if isinstance(data_description, DataDescription) vs + # isinstance(data_description, OnTheFlyDataDescription) + + if isinstance(data_description, DataDescription): + return datasets_from_data_description_with_uri( + db_manager=db_manager, + data_description=data_description, + transform=transform, + stage=stage, # type: ignore + ) + if isinstance(data_description, OnTheFlyDataDescription): + return datasets_from_on_the_fly_data_description( + db_manager=db_manager, + data_description=data_description, + transform=transform, + ) + else: + raise ValueError( + f"Can't create datasets with current config. engine={db_manager._database_uri}, \ + uri={db_manager._database_engine}" + ) + + +def datasets_from_on_the_fly_data_description( + db_manager: DataManager, + data_description: OnTheFlyDataDescription, + transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, +) -> Generator[TiledWsiDataset, None, None]: + # TODO Add masks for inference of, e.g., feature extraction + image_root = data_description.data_dir + grid_description = data_description.inference_grid + images = db_manager.get_all_images() + for image in images: + mask_threshold = 0.0 # Only set differently for `fit` in original method + dataset = TiledWsiDataset.from_standard_tiling( + path=image_root / image.filename, + mpp=grid_description.mpp, + tile_size=grid_description.tile_size, + tile_overlap=grid_description.tile_overlap, + tile_mode=TilingMode.overflow, + grid_order=GridOrder.C, + crop=False, + mask=None, + mask_threshold=mask_threshold, + output_tile_size=getattr(grid_description, "output_tile_size", None), + labels=None, # type: ignore + transform=transform, + backend=ImageBackend[str(image.reader)], + overwrite_mpp=(image.mpp, image.mpp), + limit_bounds=True, + 
apply_color_profile=data_description.apply_color_profile, + ) + + yield dataset + + +def datasets_from_data_description_with_uri( db_manager: DataManager, data_description: DataDescription, transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, @@ -389,47 +489,6 @@ class Config: mpp: PositiveFloat -def open_db(database_uri: str, ensure_exists: bool = True) -> Session: - """Open a database connection. - - Parameters - ---------- - database_uri : str - The URI of the database. - ensure_exists : bool, default=True - Whether to raise an exception of the database does not exist. - - Returns - ------- - Session - The database session. - """ - engine = create_engine(database_uri) - - if not ensure_exists: - # Create tables if they don't exist - create_tables(engine) - else: - # Check if the "manifest" table exists - inspector = inspect(engine) - if "manifest" not in inspector.get_table_names(): - raise RuntimeError("Manifest table does not exist. Likely you have set the wrong URI.") - - # Check if the "manifest" table is not empty - with engine.connect() as connection: - result = connection.execute(exists().where(Manifest.id.isnot(None)).select()) - if not result.scalar(): - raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.") - - SessionLocal = sessionmaker(bind=engine) - return SessionLocal() - - -def create_tables(engine: Engine) -> None: - """Create the database tables.""" - Base.metadata.create_all(bind=engine) - - def fetch_image_metadata(image: Image) -> ImageMetadata: """Extract metadata from an Image object.""" return ImageMetadata( diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py new file mode 100644 index 0000000..2b4fa98 --- /dev/null +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -0,0 +1,93 @@ +""" +Functions for generating an in-memory MinimalImage database on-the-fly with only an image root directory and glob +pattern. 
Used for inference of, e.g., a segmentation model on a directory filled with WSIs, without generating a +database explicitly. +""" + +from pathlib import Path + +import sqlalchemy +from dlup import SlideImage +from dlup.experimental_backends import ImageBackend +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from ahcore.utils.data import create_tables, open_db_from_engine +from ahcore.utils.database_models import MinimalImage, OnTheFlyBase + + +def populate_from_directory_and_glob_pattern(session: Session, image_folder: Path | str, glob_pattern: str) -> None: + """ + Populates the MinimalImage database in the in-memory session + with slides found in the image_folder using the + glob_pattern. Population happens in-place, so no return. + + Parameters + ---------- + session : Session + The opened session that connects to the DB engine + image_folder : str + The root directory of the images + glob_pattern : str + The glob pattern to find images within the root directory + + Returns + ------- + None + """ + for wsi in Path(image_folder).glob(glob_pattern): + with SlideImage.from_file_path( + image_folder / wsi, + backend=ImageBackend.PYVIPS, # type: ignore + ) as slide: # type: ignore + mpp = slide.mpp + width, height = slide.size + image = MinimalImage( + filename=str(wsi), + mpp=mpp, + height=height, + width=width, + reader="PYVIPS", + relative_filename=str(wsi.relative_to(image_folder)), + ) + session.add(image) + session.flush() # Flush so that Image ID is populated for future records + session.commit + + +def get_populated_in_memory_db(image_folder: Path, glob_pattern: str) -> Engine: + """ + Callable function to get the populated in-memory DB as an Engine + + Parameters + ---------- + image_folder : Path + The root directory of the images + glob_pattern : str + The glob pattern to find images within the root directory + + Returns + ------- + Engine + an in-memory sqlalchemy Engine + """ + assert image_folder.is_dir(), f"image_folder 
({image_folder}) does not exist" + + assert ( + len([i for i in image_folder.glob(glob_pattern)]) > 0 + ), f"No images found in {image_folder} with glob pattern {glob_pattern}" + + # An empty URL will create a `:memory:` database + engine = sqlalchemy.create_engine("sqlite://") + create_tables(engine=engine, base=OnTheFlyBase) + + with open_db_from_engine(engine) as session: + # Populate the DB through the session. Happens in-place + populate_from_directory_and_glob_pattern(session, image_folder, glob_pattern) + + # Commit is required before passing the engine back. If not commited, the engine and + # session here will contain the information, But outside of this context it will be lost. + session.commit() + + # Return the engine object that is bound to the session, so we can close the session + return engine diff --git a/setup.py b/setup.py index 7810677..789766d 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,8 @@ "numpydoc", "myst-parser", "sphinx-book-theme", + "pre-commit", + "tox", ], }, license="Apache Software License 2.0", diff --git a/tests/test_in_memory_db/populate_minimal_db_for_inference.py b/tests/test_in_memory_db/populate_minimal_db_for_inference.py new file mode 100644 index 0000000..02bb985 --- /dev/null +++ b/tests/test_in_memory_db/populate_minimal_db_for_inference.py @@ -0,0 +1,22 @@ +""" +This is a test for in-memory on-the-fly generation of a minimal ahcore database using a dummy dataset +using tiny .svs files from openslide +""" + +from pathlib import Path + +from sqlalchemy.orm import sessionmaker + +from ahcore.utils.database_models import MinimalImage +from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db + +if __name__ == "__main__": + image_folder = Path(__file__) / "test_in_memory_db" + glob_pattern = "**/*.svs" + engine = get_populated_in_memory_db(image_folder=image_folder, glob_pattern=glob_pattern) + + with sessionmaker(bind=engine)() as session: + table = session.query(MinimalImage).all() + + assert 
len(table) > 0, "The database was not populated" + assert len(table) == 3, f"The database has {len(table)} entries while there are only 3 test images" diff --git a/tests/test_in_memory_db/test_images/small_1.svs b/tests/test_in_memory_db/test_images/small_1.svs new file mode 100644 index 0000000..6d113e1 Binary files /dev/null and b/tests/test_in_memory_db/test_images/small_1.svs differ diff --git a/tests/test_in_memory_db/test_images/small_2.svs b/tests/test_in_memory_db/test_images/small_2.svs new file mode 100644 index 0000000..6d113e1 Binary files /dev/null and b/tests/test_in_memory_db/test_images/small_2.svs differ diff --git a/tests/test_in_memory_db/test_images/small_3.svs b/tests/test_in_memory_db/test_images/small_3.svs new file mode 100644 index 0000000..6d113e1 Binary files /dev/null and b/tests/test_in_memory_db/test_images/small_3.svs differ diff --git a/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh b/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh new file mode 100644 index 0000000..c80d0a6 --- /dev/null +++ b/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh @@ -0,0 +1,61 @@ +# Script to run inference on small test WSIs +# Can be easily adjusted by changing +# 1) the data_dir your directory containing WSIs +# 2) the glob_pattern to find the WSIs +# 3) Changing the mpp to 2.0 (which the model is trained on) and to get manageable output. + +# ===== Step 1 ===== +# If on CPU, set `trainer.accelerator=cpu` and set map_location to cpu in utils.io.load_weights +# If on MacOS, set multiprocessing.set_start_method("fork", force=True) before main() in inference.py. mps is not supported. + +# ===== Step 2 ===== +export SCRATCH=/set/to/your/log/dir + +# ===== Step 3 ===== +# Set path to model checkpoint +# Expected to be an attention unet with current configuration. 
+# ABSOLUTE_PATH_TO_CKPT=/absolute/path/to/checkpoint.ckpt + +# ===== Step 4 ===== +# Create ~/ahcore/additional_config/data_description/on_the_fly_data_description.yaml +# With the following contents + +#~/ahcore/additional_config/data_description/on_the_fly_data_description.yaml +# --------------------------------------------------------------------------------- # +# _target_: ahcore.utils.data.OnTheFlyDataDescription #TODO check if this works +# data_dir: test_in_memory_db # Root dir for images to be found +# glob_pattern: "**/*.svs" # Glob pattern relative to root dir to find images +# use_roi: False +# apply_color_profile: False +# inference_grid: +# mpp: 0.01 # Since the images are VERY small, we need to blow them up to be able to get proper input image size +# tile_size: [512, 512] +# tile_overlap: [0, 0] + +# num_classes: 2 + +# remap_labels: +# specimen: specimen + +# index_map: +# specimen: 1 + +# color_map: +# 0: "yellow" +# 1: "red" +# --------------------------------------------------------------------------------- # + + +# ===== Step 5 ===== +# Run this file from ~/ahcore/tools as CWD, with an installed ahcore environment. +python inference.py \ + callbacks=inference \ + data_description=on_the_fly_data_description \ + pre_transform=segmentation \ + augmentations=segmentation \ + lit_module=monai_segmentation/attention_unet \ + ckpt_path=$ABSOLUTE_PATH_TO_CKPT \ + # trainer.accelerator=cpu # If you're running on CPU + +# ===== Step 6 ===== +# Files should be saved in $SCRATCH / segmentation_inference / $datetime / outputs / AttentionUNet / 0_0 / *