Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ node_modules/
# memray report
*.bin

# speedscope report
profile.speedscope.json

# test datasets (e.g. Xenium ones)
# symlinks
data
Expand Down
230 changes: 126 additions & 104 deletions src/spatialdata_io/readers/xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from geopandas import GeoDataFrame
from shapely import GeometryType, Polygon, from_ragged_array
from spatialdata import SpatialData
from spatialdata._core.query.relational_query import get_element_instances
from spatialdata._logging import logger
from spatialdata.models import (
Image2DModel,
Labels2DModel,
Expand Down Expand Up @@ -203,16 +201,17 @@ def xenium(
# open cells.zarr.zip once and reuse across all functions that need it
cells_zarr: zarr.Group | None = None
need_cells_zarr = (
nucleus_labels
nucleus_boundaries
or nucleus_labels
or cells_boundaries
or cells_labels
or (version is not None and version >= packaging.version.parse("2.0.0") and table is not None)
)
if need_cells_zarr:
cells_zarr_store = zarr.storage.ZipStore(path / XeniumKeys.CELLS_ZARR, read_only=True)
cells_zarr = zarr.open(cells_zarr_store, mode="r")

# pre-compute cell_id strings from the zarr once, to avoid redundant conversion
# in both _get_cells_metadata_table_from_zarr and _get_labels_and_indices_mapping.
# pre-compute cell_id strings from the zarr once, to avoid redundant conversion.
cells_zarr_cell_id_str: np.ndarray | None = None
if cells_zarr is not None and version is not None and version >= packaging.version.parse("1.3.0"):
cell_id_raw = cells_zarr["cell_id"][...]
Expand All @@ -221,7 +220,7 @@ def xenium(

if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
assert cells_zarr is not None
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, specs, cells_zarr_cell_id_str)
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, cells_zarr_cell_id_str)
try:
_assert_arrays_equal_sampled(
cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values
Expand All @@ -243,36 +242,31 @@ def xenium(
points = {}
images = {}

# From the public release notes here:
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/release-notes/release-notes-for-xoa
# we see that for distinguishing between the nuclei of polinucleated cells, the `label_id` column is used.
# This column is currently not found in the preview data, while I think it is needed in order to unambiguously match
# nuclei to cells. Therefore for the moment we only link the table to the cell labels, and not to the nucleus
# labels.
# Build the label_index <-> cell_id mappings from the zarr once, reuse for both labels
# and boundaries. For v2.0+ this is deterministic from the zarr polygon_sets
# (label_id = cell_index + 1). For older versions, use seg_mask_value (cells only).
# For nuclei in v2.0+, this correctly handles multinucleate cells: each nucleus gets its
# own label_index, avoiding the bug of merging multiple nuclei into a single polygon.
# Older versions do not support multinucleate cells, so cell_id-based grouping is correct.
nucleus_indices_mapping: pd.DataFrame | None = None
cell_indices_mapping: pd.DataFrame | None = None
if cells_zarr_cell_id_str is not None and cells_zarr is not None and "polygon_sets" in cells_zarr:
if nucleus_boundaries or nucleus_labels:
nucleus_indices_mapping = _get_indices_mapping_from_zarr(cells_zarr, cells_zarr_cell_id_str, mask_index=0)
if cells_boundaries or cells_labels:
cell_indices_mapping = _get_indices_mapping_from_zarr(cells_zarr, cells_zarr_cell_id_str, mask_index=1)
elif cells_zarr_cell_id_str is not None:
if cells_boundaries or cells_labels:
cell_indices_mapping = _get_indices_mapping_legacy(cells_zarr, cells_zarr_cell_id_str, specs=specs)

if nucleus_labels:
labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
path=path,
specs=specs,
mask_index=0,
labels_name="nucleus_labels",
labels_models_kwargs=labels_models_kwargs,
cells_zarr=cells_zarr,
cell_id_str=None,
)
labels["nucleus_labels"] = _get_labels(cells_zarr, mask_index=0, labels_models_kwargs=labels_models_kwargs)
if cells_labels:
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
path=path,
specs=specs,
mask_index=1,
labels_name="cell_labels",
labels_models_kwargs=labels_models_kwargs,
cells_zarr=cells_zarr,
cell_id_str=cells_zarr_cell_id_str,
)
if cell_labels_indices_mapping is not None and table is not None:
labels["cell_labels"] = _get_labels(cells_zarr, mask_index=1, labels_models_kwargs=labels_models_kwargs)
if cell_indices_mapping is not None and table is not None:
try:
_assert_arrays_equal_sampled(
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
cell_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
)
except AssertionError:
warnings.warn(
Expand All @@ -283,7 +277,7 @@ def xenium(
stacklevel=2,
)
else:
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"].values
table.obs["cell_labels"] = cell_indices_mapping["label_index"].values
if not cells_as_circles:
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"

Expand All @@ -292,15 +286,16 @@ def xenium(
path,
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
specs,
idx=None,
indices_mapping=nucleus_indices_mapping,
is_nucleus=True,
)

if cells_boundaries:
polygons["cell_boundaries"] = _get_polygons(
path,
XeniumKeys.CELL_BOUNDARIES_FILE,
specs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
indices_mapping=cell_indices_mapping,
)

if transcripts:
Expand Down Expand Up @@ -455,37 +450,57 @@ def _get_polygons(
path: Path,
file: str,
specs: dict[str, Any],
idx: pd.Series | None = None,
indices_mapping: pd.DataFrame | None = None,
is_nucleus: bool = False,
) -> GeoDataFrame:
# Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
# The original approach was:
# df = pq.read_table(path / file).to_pandas()
# cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
# which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
# By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
table = pq.read_table(path / file)
cell_id_col = table.column(str(XeniumKeys.CELL_ID))
"""Parse boundary polygons from a parquet file.

Parameters
----------
indices_mapping
When provided (from ``_get_indices_mapping_from_zarr`` or ``_get_indices_mapping_legacy``),
contains ``cell_id`` and ``label_index`` columns. The parquet ``label_id`` column is used
for fast integer-based change detection (to locate all the vertices of each polygon).
When None, falls back to cell_id-based grouping from the parquet (Xenium < 2.0).
is_nucleus
When True (nucleus boundaries), use ``label_index`` as the GeoDataFrame index and store
``cell_id`` as a column. This gives each nucleus a distinct integer id matching the raster
labels, correctly handling multinucleate cells.
When False (cell boundaries), use ``cell_id`` as the GeoDataFrame index.
"""
# Check whether the parquet has a label_id column (v2.0+). When present, use it for
# fast integer-based change detection. Otherwise fall back to cell_id strings.
parquet_schema = pq.read_schema(path / file)
has_label_id = "label_id" in parquet_schema.names

columns_to_read = [str(XeniumKeys.BOUNDARIES_VERTEX_X), str(XeniumKeys.BOUNDARIES_VERTEX_Y)]
columns_to_read.append("label_id" if has_label_id else str(XeniumKeys.CELL_ID))
table = pq.read_table(path / file, columns=columns_to_read)

x = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_X)).to_numpy()
y = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_Y)).to_numpy()
coords = np.column_stack([x, y])

