From 3249d5293a79e7bb8247e14e6ecc46094822f19c Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 3 Mar 2026 12:11:23 +0100 Subject: [PATCH] add extension functionality (see #111) This is basically a copy of AnnData's implementation, lightly adapted. --- CHANGELOG.md | 9 +- docs/api.md | 9 ++ src/mudata/__init__.py | 1 + src/mudata/_core/extensions.py | 190 +++++++++++++++++++++++++++ src/mudata/_core/mudata.py | 205 +++++++++++++++-------------- tests/test_extensions.py | 229 +++++++++++++++++++++++++++++++++ 6 files changed, 542 insertions(+), 101 deletions(-) create mode 100644 src/mudata/_core/extensions.py create mode 100644 tests/test_extensions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fd2100c..10e7a08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.1.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.3.4] (unreleased) + +### Added + +- `mudata.register_mudata_namespace()` functionality for adding custom functionality to `MuData` objects. + ## [0.3.3] ### Fixed @@ -129,7 +135,8 @@ To copy the annotations explicitly, you will need to use `pull_obs()` and/or `pu Initial `mudata` release with `MuData`, previously a part of the `muon` framework. -[0.3.3]: https://github.com/scverse/mudata/compare/v0.3.1...v0.3.3 +[0.3.4]: https://github.com/scverse/mudata/compare/v0.3.3...v0.3.4 +[0.3.3]: https://github.com/scverse/mudata/compare/v0.3.2...v0.3.3 [0.3.2]: https://github.com/scverse/mudata/compare/v0.3.1...v0.3.2 [0.3.1]: https://github.com/scverse/mudata/compare/v0.3.0...v0.3.1 [0.3.0]: https://github.com/scverse/mudata/compare/v0.2.4...v0.3.0 diff --git a/docs/api.md b/docs/api.md index 0f2084e..acbc025 100644 --- a/docs/api.md +++ b/docs/api.md @@ -30,3 +30,12 @@ :functions-only: :toctree: generated ``` + +## Extensions +```{eval-rst} +.. module::mudata +.. autosummary:: + :toctree: generated + + register_mudata_namespace +``` diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index fa9ca20..6b4772d 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -4,6 +4,7 @@ from ._core import utils from ._core.config import set_options +from ._core.extensions import ExtensionNamespace, register_mudata_namespace from ._core.io import ( read, read_anndata, diff --git a/src/mudata/_core/extensions.py b/src/mudata/_core/extensions.py new file mode 100644 index 0000000..e8c8d7e --- /dev/null +++ b/src/mudata/_core/extensions.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import inspect +import warnings +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, get_type_hints, overload, runtime_checkable + +from .mudata import MuData + +if TYPE_CHECKING: + from collections.abc import Callable + + +@runtime_checkable +class ExtensionNamespace(Protocol): + """Protocol for extension namespaces. + + Enforces that the namespace initializer accepts a class with the proper `__init__` method. + Protocol's can't enforce that the `__init__` accepts the correct types. See + `_check_namespace_signature` for that. This is mainly useful for static type + checking with mypy and IDEs. + """ + + def __init__(self, mdata: MuData) -> None: + """Used to enforce the correct signature for extension namespaces.""" + + +# Based off of the extension framework in Polars +# https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py + +__all__ = ["register_mudata_namespace", "ExtensionNamespace"] + + +# Reserved namespaces include accessors built into MuData (currently there are none) +# and all current attributes of MuData +_reserved_namespaces: set[str] = set(dir(MuData)) + +NameSpT = TypeVar("NameSpT", bound=ExtensionNamespace) +T = TypeVar("T") + + +class AccessorNameSpace(ExtensionNamespace, Generic[NameSpT]): + """Establish property-like namespace object for user-defined functionality.""" + + def __init__(self, name: str, namespace: type[NameSpT]) -> None: + self._accessor = name + self._ns = namespace + + @overload + def __get__(self, instance: None, cls: type[T]) -> type[NameSpT]: ... + + @overload + def __get__(self, instance: T, cls: type[T]) -> NameSpT: ... + + def __get__(self, instance: T | None, cls: type[T]) -> NameSpT | type[NameSpT]: + if instance is None: + return self._ns + + ns_instance = self._ns(instance) # type: ignore[call-arg] + setattr(instance, self._accessor, ns_instance) + return ns_instance + + +def _check_namespace_signature(ns_class: type) -> None: + """Validate the signature of a namespace class for MuData extensions. + + This function ensures that any class intended to be used as an extension namespace + has a properly formatted `__init__` method such that: + + 1. Accepts at least two parameters (self and mdata) + 2. Has 'mdata' as the name of the second parameter + 3. Has the second parameter properly type-annotated as 'MuData' or any equivalent import alias + + The function performs runtime validation of these requirements before a namespace + can be registered through the `register_mudata_namespace` decorator. + + Parameters + ---------- + ns_class + The namespace class to validate. + + Raises + ------ + TypeError + If the `__init__` method has fewer than 2 parameters (missing the MuData parameter). + AttributeError + If the second parameter of `__init__` lacks a type annotation. + TypeError + If the second parameter of `__init__` is not named 'mdata'. + TypeError + If the second parameter of `__init__` is not annotated as the 'MuData' class. + TypeError + If both the name and type annotation of the second parameter are incorrect. + + """ + sig = inspect.signature(ns_class.__init__) + params = sig.parameters + + # Ensure there are at least two parameters (self and mdata) + if len(params) < 2: + raise TypeError("Namespace initializer must accept an MuData instance as the second parameter.") + + # Get the second parameter (expected to be 'mdata') + param = iter(params.values()) + next(param) + param = next(param) + if param.annotation is inspect.Parameter.empty: + raise AttributeError( + "Namespace initializer's second parameter must be annotated as the 'MuData' class, got empty annotation." + ) + + name_ok = param.name == "mdata" + + # Resolve the annotation using get_type_hints to handle forward references and aliases. + try: + type_hints = get_type_hints(ns_class.__init__) + resolved_type = type_hints.get(param.name, param.annotation) + except NameError as e: + raise NameError(f"Namespace initializer's second parameter must be named 'mdata', got '{param.name}'.") from e + + type_ok = resolved_type is MuData + + match (name_ok, type_ok): + case (True, True): + return # Signature is correct. + case (False, True): + raise TypeError(f"Namespace initializer's second parameter must be named 'mdata', got {param.name!r}.") + case (True, False): + type_repr = getattr(resolved_type, "__name__", str(resolved_type)) + raise TypeError( + f"Namespace initializer's second parameter must be annotated as the 'MuData' class, got {type_repr!r}." + ) + case _: + type_repr = getattr(resolved_type, "__name__", str(resolved_type)) + raise TypeError( + f"Namespace initializer's second parameter must be named 'mdata', got {param.name!r}. " + f"And must be annotated as 'MuData', got {type_repr!r}." + ) + + +def _create_namespace(name: str, cls: type[MuData]) -> Callable[[type[NameSpT]], type[NameSpT]]: + """Register custom namespace against the underlying MuData class.""" + + def namespace(ns_class: type[NameSpT]) -> type[NameSpT]: + _check_namespace_signature(ns_class) # Perform the runtime signature check + if name in _reserved_namespaces: + raise AttributeError(f"cannot override reserved attribute {name!r}") + elif hasattr(cls, name): + warnings.warn( + f"Overriding existing custom namespace {name!r} (on {cls.__name__!r})", UserWarning, stacklevel=2 + ) + setattr(cls, name, AccessorNameSpace(name, ns_class)) + return ns_class + + return namespace + + +def register_mudata_namespace(name: str) -> Callable[[type[NameSpT]], type[NameSpT]]: + """Decorator for registering custom functionality with an :class:`~mudata.MuData` object. + + This decorator allows you to extend MuData objects with custom methods and properties + organized under a namespace. The namespace becomes accessible as an attribute on MuData + instances, providing a clean way to you to add domain-specific functionality without modifying + the MuData class itself, or extending the class with additional methods as you see fit in your workflow. + + This is equivalent to :func:`anndata.register_anndata_namespace`. + + Parameters + ---------- + name + Name under which the accessor should be registered. This will be the attribute name + used to access your namespace's functionality on MuData objects (e.g., `mdata.{name}`). + Cannot conflict with existing MuData attributes like `obs`, `var`, `mod`, etc. The list of reserved + attributes includes everything outputted by `dir(MuData)`. + + Returns + ------- + A decorator that registers the decorated class as a custom namespace. + + Notes + ----- + Implementation requirements: + + 1. The decorated class must have an `__init__` method that accepts exactly one parameter + (besides `self`) named `mdata` and annotated with type :class:`~mudata.MuData`. + 2. The namespace will be initialized with the MuData object on first access and then + cached on the instance. + 3. If the namespace name conflicts with an existing namespace, a warning is issued. + 4. If the namespace name conflicts with a built-in MuData attribute, an AttributeError is raised. + """ + return _create_namespace(name, MuData) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index b5a0e11..01bb9f3 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -156,13 +156,13 @@ def __init__( return # Add all modalities to a MuData object - self.mod = ModDict() + self._mod = ModDict() if data is None: # Initialize an empty MuData object pass elif isinstance(data, abc.Mapping): for k, v in data.items(): - self.mod[k] = v + self._mod[k] = v elif isinstance(data, AnnData): # Get the list of modalities if "feature_types" in data.var.columns: @@ -176,9 +176,9 @@ def __init__( if feature_types_names is not None: if name in feature_types_names.keys(): alias = feature_types_names[name] - self.mod[alias] = data[:, data.var.feature_types == name].copy() + self._mod[alias] = data[:, data.var.feature_types == name].copy() else: - self.mod["data"] = data + self._mod["data"] = data else: raise TypeError("Expected AnnData object or dictionary with AnnData objects as values") @@ -272,7 +272,7 @@ def _init_as_view(self, mudata_ref: "MuData", index): if isinstance(varidx, Integral): varidx = slice(varidx, varidx + 1) - self.mod = ModDict() + self._mod = ModDict() for m, a in mudata_ref.mod.items(): cobsidx, cvaridx = mudata_ref.obsmap[m][obsidx], mudata_ref.varmap[m][varidx] cobsidx, cvaridx = cobsidx[cobsidx > 0] - 1, cvaridx[cvaridx > 0] - 1 @@ -301,11 +301,11 @@ def _init_as_view(self, mudata_ref: "MuData", index): cvaridx = slice(None) if a.is_view: if isinstance(a, MuData): - self.mod[m] = a._mudata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._mudata_ref)] + self._mod[m] = a._mudata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._mudata_ref)] else: - self.mod[m] = a._adata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._adata_ref)] + self._mod[m] = a._adata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._adata_ref)] else: - self.mod[m] = a[cobsidx, cvaridx] + self._mod[m] = a[cobsidx, cvaridx] self._obs = DataFrameView(mudata_ref.obs.iloc[obsidx, :], view_args=(self, "obs")) self._obsm = mudata_ref.obsm._view(self, (obsidx,)) @@ -338,7 +338,7 @@ def _init_as_view(self, mudata_ref: "MuData", index): def _init_as_actual(self, data: "MuData"): self._init_common() - self.mod = data.mod + self._mod = data.mod self._obs = data.obs self._var = data.var self._obsm = MuAxisArrays(self, axis=0, store=convert_to_dict(data.obsm)) @@ -389,21 +389,21 @@ def _init_from_dict_( ) def _check_duplicated_attr_names(self, attr: str): - if any(not getattr(self.mod[mod_i], attr + "_names").astype(str).is_unique for mod_i in self.mod): + if any(not getattr(self._mod[mod_i], attr + "_names").astype(str).is_unique for mod_i in self._mod): # If there are non-unique attr_names, we can only handle outer joins # under the condition the duplicated values are restricted to one modality dups = [ np.unique( - getattr(self.mod[mod_i], attr + "_names")[ - getattr(self.mod[mod_i], attr + "_names").astype(str).duplicated() + getattr(self._mod[mod_i], attr + "_names")[ + getattr(self._mod[mod_i], attr + "_names").astype(str).duplicated() ] ) - for mod_i in self.mod + for mod_i in self._mod ] for i, mod_i_dup_attrs in enumerate(dups): - for j, mod_j in enumerate(self.mod): + for j, mod_j in enumerate(self._mod): if j != i: - if any(np.isin(mod_i_dup_attrs, getattr(self.mod[mod_j], attr + "_names").values)): + if any(np.isin(mod_i_dup_attrs, getattr(self._mod[mod_j], attr + "_names").values)): warnings.warn( f"Duplicated {attr}_names should not be present in different modalities due to the ambiguity that leads to.", stacklevel=3, @@ -416,9 +416,9 @@ def _check_duplicated_names(self): self._check_duplicated_attr_names("var") def _check_intersecting_attr_names(self, attr: str): - for mod_i, mod_j in combinations(self.mod, 2): - mod_i_attr_index = getattr(self.mod[mod_i], attr + "_names") - mod_j_attr_index = getattr(self.mod[mod_j], attr + "_names") + for mod_i, mod_j in combinations(self._mod, 2): + mod_i_attr_index = getattr(self._mod[mod_i], attr + "_names") + mod_j_attr_index = getattr(self._mod[mod_j], attr + "_names") intersection = mod_i_attr_index.intersection(mod_j_attr_index, sort=False) if intersection.shape[0] > 0: # Some of the elements are also in another index @@ -430,15 +430,15 @@ def _check_changed_attr_names(self, attr: str, columns: bool = False): attr_names_changed, attr_columns_changed = False, False if not hasattr(self, attrhash): attr_names_changed, attr_columns_changed = True, True - elif len(self.mod) < len(getattr(self, attrhash)): + elif len(self._mod) < len(getattr(self, attrhash)): attr_names_changed, attr_columns_changed = True, None else: - for m in self.mod.keys(): + for m in self._mod.keys(): if m in getattr(self, attrhash): cached_hash = getattr(self, attrhash)[m] new_hash = ( - sha1(np.ascontiguousarray(getattr(self.mod[m], attr).index.values)).hexdigest(), - sha1(np.ascontiguousarray(getattr(self.mod[m], attr).columns.values)).hexdigest(), + sha1(np.ascontiguousarray(getattr(self._mod[m], attr).index.values)).hexdigest(), + sha1(np.ascontiguousarray(getattr(self._mod[m], attr).columns.values)).hexdigest(), ) if cached_hash[0] != new_hash[0]: attr_names_changed = True @@ -464,7 +464,7 @@ def copy(self, filename: PathLike | None = None) -> "MuData": """ if not self.isbacked: mod = {} - for k, v in self.mod.items(): + for k, v in self._mod.items(): mod[k] = v.copy() return self._init_from_dict_( mod, @@ -498,8 +498,8 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None): # Call the same method on each modality if df is None: - for k in self.mod: - self.mod[k].strings_to_categoricals() + for k in self._mod: + self._mod[k].strings_to_categoricals() else: return df @@ -508,10 +508,15 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None): def __getitem__(self, index) -> Union["MuData", AnnData]: if isinstance(index, str): - return self.mod[index] + return self._mod[index] else: return MuData(self, as_view=True, index=index) + @property + def mod(self) -> Mapping[str, "AnnData | MuData"]: + """Dictionary of modalities.""" + return self._mod + @property def shape(self) -> tuple[int, int]: """Shape of data, all variables and observations combined (:attr:`n_obs`, :attr:`n_var`).""" @@ -533,13 +538,13 @@ def __len__(self) -> int: def _create_global_attr_index(self, attr: str, axis: int): if axis == (1 - self._axis): # Shared indices - modindices = [getattr(self.mod[m], attr).index for m in self.mod] + modindices = [getattr(self._mod[m], attr).index for m in self._mod] if all(modindices[i].equals(modindices[i + 1]) for i in range(len(modindices) - 1)): attrindex = modindices[0].copy() - attrindex = reduce(pd.Index.union, [getattr(self.mod[m], attr).index for m in self.mod]).values + attrindex = reduce(pd.Index.union, [getattr(self._mod[m], attr).index for m in self._mod]).values else: # Modality-specific indices - attrindex = np.concatenate([getattr(self.mod[m], attr).index.values for m in self.mod], axis=0) + attrindex = np.concatenate([getattr(self._mod[m], attr).index.values for m in self._mod], axis=0) return attrindex def _update_attr( @@ -595,7 +600,7 @@ def _update_attr( dfs = [ getattr(a, attr).loc[:, []].assign(**{f"{m}:{rowcol}": np.arange(getattr(a, attr).shape[0])}) - for m, a in self.mod.items() + for m, a in self._mod.items() ] index_order = None @@ -645,7 +650,7 @@ def calc_attrm_update(): data_mod = pd.concat( dfs, join="outer", axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, sort=False ) - for mod in self.mod.keys(): + for mod in self._mod.keys(): fix_attrmap_col(data_mod, mod, rowcol) data_mod = _make_index_unique(data_mod, force=attr_intersecting) @@ -671,7 +676,7 @@ def calc_attrm_update(): data_mod = _restore_index(data_mod) data_mod.index.set_names(rowcol, inplace=True) data_global.index.set_names(rowcol, inplace=True) - for mod, amod in self.mod.items(): + for mod, amod in self._mod.items(): colname = fix_attrmap_col(data_mod, mod, rowcol) if mod in attrmap: modmap = attrmap[mod].ravel() @@ -715,7 +720,7 @@ def calc_attrm_update(): # get adata positions and remove columns from the data frame mdict = {} - for m in self.mod.keys(): + for m in self._mod.keys(): colname = m + ":" + rowcol mdict[m] = data_mod[colname].to_numpy() data_mod.drop(colname, axis=1, inplace=True) @@ -741,7 +746,7 @@ def calc_attrm_update(): if index_order is not None: if can_update: for mx_key, mx in attrm.items(): - if mx_key not in self.mod.keys(): # not a modality name + if mx_key not in self._mod.keys(): # not a modality name if isinstance(mx, pd.DataFrame): mx = mx.iloc[index_order, :] mx.iloc[index_order == -1, :] = pd.NA @@ -767,7 +772,7 @@ def calc_attrm_update(): if attr_changed: if not hasattr(self, _attrhash): setattr(self, _attrhash, {}) - for m, mod in self.mod.items(): + for m, mod in self._mod.items(): getattr(self, _attrhash)[m] = ( sha1(np.ascontiguousarray(getattr(mod, attr).index.values)).hexdigest(), sha1(np.ascontiguousarray(getattr(mod, attr).columns.values)).hexdigest(), @@ -830,10 +835,10 @@ def _update_attr_legacy( [ not col.startswith(mod + ":") or col[col.startswith(mod + ":") and len(mod + ":") :] - not in getattr(self.mod[mod], attr).columns + not in getattr(self._mod[mod], attr).columns for col in getattr(self, attr).columns ] - for mod in self.mod + for mod in self._mod ], strict=False, ), @@ -854,7 +859,7 @@ def _update_attr_legacy( if join_common: # If all modalities have a column with the same name, it is not global columns_common = reduce( - lambda a, b: a.intersection(b), [getattr(self.mod[mod], attr).columns for mod in self.mod] + lambda a, b: a.intersection(b), [getattr(self._mod[mod], attr).columns for mod in self._mod] ) data_global = data_global.loc[:, [c not in columns_common for c in data_global.columns]] @@ -877,7 +882,7 @@ def _update_attr_legacy( .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) .add_prefix(m + ":") ) - for m, a in self.mod.items() + for m, a in self._mod.items() ], join="outer", axis=1, @@ -895,14 +900,14 @@ def _update_attr_legacy( .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) .add_prefix(m + ":") ) - for m, a in self.mod.items() + for m, a in self._mod.items() ], join="outer", axis=0, sort=False, ) data_common = pd.concat( - [_maybe_coerce_to_boolean(getattr(a, attr)[columns_common]) for m, a in self.mod.items()], + [_maybe_coerce_to_boolean(getattr(a, attr)[columns_common]) for m, a in self._mod.items()], join="outer", axis=0, sort=False, @@ -923,7 +928,7 @@ def _update_attr_legacy( .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) .add_prefix(m + ":") ) - for m, a in self.mod.items() + for m, a in self._mod.items() ], join="outer", axis=0, @@ -931,7 +936,7 @@ def _update_attr_legacy( ) ) - for mod in self.mod.keys(): + for mod in self._mod.keys(): colname = mod + ":" + rowcol # use 0 as special value for missing # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write @@ -977,7 +982,7 @@ def _update_attr_legacy( force=True, ) ) - for m, a in self.mod.items() + for m, a in self._mod.items() ] # Here, attr_names are guaranteed to be unique and are safe to be used for joins @@ -986,7 +991,7 @@ def _update_attr_legacy( data_common = pd.concat( [ _maybe_coerce_to_boolean(_make_index_unique(getattr(a, attr)[columns_common], force=True)) - for m, a in self.mod.items() + for m, a in self._mod.items() ], join="outer", axis=0, @@ -1001,7 +1006,7 @@ def _update_attr_legacy( getattr(a, attr).assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}).add_prefix(m + ":"), force=True, ) - for m, a in self.mod.items() + for m, a in self._mod.items() ] data_mod = pd.concat(dfs, join="outer", axis=axis, sort=False) @@ -1029,7 +1034,7 @@ def _update_attr_legacy( data_mod = _restore_index(data_mod) data_mod.index.set_names(rowcol, inplace=True) data_global.index.set_names(rowcol, inplace=True) - for mod, amod in self.mod.items(): + for mod, amod in self._mod.items(): colname = mod + ":" + rowcol # use 0 as special value for missing # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write @@ -1069,7 +1074,7 @@ def _update_attr_legacy( # get adata positions and remove columns from the data frame mdict = {} - for m in self.mod.keys(): + for m in self._mod.keys(): colname = m + ":" + rowcol mdict[m] = data_mod[colname].to_numpy() data_mod.drop(colname, axis=1, inplace=True) @@ -1130,7 +1135,7 @@ def _update_attr_legacy( index_order = prev_index.get_indexer(now_index) for mx_key in attrm.keys(): - if mx_key not in self.mod.keys(): # not a modality name + if mx_key not in self._mod.keys(): # not a modality name attrm[mx_key] = attrm[mx_key][index_order] attrm[mx_key][index_order == -1] = np.nan @@ -1156,7 +1161,7 @@ def _update_attr_legacy( if attr_changed: if not hasattr(self, _attrhash): setattr(self, _attrhash, {}) - for m, mod in self.mod.items(): + for m, mod in self._mod.items(): getattr(self, _attrhash)[m] = ( sha1(np.ascontiguousarray(getattr(mod, attr).index.values)).hexdigest(), sha1(np.ascontiguousarray(getattr(mod, attr).columns.values)).hexdigest(), @@ -1169,18 +1174,18 @@ def _shrink_attr(self, attr: str, inplace=True) -> pd.DataFrame: map( all, zip( - *([not col.startswith(mod + ":") for col in getattr(self, attr).columns] for mod in self.mod), + *([not col.startswith(mod + ":") for col in getattr(self, attr).columns] for mod in self._mod), strict=False, ), ) ) # Make sure modname-prefix columns exist in modalities, # keep them in place if they don't - for mod in self.mod: + for mod in self._mod: for i, col in enumerate(getattr(self, attr).columns): if col.startswith(mod + ":"): mcol = col[len(mod) + 1 :] - if mcol not in getattr(self.mod[mod], attr).columns: + if mcol not in getattr(self._mod[mod], attr).columns: columns_global[i] = True # Only keep data from global .obs/.var columns newdf = getattr(self, attr).loc[:, columns_global] @@ -1197,7 +1202,7 @@ def n_mod(self) -> int: ------- int: The number of modalities. """ - return len(self.mod) + return len(self._mod) @property def isbacked(self) -> bool: @@ -1234,7 +1239,7 @@ def filename(self, filename: PathLike | None): elif filename is not None: self.write(filename) self.file.open(filename, "r+") - for ad in self.mod.values(): + for ad in self._mod.values(): ad._X = None @property @@ -1261,7 +1266,7 @@ def n_obs(self) -> int: def obs_vector(self, key: str, layer: str | None = None) -> np.ndarray: """Return an array of values for the requested key of length n_obs""" if key not in self.obs.columns: - for m, a in self.mod.items(): + for m, a in self._mod.items(): if key in a.obs.columns: raise KeyError( f"There is no {key} in MuData .obs but there is one in {m} .obs. Consider running `mu.update_obs()` to update global .obs." @@ -1281,35 +1286,35 @@ def obs_names_make_unique(self): If there are obs_names, which are the same for multiple modalities, append modality name to all obs_names. """ - mod_obs_sum = np.sum([a.n_obs for a in self.mod.values()]) + mod_obs_sum = np.sum([a.n_obs for a in self._mod.values()]) if mod_obs_sum != self.n_obs: self.update_obs() - for k in self.mod: - if isinstance(self.mod[k], AnnData): - self.mod[k].obs_names_make_unique() + for k in self._mod: + if isinstance(self._mod[k], AnnData): + self._mod[k].obs_names_make_unique() # Only propagate to individual modalities with shared vars - elif isinstance(self.mod[k], MuData) and getattr(self.mod[k], "axis", 1) == 1: - self.mod[k].obs_names_make_unique() + elif isinstance(self._mod[k], MuData) and getattr(self._mod[k], "axis", 1) == 1: + self._mod[k].obs_names_make_unique() # Check if there are observations with the same name in different modalities common_obs = [] - mods = list(self.mod.keys()) - for i in range(len(self.mod) - 1): + mods = list(self._mod.keys()) + for i in range(len(self._mod) - 1): ki = mods[i] - for j in range(i + 1, len(self.mod)): + for j in range(i + 1, len(self._mod)): kj = mods[j] - common_obs.append(self.mod[ki].obs_names.intersection(self.mod[kj].obs_names.values)) + common_obs.append(self._mod[ki].obs_names.intersection(self._mod[kj].obs_names.values)) if any(len(x) > 0 for x in common_obs): warnings.warn( "Modality names will be prepended to obs_names since there are identical obs_names in different modalities.", stacklevel=1, ) - for k in self.mod: - self.mod[k].obs_names = k + ":" + self.mod[k].obs_names.astype(str) + for k in self._mod: + self._mod[k].obs_names = k + ":" + self._mod[k].obs_names.astype(str) # Update .obs.index in the MuData - obs_names = [obs for a in self.mod.values() for obs in a.obs_names.values] + obs_names = [obs for a in self._mod.values() for obs in a.obs_names.values] self._obs.index = obs_names def _set_names(self, attr: str, axis: int, names: Sequence[str]): @@ -1324,7 +1329,7 @@ def _set_names(self, attr: str, axis: int, names: Sequence[str]): if not isinstance(names.name, str | type(None)): names.name = None - mod_shape_sum = np.sum([a.shape[axis] for a in self.mod.values()]) + mod_shape_sum = np.sum([a.shape[axis] for a in self._mod.values()]) if mod_shape_sum != self.shape[axis]: self._update_attr(attr, axis=1 - axis) @@ -1338,7 +1343,7 @@ def _set_names(self, attr: str, axis: int, names: Sequence[str]): getattr(self, attr).index = names map = getattr(self, f"{attr}map") - for modname, mod in self.mod.items(): + for modname, mod in self._mod.items(): newnames = np.empty(mod.shape[axis], dtype=object) modmap = map[modname].ravel() mask = modmap > 0 @@ -1389,7 +1394,7 @@ def n_var(self) -> int: def var_vector(self, key: str, layer: str | None = None) -> np.ndarray: """Return an array of values for the requested key of length n_var.""" if key not in self.var.columns: - for m, a in self.mod.items(): + for m, a in self._mod.items(): if key in a.var.columns: raise KeyError( f"There is no {key} in MuData .var but there is one in {m} .var. Consider running `mu.update_var()` to update global .var." @@ -1409,35 +1414,35 @@ def var_names_make_unique(self): If there are var_names, which are the same for multiple modalities, append modality name to all var_names. """ - mod_var_sum = np.sum([a.n_vars for a in self.mod.values()]) + mod_var_sum = np.sum([a.n_vars for a in self._mod.values()]) if mod_var_sum != self.n_vars: self.update_var() - for k in self.mod: - if isinstance(self.mod[k], AnnData): - self.mod[k].var_names_make_unique() + for k in self._mod: + if isinstance(self._mod[k], AnnData): + self._mod[k].var_names_make_unique() # Only propagate to individual modalities with shared obs - elif isinstance(self.mod[k], MuData) and getattr(self.mod[k], "axis", 0) == 0: - self.mod[k].var_names_make_unique() + elif isinstance(self._mod[k], MuData) and getattr(self._mod[k], "axis", 0) == 0: + self._mod[k].var_names_make_unique() # Check if there are variables with the same name in different modalities common_vars = [] - mods = list(self.mod.keys()) - for i in range(len(self.mod) - 1): + mods = list(self._mod.keys()) + for i in range(len(self._mod) - 1): ki = mods[i] - for j in range(i + 1, len(self.mod)): + for j in range(i + 1, len(self._mod)): kj = mods[j] - common_vars.append(np.intersect1d(self.mod[ki].var_names.values, self.mod[kj].var_names.values)) + common_vars.append(np.intersect1d(self._mod[ki].var_names.values, self._mod[kj].var_names.values)) if any(len(x) > 0 for x in common_vars): warnings.warn( "Modality names will be prepended to var_names since there are identical var_names in different modalities.", stacklevel=1, ) - for k in self.mod: - self.mod[k].var_names = k + ":" + self.mod[k].var_names.astype(str) + for k in self._mod: + self._mod[k].var_names = k + ":" + self._mod[k].var_names.astype(str) # Update .var.index in the MuData - var_names = [var for a in self.mod.values() for var in a.var_names.values] + var_names = [var for a in self._mod.values() for var in a.var_names.values] self._var.index = var_names @property @@ -1588,7 +1593,7 @@ def update(self): NOTE: From v0.4, it will not pull columns from modalities by default. """ - if len(self.mod) > 0: + if len(self._mod) > 0: self.update_var() self.update_obs() @@ -1604,7 +1609,7 @@ def mod_names(self) -> list[str]: This property is read-only. """ - return list(self.mod.keys()) + return list(self._mod.keys()) def _pull_attr( self, @@ -1675,7 +1680,7 @@ def _pull_attr( if mods is not None: if isinstance(mods, str): mods = [mods] - if not all(m in self.mod for m in mods): + if not all(m in self._mod for m in mods): raise ValueError("All mods should be present in mdata.mod") elif len(mods) == self.n_mod: mods = None @@ -1687,12 +1692,12 @@ def _pull_attr( # get all columns from all modalities and count how many times each column is present derived_name_counts = Counter() - for prefix, mod in self.mod.items(): + for prefix, mod in self._mod.items(): modcols = getattr(mod, attr).columns ccols = [] for name in modcols: ccols.append( - MetadataColumn(allowed_prefixes=self.mod.keys(), prefix=prefix, name=name, strip_prefix=False) + MetadataColumn(allowed_prefixes=self._mod.keys(), prefix=prefix, name=name, strip_prefix=False) ) derived_name_counts[name] += 1 cols[prefix] = ccols @@ -1770,7 +1775,7 @@ def _pull_attr( dfs: list[pd.DataFrame] = [] for m, modcols in cols.items(): - mod = self.mod[m] + mod = self._mod[m] mod_map = attrmap[m].ravel() mask = mod_map > 0 @@ -2003,7 +2008,7 @@ def _push_attr( if isinstance(mods, str): mods = [mods] mods = list(dict.fromkeys(mods)) - if not all(m in self.mod for m in mods): + if not all(m in self._mod for m in mods): raise ValueError("All mods should be present in mdata.mod") elif len(mods) == self.n_mod: mods = None @@ -2012,7 +2017,7 @@ def _push_attr( drop = True # get all global columns - cols = [MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) for name in getattr(self, attr).columns] + cols = [MetadataColumn(allowed_prefixes=self._mod.keys(), name=name) for name in getattr(self, attr).columns] if columns is not None: for k, v in {"common": common, "prefixed": prefixed}.items(): @@ -2055,7 +2060,7 @@ def _push_attr( ) attrmap = getattr(self, f"{attr}map") - for m, mod in self.mod.items(): + for m, mod in self._mod.items(): if mods is not None and m not in mods: continue @@ -2232,7 +2237,7 @@ def _gen_repr(self, n_obs, n_vars, extensive: bool = False, nest_level: int = 0) zip( *[ [not col.startswith(mod + mod_sep) for col in getattr(self, attr).keys()] - for mod in self.mod + for mod in self._mod ], strict=False, ), @@ -2242,8 +2247,8 @@ def _gen_repr(self, n_obs, n_vars, extensive: bool = False, nest_level: int = 0) descr += ( f"\n{indent} {attr}:\t{str([keys[i] for i in range(len(keys)) if global_keys[i]])[1:-1]}" ) - descr += f"\n{indent} {len(self.mod)} modalit{'y' if len(self.mod) == 1 else 'ies'}" - for k, v in self.mod.items(): + descr += f"\n{indent} {len(self._mod)} modalit{'y' if len(self._mod) == 1 else 'ies'}" + for k, v in self._mod.items(): mod_indent = " " * (nest_level + 1) if isinstance(v, MuData): descr += f"\n{mod_indent}{k}:\t" + v._gen_repr(v.n_obs, v.n_vars, extensive, nest_level + 1) @@ -2282,7 +2287,7 @@ def _repr_html_(self, expand=None): # General object properties header = "MuData object {} obs × {} var in {} modalit{}".format( - self.n_obs, self.n_vars, len(self.mod), "y" if len(self.mod) < 2 else "ies" + self.n_obs, self.n_vars, len(self._mod), "y" if len(self._mod) < 2 else "ies" ) if self.isbacked: header += f"
backed at {self.file.filename}" @@ -2299,7 +2304,7 @@ def _repr_html_(self, expand=None): if self.uns: mods += details_block_table(self, "uns", "Miscellaneous", expand >> 2) - for m, dat in self.mod.items(): + for m, dat in self._mod.items(): mods += "
" mods += "".format(" open" if (expand & 0b010) >> 1 else "") mods += "
{}
{} × {}
".format( @@ -2345,7 +2350,7 @@ def _find_unique_colnames(self, attr: str, ncols: int): for i in range(ncols): finished = False while not finished: - for ad in chain((self,), self.mod.values()): + for ad in chain((self,), self._mod.values()): if colnames[i] in getattr(ad, attr).columns: colnames[i] = "_" + colnames[i] break diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 0000000..b5520d3 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import anndata as ad +import numpy as np +import pytest + +import mudata as md +from mudata._core import extensions + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture(autouse=True) +def _cleanup_dummy() -> Generator[None, None, None]: + """Automatically cleanup dummy namespace after each test.""" + original = getattr(md.MuData, "dummy", None) + yield + if original is not None: + md.MuData.dummy = original + elif hasattr(md.MuData, "dummy"): + delattr(md.MuData, "dummy") + + +@pytest.fixture +def dummy_namespace() -> type: + """Create a basic dummy namespace class.""" + + @md.register_mudata_namespace("dummy") + class DummyNamespace: + def __init__(self, mdata: md.MuData) -> None: + self._mdata = mdata + + def greet(self) -> str: + return "hello" + + return DummyNamespace + + +@pytest.fixture +def mdata(rng) -> md.MuData: + """Create a basic MuData object for testing.""" + return md.MuData({"test": ad.AnnData(X=rng.poisson(1, size=(10, 10)))}) + + +def test_accessor_namespace() -> None: + """Test the behavior of the AccessorNameSpace descriptor. + + This test verifies that: + - When accessed at the class level (i.e., without an instance), the descriptor + returns the namespace type. + - When accessed via an instance, the descriptor instantiates the namespace, + passing the instance to its constructor. + - The instantiated namespace is then cached on the instance such that subsequent + accesses of the same attribute return the cached namespace instance. + """ + + # Define a dummy namespace class to be used via the descriptor. + class DummyNamespace: + def __init__(self, mdata: md.MuData) -> None: + self._mdata = mdata + + def foo(self) -> str: + return "foo" + + class Dummy: + pass + + descriptor = extensions.AccessorNameSpace("dummy", DummyNamespace) + + # When accessed on the class, it should return the namespace type. + ns_class = descriptor.__get__(None, Dummy) + assert ns_class is DummyNamespace + + # When accessed via an instance, it should instantiate DummyNamespace. + dummy_obj = Dummy() + ns_instance = descriptor.__get__(dummy_obj, Dummy) + assert isinstance(ns_instance, DummyNamespace) + assert ns_instance._mdata is dummy_obj + + # __get__ should cache the namespace instance on the object. + # Subsequent access should return the same cached instance. + assert dummy_obj.dummy is ns_instance + + +def test_descriptor_instance_caching(dummy_namespace: type, mdata: md.MuData) -> None: + """Test that namespace instances are cached on individual MuData objects.""" + # First access creates the instance + ns_instance = mdata.dummy + # Subsequent accesses should return the same instance + assert mdata.dummy is ns_instance + + +def test_register_namespace_basic(dummy_namespace: type, mdata: md.MuData) -> None: + """Test basic namespace registration and access.""" + assert mdata.dummy.greet() == "hello" + + +def test_register_namespace_override(dummy_namespace: type) -> None: + """Test namespace registration and override behavior.""" + assert hasattr(md.MuData, "dummy") + + # Override should warn and update the namespace + with pytest.warns(UserWarning, match="Overriding existing custom namespace 'dummy'"): + + @md.register_mudata_namespace("dummy") + class DummyNamespaceOverride: + def __init__(self, mdata: md.MuData) -> None: + self._mdata = mdata + + def greet(self) -> str: + return "world" + + # Verify the override worked + mdata = md.MuData({"test": ad.AnnData(X=np.random.poisson(1, size=(10, 10)))}) + assert mdata.dummy.greet() == "world" + + +@pytest.mark.parametrize( + "attr", + [ + "mod", + "obs", + "var", + "uns", + "obsm", + "varm", + "copy", + "write", + "obsmap", + "varmap", + "obsp", + "varp", + "update", + "update_obs", + "update_var", + "push_obs", + "push_var", + "pull_obs", + "pull_var", + ], +) +def test_register_existing_attributes(attr: str) -> None: + """ + Test that registering an accessor with a name that is a reserved attribute of MuData raises an attribute error. + + We only test a representative sample of important attributes rather than all of them. + """ + # Test a representative sample of key AnnData attributes + with pytest.raises(AttributeError, match=f"cannot override reserved attribute {attr!r}"): + + @md.register_mudata_namespace(attr) + class DummyNamespace: + def __init__(self, mdata: md.MuData) -> None: + self._mdata = mdata + + +def test_valid_signature() -> None: + """Test that a namespace with valid signature is accepted.""" + + @md.register_mudata_namespace("valid") + class ValidNamespace: + def __init__(self, mdata: md.MuData) -> None: + self.mdata = mdata + + +def test_missing_param() -> None: + """Test that a namespace missing the second parameter is rejected.""" + with pytest.raises( + TypeError, match=r"Namespace initializer must accept an MuData instance as the second parameter\." + ): + + @md.register_mudata_namespace("missing_param") + class MissingParamNamespace: + def __init__(self) -> None: + pass + + +def test_wrong_name() -> None: + """Test that a namespace with wrong parameter name is rejected.""" + with pytest.raises( + TypeError, match=r"Namespace initializer's second parameter must be named 'mdata', got 'notmdata'\." + ): + + @md.register_mudata_namespace("wrong_name") + class WrongNameNamespace: + def __init__(self, notmdata: md.MuData) -> None: + self.notmdata = notmdata + + +def test_wrong_annotation() -> None: + """Test that a namespace with wrong parameter annotation is rejected.""" + with pytest.raises( + TypeError, + match=r"Namespace initializer's second parameter must be annotated as the 'MuData' class, got 'int'\.", + ): + + @md.register_mudata_namespace("wrong_annotation") + class WrongAnnotationNamespace: + def __init__(self, mdata: int) -> None: + self.mdata = mdata + + +def test_missing_annotation() -> None: + """Test that a namespace with missing parameter annotation is rejected.""" + with pytest.raises(AttributeError): + + @md.register_mudata_namespace("missing_annotation") + class MissingAnnotationNamespace: + def __init__(self, mdata) -> None: + self.mdata = mdata + + +def test_both_wrong() -> None: + """Test that a namespace with both wrong name and annotation is rejected.""" + with pytest.raises( + TypeError, + match=( + r"Namespace initializer's second parameter must be named 'mdata', got 'info'\. " + r"And must be annotated as 'MuData', got 'str'\." + ), + ): + + @md.register_mudata_namespace("both_wrong") + class BothWrongNamespace: + def __init__(self, info: str) -> None: + self.info = info