Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/spatialdata/_io/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def validate_coordinate_transformations(
import json

json0 = [json.dumps(t) for t in transformations]
from spatialdata.transformations.ngff.ngff_transformations import NgffBaseTransformation
from spatialdata.transformations.ngff.ngff_transformations import (
NgffBaseTransformation,
)

parsed = [NgffBaseTransformation.from_dict(t) for t in transformations]
json1 = [json.dumps(p.to_dict()) for p in parsed]
Expand Down Expand Up @@ -298,6 +300,10 @@ def spatialdata_format_version(self) -> str:
"0.4-dev-spatialdata": FormatV04(),
"0.5-dev-spatialdata": FormatV05(),
}
sdata_zarr_version_to_raster_format: dict[str, FormatV04 | FormatV05] = {
fmt.version: fmt # type: ignore[attr-defined]
for fmt in [RasterFormatV01(), RasterFormatV02(), RasterFormatV03()]
}
RasterFormats: dict[str, RasterFormatType] = {
"0.1": RasterFormatV01(),
"0.2": RasterFormatV02(),
Expand Down
31 changes: 24 additions & 7 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _read_zarr_group_spatialdata_element(
read_func: Callable[..., Any],
group_name: Literal["images", "labels", "shapes", "points", "tables"],
element_type: Literal["image", "labels", "shapes", "points", "tables"],
element_container: dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData],
element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]),
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN],
) -> None:
with handle_read_errors(
Expand Down Expand Up @@ -80,7 +80,11 @@ def _read_zarr_group_spatialdata_element(
logger.debug(f"Found {count} elements in {group}")


def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format:
def get_raster_format_for_read(
group: zarr.Group,
sdata_version: Literal["0.1", "0.2"],
return_ome_zarr_format: bool = False,
) -> Format:
"""Get raster format of stored raster data.

This checks the image or label element zarr group metadata to retrieve the format that is used by
Expand All @@ -92,18 +96,27 @@ def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1",
The zarr group of the raster element to be read.
sdata_version
The version of the SpatialData zarr store retrieved from the spatialdata attributes.
return_ome_zarr_format
Whether to return the ome-zarr Format or a SpatialData raster format class (which is a subclass of Format)

Returns
-------
The ome-zarr format to use for reading the raster element.
"""
from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format
from spatialdata._io.format import (
sdata_zarr_version_to_ome_zarr_format,
sdata_zarr_version_to_raster_format,
)

if sdata_version == "0.1":
group_version = group.metadata.attributes["multiscales"][0]["version"]
if sdata_version == "0.2":
elif sdata_version == "0.2":
group_version = group.metadata.attributes["ome"]["version"]
return sdata_zarr_version_to_ome_zarr_format[group_version]
else:
raise ValueError(f"Unknown SpatialData zarr version {sdata_version}")
if return_ome_zarr_format:
return sdata_zarr_version_to_ome_zarr_format[group_version]
return sdata_zarr_version_to_raster_format[group_version]


def read_zarr(
Expand Down Expand Up @@ -146,7 +159,7 @@ def read_zarr(
if sdata_version == "0.1":
warnings.warn(
"SpatialData is not stored in the most current format. If you want to use Zarr v3"
", please write the store to a new location.",
", please write the store to a new location using `sdata.write()`.",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -177,7 +190,11 @@ def read_zarr(
"shapes": (_read_shapes, "shapes", shapes),
"tables": (_read_table, "tables", tables),
}
for group_name, (read_func, element_type, element_container) in group_readers.items():
for group_name, (
read_func,
element_type,
element_container,
) in group_readers.items():
_read_zarr_group_spatialdata_element(
root_group=root_group,
root_store_path=root_store_path,
Expand Down
Loading