n = len(cell_id_col)
change_mask = np.empty(n, dtype=bool)
change_mask[0] = True
change_mask[1:] = pc.not_equal(cell_id_col.slice(0, n - 1), cell_id_col.slice(1)).to_numpy(zero_copy_only=False)
group_starts = np.where(change_mask)[0]
group_ends = np.concatenate([group_starts[1:], [n]])
n = len(x)

# sanity check
n_unique_ids = pc.count_distinct(cell_id_col).as_py()
if has_label_id:
id_col = table.column("label_id")
id_arr = id_col.to_numpy()
change_mask = id_arr[1:] != id_arr[:-1]
else:
id_col = table.column(str(XeniumKeys.CELL_ID))
change_mask = pc.not_equal(id_col.slice(0, n - 1), id_col.slice(1)).to_numpy(zero_copy_only=False)
group_starts = np.where(np.concatenate([[True], change_mask]))[0]
n_unique_ids = pc.count_distinct(id_col).as_py()
if len(group_starts) != n_unique_ids:
raise ValueError(
f"In {file}, rows belonging to the same polygon must be contiguous. "
f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
f"This indicates non-consecutive polygon rows."
)

group_ends = np.concatenate([group_starts[1:], [n]])

# offsets for ragged array:
# offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
# offsets[1] (geom_offsets): describing to which polygons the rings belong to
Expand All @@ -494,85 +509,92 @@ def _get_polygons(

geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))

