diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd2100c..10e7a08 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.1.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html
+## [0.3.4] (unreleased)
+
+### Added
+
+- `mudata.register_mudata_namespace()` functionality for adding custom functionality to `MuData` objects.
+
## [0.3.3]
### Fixed
@@ -129,7 +135,8 @@ To copy the annotations explicitly, you will need to use `pull_obs()` and/or `pu
Initial `mudata` release with `MuData`, previously a part of the `muon` framework.
-[0.3.3]: https://github.com/scverse/mudata/compare/v0.3.1...v0.3.3
+[0.3.4]: https://github.com/scverse/mudata/compare/v0.3.3...v0.3.4
+[0.3.3]: https://github.com/scverse/mudata/compare/v0.3.2...v0.3.3
[0.3.2]: https://github.com/scverse/mudata/compare/v0.3.1...v0.3.2
[0.3.1]: https://github.com/scverse/mudata/compare/v0.3.0...v0.3.1
[0.3.0]: https://github.com/scverse/mudata/compare/v0.2.4...v0.3.0
diff --git a/docs/api.md b/docs/api.md
index 0f2084e..acbc025 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -30,3 +30,12 @@
:functions-only:
:toctree: generated
```
+
+## Extensions
+```{eval-rst}
+.. module:: mudata
+.. autosummary::
+ :toctree: generated
+
+ register_mudata_namespace
+```
diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py
index fa9ca20..6b4772d 100644
--- a/src/mudata/__init__.py
+++ b/src/mudata/__init__.py
@@ -4,6 +4,7 @@
from ._core import utils
from ._core.config import set_options
+from ._core.extensions import ExtensionNamespace, register_mudata_namespace
from ._core.io import (
read,
read_anndata,
diff --git a/src/mudata/_core/extensions.py b/src/mudata/_core/extensions.py
new file mode 100644
index 0000000..e8c8d7e
--- /dev/null
+++ b/src/mudata/_core/extensions.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import inspect
+import warnings
+from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, get_type_hints, overload, runtime_checkable
+
+from .mudata import MuData
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+
+@runtime_checkable
+class ExtensionNamespace(Protocol):
+ """Protocol for extension namespaces.
+
+ Enforces that the namespace initializer accepts a class with the proper `__init__` method.
+    Protocols can't enforce that the `__init__` accepts the correct types. See
+ `_check_namespace_signature` for that. This is mainly useful for static type
+ checking with mypy and IDEs.
+ """
+
+ def __init__(self, mdata: MuData) -> None:
+ """Used to enforce the correct signature for extension namespaces."""
+
+
+# Based off of the extension framework in Polars
+# https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py
+
+__all__ = ["register_mudata_namespace", "ExtensionNamespace"]
+
+
+# Reserved namespaces include accessors built into MuData (currently there are none)
+# and all current attributes of MuData
+_reserved_namespaces: set[str] = set(dir(MuData))
+
+NameSpT = TypeVar("NameSpT", bound=ExtensionNamespace)
+T = TypeVar("T")
+
+
+class AccessorNameSpace(ExtensionNamespace, Generic[NameSpT]):
+ """Establish property-like namespace object for user-defined functionality."""
+
+ def __init__(self, name: str, namespace: type[NameSpT]) -> None:
+ self._accessor = name
+ self._ns = namespace
+
+ @overload
+ def __get__(self, instance: None, cls: type[T]) -> type[NameSpT]: ...
+
+ @overload
+ def __get__(self, instance: T, cls: type[T]) -> NameSpT: ...
+
+ def __get__(self, instance: T | None, cls: type[T]) -> NameSpT | type[NameSpT]:
+ if instance is None:
+ return self._ns
+
+ ns_instance = self._ns(instance) # type: ignore[call-arg]
+ setattr(instance, self._accessor, ns_instance)
+ return ns_instance
+
+
+def _check_namespace_signature(ns_class: type) -> None:
+ """Validate the signature of a namespace class for MuData extensions.
+
+ This function ensures that any class intended to be used as an extension namespace
+ has a properly formatted `__init__` method such that:
+
+ 1. Accepts at least two parameters (self and mdata)
+ 2. Has 'mdata' as the name of the second parameter
+ 3. Has the second parameter properly type-annotated as 'MuData' or any equivalent import alias
+
+ The function performs runtime validation of these requirements before a namespace
+ can be registered through the `register_mudata_namespace` decorator.
+
+ Parameters
+ ----------
+ ns_class
+ The namespace class to validate.
+
+ Raises
+ ------
+ TypeError
+ If the `__init__` method has fewer than 2 parameters (missing the MuData parameter).
+ AttributeError
+ If the second parameter of `__init__` lacks a type annotation.
+ TypeError
+ If the second parameter of `__init__` is not named 'mdata'.
+ TypeError
+ If the second parameter of `__init__` is not annotated as the 'MuData' class.
+ TypeError
+ If both the name and type annotation of the second parameter are incorrect.
+
+ """
+ sig = inspect.signature(ns_class.__init__)
+ params = sig.parameters
+
+ # Ensure there are at least two parameters (self and mdata)
+ if len(params) < 2:
+ raise TypeError("Namespace initializer must accept an MuData instance as the second parameter.")
+
+ # Get the second parameter (expected to be 'mdata')
+ param = iter(params.values())
+ next(param)
+ param = next(param)
+ if param.annotation is inspect.Parameter.empty:
+ raise AttributeError(
+ "Namespace initializer's second parameter must be annotated as the 'MuData' class, got empty annotation."
+ )
+
+ name_ok = param.name == "mdata"
+
+ # Resolve the annotation using get_type_hints to handle forward references and aliases.
+ try:
+ type_hints = get_type_hints(ns_class.__init__)
+ resolved_type = type_hints.get(param.name, param.annotation)
+ except NameError as e:
+ raise NameError(f"Namespace initializer's second parameter must be named 'mdata', got '{param.name}'.") from e
+
+ type_ok = resolved_type is MuData
+
+ match (name_ok, type_ok):
+ case (True, True):
+ return # Signature is correct.
+ case (False, True):
+ raise TypeError(f"Namespace initializer's second parameter must be named 'mdata', got {param.name!r}.")
+ case (True, False):
+ type_repr = getattr(resolved_type, "__name__", str(resolved_type))
+ raise TypeError(
+ f"Namespace initializer's second parameter must be annotated as the 'MuData' class, got {type_repr!r}."
+ )
+ case _:
+ type_repr = getattr(resolved_type, "__name__", str(resolved_type))
+ raise TypeError(
+ f"Namespace initializer's second parameter must be named 'mdata', got {param.name!r}. "
+ f"And must be annotated as 'MuData', got {type_repr!r}."
+ )
+
+
+def _create_namespace(name: str, cls: type[MuData]) -> Callable[[type[NameSpT]], type[NameSpT]]:
+ """Register custom namespace against the underlying MuData class."""
+
+ def namespace(ns_class: type[NameSpT]) -> type[NameSpT]:
+ _check_namespace_signature(ns_class) # Perform the runtime signature check
+ if name in _reserved_namespaces:
+ raise AttributeError(f"cannot override reserved attribute {name!r}")
+ elif hasattr(cls, name):
+ warnings.warn(
+ f"Overriding existing custom namespace {name!r} (on {cls.__name__!r})", UserWarning, stacklevel=2
+ )
+ setattr(cls, name, AccessorNameSpace(name, ns_class))
+ return ns_class
+
+ return namespace
+
+
+def register_mudata_namespace(name: str) -> Callable[[type[NameSpT]], type[NameSpT]]:
+ """Decorator for registering custom functionality with an :class:`~mudata.MuData` object.
+
+ This decorator allows you to extend MuData objects with custom methods and properties
+ organized under a namespace. The namespace becomes accessible as an attribute on MuData
+    instances, providing a clean way for you to add domain-specific functionality without modifying
+ the MuData class itself, or extending the class with additional methods as you see fit in your workflow.
+
+ This is equivalent to :func:`anndata.register_anndata_namespace`.
+
+ Parameters
+ ----------
+ name
+ Name under which the accessor should be registered. This will be the attribute name
+ used to access your namespace's functionality on MuData objects (e.g., `mdata.{name}`).
+ Cannot conflict with existing MuData attributes like `obs`, `var`, `mod`, etc. The list of reserved
+ attributes includes everything outputted by `dir(MuData)`.
+
+ Returns
+ -------
+ A decorator that registers the decorated class as a custom namespace.
+
+ Notes
+ -----
+ Implementation requirements:
+
+ 1. The decorated class must have an `__init__` method that accepts exactly one parameter
+ (besides `self`) named `mdata` and annotated with type :class:`~mudata.MuData`.
+ 2. The namespace will be initialized with the MuData object on first access and then
+ cached on the instance.
+ 3. If the namespace name conflicts with an existing namespace, a warning is issued.
+ 4. If the namespace name conflicts with a built-in MuData attribute, an AttributeError is raised.
+ """
+ return _create_namespace(name, MuData)
diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py
index b5a0e11..01bb9f3 100644
--- a/src/mudata/_core/mudata.py
+++ b/src/mudata/_core/mudata.py
@@ -156,13 +156,13 @@ def __init__(
return
# Add all modalities to a MuData object
- self.mod = ModDict()
+ self._mod = ModDict()
if data is None:
# Initialize an empty MuData object
pass
elif isinstance(data, abc.Mapping):
for k, v in data.items():
- self.mod[k] = v
+ self._mod[k] = v
elif isinstance(data, AnnData):
# Get the list of modalities
if "feature_types" in data.var.columns:
@@ -176,9 +176,9 @@ def __init__(
if feature_types_names is not None:
if name in feature_types_names.keys():
alias = feature_types_names[name]
- self.mod[alias] = data[:, data.var.feature_types == name].copy()
+ self._mod[alias] = data[:, data.var.feature_types == name].copy()
else:
- self.mod["data"] = data
+ self._mod["data"] = data
else:
raise TypeError("Expected AnnData object or dictionary with AnnData objects as values")
@@ -272,7 +272,7 @@ def _init_as_view(self, mudata_ref: "MuData", index):
if isinstance(varidx, Integral):
varidx = slice(varidx, varidx + 1)
- self.mod = ModDict()
+ self._mod = ModDict()
for m, a in mudata_ref.mod.items():
cobsidx, cvaridx = mudata_ref.obsmap[m][obsidx], mudata_ref.varmap[m][varidx]
cobsidx, cvaridx = cobsidx[cobsidx > 0] - 1, cvaridx[cvaridx > 0] - 1
@@ -301,11 +301,11 @@ def _init_as_view(self, mudata_ref: "MuData", index):
cvaridx = slice(None)
if a.is_view:
if isinstance(a, MuData):
- self.mod[m] = a._mudata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._mudata_ref)]
+ self._mod[m] = a._mudata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._mudata_ref)]
else:
- self.mod[m] = a._adata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._adata_ref)]
+ self._mod[m] = a._adata_ref[_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._adata_ref)]
else:
- self.mod[m] = a[cobsidx, cvaridx]
+ self._mod[m] = a[cobsidx, cvaridx]
self._obs = DataFrameView(mudata_ref.obs.iloc[obsidx, :], view_args=(self, "obs"))
self._obsm = mudata_ref.obsm._view(self, (obsidx,))
@@ -338,7 +338,7 @@ def _init_as_view(self, mudata_ref: "MuData", index):
def _init_as_actual(self, data: "MuData"):
self._init_common()
- self.mod = data.mod
+ self._mod = data.mod
self._obs = data.obs
self._var = data.var
self._obsm = MuAxisArrays(self, axis=0, store=convert_to_dict(data.obsm))
@@ -389,21 +389,21 @@ def _init_from_dict_(
)
def _check_duplicated_attr_names(self, attr: str):
- if any(not getattr(self.mod[mod_i], attr + "_names").astype(str).is_unique for mod_i in self.mod):
+ if any(not getattr(self._mod[mod_i], attr + "_names").astype(str).is_unique for mod_i in self._mod):
# If there are non-unique attr_names, we can only handle outer joins
# under the condition the duplicated values are restricted to one modality
dups = [
np.unique(
- getattr(self.mod[mod_i], attr + "_names")[
- getattr(self.mod[mod_i], attr + "_names").astype(str).duplicated()
+ getattr(self._mod[mod_i], attr + "_names")[
+ getattr(self._mod[mod_i], attr + "_names").astype(str).duplicated()
]
)
- for mod_i in self.mod
+ for mod_i in self._mod
]
for i, mod_i_dup_attrs in enumerate(dups):
- for j, mod_j in enumerate(self.mod):
+ for j, mod_j in enumerate(self._mod):
if j != i:
- if any(np.isin(mod_i_dup_attrs, getattr(self.mod[mod_j], attr + "_names").values)):
+ if any(np.isin(mod_i_dup_attrs, getattr(self._mod[mod_j], attr + "_names").values)):
warnings.warn(
f"Duplicated {attr}_names should not be present in different modalities due to the ambiguity that leads to.",
stacklevel=3,
@@ -416,9 +416,9 @@ def _check_duplicated_names(self):
self._check_duplicated_attr_names("var")
def _check_intersecting_attr_names(self, attr: str):
- for mod_i, mod_j in combinations(self.mod, 2):
- mod_i_attr_index = getattr(self.mod[mod_i], attr + "_names")
- mod_j_attr_index = getattr(self.mod[mod_j], attr + "_names")
+ for mod_i, mod_j in combinations(self._mod, 2):
+ mod_i_attr_index = getattr(self._mod[mod_i], attr + "_names")
+ mod_j_attr_index = getattr(self._mod[mod_j], attr + "_names")
intersection = mod_i_attr_index.intersection(mod_j_attr_index, sort=False)
if intersection.shape[0] > 0:
# Some of the elements are also in another index
@@ -430,15 +430,15 @@ def _check_changed_attr_names(self, attr: str, columns: bool = False):
attr_names_changed, attr_columns_changed = False, False
if not hasattr(self, attrhash):
attr_names_changed, attr_columns_changed = True, True
- elif len(self.mod) < len(getattr(self, attrhash)):
+ elif len(self._mod) < len(getattr(self, attrhash)):
attr_names_changed, attr_columns_changed = True, None
else:
- for m in self.mod.keys():
+ for m in self._mod.keys():
if m in getattr(self, attrhash):
cached_hash = getattr(self, attrhash)[m]
new_hash = (
- sha1(np.ascontiguousarray(getattr(self.mod[m], attr).index.values)).hexdigest(),
- sha1(np.ascontiguousarray(getattr(self.mod[m], attr).columns.values)).hexdigest(),
+ sha1(np.ascontiguousarray(getattr(self._mod[m], attr).index.values)).hexdigest(),
+ sha1(np.ascontiguousarray(getattr(self._mod[m], attr).columns.values)).hexdigest(),
)
if cached_hash[0] != new_hash[0]:
attr_names_changed = True
@@ -464,7 +464,7 @@ def copy(self, filename: PathLike | None = None) -> "MuData":
"""
if not self.isbacked:
mod = {}
- for k, v in self.mod.items():
+ for k, v in self._mod.items():
mod[k] = v.copy()
return self._init_from_dict_(
mod,
@@ -498,8 +498,8 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None):
# Call the same method on each modality
if df is None:
- for k in self.mod:
- self.mod[k].strings_to_categoricals()
+ for k in self._mod:
+ self._mod[k].strings_to_categoricals()
else:
return df
@@ -508,10 +508,15 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None):
def __getitem__(self, index) -> Union["MuData", AnnData]:
if isinstance(index, str):
- return self.mod[index]
+ return self._mod[index]
else:
return MuData(self, as_view=True, index=index)
+ @property
+ def mod(self) -> Mapping[str, "AnnData | MuData"]:
+ """Dictionary of modalities."""
+ return self._mod
+
@property
def shape(self) -> tuple[int, int]:
"""Shape of data, all variables and observations combined (:attr:`n_obs`, :attr:`n_var`)."""
@@ -533,13 +538,13 @@ def __len__(self) -> int:
def _create_global_attr_index(self, attr: str, axis: int):
if axis == (1 - self._axis):
# Shared indices
- modindices = [getattr(self.mod[m], attr).index for m in self.mod]
+ modindices = [getattr(self._mod[m], attr).index for m in self._mod]
if all(modindices[i].equals(modindices[i + 1]) for i in range(len(modindices) - 1)):
attrindex = modindices[0].copy()
- attrindex = reduce(pd.Index.union, [getattr(self.mod[m], attr).index for m in self.mod]).values
+ attrindex = reduce(pd.Index.union, [getattr(self._mod[m], attr).index for m in self._mod]).values
else:
# Modality-specific indices
- attrindex = np.concatenate([getattr(self.mod[m], attr).index.values for m in self.mod], axis=0)
+ attrindex = np.concatenate([getattr(self._mod[m], attr).index.values for m in self._mod], axis=0)
return attrindex
def _update_attr(
@@ -595,7 +600,7 @@ def _update_attr(
dfs = [
getattr(a, attr).loc[:, []].assign(**{f"{m}:{rowcol}": np.arange(getattr(a, attr).shape[0])})
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
]
index_order = None
@@ -645,7 +650,7 @@ def calc_attrm_update():
data_mod = pd.concat(
dfs, join="outer", axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, sort=False
)
- for mod in self.mod.keys():
+ for mod in self._mod.keys():
fix_attrmap_col(data_mod, mod, rowcol)
data_mod = _make_index_unique(data_mod, force=attr_intersecting)
@@ -671,7 +676,7 @@ def calc_attrm_update():
data_mod = _restore_index(data_mod)
data_mod.index.set_names(rowcol, inplace=True)
data_global.index.set_names(rowcol, inplace=True)
- for mod, amod in self.mod.items():
+ for mod, amod in self._mod.items():
colname = fix_attrmap_col(data_mod, mod, rowcol)
if mod in attrmap:
modmap = attrmap[mod].ravel()
@@ -715,7 +720,7 @@ def calc_attrm_update():
# get adata positions and remove columns from the data frame
mdict = {}
- for m in self.mod.keys():
+ for m in self._mod.keys():
colname = m + ":" + rowcol
mdict[m] = data_mod[colname].to_numpy()
data_mod.drop(colname, axis=1, inplace=True)
@@ -741,7 +746,7 @@ def calc_attrm_update():
if index_order is not None:
if can_update:
for mx_key, mx in attrm.items():
- if mx_key not in self.mod.keys(): # not a modality name
+ if mx_key not in self._mod.keys(): # not a modality name
if isinstance(mx, pd.DataFrame):
mx = mx.iloc[index_order, :]
mx.iloc[index_order == -1, :] = pd.NA
@@ -767,7 +772,7 @@ def calc_attrm_update():
if attr_changed:
if not hasattr(self, _attrhash):
setattr(self, _attrhash, {})
- for m, mod in self.mod.items():
+ for m, mod in self._mod.items():
getattr(self, _attrhash)[m] = (
sha1(np.ascontiguousarray(getattr(mod, attr).index.values)).hexdigest(),
sha1(np.ascontiguousarray(getattr(mod, attr).columns.values)).hexdigest(),
@@ -830,10 +835,10 @@ def _update_attr_legacy(
[
not col.startswith(mod + ":")
or col[col.startswith(mod + ":") and len(mod + ":") :]
- not in getattr(self.mod[mod], attr).columns
+ not in getattr(self._mod[mod], attr).columns
for col in getattr(self, attr).columns
]
- for mod in self.mod
+ for mod in self._mod
],
strict=False,
),
@@ -854,7 +859,7 @@ def _update_attr_legacy(
if join_common:
# If all modalities have a column with the same name, it is not global
columns_common = reduce(
- lambda a, b: a.intersection(b), [getattr(self.mod[mod], attr).columns for mod in self.mod]
+ lambda a, b: a.intersection(b), [getattr(self._mod[mod], attr).columns for mod in self._mod]
)
data_global = data_global.loc[:, [c not in columns_common for c in data_global.columns]]
@@ -877,7 +882,7 @@ def _update_attr_legacy(
.assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
.add_prefix(m + ":")
)
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
],
join="outer",
axis=1,
@@ -895,14 +900,14 @@ def _update_attr_legacy(
.assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
.add_prefix(m + ":")
)
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
],
join="outer",
axis=0,
sort=False,
)
data_common = pd.concat(
- [_maybe_coerce_to_boolean(getattr(a, attr)[columns_common]) for m, a in self.mod.items()],
+ [_maybe_coerce_to_boolean(getattr(a, attr)[columns_common]) for m, a in self._mod.items()],
join="outer",
axis=0,
sort=False,
@@ -923,7 +928,7 @@ def _update_attr_legacy(
.assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
.add_prefix(m + ":")
)
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
],
join="outer",
axis=0,
@@ -931,7 +936,7 @@ def _update_attr_legacy(
)
)
- for mod in self.mod.keys():
+ for mod in self._mod.keys():
colname = mod + ":" + rowcol
# use 0 as special value for missing
# we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write
@@ -977,7 +982,7 @@ def _update_attr_legacy(
force=True,
)
)
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
]
# Here, attr_names are guaranteed to be unique and are safe to be used for joins
@@ -986,7 +991,7 @@ def _update_attr_legacy(
data_common = pd.concat(
[
_maybe_coerce_to_boolean(_make_index_unique(getattr(a, attr)[columns_common], force=True))
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
],
join="outer",
axis=0,
@@ -1001,7 +1006,7 @@ def _update_attr_legacy(
getattr(a, attr).assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}).add_prefix(m + ":"),
force=True,
)
- for m, a in self.mod.items()
+ for m, a in self._mod.items()
]
data_mod = pd.concat(dfs, join="outer", axis=axis, sort=False)
@@ -1029,7 +1034,7 @@ def _update_attr_legacy(
data_mod = _restore_index(data_mod)
data_mod.index.set_names(rowcol, inplace=True)
data_global.index.set_names(rowcol, inplace=True)
- for mod, amod in self.mod.items():
+ for mod, amod in self._mod.items():
colname = mod + ":" + rowcol
# use 0 as special value for missing
# we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write
@@ -1069,7 +1074,7 @@ def _update_attr_legacy(
# get adata positions and remove columns from the data frame
mdict = {}
- for m in self.mod.keys():
+ for m in self._mod.keys():
colname = m + ":" + rowcol
mdict[m] = data_mod[colname].to_numpy()
data_mod.drop(colname, axis=1, inplace=True)
@@ -1130,7 +1135,7 @@ def _update_attr_legacy(
index_order = prev_index.get_indexer(now_index)
for mx_key in attrm.keys():
- if mx_key not in self.mod.keys(): # not a modality name
+ if mx_key not in self._mod.keys(): # not a modality name
attrm[mx_key] = attrm[mx_key][index_order]
attrm[mx_key][index_order == -1] = np.nan
@@ -1156,7 +1161,7 @@ def _update_attr_legacy(
if attr_changed:
if not hasattr(self, _attrhash):
setattr(self, _attrhash, {})
- for m, mod in self.mod.items():
+ for m, mod in self._mod.items():
getattr(self, _attrhash)[m] = (
sha1(np.ascontiguousarray(getattr(mod, attr).index.values)).hexdigest(),
sha1(np.ascontiguousarray(getattr(mod, attr).columns.values)).hexdigest(),
@@ -1169,18 +1174,18 @@ def _shrink_attr(self, attr: str, inplace=True) -> pd.DataFrame:
map(
all,
zip(
- *([not col.startswith(mod + ":") for col in getattr(self, attr).columns] for mod in self.mod),
+ *([not col.startswith(mod + ":") for col in getattr(self, attr).columns] for mod in self._mod),
strict=False,
),
)
)
# Make sure modname-prefix columns exist in modalities,
# keep them in place if they don't
- for mod in self.mod:
+ for mod in self._mod:
for i, col in enumerate(getattr(self, attr).columns):
if col.startswith(mod + ":"):
mcol = col[len(mod) + 1 :]
- if mcol not in getattr(self.mod[mod], attr).columns:
+ if mcol not in getattr(self._mod[mod], attr).columns:
columns_global[i] = True
# Only keep data from global .obs/.var columns
newdf = getattr(self, attr).loc[:, columns_global]
@@ -1197,7 +1202,7 @@ def n_mod(self) -> int:
-------
int: The number of modalities.
"""
- return len(self.mod)
+ return len(self._mod)
@property
def isbacked(self) -> bool:
@@ -1234,7 +1239,7 @@ def filename(self, filename: PathLike | None):
elif filename is not None:
self.write(filename)
self.file.open(filename, "r+")
- for ad in self.mod.values():
+ for ad in self._mod.values():
ad._X = None
@property
@@ -1261,7 +1266,7 @@ def n_obs(self) -> int:
def obs_vector(self, key: str, layer: str | None = None) -> np.ndarray:
"""Return an array of values for the requested key of length n_obs"""
if key not in self.obs.columns:
- for m, a in self.mod.items():
+ for m, a in self._mod.items():
if key in a.obs.columns:
raise KeyError(
f"There is no {key} in MuData .obs but there is one in {m} .obs. Consider running `mu.update_obs()` to update global .obs."
@@ -1281,35 +1286,35 @@ def obs_names_make_unique(self):
If there are obs_names, which are the same for multiple modalities,
append modality name to all obs_names.
"""
- mod_obs_sum = np.sum([a.n_obs for a in self.mod.values()])
+ mod_obs_sum = np.sum([a.n_obs for a in self._mod.values()])
if mod_obs_sum != self.n_obs:
self.update_obs()
- for k in self.mod:
- if isinstance(self.mod[k], AnnData):
- self.mod[k].obs_names_make_unique()
+ for k in self._mod:
+ if isinstance(self._mod[k], AnnData):
+ self._mod[k].obs_names_make_unique()
# Only propagate to individual modalities with shared vars
- elif isinstance(self.mod[k], MuData) and getattr(self.mod[k], "axis", 1) == 1:
- self.mod[k].obs_names_make_unique()
+ elif isinstance(self._mod[k], MuData) and getattr(self._mod[k], "axis", 1) == 1:
+ self._mod[k].obs_names_make_unique()
# Check if there are observations with the same name in different modalities
common_obs = []
- mods = list(self.mod.keys())
- for i in range(len(self.mod) - 1):
+ mods = list(self._mod.keys())
+ for i in range(len(self._mod) - 1):
ki = mods[i]
- for j in range(i + 1, len(self.mod)):
+ for j in range(i + 1, len(self._mod)):
kj = mods[j]
- common_obs.append(self.mod[ki].obs_names.intersection(self.mod[kj].obs_names.values))
+ common_obs.append(self._mod[ki].obs_names.intersection(self._mod[kj].obs_names.values))
if any(len(x) > 0 for x in common_obs):
warnings.warn(
"Modality names will be prepended to obs_names since there are identical obs_names in different modalities.",
stacklevel=1,
)
- for k in self.mod:
- self.mod[k].obs_names = k + ":" + self.mod[k].obs_names.astype(str)
+ for k in self._mod:
+ self._mod[k].obs_names = k + ":" + self._mod[k].obs_names.astype(str)
# Update .obs.index in the MuData
- obs_names = [obs for a in self.mod.values() for obs in a.obs_names.values]
+ obs_names = [obs for a in self._mod.values() for obs in a.obs_names.values]
self._obs.index = obs_names
def _set_names(self, attr: str, axis: int, names: Sequence[str]):
@@ -1324,7 +1329,7 @@ def _set_names(self, attr: str, axis: int, names: Sequence[str]):
if not isinstance(names.name, str | type(None)):
names.name = None
- mod_shape_sum = np.sum([a.shape[axis] for a in self.mod.values()])
+ mod_shape_sum = np.sum([a.shape[axis] for a in self._mod.values()])
if mod_shape_sum != self.shape[axis]:
self._update_attr(attr, axis=1 - axis)
@@ -1338,7 +1343,7 @@ def _set_names(self, attr: str, axis: int, names: Sequence[str]):
getattr(self, attr).index = names
map = getattr(self, f"{attr}map")
- for modname, mod in self.mod.items():
+ for modname, mod in self._mod.items():
newnames = np.empty(mod.shape[axis], dtype=object)
modmap = map[modname].ravel()
mask = modmap > 0
@@ -1389,7 +1394,7 @@ def n_var(self) -> int:
def var_vector(self, key: str, layer: str | None = None) -> np.ndarray:
"""Return an array of values for the requested key of length n_var."""
if key not in self.var.columns:
- for m, a in self.mod.items():
+ for m, a in self._mod.items():
if key in a.var.columns:
raise KeyError(
f"There is no {key} in MuData .var but there is one in {m} .var. Consider running `mu.update_var()` to update global .var."
@@ -1409,35 +1414,35 @@ def var_names_make_unique(self):
If there are var_names, which are the same for multiple modalities,
append modality name to all var_names.
"""
- mod_var_sum = np.sum([a.n_vars for a in self.mod.values()])
+ mod_var_sum = np.sum([a.n_vars for a in self._mod.values()])
if mod_var_sum != self.n_vars:
self.update_var()
- for k in self.mod:
- if isinstance(self.mod[k], AnnData):
- self.mod[k].var_names_make_unique()
+ for k in self._mod:
+ if isinstance(self._mod[k], AnnData):
+ self._mod[k].var_names_make_unique()
# Only propagate to individual modalities with shared obs
- elif isinstance(self.mod[k], MuData) and getattr(self.mod[k], "axis", 0) == 0:
- self.mod[k].var_names_make_unique()
+ elif isinstance(self._mod[k], MuData) and getattr(self._mod[k], "axis", 0) == 0:
+ self._mod[k].var_names_make_unique()
# Check if there are variables with the same name in different modalities
common_vars = []
- mods = list(self.mod.keys())
- for i in range(len(self.mod) - 1):
+ mods = list(self._mod.keys())
+ for i in range(len(self._mod) - 1):
ki = mods[i]
- for j in range(i + 1, len(self.mod)):
+ for j in range(i + 1, len(self._mod)):
kj = mods[j]
- common_vars.append(np.intersect1d(self.mod[ki].var_names.values, self.mod[kj].var_names.values))
+ common_vars.append(np.intersect1d(self._mod[ki].var_names.values, self._mod[kj].var_names.values))
if any(len(x) > 0 for x in common_vars):
warnings.warn(
"Modality names will be prepended to var_names since there are identical var_names in different modalities.",
stacklevel=1,
)
- for k in self.mod:
- self.mod[k].var_names = k + ":" + self.mod[k].var_names.astype(str)
+ for k in self._mod:
+ self._mod[k].var_names = k + ":" + self._mod[k].var_names.astype(str)
# Update .var.index in the MuData
- var_names = [var for a in self.mod.values() for var in a.var_names.values]
+ var_names = [var for a in self._mod.values() for var in a.var_names.values]
self._var.index = var_names
@property
@@ -1588,7 +1593,7 @@ def update(self):
NOTE: From v0.4, it will not pull columns from modalities by default.
"""
- if len(self.mod) > 0:
+ if len(self._mod) > 0:
self.update_var()
self.update_obs()
@@ -1604,7 +1609,7 @@ def mod_names(self) -> list[str]:
This property is read-only.
"""
- return list(self.mod.keys())
+ return list(self._mod.keys())
def _pull_attr(
self,
@@ -1675,7 +1680,7 @@ def _pull_attr(
if mods is not None:
if isinstance(mods, str):
mods = [mods]
- if not all(m in self.mod for m in mods):
+ if not all(m in self._mod for m in mods):
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
@@ -1687,12 +1692,12 @@ def _pull_attr(
# get all columns from all modalities and count how many times each column is present
derived_name_counts = Counter()
- for prefix, mod in self.mod.items():
+ for prefix, mod in self._mod.items():
modcols = getattr(mod, attr).columns
ccols = []
for name in modcols:
ccols.append(
- MetadataColumn(allowed_prefixes=self.mod.keys(), prefix=prefix, name=name, strip_prefix=False)
+ MetadataColumn(allowed_prefixes=self._mod.keys(), prefix=prefix, name=name, strip_prefix=False)
)
derived_name_counts[name] += 1
cols[prefix] = ccols
@@ -1770,7 +1775,7 @@ def _pull_attr(
dfs: list[pd.DataFrame] = []
for m, modcols in cols.items():
- mod = self.mod[m]
+ mod = self._mod[m]
mod_map = attrmap[m].ravel()
mask = mod_map > 0
@@ -2003,7 +2008,7 @@ def _push_attr(
if isinstance(mods, str):
mods = [mods]
mods = list(dict.fromkeys(mods))
- if not all(m in self.mod for m in mods):
+ if not all(m in self._mod for m in mods):
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
@@ -2012,7 +2017,7 @@ def _push_attr(
drop = True
# get all global columns
- cols = [MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) for name in getattr(self, attr).columns]
+ cols = [MetadataColumn(allowed_prefixes=self._mod.keys(), name=name) for name in getattr(self, attr).columns]
if columns is not None:
for k, v in {"common": common, "prefixed": prefixed}.items():
@@ -2055,7 +2060,7 @@ def _push_attr(
)
attrmap = getattr(self, f"{attr}map")
- for m, mod in self.mod.items():
+ for m, mod in self._mod.items():
if mods is not None and m not in mods:
continue
@@ -2232,7 +2237,7 @@ def _gen_repr(self, n_obs, n_vars, extensive: bool = False, nest_level: int = 0)
zip(
*[
[not col.startswith(mod + mod_sep) for col in getattr(self, attr).keys()]
- for mod in self.mod
+ for mod in self._mod
],
strict=False,
),
@@ -2242,8 +2247,8 @@ def _gen_repr(self, n_obs, n_vars, extensive: bool = False, nest_level: int = 0)
descr += (
f"\n{indent} {attr}:\t{str([keys[i] for i in range(len(keys)) if global_keys[i]])[1:-1]}"
)
- descr += f"\n{indent} {len(self.mod)} modalit{'y' if len(self.mod) == 1 else 'ies'}"
- for k, v in self.mod.items():
+ descr += f"\n{indent} {len(self._mod)} modalit{'y' if len(self._mod) == 1 else 'ies'}"
+ for k, v in self._mod.items():
mod_indent = " " * (nest_level + 1)
if isinstance(v, MuData):
descr += f"\n{mod_indent}{k}:\t" + v._gen_repr(v.n_obs, v.n_vars, extensive, nest_level + 1)
@@ -2282,7 +2287,7 @@ def _repr_html_(self, expand=None):
# General object properties
header = "MuData object {} obs × {} var in {} modalit{}".format(
- self.n_obs, self.n_vars, len(self.mod), "y" if len(self.mod) < 2 else "ies"
+ self.n_obs, self.n_vars, len(self._mod), "y" if len(self._mod) < 2 else "ies"
)
if self.isbacked:
header += f"
↳ backed at {self.file.filename}"
@@ -2299,7 +2304,7 @@ def _repr_html_(self, expand=None):
if self.uns:
mods += details_block_table(self, "uns", "Miscellaneous", expand >> 2)
- for m, dat in self.mod.items():
+ for m, dat in self._mod.items():
mods += "