From 4a524bcbc18d26cbad3bab8a98abd8149ff9793b Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Wed, 5 Jun 2024 16:38:44 +0200 Subject: [PATCH 1/8] feat: add on-the-fly inference --- ahcore/data/dataset.py | 18 ++- ahcore/utils/data.py | 21 +++ ahcore/utils/database_models.py | 25 +++ ahcore/utils/io.py | 3 +- ahcore/utils/manifest.py | 143 +++++++++++++++--- .../utils/on_the_fly_database_generation.py | 110 ++++++++++++++ .../populate_minimal_db_for_inference.py | 22 +++ .../test_in_memory_db/test_images/small_1.svs | Bin 0 -> 2651 bytes .../test_in_memory_db/test_images/small_2.svs | Bin 0 -> 2651 bytes .../test_in_memory_db/test_images/small_3.svs | Bin 0 -> 2651 bytes ...ence_with_on_the_fly_in_memory_database.sh | 61 ++++++++ 11 files changed, 373 insertions(+), 30 deletions(-) create mode 100644 ahcore/utils/on_the_fly_database_generation.py create mode 100644 tests/test_in_memory_db/populate_minimal_db_for_inference.py create mode 100644 tests/test_in_memory_db/test_images/small_1.svs create mode 100644 tests/test_in_memory_db/test_images/small_2.svs create mode 100644 tests/test_in_memory_db/test_images/small_3.svs create mode 100644 tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index e521868..3646fa2 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -14,7 +14,7 @@ from dlup.data.dataset import Dataset, TiledWsiDataset from torch.utils.data import DataLoader, DistributedSampler, Sampler -from ahcore.utils.data import DataDescription, basemodel_to_uuid +from ahcore.utils.data import DataDescription, OnTheFlyDataDescription, basemodel_to_uuid from ahcore.utils.io import fullname, get_cache_dir, get_logger from ahcore.utils.manifest import DataManager, datasets_from_data_description from ahcore.utils.types import DlupDatasetSample, _DlupDataset @@ -87,10 +87,12 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: ... + def __getitem__(self, index: int) -> DlupDatasetSample: + ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: + ... def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" @@ -109,7 +111,7 @@ class DlupDataModule(pl.LightningDataModule): def __init__( self, - data_description: DataDescription, + data_description: DataDescription | OnTheFlyDataDescription, pre_transform: Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]], batch_size: int = 32, # noqa,pylint: disable=unused-argument validate_batch_size: int | None = None, # noqa,pylint: disable=unused-argument @@ -122,8 +124,8 @@ def __init__( Parameters ---------- - data_description : DataDescription - See `ahcore.utils.data.DataDescription` for more information. + data_description : DataDescription | OnTheFlyDataDescription + See `ahcore.utils.data.DataDescription` and `ahcore.utils.data.DataDescription` for more information. pre_transform : Callable A pre-transform is a callable which is directly applied to the output of the dataset before collation in the dataloader. The transforms typically convert the image in the output to a tensor, convert the @@ -157,9 +159,9 @@ def __init__( ) # save all relevant hyperparams # Data settings - self.data_description: DataDescription = data_description + self.data_description = data_description - self._data_manager = DataManager(database_uri=data_description.manifest_database_uri) + self._data_manager = DataManager(data_description) self._batch_size = self.hparams.batch_size # type: ignore self._validate_batch_size = self.hparams.validate_batch_size # type: ignore diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index 39a5fa9..3c85cd0 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -67,3 +67,24 @@ class DataDescription(BaseModel): convert_mask_to_rois: bool = True use_roi: bool = True apply_color_profile: bool = False + + +class OnTheFlyDataDescription(BaseModel): + # Required + data_dir: Path + glob_pattern: str + num_classes: NonNegativeInt + inference_grid: GridDescription + + # Preset? + convert_mask_to_rois: bool = True + use_roi: bool = True + apply_color_profile: bool = False + + # Explicitly optional + annotations_dir: Optional[Path] = None # May be used to provde a mask. + mask_label: Optional[str] = None + mask_threshold: Optional[float] = None # This is only used for training + roi_name: Optional[str] = None + index_map: Optional[Dict[str, int]] + remap_labels: Optional[Dict[str, str]] = None diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 21538b7..90afbac 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -210,3 +210,28 @@ class Split(Base): split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits") __table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),) + + +class OnTheFlyBase(DeclarativeBase): + """ + Base for creating an in-memory DB on-the-fly for, e.g., segmentation inference on a directory of WSIs. + """ + + pass + + +class MinimalImage(OnTheFlyBase): + """Minimal image table for an in-memory db for instant inference""" + + # TODO Link to annotations or masks + __tablename__ = "image" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True, nullable=False) + relative_filename = Column(String, unique=True, nullable=False) + reader = Column(String) + height = Column(Integer) + width = Column(Integer) + mpp = Column(Float) diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index 246897d..9f08ae1 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -85,6 +85,7 @@ def print_config( config: DictConfig, fields: Sequence[str] = ( "trainer", + "data_description", "model", "experiment", "transforms", @@ -241,7 +242,7 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: return model else: # Load checkpoint weights - lit_ckpt = torch.load(config.ckpt_path) + lit_ckpt = torch.load(config.ckpt_path, map_location="cpu") model.load_state_dict(lit_ckpt["state_dict"], strict=True) return model diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 6760e80..e34ee64 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -23,7 +23,7 @@ from sqlalchemy.sql import exists from ahcore.exceptions import RecordNotFoundError -from ahcore.utils.data import DataDescription +from ahcore.utils.data import DataDescription, OnTheFlyDataDescription from ahcore.utils.database_models import ( Base, CategoryEnum, @@ -31,11 +31,13 @@ ImageAnnotations, Manifest, Mask, + MinimalImage, Patient, Split, SplitDefinitions, ) from ahcore.utils.io import get_enum_key_from_value, get_logger +from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db from ahcore.utils.rois import compute_rois from ahcore.utils.types import PositiveFloat, PositiveInt, Rois @@ -156,15 +158,30 @@ def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, st class DataManager: - def __init__(self, database_uri: str) -> None: - self._database_uri = database_uri + def __init__(self, data_description: DataDescription | OnTheFlyDataDescription) -> None: + self._data_description = data_description + self._on_the_fly = isinstance(data_description, OnTheFlyDataDescription) + self._image_table: type[Image] | type[MinimalImage] + if self._on_the_fly: + self._database_engine = get_populated_in_memory_db( + data_description.data_dir, glob_pattern=data_description.glob_pattern # type: ignore + ) + self._image_table = MinimalImage + else: + self._database_uri = data_description.manifest_database_uri # type: ignore + self._image_table = Image + self.__session: Optional[Session] = None self._logger = get_logger(type(self).__name__) @property def _session(self) -> Session: if self.__session is None: - self.__session = open_db(self._database_uri) + if self._on_the_fly: + self.__session = open_db_from_engine(engine=self._database_engine) + else: + self.__session = open_db_from_uri(uri=self._database_uri) + return self.__session @staticmethod @@ -173,13 +190,31 @@ def _ensure_record(record: Any, description: str) -> None: if not record: raise RecordNotFoundError(f"{description} not found.") + def get_all_images( + self, + ) -> Generator[MinimalImage, None, None]: + """ + Queries the db minimal image table for all the provided images. + + This is an ultimately minimal query that is only meaningful for a `MinimalImage` table, since + connections to other information about the images in a fully populated table would be lost + """ + assert ( + self._on_the_fly + ), "This function should only be called when doing on-the-fly inference with a MinimalImage table \ + from an OnTheFlyDataDescription." + images = self._session.query(self._image_table).all() + self._logger.info(f"Found {len(images)} images in {self._data_description.data_dir}") + for image in images: + yield image # type: ignore + def get_records_by_split( self, manifest_name: str, split_version: str, split_category: Optional[str] = None, ) -> Generator[Patient, None, None]: - manifest = self._session.query(Manifest).filter_by(name=manifest_name).first() + manifest = self._session.query(Manifest).filter_by(name=manifest_name).first() # type: ignore try: self._ensure_record(manifest, f"Manifest with name {manifest_name}") except RecordNotFoundError as e: @@ -320,6 +355,76 @@ def close(self) -> None: def datasets_from_data_description( + db_manager: DataManager, + data_description: DataDescription | OnTheFlyDataDescription, + transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, + stage: Optional[str] = None, +) -> Generator[TiledWsiDataset, None, None]: + """ + I think a factory is not necessary here. We simply generate it, and we can use the + class of the datadescription to decide how to do this. + + See the required parameters and typing in + `ahcore.utils.manifest.datasets_from_on_the_fly_data_description` or + `ahcore.utils.manifest.datasets_from_data_description_with_uri` + """ + # This is the same as checking if isinstance(data_description, DataDescription) vs + # isinstance(data_description, OnTheFlyDataDescription) + + if isinstance(data_description, DataDescription): + return datasets_from_data_description_with_uri( + db_manager=db_manager, + data_description=data_description, + transform=transform, + stage=stage, # type: ignore + ) + if isinstance(data_description, OnTheFlyDataDescription): + return datasets_from_on_the_fly_data_description( + db_manager=db_manager, + data_description=data_description, + transform=transform, + ) + else: + raise ValueError( + f"Can't create datasets with current config. engine={db_manager._database_uri}, \ + uri={db_manager._database_engine}" + ) + + +def datasets_from_on_the_fly_data_description( + db_manager: DataManager, + data_description: OnTheFlyDataDescription, + transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, +) -> Generator[TiledWsiDataset, None, None]: + # TODO Add masks for inference of, e.g., feature extraction + image_root = data_description.data_dir + grid_description = data_description.inference_grid + images = db_manager.get_all_images() + for image in images: + mask_threshold = 0.0 # Only set differently for `fit` in original method + dataset = TiledWsiDataset.from_standard_tiling( + path=image_root / image.filename, + mpp=grid_description.mpp, + tile_size=grid_description.tile_size, + tile_overlap=grid_description.tile_overlap, + tile_mode=TilingMode.overflow, + grid_order=GridOrder.C, + crop=False, + mask=None, + mask_threshold=mask_threshold, + output_tile_size=getattr(grid_description, "output_tile_size", None), + labels=None, # type: ignore + transform=transform, + backend=ImageBackend[str(image.reader)], + overwrite_mpp=(image.mpp, image.mpp), + limit_bounds=True, + apply_color_profile=data_description.apply_color_profile, + ) + + yield dataset + + +def datasets_from_data_description_with_uri( db_manager: DataManager, data_description: DataDescription, transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, @@ -389,22 +494,19 @@ class Config: mpp: PositiveFloat -def open_db(database_uri: str, ensure_exists: bool = True) -> Session: - """Open a database connection. +def open_db_from_engine(engine: Engine) -> Session: + SessionLocal = sessionmaker(bind=engine) + return SessionLocal() - Parameters - ---------- - database_uri : str - The URI of the database. - ensure_exists : bool, default=True - Whether to raise an exception of the database does not exist. - Returns - ------- - Session - The database session. - """ - engine = create_engine(database_uri) +def open_db_from_uri( + uri: str, + ensure_exists: bool = True, +) -> Session: + """Open a database connection from a uri""" + + # Set up the engine if no engine is given and uri is given. + engine = create_engine(uri) if not ensure_exists: # Create tables if they don't exist @@ -421,8 +523,7 @@ def open_db(database_uri: str, ensure_exists: bool = True) -> Session: if not result.scalar(): raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.") - SessionLocal = sessionmaker(bind=engine) - return SessionLocal() + return open_db_from_engine(engine) def create_tables(engine: Engine) -> None: diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py new file mode 100644 index 0000000..500d129 --- /dev/null +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -0,0 +1,110 @@ +""" +Functions for generating an in-memory MinimalImage database on-the-fly with only an image root directory and glob +pattern. Used for inference of, e.g., a segmentation model on a directory filled with WSIs, without generating a +database explicitly. +""" + +from pathlib import Path + +import sqlalchemy +from dlup import SlideImage +from dlup.experimental_backends import ImageBackend +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, sessionmaker + +from ahcore.utils.database_models import MinimalImage, OnTheFlyBase + + +def populate_from_directory_and_glob_pattern(session: Session, image_folder: Path | str, glob_pattern: str) -> None: + """ + Populates the MinimalImage database in the in-memory session + with slides found in the image_folder using the + glob_pattern. Population happens in-place, so no return. + + Parameters + ---------- + session : Session + The opened session that connects to the DB engine + image_folder : str + The root directory of the images + glob_pattern : str + The glob pattern to find images within the root directory + + Returns + ------- + None + """ + for wsi in Path(image_folder).glob(glob_pattern): + with SlideImage.from_file_path( + image_folder / wsi, + backend=ImageBackend.PYVIPS, # type: ignore + ) as slide: # type: ignore + mpp = slide.mpp + width, height = slide.size + image = MinimalImage( + filename=str(wsi), + mpp=mpp, + height=height, + width=width, + reader="PYVIPS", + relative_filename=str(wsi.relative_to(image_folder)), + ) + session.add(image) + session.flush() # Flush so that Image ID is populated for future records + session.commit + + +def open_session_for_engine(engine: sqlalchemy.Engine) -> Session: + """ + Creates an engine, populates it with the database model for a minimal image, + generates a session and returns this + + Parameters + ---------- + engine : Engine + The in-memory engine that is to be populated. + Returns + ------- + Session + The session bound to the engine to populate it with. + """ + OnTheFlyBase.metadata.create_all(bind=engine) + session = sessionmaker(bind=engine) + return session() + + +def get_populated_in_memory_db(image_folder: Path, glob_pattern: str) -> Engine: + """ + Callable function to get the populated in-memory DB as an Engine + + Parameters + ---------- + image_folder : str + The root directory of the images + glob_pattern : str + The glob pattern to find images within the root directory + + Returns + ------- + Engine + an in-memory sqlalchemy Engine + """ + assert image_folder.is_dir(), f"image_folder ({image_folder}) does not exist" + + assert ( + len([i for i in image_folder.glob(glob_pattern)]) > 0 + ), f"No images found in {image_folder} with glob pattern {glob_pattern}" + + # An empty URL will create a `:memory:` database + engine = sqlalchemy.create_engine("sqlite://") + + with open_session_for_engine(engine) as session: + # Populate the DB through the session. Happens in-place + populate_from_directory_and_glob_pattern(session, image_folder, glob_pattern) + + # Commit is required before passing the engine back. If not commited, the engine and + # session here will contain the information, But outside of this context it will be lost. + session.commit() + + # Return the engine object that is bound to the session, so we can close the session + return engine diff --git a/tests/test_in_memory_db/populate_minimal_db_for_inference.py b/tests/test_in_memory_db/populate_minimal_db_for_inference.py new file mode 100644 index 0000000..02bb985 --- /dev/null +++ b/tests/test_in_memory_db/populate_minimal_db_for_inference.py @@ -0,0 +1,22 @@ +""" +This is a test for in-memory on-the-fly generation of a minimal ahcore database using a dummy dataset +using tiny .svs files from openslide +""" + +from pathlib import Path + +from sqlalchemy.orm import sessionmaker + +from ahcore.utils.database_models import MinimalImage +from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db + +if __name__ == "__main__": + image_folder = Path(__file__) / "test_in_memory_db" + glob_pattern = "**/*.svs" + engine = get_populated_in_memory_db(image_folder=image_folder, glob_pattern=glob_pattern) + + with sessionmaker(bind=engine)() as session: + table = session.query(MinimalImage).all() + + assert len(table) > 0, "The database was not populated" + assert len(table) == 3, f"The database has {len(table)} entires while there are only 3 test images" diff --git a/tests/test_in_memory_db/test_images/small_1.svs b/tests/test_in_memory_db/test_images/small_1.svs new file mode 100644 index 0000000000000000000000000000000000000000..6d113e1d5786e04e3557f1c81d86281df47b5162 GIT binary patch literal 2651 zcmebD)MAieWPpSJH~t@B5aeKRU~ph&5M*E!WMC3x_Z;(R9+&9+yXp#6L2ivk?9M?~zM3L*!i}H@ zBmVzQ1_6eD3@kt&LjfZ*m@NS0G9s~=fNWsEFx&$wWQK~f1Nm%FHZZ&x_!yZOGJuRH zKsBOJagZJvC>vyl6O`=$WIH3VT^J#50NLvbWwQe{L^HB8C<19876t|(CT38)05KE8 z(*}rokeOy|3=AA#=va_il$o#KnVXoNs^F8ERFqg$sZeHUqz8lwyj&(`1_lNd#zvNg zCJNDpmKGL07nQ;UG|NhXFS1_s6{x|XKNX1XR8Nd~%>#%7kf76yg} zsVSz0=1Ga>H35l5xy7j^K;@yqz6wCAN>Ynzd;%qpBv939)bwQa(@#jDq@PH)(A za@VrsGgqu$^8W~fJTRNFGJ*jxC4&JY6Eh1d8#@Ol7dKGBRsjZJ-eYEBVP<7z0cHZm zTA(}wiy*6zqM;+3a9|?4QlW@Z(iUwW$pkka<)WpdpCN3cY31zV>gMj@=@lFj8WtWA8I_!pnwFlCnN?g;T2@|BS=HRq+ScCD*)?hMl&RCE z&zL!D(c&dbmn~nha@D5ITefc7zGLUELx+zXJ$C%W$y1juU%7hi`i+~n9zJ^f}jBP9rNV4wdIB1yvyfKnd+if z_=3&1>gbhcx9&xHuqfQ^53stdlOG(gweOi>qgsXS#=uOgrmR^x<`a7=o0l(SpOxBr z>fWt2uMenC4Bq1PrnOViQR0o`*Okv2&Tm(L!WFVUjeT8KcyG#qItLcPSG-2I&gSwo z8E%z%qW`^as_fJ`Wx56}f-?Pk!#7PbD0~}u>-O61GdeijERQe0wEzE2ev(SZ55STL zU+L%!EWW-0)d)gMxLj7$l2Qy>BCY}^1aYXElR!3d2?{F_tvGR)h@f;tM2W}+F8364 z?H~o9g08}7VK-XX<)9UIu+ol;ot=Z7jf;bWgPV(sn@>W3kC&HER#Z$-LRnr_MM+*s mQBB)iUrp0QQ&Gvl(a^-w+Rn~SRnNuK*~Y`%*3NqPmUaN}3g}V* literal 0 HcmV?d00001 diff --git a/tests/test_in_memory_db/test_images/small_2.svs b/tests/test_in_memory_db/test_images/small_2.svs new file mode 100644 index 0000000000000000000000000000000000000000..6d113e1d5786e04e3557f1c81d86281df47b5162 GIT binary patch literal 2651 zcmebD)MAieWPpSJH~t@B5aeKRU~ph&5M*E!WMC3x_Z;(R9+&9+yXp#6L2ivk?9M?~zM3L*!i}H@ zBmVzQ1_6eD3@kt&LjfZ*m@NS0G9s~=fNWsEFx&$wWQK~f1Nm%FHZZ&x_!yZOGJuRH zKsBOJagZJvC>vyl6O`=$WIH3VT^J#50NLvbWwQe{L^HB8C<19876t|(CT38)05KE8 z(*}rokeOy|3=AA#=va_il$o#KnVXoNs^F8ERFqg$sZeHUqz8lwyj&(`1_lNd#zvNg zCJNDpmKGL07nQ;UG|NhXFS1_s6{x|XKNX1XR8Nd~%>#%7kf76yg} zsVSz0=1Ga>H35l5xy7j^K;@yqz6wCAN>Ynzd;%qpBv939)bwQa(@#jDq@PH)(A za@VrsGgqu$^8W~fJTRNFGJ*jxC4&JY6Eh1d8#@Ol7dKGBRsjZJ-eYEBVP<7z0cHZm zTA(}wiy*6zqM;+3a9|?4QlW@Z(iUwW$pkka<)WpdpCN3cY31zV>gMj@=@lFj8WtWA8I_!pnwFlCnN?g;T2@|BS=HRq+ScCD*)?hMl&RCE z&zL!D(c&dbmn~nha@D5ITefc7zGLUELx+zXJ$C%W$y1juU%7hi`i+~n9zJ^f}jBP9rNV4wdIB1yvyfKnd+if z_=3&1>gbhcx9&xHuqfQ^53stdlOG(gweOi>qgsXS#=uOgrmR^x<`a7=o0l(SpOxBr z>fWt2uMenC4Bq1PrnOViQR0o`*Okv2&Tm(L!WFVUjeT8KcyG#qItLcPSG-2I&gSwo z8E%z%qW`^as_fJ`Wx56}f-?Pk!#7PbD0~}u>-O61GdeijERQe0wEzE2ev(SZ55STL zU+L%!EWW-0)d)gMxLj7$l2Qy>BCY}^1aYXElR!3d2?{F_tvGR)h@f;tM2W}+F8364 z?H~o9g08}7VK-XX<)9UIu+ol;ot=Z7jf;bWgPV(sn@>W3kC&HER#Z$-LRnr_MM+*s mQBB)iUrp0QQ&Gvl(a^-w+Rn~SRnNuK*~Y`%*3NqPmUaN}3g}V* literal 0 HcmV?d00001 diff --git a/tests/test_in_memory_db/test_images/small_3.svs b/tests/test_in_memory_db/test_images/small_3.svs new file mode 100644 index 0000000000000000000000000000000000000000..6d113e1d5786e04e3557f1c81d86281df47b5162 GIT binary patch literal 2651 zcmebD)MAieWPpSJH~t@B5aeKRU~ph&5M*E!WMC3x_Z;(R9+&9+yXp#6L2ivk?9M?~zM3L*!i}H@ zBmVzQ1_6eD3@kt&LjfZ*m@NS0G9s~=fNWsEFx&$wWQK~f1Nm%FHZZ&x_!yZOGJuRH zKsBOJagZJvC>vyl6O`=$WIH3VT^J#50NLvbWwQe{L^HB8C<19876t|(CT38)05KE8 z(*}rokeOy|3=AA#=va_il$o#KnVXoNs^F8ERFqg$sZeHUqz8lwyj&(`1_lNd#zvNg zCJNDpmKGL07nQ;UG|NhXFS1_s6{x|XKNX1XR8Nd~%>#%7kf76yg} zsVSz0=1Ga>H35l5xy7j^K;@yqz6wCAN>Ynzd;%qpBv939)bwQa(@#jDq@PH)(A za@VrsGgqu$^8W~fJTRNFGJ*jxC4&JY6Eh1d8#@Ol7dKGBRsjZJ-eYEBVP<7z0cHZm zTA(}wiy*6zqM;+3a9|?4QlW@Z(iUwW$pkka<)WpdpCN3cY31zV>gMj@=@lFj8WtWA8I_!pnwFlCnN?g;T2@|BS=HRq+ScCD*)?hMl&RCE z&zL!D(c&dbmn~nha@D5ITefc7zGLUELx+zXJ$C%W$y1juU%7hi`i+~n9zJ^f}jBP9rNV4wdIB1yvyfKnd+if z_=3&1>gbhcx9&xHuqfQ^53stdlOG(gweOi>qgsXS#=uOgrmR^x<`a7=o0l(SpOxBr z>fWt2uMenC4Bq1PrnOViQR0o`*Okv2&Tm(L!WFVUjeT8KcyG#qItLcPSG-2I&gSwo z8E%z%qW`^as_fJ`Wx56}f-?Pk!#7PbD0~}u>-O61GdeijERQe0wEzE2ev(SZ55STL zU+L%!EWW-0)d)gMxLj7$l2Qy>BCY}^1aYXElR!3d2?{F_tvGR)h@f;tM2W}+F8364 z?H~o9g08}7VK-XX<)9UIu+ol;ot=Z7jf;bWgPV(sn@>W3kC&HER#Z$-LRnr_MM+*s mQBB)iUrp0QQ&Gvl(a^-w+Rn~SRnNuK*~Y`%*3NqPmUaN}3g}V* literal 0 HcmV?d00001 diff --git a/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh b/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh new file mode 100644 index 0000000..c80d0a6 --- /dev/null +++ b/tests/test_run_segmentation_inference_with_on_the_fly_in_memory_database.sh @@ -0,0 +1,61 @@ +# Script to run inference on small test WSIs +# Can easily adjusted by changing +# 1) the data_dir your directory containing WSIs +# 2) the glob_pattern to find the WSIs +# 3) Changing the mpp to 2.0 (which the model is trained on) and to get manageable output. + +# ===== Step 1 ===== +# If on CPU, set `trainer.accelerator=cpu` and set map_location to cpu in utils.io.load_weights +# If on MacOS, set multiprocessing.set_start_method("fork", force=True) before main() in inference.py. mps is not supported. + +# ===== Step 2 ===== +export SCRATCH=/set/to/your/log/dir + +# ===== Step 3 ===== +# Set path to model checkpoint +# Expected to be an attention unet with current configuration. +# ABSOLUTE_PATH_TO_CKPT=/absolute/path/to/checkpoint.ckpt + +# ===== Step 4 ===== +# Create ~/ahcore/additional_config/data_description/on_the_fly_data_description.yaml +# With the following contents + +#~/ahcore/additional_config/data_description/on_the_fly_data_description.yaml +# --------------------------------------------------------------------------------- # +# _target_: ahcore.utils.data.OnTheFlyDataDescription #TODO check if this works +# data_dir: test_in_memory_db # Root dir for images to be found +# glob_pattern: "**/*.svs" # Glob pattern relative to root dir to find images +# use_roi: False +# apply_color_profile: False +# inference_grid: +# mpp: 0.01 # Since the images are VERY small, we need to blow them up to be able to get proper input image size +# tile_size: [512, 512] +# tile_overlap: [0, 0] + +# num_classes: 2 + +# remap_labels: +# specimen: specimen + +# index_map: +# specimen: 1 + +# color_map: +# 0: "yellow" +# 1: "red" +# --------------------------------------------------------------------------------- # + + +# ===== Step 5 ===== +# Run this file from ~/ahcore/tools as CWD, with an installed ahcore environment. +python inference.py \ + callbacks=inference \ + data_description=on_the_fly_data_description \ + pre_transform=segmentation \ + augmentations=segmentation \ + lit_module=monai_segmentation/attention_unet \ + ckpt_path=$ABSOLUTE_PATH_TO_CKPT \ + # trainer.accelerator=cpu # If you're running on CPU + +# ===== Step 6 ===== +# Files should be saved in $SCRATCH / segmentation_inference / $datetime / outputs / AttentionUNet / 0_0 / * From 9361c7da2d0e0c2cd2d2f2564f0ead1dd83cf793 Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Thu, 6 Jun 2024 11:38:12 +0200 Subject: [PATCH 2/8] add tox and pre-commit to dev install --- ahcore/utils/io.py | 11 +++++++++-- setup.py | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index 9f08ae1..b1645f2 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -26,7 +26,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import InterpolationKeyError from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_only # type: ignore[attr-defined] +from pytorch_lightning.utilities import rank_zero_only from ahcore.models.jit_model import AhcoreJitModel @@ -242,7 +242,14 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: return model else: # Load checkpoint weights - lit_ckpt = torch.load(config.ckpt_path, map_location="cpu") + accelerator = config.trainer.accelerator + if accelerator == "cpu": + map_location = "cpu" + elif accelerator == "gpu": + map_location = "cuda" + else: + raise ValueError(f"Accelerator must be either cpu or gpu, but config.trainer.accelerator={accelerator}") + lit_ckpt = torch.load(config.ckpt_path, map_location=map_location) model.load_state_dict(lit_ckpt["state_dict"], strict=True) return model diff --git a/setup.py b/setup.py index 7810677..789766d 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,8 @@ "numpydoc", "myst-parser", "sphinx-book-theme", + "pre-commit", + "tox", ], }, license="Apache Software License 2.0", From 7e7e1492e06ce7cb42c6733fd82d7abe6850124f Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Thu, 6 Jun 2024 11:39:26 +0200 Subject: [PATCH 3/8] re-add type ignore since this does not solve mypy issue --- ahcore/utils/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index b1645f2..9ca6283 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -26,7 +26,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import InterpolationKeyError from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_only # type: ignore[attr-defined] from ahcore.models.jit_model import AhcoreJitModel From e9b0c0e7a8033746dc6e621b1480dc19aee7f619 Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Thu, 6 Jun 2024 11:50:28 +0200 Subject: [PATCH 4/8] reduce code redundancy for table creation --- ahcore/utils/manifest.py | 7 +++--- .../utils/on_the_fly_database_generation.py | 25 +++---------------- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index e34ee64..51b7191 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -32,6 +32,7 @@ Manifest, Mask, MinimalImage, + OnTheFlyBase, Patient, Split, SplitDefinitions, @@ -510,7 +511,7 @@ def open_db_from_uri( if not ensure_exists: # Create tables if they don't exist - create_tables(engine) + create_tables(engine, base=Base) else: # Check if the "manifest" table exists inspector = inspect(engine) @@ -526,9 +527,9 @@ def open_db_from_uri( return open_db_from_engine(engine) -def create_tables(engine: Engine) -> None: +def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: """Create the database tables.""" - Base.metadata.create_all(bind=engine) + base.metadata.create_all(bind=engine) def fetch_image_metadata(image: Image) -> ImageMetadata: diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py index 500d129..1a34c26 100644 --- a/ahcore/utils/on_the_fly_database_generation.py +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -10,9 +10,10 @@ from dlup import SlideImage from dlup.experimental_backends import ImageBackend from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from ahcore.utils.database_models import MinimalImage, OnTheFlyBase +from ahcore.utils.manifest import create_tables, open_db_from_engine def populate_from_directory_and_glob_pattern(session: Session, image_folder: Path | str, glob_pattern: str) -> None: @@ -54,25 +55,6 @@ def populate_from_directory_and_glob_pattern(session: Session, image_folder: Pat session.commit -def open_session_for_engine(engine: sqlalchemy.Engine) -> Session: - """ - Creates an engine, populates it with the database model for a minimal image, - generates a session and returns this - - Parameters - ---------- - engine : Engine - The in-memory engine that is to be populated. - Returns - ------- - Session - The session bound to the engine to populate it with. - """ - OnTheFlyBase.metadata.create_all(bind=engine) - session = sessionmaker(bind=engine) - return session() - - def get_populated_in_memory_db(image_folder: Path, glob_pattern: str) -> Engine: """ Callable function to get the populated in-memory DB as an Engine @@ -97,8 +79,9 @@ def get_populated_in_memory_db(image_folder: Path, glob_pattern: str) -> Engine: # An empty URL will create a `:memory:` database engine = sqlalchemy.create_engine("sqlite://") + create_tables(engine=engine, base=OnTheFlyBase) - with open_session_for_engine(engine) as session: + with open_db_from_engine(engine) as session: # Populate the DB through the session. Happens in-place populate_from_directory_and_glob_pattern(session, image_folder, glob_pattern) From 8d5340bf6524af98f7479b1b0089ea471c969b9e Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Thu, 6 Jun 2024 11:51:23 +0200 Subject: [PATCH 5/8] fix docstring type in get_populated_in_memory_db --- ahcore/utils/on_the_fly_database_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py index 1a34c26..f2d2141 100644 --- a/ahcore/utils/on_the_fly_database_generation.py +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -61,7 +61,7 @@ def get_populated_in_memory_db(image_folder: Path, glob_pattern: str) -> Engine: Parameters ---------- - image_folder : str + image_folder : Path The root directory of the images glob_pattern : str The glob pattern to find images within the root directory From b1de6dad63ce092571b92a39c4d8a18bca0346bd Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Thu, 6 Jun 2024 15:37:37 +0200 Subject: [PATCH 6/8] fix circular import error --- ahcore/utils/manifest.py | 2 +- ahcore/utils/on_the_fly_database_generation.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 51b7191..3031736 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -38,7 +38,7 @@ SplitDefinitions, ) from ahcore.utils.io import get_enum_key_from_value, get_logger -from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db +from ahcore.utils.on_the_fly_database_generation import create_tables, get_populated_in_memory_db from ahcore.utils.rois import compute_rois from ahcore.utils.types import PositiveFloat, PositiveInt, Rois diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py index f2d2141..6fe029a 100644 --- a/ahcore/utils/on_the_fly_database_generation.py +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -12,8 +12,13 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import Session -from ahcore.utils.database_models import MinimalImage, OnTheFlyBase -from ahcore.utils.manifest import create_tables, open_db_from_engine +from ahcore.utils.database_models import Base, MinimalImage, OnTheFlyBase +from ahcore.utils.manifest import open_db_from_engine + + +def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: + """Create the database tables.""" + base.metadata.create_all(bind=engine) def populate_from_directory_and_glob_pattern(session: Session, image_folder: Path | str, glob_pattern: str) -> None: From e68b5b2c3c898eb74ea822968842831dd29cf996 Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Fri, 7 Jun 2024 12:43:10 +0200 Subject: [PATCH 7/8] refactor for circular imports --- ahcore/utils/data.py | 42 ++++++++++++++++ ahcore/utils/manifest.py | 49 ++----------------- .../utils/on_the_fly_database_generation.py | 9 +--- 3 files changed, 47 insertions(+), 53 deletions(-) diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index 3c85cd0..8f4265a 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -9,7 +9,12 @@ from typing import Dict, Optional, Tuple from pydantic import BaseModel +from sqlalchemy import create_engine, exists +from sqlalchemy.engine import Engine +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import Session, sessionmaker +from ahcore.utils.database_models import Base, Manifest, OnTheFlyBase from ahcore.utils.types import NonNegativeInt, PositiveFloat, PositiveInt @@ -42,6 +47,43 @@ def basemodel_to_uuid(base_model: BaseModel) -> uuid.UUID: return unique_id +def open_db_from_engine(engine: Engine) -> Session: + SessionLocal = sessionmaker(bind=engine) + return SessionLocal() + + +def open_db_from_uri( + uri: str, + ensure_exists: bool = True, +) -> Session: + """Open a database connection from a uri""" + + # Set up the engine if no engine is given and uri is given. + engine = create_engine(uri) + + if not ensure_exists: + # Create tables if they don't exist + create_tables(engine, base=Base) + else: + # Check if the "manifest" table exists + inspector = inspect(engine) + if "manifest" not in inspector.get_table_names(): + raise RuntimeError("Manifest table does not exist. Likely you have set the wrong URI.") + + # Check if the "manifest" table is not empty + with engine.connect() as connection: + result = connection.execute(exists().where(Manifest.id.isnot(None)).select()) + if not result.scalar(): + raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.") + + return open_db_from_engine(engine) + + +def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: + """Create the database tables.""" + base.metadata.create_all(bind=engine) + + class GridDescription(BaseModel): mpp: Optional[PositiveFloat] tile_size: Tuple[PositiveInt, PositiveInt] diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 3031736..102ca34 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -16,29 +16,23 @@ from dlup.experimental_backends import ImageBackend # type: ignore from dlup.tiling import GridOrder, TilingMode from pydantic import BaseModel -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.inspection import inspect -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.sql import exists +from sqlalchemy.orm import Session from ahcore.exceptions import RecordNotFoundError -from ahcore.utils.data import DataDescription, OnTheFlyDataDescription +from ahcore.utils.data import DataDescription, OnTheFlyDataDescription, open_db_from_engine, open_db_from_uri from ahcore.utils.database_models import ( - Base, CategoryEnum, Image, ImageAnnotations, Manifest, Mask, MinimalImage, - OnTheFlyBase, Patient, Split, SplitDefinitions, ) from ahcore.utils.io import get_enum_key_from_value, get_logger -from ahcore.utils.on_the_fly_database_generation import create_tables, get_populated_in_memory_db +from ahcore.utils.on_the_fly_database_generation import get_populated_in_memory_db from ahcore.utils.rois import compute_rois from ahcore.utils.types import PositiveFloat, PositiveInt, Rois @@ -495,43 +489,6 @@ class Config: mpp: PositiveFloat -def open_db_from_engine(engine: Engine) -> Session: - SessionLocal = sessionmaker(bind=engine) - return SessionLocal() - - -def open_db_from_uri( - uri: str, - ensure_exists: bool = True, -) -> Session: - """Open a database connection from a uri""" - - # Set up the engine if no engine is given and uri is given. - engine = create_engine(uri) - - if not ensure_exists: - # Create tables if they don't exist - create_tables(engine, base=Base) - else: - # Check if the "manifest" table exists - inspector = inspect(engine) - if "manifest" not in inspector.get_table_names(): - raise RuntimeError("Manifest table does not exist. Likely you have set the wrong URI.") - - # Check if the "manifest" table is not empty - with engine.connect() as connection: - result = connection.execute(exists().where(Manifest.id.isnot(None)).select()) - if not result.scalar(): - raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.") - - return open_db_from_engine(engine) - - -def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: - """Create the database tables.""" - base.metadata.create_all(bind=engine) - - def fetch_image_metadata(image: Image) -> ImageMetadata: """Extract metadata from an Image object.""" return ImageMetadata( diff --git a/ahcore/utils/on_the_fly_database_generation.py b/ahcore/utils/on_the_fly_database_generation.py index 6fe029a..2b4fa98 100644 --- a/ahcore/utils/on_the_fly_database_generation.py +++ b/ahcore/utils/on_the_fly_database_generation.py @@ -12,13 +12,8 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import Session -from ahcore.utils.database_models import Base, MinimalImage, OnTheFlyBase -from ahcore.utils.manifest import open_db_from_engine - - -def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None: - """Create the database tables.""" - base.metadata.create_all(bind=engine) +from ahcore.utils.data import create_tables, open_db_from_engine +from ahcore.utils.database_models import MinimalImage, OnTheFlyBase def populate_from_directory_and_glob_pattern(session: Session, image_folder: Path | str, glob_pattern: str) -> None: From b1f747e0c8e44b52eb88cb819e9bc4e53db97684 Mon Sep 17 00:00:00 2001 From: Yoni Schirris Date: Wed, 12 Jun 2024 12:03:49 +0200 Subject: [PATCH 8/8] reduce setup time of abstractwritercallback --- ahcore/callbacks/abstract_writer_callback.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ahcore/callbacks/abstract_writer_callback.py b/ahcore/callbacks/abstract_writer_callback.py index 3f50a90..d6cb099 100644 --- a/ahcore/callbacks/abstract_writer_callback.py +++ b/ahcore/callbacks/abstract_writer_callback.py @@ -179,8 +179,9 @@ def _on_epoch_start(self, trainer: "pl.Trainer") -> None: current_dataset: TiledWsiDataset assert self._total_dataset for current_dataset in self._total_dataset.datasets: # type: ignore - assert current_dataset.slide_image.identifier - self._dataset_sizes[current_dataset.slide_image.identifier] = len(current_dataset) + curr_filename = current_dataset._path + assert curr_filename + self._dataset_sizes[str(curr_filename)] = len(current_dataset) self._start_callback_workers()