# idx is not None for the cells and None for the nuclei (for xenium(cells_table=False) is None for both
if idx is not None:
# Cell IDs already available from the annotation table
assert len(idx) == len(group_starts), f"Expected {len(group_starts)} cell IDs, got {len(idx)}"
geo_df = GeoDataFrame({"geometry": geoms}, index=idx.values)
if indices_mapping is not None:
assert len(indices_mapping) == len(group_starts), (
f"Expected {len(group_starts)} polygons, but indices_mapping has {len(indices_mapping)} entries."
)
if is_nucleus:
# Use label_index (int) as GeoDataFrame index, cell_id as column.
geo_df = GeoDataFrame(
{"geometry": geoms, str(XeniumKeys.CELL_ID): indices_mapping["cell_id"].values},
index=indices_mapping["label_index"].values,
)
else:
# Use cell_id (str) as GeoDataFrame index.
geo_df = GeoDataFrame({"geometry": geoms}, index=indices_mapping["cell_id"].values)
else:
# Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
unique_ids = cell_id_col.filter(change_mask).to_pylist()
unique_ids = id_col.filter(np.concatenate([[True], change_mask])).to_pylist()
index = _decode_cell_id_column(pd.Series(unique_ids))
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)

scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
return ShapesModel.parse(geo_df, transformations={"global": scale})


def _get_labels_and_indices_mapping(
path: Path,
specs: dict[str, Any],
mask_index: int,
labels_name: str,
def _get_labels(
cells_zarr: zarr.Group,
cell_id_str: ArrayLike,
mask_index: int,
labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> tuple[GeoDataFrame, pd.DataFrame | None]:
) -> DataArray:
"""Read the labels raster from cells.zarr.zip masks/{mask_index}."""
if mask_index not in [0, 1]:
raise ValueError(f"mask_index must be 0 or 1, found {mask_index}.")

# get the labels
masks = da.from_array(cells_zarr["masks"][f"{mask_index}"])
labels = Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)
return Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)

# build the matching table
version = _parse_version_of_xenium_analyzer(specs)
if mask_index == 0:
# nuclei currently not supported
return labels, None
if version is None or version is not None and version < packaging.version.parse("1.3.0"):
# supported in version 1.3.0 and not supported in version 1.0.2; conservatively, let's assume it is not
# supported in versions < 1.3.0
return labels, None

if version < packaging.version.parse("2.0.0"):
label_index = cells_zarr["seg_mask_value"][...]
else:
# For v >= 2.0.0, seg_mask_value is no longer available in the zarr;
# read label_id from the corresponding parquet boundary file instead
boundaries_file = XeniumKeys.NUCLEUS_BOUNDARIES_FILE if mask_index == 0 else XeniumKeys.CELL_BOUNDARIES_FILE
boundary_columns = pq.read_schema(path / boundaries_file).names
if "label_id" in boundary_columns:
boundary_df = pq.read_table(path / boundaries_file, columns=[XeniumKeys.CELL_ID, "label_id"]).to_pandas()
unique_pairs = boundary_df.drop_duplicates(subset=[XeniumKeys.CELL_ID, "label_id"]).copy()
unique_pairs[XeniumKeys.CELL_ID] = _decode_cell_id_column(unique_pairs[XeniumKeys.CELL_ID])
cell_id_to_label_id = unique_pairs.set_index(XeniumKeys.CELL_ID)["label_id"]
label_index = cell_id_to_label_id.loc[cell_id_str].values
else:
# fallback for dev versions around 2.0.0 that lack both seg_mask_value and label_id
logger.warn(
f"Could not find the labels ids from the metadata for version {version}. Using a fallback (slower) implementation."
)
label_index = get_element_instances(labels).values

