From 3ef9098b8d5b384e53d80d6846b05373487ab588 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 27 Oct 2025 14:14:53 +0100 Subject: [PATCH] fix warning raster format mismatch ome-zarr --- src/spatialdata/_io/format.py | 8 +++++++- src/spatialdata/_io/io_zarr.py | 31 ++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 9d639a0a5..cce6654ef 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -91,7 +91,9 @@ def validate_coordinate_transformations( import json json0 = [json.dumps(t) for t in transformations] - from spatialdata.transformations.ngff.ngff_transformations import NgffBaseTransformation + from spatialdata.transformations.ngff.ngff_transformations import ( + NgffBaseTransformation, + ) parsed = [NgffBaseTransformation.from_dict(t) for t in transformations] json1 = [json.dumps(p.to_dict()) for p in parsed] @@ -298,6 +300,10 @@ def spatialdata_format_version(self) -> str: "0.4-dev-spatialdata": FormatV04(), "0.5-dev-spatialdata": FormatV05(), } +sdata_zarr_version_to_raster_format: dict[str, FormatV04 | FormatV05] = { + fmt.version: fmt # type: ignore[attr-defined] + for fmt in [RasterFormatV01(), RasterFormatV02(), RasterFormatV03()] +} RasterFormats: dict[str, RasterFormatType] = { "0.1": RasterFormatV01(), "0.2": RasterFormatV02(), diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 6a6569f62..ea459b953 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -35,7 +35,7 @@ def _read_zarr_group_spatialdata_element( read_func: Callable[..., Any], group_name: Literal["images", "labels", "shapes", "points", "tables"], element_type: Literal["image", "labels", "shapes", "points", "tables"], - element_container: dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], + element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]), on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], ) -> None: with handle_read_errors( @@ -80,7 +80,11 @@ def _read_zarr_group_spatialdata_element( logger.debug(f"Found {count} elements in {group}") -def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format: +def get_raster_format_for_read( + group: zarr.Group, + sdata_version: Literal["0.1", "0.2"], + return_ome_zarr_format: bool = False, +) -> Format: """Get raster format of stored raster data. This checks the image or label element zarr group metadata to retrieve the format that is used by @@ -92,18 +96,27 @@ def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", The zarr group of the raster element to be read. sdata_version The version of the SpatialData zarr store retrieved from the spatialdata attributes. + return_ome_zarr_format + Whether to return the ome-zarr Format or a SpatialData raster format class (which is a subclass of Format) Returns ------- The ome-zarr format to use for reading the raster element. """ - from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format + from spatialdata._io.format import ( + sdata_zarr_version_to_ome_zarr_format, + sdata_zarr_version_to_raster_format, + ) if sdata_version == "0.1": group_version = group.metadata.attributes["multiscales"][0]["version"] - if sdata_version == "0.2": + elif sdata_version == "0.2": group_version = group.metadata.attributes["ome"]["version"] - return sdata_zarr_version_to_ome_zarr_format[group_version] + else: + raise ValueError(f"Unknown SpatialData zarr version {sdata_version}") + if return_ome_zarr_format: + return sdata_zarr_version_to_ome_zarr_format[group_version] + return sdata_zarr_version_to_raster_format[group_version] def read_zarr( @@ -146,7 +159,7 @@ def read_zarr( if sdata_version == "0.1": warnings.warn( "SpatialData is not stored in the most current format. If you want to use Zarr v3" - ", please write the store to a new location.", + ", please write the store to a new location using `sdata.write()`.", UserWarning, stacklevel=2, ) @@ -177,7 +190,11 @@ def read_zarr( "shapes": (_read_shapes, "shapes", shapes), "tables": (_read_table, "tables", tables), } - for group_name, (read_func, element_type, element_container) in group_readers.items(): + for group_name, ( + read_func, + element_type, + element_container, + ) in group_readers.items(): _read_zarr_group_spatialdata_element( root_group=root_group, root_store_path=root_store_path,