Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def _read_points(
return points


class PointsReader:
    """Callable adapter exposing :func:`_read_points` through the shared reader interface."""

    def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> DaskDataFrame:
        # Pure delegation: all parsing logic lives in ``_read_points``.
        points = _read_points(store)
        return points


def write_points(
points: DaskDataFrame,
group: zarr.Group,
Expand Down
62 changes: 35 additions & 27 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,6 @@
)


def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]:
    """Get nodes with Multiscales spec from a list of nodes.

    The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check
    the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have
    the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific
    metadata though.

    Parameters
    ----------
    image_nodes
        List of nodes returned from the ome-zarr-py Reader.
    nodes
        List to append the nodes with the multiscales spec to.

    Returns
    -------
    List of nodes with the multiscales spec.
    """
    # No explicit emptiness guard needed: iterating an empty list is a no-op.
    for node in image_nodes:
        # Labels are now also Multiscales in newer version of ome-zarr-py.
        # Builtin any() short-circuits on the first matching spec, unlike
        # np.any over a fully materialized list.
        if any(isinstance(spec, Multiscales) for spec in node.specs):
            nodes.append(node)
    return nodes


def _read_multiscale(
store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
) -> DataArray | DataTree:
Expand Down Expand Up @@ -134,6 +107,7 @@ def _read_multiscale(
msi = DataTree.from_dict(multiscale_image)
_set_transformations(msi, transformations)
return compute_coordinates(msi)

data = node.load(Multiscales).array(resolution=datasets[0])
si = DataArray(
data,
Expand All @@ -145,6 +119,40 @@ def _read_multiscale(
return compute_coordinates(si)


def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]:
    """Get nodes with Multiscales spec from a list of nodes.

    The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check
    the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have
    the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific
    metadata though.

    Parameters
    ----------
    image_nodes
        List of nodes returned from the ome-zarr-py Reader.
    nodes
        List to append the nodes with the multiscales spec to.

    Returns
    -------
    List of nodes with the multiscales spec.
    """
    # No explicit emptiness guard needed: iterating an empty list is a no-op.
    for node in image_nodes:
        # Labels are now also Multiscales in newer version of ome-zarr-py.
        # Builtin any() short-circuits on the first matching spec, unlike
        # np.any over a fully materialized list.
        if any(isinstance(spec, Multiscales) for spec in node.specs):
            nodes.append(node)
    return nodes


class MultiscaleReader:
    """Callable adapter exposing :func:`_read_multiscale` through the shared reader interface."""

    def __call__(
        self, path: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
    ) -> DataArray | DataTree:
        # Pure delegation: all parsing logic lives in ``_read_multiscale``.
        result = _read_multiscale(path, raster_type, reader_format)
        return result


def _write_raster(
raster_type: Literal["image", "labels"],
raster_data: DataArray | DataTree,
Expand Down
7 changes: 6 additions & 1 deletion src/spatialdata/_io/io_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def _read_shapes(
store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg]
store: str | Path | MutableMapping[str, object] | zarr.Group,
) -> GeoDataFrame:
"""Read shapes from a zarr store."""
assert isinstance(store, str | Path)
Expand Down Expand Up @@ -67,6 +67,11 @@ def _read_shapes(
return geo_df


class ShapesReader:
    """Callable adapter exposing :func:`_read_shapes` through the shared reader interface."""

    def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> GeoDataFrame:
        # Pure delegation: all parsing logic lives in ``_read_shapes``.
        shapes = _read_shapes(store)
        return shapes


def write_shapes(
shapes: GeoDataFrame,
group: zarr.Group,
Expand Down
14 changes: 12 additions & 2 deletions src/spatialdata/_io/io_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _read_table(
group: zarr.Group,
tables: dict[str, AnnData],
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
) -> dict[str, AnnData]:
) -> None:
"""
Read in tables in the tables Zarr.group of a SpatialData Zarr store.

Expand Down Expand Up @@ -85,7 +85,17 @@ def _read_table(
count += 1

logger.debug(f"Found {count} elements in {group}")
return tables


class TablesReader:
    """Callable adapter exposing :func:`_read_table` through the shared reader interface."""

    def __call__(
        self,
        path: str,
        group: zarr.Group,
        container: dict[str, AnnData],
        on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN],
    ) -> None:
        # ``_read_table`` fills ``container`` in place; both it and this adapter return None.
        return _read_table(path, group, container, on_bad_files)


def write_table(
Expand Down
207 changes: 80 additions & 127 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from json import JSONDecodeError
from pathlib import Path
from typing import Literal
from typing import Literal, cast

import zarr.storage
from anndata import AnnData
Expand All @@ -16,11 +16,59 @@
_resolve_zarr_store,
handle_read_errors,
)
from spatialdata._io.io_points import _read_points
from spatialdata._io.io_raster import _read_multiscale
from spatialdata._io.io_shapes import _read_shapes
from spatialdata._io.io_table import _read_table
from spatialdata._io.io_points import PointsReader
from spatialdata._io.io_raster import MultiscaleReader
from spatialdata._io.io_shapes import ShapesReader
from spatialdata._io.io_table import TablesReader
from spatialdata._logging import logger
from spatialdata.models import SpatialElement

ReadClasses = MultiscaleReader | PointsReader | ShapesReader | TablesReader


def _read_zarr_group_spatialdata_element(
    root_group: zarr.Group,
    root_store_path: str,
    sdata_version: Literal["0.1", "0.2"],
    selector: set[str],
    read_func: ReadClasses,
    group_name: Literal["images", "labels", "shapes", "points", "tables"],
    element_type: Literal["image", "labels", "shapes", "points", "tables"],
    element_container: dict[str, SpatialElement | AnnData],
    on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN],
) -> None:
    """Read one element group of a SpatialData Zarr store into ``element_container``.

    Parameters
    ----------
    root_group
        Root zarr group of the SpatialData store.
    root_store_path
        Path of the root store; element paths are joined onto it.
    sdata_version
        SpatialData format version, used to resolve the raster read format.
    selector
        Element-group names the caller asked to read; anything else is skipped.
    read_func
        Reader instance (``MultiscaleReader`` / ``PointsReader`` / ``ShapesReader`` /
        ``TablesReader``) used to parse elements of this group.
    group_name
        Name of the zarr subgroup to read (e.g. ``"images"``).
    element_type
        Raster type forwarded to ``MultiscaleReader`` (``"image"`` or ``"labels"``);
        ignored by the other readers.
    element_container
        Dict filled in place with the elements read, keyed by element name.
    on_bad_files
        Whether read errors raise or are reduced to warnings.
    """
    with handle_read_errors(
        on_bad_files,
        location=group_name,
        exc_types=JSONDecodeError,
    ):
        # Guard clauses instead of nesting the whole body under one condition.
        if group_name not in selector or group_name not in root_group:
            return
        group = root_group[group_name]
        if isinstance(read_func, TablesReader):
            # Tables are read as a whole group; the reader fills the container itself.
            read_func(root_store_path, group, element_container, on_bad_files=on_bad_files)
            return
        count = 0
        for subgroup_name in group:
            if Path(subgroup_name).name.startswith("."):
                # skip hidden files like .zgroup or .zmetadata
                continue
            elem_group = group[subgroup_name]
            elem_group_path = os.path.join(root_store_path, elem_group.path)
            with handle_read_errors(
                on_bad_files,
                location=f"{group.path}/{subgroup_name}",
                exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError),
            ):
                if isinstance(read_func, MultiscaleReader):
                    # Raster elements need the on-disk format resolved per element.
                    reader_format = get_raster_format_for_read(elem_group, sdata_version)
                    element = read_func(
                        elem_group_path, cast(Literal["image", "labels"], element_type), reader_format
                    )
                elif isinstance(read_func, PointsReader | ShapesReader):
                    element = read_func(elem_group_path)
                else:
                    # Previously two independent `if`s: an unhandled reader type would have
                    # raised UnboundLocalError on `element`. Fail explicitly instead.
                    raise TypeError(f"Unsupported reader type: {type(read_func).__name__}")
                element_container[subgroup_name] = element
                count += 1
        logger.debug(f"Found {count} elements in {group}")


def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format:
Expand Down Expand Up @@ -87,134 +135,39 @@ def read_zarr(
sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"]
root_store_path = root_group.store.root

images = {}
labels = {}
points = {}
images: dict[str, SpatialElement] = {}
labels: dict[str, SpatialElement] = {}
points: dict[str, SpatialElement] = {}
tables: dict[str, AnnData] = {}
shapes = {}
shapes: dict[str, SpatialElement] = {}

selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or [])
logger.debug(f"Reading selection {selector}")

# We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read.
# related to images / labels.
with handle_read_errors(
on_bad_files,
location="images",
exc_types=JSONDecodeError,
):
if "images" in selector and "images" in root_group:
group = root_group["images"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
elem_group = group[subgroup_name]
reader_format = get_raster_format_for_read(elem_group, sdata_version)
elem_group_path = os.path.join(root_store_path, elem_group.path)
with handle_read_errors(
on_bad_files,
location=f"{group.path}/{subgroup_name}",
exc_types=(
KeyError,
ArrayNotFoundError,
OSError,
),
):
element = _read_multiscale(elem_group_path, raster_type="image", reader_format=reader_format)
images[subgroup_name] = element
count += 1
logger.debug(f"Found {count} elements in {group}")

# read multiscale labels
with handle_read_errors(
on_bad_files,
location="labels",
exc_types=JSONDecodeError,
):
if "labels" in selector and "labels" in root_group:
group = root_group["labels"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
elem_group = group[subgroup_name]
reader_format = get_raster_format_for_read(elem_group, sdata_version)
elem_group_path = root_store_path / elem_group.path
with handle_read_errors(
on_bad_files,
location=f"{group.path}/{subgroup_name}",
exc_types=(
KeyError,
ArrayNotFoundError,
OSError,
),
):
labels[subgroup_name] = _read_multiscale(
elem_group_path, raster_type="labels", reader_format=reader_format
)
count += 1
logger.debug(f"Found {count} elements in {group}")
# now read rest of the data
with handle_read_errors(
on_bad_files,
location="points",
exc_types=JSONDecodeError,
):
if "points" in selector and "points" in root_group:
group = root_group["points"]
count = 0
for subgroup_name in group:
elem_group = group[subgroup_name]
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
elem_group_path = os.path.join(root_store_path, elem_group.path)
with handle_read_errors(
on_bad_files,
location=f"{group.path}/{subgroup_name}",
exc_types=(KeyError, ArrowInvalid, JSONDecodeError),
):
points[subgroup_name] = _read_points(elem_group_path)
count += 1
logger.debug(f"Found {count} elements in {group}")

with handle_read_errors(
on_bad_files,
location="shapes",
exc_types=JSONDecodeError,
):
if "shapes" in selector and "shapes" in root_group:
group = root_group["shapes"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
elem_group = group[subgroup_name]
elem_group_path = os.path.join(root_store_path, elem_group.path)
with handle_read_errors(
on_bad_files,
location=f"{group.path}/{subgroup_name}",
exc_types=(
JSONDecodeError,
KeyError,
ArrayNotFoundError,
),
):
shapes[subgroup_name] = _read_shapes(elem_group_path)
count += 1
logger.debug(f"Found {count} elements in {group}")
if "tables" in selector and "tables" in root_group:
with handle_read_errors(
group_readers: dict[
Literal["images", "labels", "shapes", "points", "tables"],
tuple[
ReadClasses, Literal["image", "labels", "shapes", "points", "tables"], dict[str, SpatialElement | AnnData]
],
] = {
"images": (MultiscaleReader(), "image", images),
"labels": (MultiscaleReader(), "labels", labels),
"points": (PointsReader(), "points", points),
"shapes": (ShapesReader(), "shapes", shapes),
"tables": (TablesReader(), "tables", tables),
}
for group_name, (reader, raster_type, container) in group_readers.items():
_read_zarr_group_spatialdata_element(
root_group,
root_store_path,
sdata_version,
selector,
reader,
group_name,
raster_type,
container,
on_bad_files,
location="tables",
exc_types=JSONDecodeError,
):
group = root_group["tables"]
tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files)
)

# read attrs metadata
attrs = root_group.attrs.asdict()
Expand Down
Loading