if label_index[0] == 0:
label_index = label_index[1:]
def _get_indices_mapping_from_zarr(
cells_zarr: zarr.Group,
cells_zarr_cell_id_str: np.ndarray,
mask_index: int,
) -> pd.DataFrame:
"""Build the label_index <-> cell_id mapping from the zarr polygon_sets.

# labels_index is an uint32, so let's cast to np.int64 to avoid the risk of overflow on some systems
indices_mapping = pd.DataFrame(
From the 10x Genomics docs: "the label ID is equal to the cell index + 1",
where cell_index is polygon_sets/{mask_index}/cell_index. This is deterministic
and avoids reading the slow parquet boundary files.

For cells (mask_index=1): cell_index is 0..N-1 (1:1 with cells), so
label_index = arange(1, N+1).
For nuclei (mask_index=0): cell_index maps each nucleus to its parent cell,
so label_index = arange(1, M+1) and cell_id = cell_id_str[cell_index[i]].
"""
cell_index = cells_zarr[f"polygon_sets/{mask_index}/cell_index"][...]
label_index = np.arange(1, len(cell_index) + 1, dtype=np.int64)
cell_id = cells_zarr_cell_id_str[cell_index]
return pd.DataFrame(
{
"cell_id": cell_id,
"label_index": label_index,
}
)


def _get_indices_mapping_legacy(
cells_zarr: zarr.Group,
cell_id_str: ArrayLike,
specs: dict[str, Any],
) -> pd.DataFrame | None:
"""Build the label_index <-> cell_id mapping for versions < 2.0.0.

Uses seg_mask_value from the zarr (available in v1.3.0+).
"""
version = _parse_version_of_xenium_analyzer(specs)
if version is None or version < packaging.version.parse("1.3.0"):
return None
label_index = cells_zarr["seg_mask_value"][...]
return pd.DataFrame(
{
"region": labels_name,
"cell_id": cell_id_str,
"label_index": label_index.astype(np.int64),
}
)
return labels, indices_mapping


@inject_docs(xx=XeniumKeys)
def _get_cells_metadata_table_from_zarr(
cells_zarr: zarr.Group,
specs: dict[str, Any],
cell_id_str: ArrayLike,
) -> AnnData:
"""Read cells metadata from ``{xx.CELLS_ZARR}``.
Expand Down
9 changes: 6 additions & 3 deletions tests/test_xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=3515.5, x=4618.5).data.compute() == 6392
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 1113949]]['x'], [2.608911, 194.917831, 1227.499268])
assert np.isclose(sdata['cell_boundaries'].loc['oipggjko-1'].geometry.centroid.x,736.4864931162789)
assert np.isclose(sdata['nucleus_boundaries'].loc['oipggjko-1'].geometry.centroid.x,736.4931256878282)
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('oipggjko-1')][0]
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x,736.4931256878282)
assert np.array_equal(sdata['table'].X.indices[:3], [1, 3, 34])
# fmt: on

Expand All @@ -138,7 +139,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=18.5, x=3015.5).data.compute() == 2764
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 20000]]['x'], [174.258392, 12.210024, 214.759186])
assert np.isclose(sdata['cell_boundaries'].loc['aaanbaof-1'].geometry.centroid.x, 43.96894317275074)
assert np.isclose(sdata['nucleus_boundaries'].loc['aaanbaof-1'].geometry.centroid.x,43.31874577809517)
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('aaanbaof-1')][0]
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x,43.31874577809517)
assert np.array_equal(sdata['table'].X.indices[:3], [1, 8, 19])
# fmt: on

Expand All @@ -164,7 +166,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=4039.5, x=93.5).data.compute() == 274
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 20000]]['x'], [43.296875, 62.484375, 93.125])
assert np.isclose(sdata['cell_boundaries'].loc['aadmbfof-1'].geometry.centroid.x, 64.54541104696033)
assert np.isclose(sdata['nucleus_boundaries'].loc['aadmbfof-1'].geometry.centroid.x, 65.43305896114295)
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('aadmbfof-1')][0]
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x, 65.43305896114295)
assert np.array_equal(sdata['table'].X.indices[:3], [3, 49, 53])
# fmt: on

Expand Down
Loading