From b907554d48326d0705289d326fa53a24487c2851 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 14:16:06 +0100 Subject: [PATCH 01/32] feat: `AnnData.can_write` based on `AnnData.fold` --- docs/api.md | 7 +++++++ src/anndata/_core/anndata.py | 39 ++++++++++++++++++++++++++++++++++++ src/anndata/types.py | 6 ++++++ tests/test_readwrite.py | 23 ++++++++++++++++++++- 4 files changed, 74 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 279070e50..0f42a381d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -264,6 +264,13 @@ Types used by the former: abc.CSCDataset ``` +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + types.FoldFunc +``` + ```{eval-rst} diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 03ac68dad..699c9832a 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -63,6 +63,9 @@ from zarr.storage import StoreLike + from anndata.types import FoldFunc + from anndata.typing import RWAble + from ..acc import AdRef, Array, MapAcc, RefAcc from ..compat import XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType @@ -1446,6 +1449,42 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) + def fold[T](self, func: FoldFunc[T], *, init: T) -> T: + acc = init + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "uns", + ]: + attr = getattr(self, attr_name) + if attr_name != "X": + for elem_name in attr: + acc = func(attr[elem_name], acc=acc) + return acc + + def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + from anndata._io.specs.registry import _REGISTRY + + writeable_elems = { + src_type + for (dest_type, src_type, __) in _REGISTRY.write + if store_type is None or store_type in dest_type.__module__ + } + + def predicate(x: RWAble, *, acc: bool): + if isinstance(x, pd.Series): + # matches behavior in methods.py + x = x._values + return acc and type(x) in writeable_elems + + return self.fold(predicate, init=True) + @deprecated( deprecation_msg( *("AnnData.concatenate", "anndata.concat"), diff --git a/src/anndata/types.py b/src/anndata/types.py index aa23d10f2..add66c1c6 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,6 +8,8 @@ from array_api.latest import ArrayNamespace + from anndata.typing import RWAble + from ._core.anndata import AnnData @@ -48,3 +50,7 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... + + +class FoldFunc[T](Protocol): + def __call__(self, elem: RWAble, *, acc: T | None) -> T | None: ... diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 3359b2ff8..c6468e86b 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -99,7 +99,7 @@ def dataset_kwargs(request): @pytest.fixture -def rw(backing_h5ad): +def rw(backing_h5ad) -> tuple[ad.AnnData, ad.AnnData]: M, N = 100, 101 orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS) orig.write(backing_h5ad) @@ -126,6 +126,27 @@ def dtype(request): # ------------------------------------------------------------------------------ +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_write( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + adata, _ = rw + assert adata.can_write(store_type=store_type) + + +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_not_write_with_custom_array( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + import pyarrow as pa + + adata, _ = rw + adata.obs["arrow_array"] = pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}] * adata.shape[0]) + ) + assert not adata.can_write(store_type=store_type) + + @pytest.mark.parametrize("typ", ARRAY_TYPES) def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2): pth1 = tmp_path / f"first.{diskfmt}" From 19daed554b9d6f4b22801f023dbee553caad8857 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 14:42:22 +0100 Subject: [PATCH 02/32] chore: docs --- docs/release-notes/2327.feat.md | 1 + src/anndata/_core/anndata.py | 35 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) create mode 100644 docs/release-notes/2327.feat.md diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2327.feat.md new file mode 100644 index 000000000..a66fef550 --- /dev/null +++ b/docs/release-notes/2327.feat.md @@ -0,0 +1 @@ +New {meth}`AnnData.fold` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 699c9832a..6f4765c16 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1450,7 +1450,21 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: return read_h5ad(filename, backed=mode) def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - acc = init + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + + Parameters + ---------- + func + The function that performs the accumulation + init + The starting value + + + Returns + ------- + An accumulated value + """ + accumulate = init for attr_name in [ "X", "obs", @@ -1465,10 +1479,21 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: attr = getattr(self, attr_name) if attr_name != "X": for elem_name in attr: - acc = func(attr[elem_name], acc=acc) - return acc + accumulate = func(attr[elem_name], accumulate=accumulate) + return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + """Whether or not an `AnnData` object can be written to disk for a given store type. + + Parameters + ---------- + store_type + Which backing store - `None` indicates that it can be writeable to either. + + Returns + ------- + Whether or not this object is writable. + """ from anndata._io.specs.registry import _REGISTRY writeable_elems = { @@ -1477,11 +1502,11 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, acc: bool): + def predicate(x: RWAble, *, accumulate: bool): if isinstance(x, pd.Series): # matches behavior in methods.py x = x._values - return acc and type(x) in writeable_elems + return accumulate and type(x) in writeable_elems return self.fold(predicate, init=True) From 4125375612ec6b702a0032526e8b5b217b1921da Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 15:38:40 +0100 Subject: [PATCH 03/32] refactor: use accessors --- src/anndata/_core/anndata.py | 84 ++++++++++++++++++++---------------- src/anndata/types.py | 5 ++- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6f4765c16..6c46d5d1e 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy from functools import partial, singledispatchmethod @@ -22,6 +22,7 @@ from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning +from anndata.acc import A, AdRef, GraphAcc, LayerAcc, MultiAcc from .. import utils from .._settings import settings @@ -66,7 +67,7 @@ from anndata.types import FoldFunc from anndata.typing import RWAble - from ..acc import AdRef, Array, MapAcc, RefAcc + from ..acc import Array, MapAcc, RefAcc from ..compat import XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -516,36 +517,39 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size(X) -> int: + def get_size[R: dict[RefAcc | None, int]]( + X: RWAble, + *, + accumulate: R, + ref_acc: RefAcc | AdRef | None, + ) -> R: def cs_to_bytes(X) -> int: return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) - if isinstance(X, h5py.Dataset) and with_disk: - return int(np.array(X.shape).prod() * X.dtype.itemsize) - elif isinstance(X, BaseCompressedSparseDataset) and with_disk: - return cs_to_bytes(X._to_backed()) - elif issparse(X): - return cs_to_bytes(X) - else: - return X.__sizeof__() - - sizes = {} - attrs = ["X", "_obs", "_var"] - attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"] - for attr in attrs + attrs_multi: - if attr in attrs_multi: - keys = getattr(self, attr).keys() - s = sum(get_size(getattr(self, attr)[k]) for k in keys) - else: - s = get_size(getattr(self, attr)) - if s > 0 and show_stratified: - from tqdm import tqdm + if is_elem := ( + (is_ad_ref := isinstance(ref_acc, AdRef)) + or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) + ) or (ref_acc is None and X is not self.uns): + key = ref_acc.acc if is_ad_ref else ref_acc + if isinstance(X, h5py.Dataset) and with_disk: + accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) + elif isinstance(X, BaseCompressedSparseDataset) and with_disk: + accumulate[key] += cs_to_bytes(X._to_backed()) + elif issparse(X): + accumulate[key] += cs_to_bytes(X) + else: + accumulate[key] += X.__sizeof__() + if not is_elem or ref_acc is A.X: + s = accumulate[ref_acc] + if s > 0 and show_stratified: + from tqdm import tqdm + + print( + f"Size of {repr(ref_acc).replace('A.', '') if ref_acc is not None else 'uns'}: {tqdm.format_sizeof(s, 'B')}" + ) + return accumulate - print( - f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}" - ) - sizes[attr] = s - return sum(sizes.values()) + return sum(self.fold(get_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1450,7 +1454,7 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: return read_h5ad(filename, backed=mode) def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object in DFS order. Parameters ---------- @@ -1474,12 +1478,18 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: "obsp", "varp", "layers", - "uns", ]: attr = getattr(self, attr_name) + acc = getattr(A, attr_name) if attr_name != "X": for elem_name in attr: - accumulate = func(attr[elem_name], accumulate=accumulate) + ref = acc[elem_name] + accumulate = func( + attr[elem_name], accumulate=accumulate, ref_acc=ref + ) + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + for elem in self.uns: + accumulate = func(elem, accumulate=accumulate, ref_acc=None) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: @@ -1502,11 +1512,13 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, accumulate: bool): - if isinstance(x, pd.Series): - # matches behavior in methods.py - x = x._values - return accumulate and type(x) in writeable_elems + def predicate(x: RWAble, *, accumulate: bool, ref_acc: AdRef | RefAcc | None): + if isinstance(ref_acc, AdRef) or ref_acc is None: + if isinstance(x, pd.Series): + # matches behavior in methods.py + x = x._values + return accumulate and type(x) in writeable_elems + return accumulate return self.fold(predicate, init=True) diff --git a/src/anndata/types.py b/src/anndata/types.py index add66c1c6..06a80a784 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,6 +8,7 @@ from array_api.latest import ArrayNamespace + from anndata.acc import AdRef, RefAcc from anndata.typing import RWAble from ._core.anndata import AnnData @@ -53,4 +54,6 @@ def __dlpack_device__(self) -> tuple[int, int]: ... class FoldFunc[T](Protocol): - def __call__(self, elem: RWAble, *, acc: T | None) -> T | None: ... + def __call__( + self, elem: RWAble, *, accumulate: T, ref_acc: RefAcc | AdRef | None + ) -> T | None: ... From 8be5ba2849f09d62818fa6fcd795227de6dcce09 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 16:27:45 +0100 Subject: [PATCH 04/32] fix: DFS order + fixes --- src/anndata/_core/anndata.py | 64 +++++++++++++++++++++++++++--------- src/anndata/acc/__init__.py | 21 ++++++++++++ src/anndata/types.py | 10 ++++-- 3 files changed, 76 insertions(+), 19 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6c46d5d1e..a6ac363ff 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -22,7 +22,7 @@ from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning -from anndata.acc import A, AdRef, GraphAcc, LayerAcc, MultiAcc +from anndata.acc import A, AdAcc, AdRef, GraphAcc, LayerAcc, MultiAcc from .. import utils from .._settings import settings @@ -517,20 +517,34 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size[R: dict[RefAcc | None, int]]( + def get_size[R: dict[type[RefAcc | MapAcc | AdAcc] | None, int]]( X: RWAble, *, accumulate: R, - ref_acc: RefAcc | AdRef | None, + ref_acc: RefAcc | AdRef | MapAcc | None, ) -> R: def cs_to_bytes(X) -> int: return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) if is_elem := ( - (is_ad_ref := isinstance(ref_acc, AdRef)) - or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) - ) or (ref_acc is None and X is not self.uns): - key = ref_acc.acc if is_ad_ref else ref_acc + ( + # an array of some sort i.e., from AdRef (from obs/var) or a reference to one + (is_ad_ref := isinstance(ref_acc, AdRef)) + or ( + is_ref_acc := isinstance( + ref_acc, LayerAcc | MultiAcc | GraphAcc + ) + ) + ) + # an element of uns + or (ref_acc is None and X is not self.uns) + ): + if is_ad_ref: + key = type(ref_acc.acc) + elif is_ref_acc: + key = ref_acc.parent_type + else: + key = None if isinstance(X, h5py.Dataset) and with_disk: accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) elif isinstance(X, BaseCompressedSparseDataset) and with_disk: @@ -539,8 +553,12 @@ def cs_to_bytes(X) -> int: accumulate[key] += cs_to_bytes(X) else: accumulate[key] += X.__sizeof__() - if not is_elem or ref_acc is A.X: - s = accumulate[ref_acc] + # if this is X or a parent elem maybe print it out. + if (is_x := ref_acc is A.X) or not is_elem: + if ref_acc is not None: + s = accumulate[AdAcc if is_x else type(ref_acc)] + else: + s = accumulate[None] if s > 0 and show_stratified: from tqdm import tqdm @@ -1453,8 +1471,14 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object in DFS order. + def fold[T]( + self, + func: FoldFunc[T], + *, + init: T, + order: Literal["DFS-pre", "DFS-post"] = "DFS-post", + ) -> T: + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. Parameters ---------- @@ -1462,6 +1486,12 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: The function that performs the accumulation init The starting value + order + How to visit the items in the fold. + "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. + "DFS-post" means they get visited afterwards. + The `AnnData` itself is not visited. + Returns @@ -1478,18 +1508,20 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: "obsp", "varp", "layers", + "uns", ]: attr = getattr(self, attr_name) - acc = getattr(A, attr_name) + acc = getattr(A, attr_name) if attr_name != "uns" else None + if order == "DFS-pre": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) if attr_name != "X": for elem_name in attr: - ref = acc[elem_name] + ref = acc[elem_name] if acc is not None else None accumulate = func( attr[elem_name], accumulate=accumulate, ref_acc=ref ) - accumulate = func(attr, accumulate=accumulate, ref_acc=acc) - for elem in self.uns: - accumulate = func(elem, accumulate=accumulate, ref_acc=None) + if order == "DFS-post": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: diff --git a/src/anndata/acc/__init__.py b/src/anndata/acc/__init__.py index ff6b80b6d..86215827b 100644 --- a/src/anndata/acc/__init__.py +++ b/src/anndata/acc/__init__.py @@ -192,6 +192,11 @@ def _maybe_flatten(self, idx: I, a: Array) -> Array: return a.__array_namespace__().reshape(a, (a.size,)) return a.ravel() + @property + @abc.abstractmethod + def parent_type(self) -> type[MapAcc | AdAcc]: + """Get the parent to this reference accessor""" + @dataclass(frozen=True) class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): @@ -209,6 +214,10 @@ class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str | None """Key this accessor refers to, e.g. `A.layers['counts'].k == 'counts'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return LayerMapAcc if self.k is not None else AdAcc + @overload def __getitem__(self, idx: Idx2D, /) -> R: ... @overload @@ -298,6 +307,10 @@ class MetaAcc[R: AdRef[str | None]](RefAcc[R, str | None]): dim: Literal["obs", "var"] """Axis this accessor refers to, e.g. `A.obs.dim == 'obs'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return AdAcc + @property def index(self) -> R: """Index :class:`AdRef`, i.e. `A.obs.index` or `A.var.index`.""" @@ -380,6 +393,10 @@ class MultiAcc[R: AdRef[int]](RefAcc[R, int]): k: str """Key this accessor refers to, e.g. `A.varm['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return MultiMapAcc + @staticmethod def process_idx(i: object, /) -> int | list[int] | pd.Index[int]: if isinstance(i, tuple): @@ -463,6 +480,10 @@ class GraphAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str """Key this accessor refers to, e.g. `A.obsp['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return GraphMapAcc + def process_idx(self, idx: Idx2D, /) -> Idx2D: if not all(isinstance(i, str | slice) for i in idx): msg = f"Unsupported index {idx!r}" diff --git a/src/anndata/types.py b/src/anndata/types.py index 06a80a784..126eb1d02 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,7 +8,7 @@ from array_api.latest import ArrayNamespace - from anndata.acc import AdRef, RefAcc + from anndata.acc import AdAcc, AdRef, MapAcc, RefAcc from anndata.typing import RWAble from ._core.anndata import AnnData @@ -55,5 +55,9 @@ def __dlpack_device__(self) -> tuple[int, int]: ... class FoldFunc[T](Protocol): def __call__( - self, elem: RWAble, *, accumulate: T, ref_acc: RefAcc | AdRef | None - ) -> T | None: ... + self, + elem: RWAble, + *, + accumulate: T, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ) -> T: ... From 0f4d1b0417b73e0e74609641a1e0c990bc5eeb06 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 16:30:04 +0100 Subject: [PATCH 05/32] chore: add test for `uns` --- tests/test_readwrite.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index c6468e86b..429a093d0 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -135,13 +135,16 @@ def test_can_write( @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +@pytest.mark.parametrize("parent_elem", ["obs", "uns"]) def test_can_not_write_with_custom_array( - rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None + rw: tuple[ad.AnnData, ad.AnnData], + store_type: Literal["h5", "zarr"] | None, + parent_elem: Literal["obs", "uns"], ): import pyarrow as pa adata, _ = rw - adata.obs["arrow_array"] = pd.arrays.ArrowExtensionArray( + getattr(adata, parent_elem)["arrow_array"] = pd.arrays.ArrowExtensionArray( pa.array([{"x": 1, "y": True}] * adata.shape[0]) ) assert not adata.can_write(store_type=store_type) From 69daf90bb0ae54a567ef3139ccd5d64d414e728d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:07:59 +0100 Subject: [PATCH 06/32] feat: `raw` + `uns` traversal --- src/anndata/_core/anndata.py | 85 ++++++++++++++++++++++-------------- tests/test_base.py | 23 ++++++++++ tests/test_readwrite.py | 13 ++++-- 3 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 59a802b69..929e08edb 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -68,7 +68,7 @@ from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc - from ..compat import XDataset + from ..compat import CSMatrix, XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -517,27 +517,37 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size[R: dict[type[RefAcc | MapAcc | AdAcc] | None, int]]( + def cs_to_bytes(X: CSArray | CSMatrix) -> int: + return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + + def get_size(X: RWAble) -> int: + if isinstance(X, h5py.Dataset) and with_disk: + return int(np.array(X.shape).prod() * X.dtype.itemsize) + elif isinstance(X, BaseCompressedSparseDataset) and with_disk: + return cs_to_bytes(X._to_backed()) + elif issparse(X): + return cs_to_bytes(X) + else: + return X.__sizeof__() + + def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( X: RWAble, *, accumulate: R, ref_acc: RefAcc | AdRef | MapAcc | None, ) -> R: - def cs_to_bytes(X) -> int: - return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) - + if isinstance(X, Raw): + ref_acc = X # type: ignore[assignment] + accumulate[Raw] += get_size(X.X) + accumulate[Raw] += get_size(X.var) + for key in X.varm: + accumulate[Raw] += get_size(X.varm[key]) + elif ref_acc is None: # "None but not Raw" is uns + accumulate[None] = sum(get_size(v) for v in self.uns.values()) if is_elem := ( - ( - # an array of some sort i.e., from AdRef (from obs/var) or a reference to one - (is_ad_ref := isinstance(ref_acc, AdRef)) - or ( - is_ref_acc := isinstance( - ref_acc, LayerAcc | MultiAcc | GraphAcc - ) - ) - ) - # an element of uns - or (ref_acc is None and X is not self.uns) + # an array of some sort i.e., from AdRef (from obs/var) or a reference to one + (is_ad_ref := isinstance(ref_acc, AdRef)) + or (is_ref_acc := isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc)) ): if is_ad_ref: key = type(ref_acc.acc) @@ -545,18 +555,11 @@ def cs_to_bytes(X) -> int: key = ref_acc.parent_type else: key = None - if isinstance(X, h5py.Dataset) and with_disk: - accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) - elif isinstance(X, BaseCompressedSparseDataset) and with_disk: - accumulate[key] += cs_to_bytes(X._to_backed()) - elif issparse(X): - accumulate[key] += cs_to_bytes(X) - else: - accumulate[key] += X.__sizeof__() + accumulate[key] += get_size(X) # if this is X or a parent elem maybe print it out. if (is_x := ref_acc is A.X) or not is_elem: if ref_acc is not None: - s = accumulate[AdAcc if is_x else type(ref_acc)] + s = accumulate[AdAcc if is_x else type(ref_acc)] # type: ignore[assignment] else: s = accumulate[None] if s > 0 and show_stratified: @@ -567,7 +570,7 @@ def cs_to_bytes(X) -> int: ) return accumulate - return sum(self.fold(get_size, init=defaultdict(int)).values()) + return sum(self.fold(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1480,10 +1483,9 @@ def fold[T]( "obsp", "varp", "layers", - "uns", ]: attr = getattr(self, attr_name) - acc = getattr(A, attr_name) if attr_name != "uns" else None + acc = getattr(A, attr_name) if order == "DFS-pre": accumulate = func(attr, accumulate=accumulate, ref_acc=acc) if attr_name != "X": @@ -1494,6 +1496,8 @@ def fold[T]( ) if order == "DFS-post": accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + accumulate = func(self.uns, accumulate=accumulate, ref_acc=None) + accumulate = func(self.raw, accumulate=accumulate, ref_acc=None) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: @@ -1516,12 +1520,29 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, accumulate: bool, ref_acc: AdRef | RefAcc | None): + def predicate( + elem: RWAble, + *, + accumulate: bool, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ): + if isinstance(elem, Raw): + accumulate = accumulate and type(elem.X) in writeable_elems + return accumulate and all( + type(e[attr]) in writeable_elems + for e in [elem.var, elem.varm] + for attr in e + ) + if ref_acc is None and isinstance(elem, dict): + return accumulate and all( + predicate(e, accumulate=accumulate, ref_acc=None) + for e in elem.values() + ) if isinstance(ref_acc, AdRef) or ref_acc is None: - if isinstance(x, pd.Series): + if isinstance(elem, pd.Series): # matches behavior in methods.py - x = x._values - return accumulate and type(x) in writeable_elems + elem = elem._values + return accumulate and type(elem) in writeable_elems return accumulate return self.fold(predicate, init=True) diff --git a/tests/test_base.py b/tests/test_base.py index 254e483b8..b24197956 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -186,6 +186,29 @@ def test_df_warnings(): adata.X = df +@pytest.mark.parametrize("use_raw", [True, False], ids=["raw", "no_raw"]) +@pytest.mark.parametrize("use_uns", [True, False], ids=["uns", "no_uns"]) +def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): + adata = gen_adata((10, 20)) + if use_uns: + adata.uns = {"foo": np.arange(10)} + if use_raw: + adata.raw = adata.copy() + adata.__sizeof__(show_stratified=True) + captured = capsys.readouterr() + for attr in [ + "X", + "layers", + "obsm", + "varm", + "obsp", + "varp", + *(["uns"] if use_uns else []), + *(["raw"] if use_raw else []), + ]: + assert attr in captured.out + + @pytest.mark.parametrize("attr", ["X", "layers", "obsm", "varm", "obsp", "varp"]) @pytest.mark.parametrize("when", ["init", "assign"]) def test_convert_matrix(attr, when): diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 429a093d0..43694da77 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -135,17 +135,22 @@ def test_can_write( @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) -@pytest.mark.parametrize("parent_elem", ["obs", "uns"]) +@pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) def test_can_not_write_with_custom_array( rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None, - parent_elem: Literal["obs", "uns"], + parent_elem: Literal["obs", "uns", "raw"], ): import pyarrow as pa adata, _ = rw - getattr(adata, parent_elem)["arrow_array"] = pd.arrays.ArrowExtensionArray( - pa.array([{"x": 1, "y": True}] * adata.shape[0]) + if parent_elem == "raw": + adata.raw = adata.copy() + getter = lambda: getattr(adata, parent_elem).var + else: + getter = lambda: getattr(adata, parent_elem) + getter()["arrow_array"] = pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}] * adata.shape[1]) ) assert not adata.can_write(store_type=store_type) From 932d766db3800c1b418857124f991d2ebe6e04e7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:08:20 +0100 Subject: [PATCH 07/32] fix: `fold` -> `reduce` --- src/anndata/_core/anndata.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 929e08edb..d692e242f 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -570,7 +570,7 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( ) return accumulate - return sum(self.fold(fold_size, init=defaultdict(int)).values()) + return sum(self.reduce(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1446,7 +1446,7 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - def fold[T]( + def reduce[T]( self, func: FoldFunc[T], *, @@ -1462,7 +1462,7 @@ def fold[T]( init The starting value order - How to visit the items in the fold. + How to visit the items in the reduce. "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. "DFS-post" means they get visited afterwards. The `AnnData` itself is not visited. @@ -1545,7 +1545,7 @@ def predicate( return accumulate and type(elem) in writeable_elems return accumulate - return self.fold(predicate, init=True) + return self.reduce(predicate, init=True) @deprecated( deprecation_msg( From e0f3ee24da7f93d5ca6b8d8edfa521f29965a70c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:42:19 +0100 Subject: [PATCH 08/32] chore: docs --- docs/api.md | 2 +- src/anndata/_core/anndata.py | 15 +++++++++++---- src/anndata/types.py | 20 ++++++++++++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/docs/api.md b/docs/api.md index 0f42a381d..5bdac918e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -268,7 +268,7 @@ Types used by the former: .. autosummary:: :toctree: generated/ - types.FoldFunc + types.ReduceFunc ``` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index d692e242f..f968a0b33 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -64,7 +64,7 @@ from zarr.storage import StoreLike - from anndata.types import FoldFunc + from anndata.types import ReduceFunc from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc @@ -1448,22 +1448,29 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: def reduce[T]( self, - func: FoldFunc[T], + func: ReduceFunc[T], *, init: T, order: Literal["DFS-pre", "DFS-post"] = "DFS-post", ) -> T: """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + All visits inside the user-defined `func` are distinguishable via the `ref_acc` + `elem` args. + Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. + Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. + Furthermore, neither element is descended into. + This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. + All other elements will have a non-`None` `ref_acc` argument. + Parameters ---------- func - The function that performs the accumulation + The function that performs the accumulation. init The starting value order How to visit the items in the reduce. - "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. + "DFS-pre" indicates that parent-elements like layers, obs, and varp get visited first. "DFS-post" means they get visited afterwards. The `AnnData` itself is not visited. diff --git a/src/anndata/types.py b/src/anndata/types.py index 126eb1d02..d168d3134 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -53,11 +53,27 @@ def __dlpack__( def __dlpack_device__(self) -> tuple[int, int]: ... -class FoldFunc[T](Protocol): +class ReduceFunc[T](Protocol): def __call__( self, elem: RWAble, *, accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, - ) -> T: ... + ) -> T: + """Function to be called on each visit within :func:`AnnData.reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + ref_acc + A reference to help uses distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... From ee04741f5a5d40e86fa9726dfabd93e2584c5e1d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:46:15 +0100 Subject: [PATCH 09/32] fix: `meth` not `func` --- src/anndata/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/types.py b/src/anndata/types.py index d168d3134..4262c639c 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -61,7 +61,7 @@ def __call__( accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, ) -> T: - """Function to be called on each visit within :func:`AnnData.reduce`. + """Function to be called on each visit within :meth:`AnnData.reduce`. Parameters ---------- From 6d6f4548c19c44d03099672e676a7264f5ce0d41 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:46:58 +0100 Subject: [PATCH 10/32] fix: `fold` not `reduce` in relnote --- docs/release-notes/2327.feat.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2327.feat.md index a66fef550..0a0feef86 100644 --- a/docs/release-notes/2327.feat.md +++ b/docs/release-notes/2327.feat.md @@ -1 +1 @@ -New {meth}`AnnData.fold` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` +New {meth}`AnnData.reduce` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` From 1f77a4c719a374984f6490c238c988a9cb8d5fce Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:50:50 +0100 Subject: [PATCH 11/32] fix: nested --- src/anndata/_core/anndata.py | 6 ++++-- tests/test_base.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 2a135eb95..ec9f2e789 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -67,7 +67,7 @@ from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc - from ..compat import CSMatrix, XDataset + from ..compat import CSArray, CSMatrix, XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -526,6 +526,8 @@ def get_size(X: RWAble) -> int: return cs_to_bytes(X._to_backed()) elif issparse(X): return cs_to_bytes(X) + elif isinstance(X, dict): + return sum(get_size(v) for v in X.values()) else: return X.__sizeof__() @@ -542,7 +544,7 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( for key in X.varm: accumulate[Raw] += get_size(X.varm[key]) elif ref_acc is None: # "None but not Raw" is uns - accumulate[None] = sum(get_size(v) for v in self.uns.values()) + accumulate[None] = get_size(self.uns) if is_elem := ( # an array of some sort i.e., from AdRef (from obs/var) or a reference to one (is_ad_ref := isinstance(ref_acc, AdRef)) diff --git a/tests/test_base.py b/tests/test_base.py index b24197956..fe6c66ba3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -191,7 +191,7 @@ def test_df_warnings(): def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): adata = gen_adata((10, 20)) if use_uns: - adata.uns = {"foo": np.arange(10)} + adata.uns = {"foo": np.arange(10), "nested": {"here": np.arange(10)}} if use_raw: adata.raw = adata.copy() adata.__sizeof__(show_stratified=True) From 91adffe29383a4a81165be9abd34c2dec7385a63 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:53:04 +0100 Subject: [PATCH 12/32] chore: more `func` clarification --- src/anndata/_core/anndata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ec9f2e789..90d2fcb8b 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1461,7 +1461,7 @@ def reduce[T]( Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. Furthermore, neither element is descended into. This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. - All other elements will have a non-`None` `ref_acc` argument. + All other elements will have a non-`None` `ref_acc` argument indicating the path at which `elem` was created in the `AnnData`. Parameters ---------- From 928b72af17297e9a4d532a17d5d71160c98761e8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:53:36 +0100 Subject: [PATCH 13/32] fix: link --- src/anndata/_core/anndata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 90d2fcb8b..c1ae4fdb5 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1456,7 +1456,7 @@ def reduce[T]( ) -> T: """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. - All visits inside the user-defined `func` are distinguishable via the `ref_acc` + `elem` args. + All visits inside the user-defined `func` (see :func:`types.ReduceFunc`) are distinguishable via the `ref_acc` + `elem` args. Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. Furthermore, neither element is descended into. From 19a915d5f3bd8f201e66b91b0dac65ee8c64200e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:58:00 +0100 Subject: [PATCH 14/32] fix: link --- src/anndata/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/types.py b/src/anndata/types.py index 4262c639c..712c09d51 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -61,7 +61,7 @@ def __call__( accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, ) -> T: - """Function to be called on each visit within :meth:`AnnData.reduce`. + """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. Parameters ---------- From c0886fef44448ec5cb2ffb2cf2c2e70ef17ab9ae Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:59:54 +0100 Subject: [PATCH 15/32] refactor: simpler --- src/anndata/_core/anndata.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index c1ae4fdb5..e059bb7fc 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -548,14 +548,9 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( if is_elem := ( # an array of some sort i.e., from AdRef (from obs/var) or a reference to one (is_ad_ref := isinstance(ref_acc, AdRef)) - or (is_ref_acc := isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc)) + or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) ): - if is_ad_ref: - key = type(ref_acc.acc) - elif is_ref_acc: - key = ref_acc.parent_type - else: - key = None + key = type(ref_acc.acc) if is_ad_ref else ref_acc.parent_type accumulate[key] += get_size(X) # if this is X or a parent elem maybe print it out. if (is_x := ref_acc is A.X) or not is_elem: From 6cffc05c480914e414dd9000e51e5b51ec757195 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 17:02:45 +0100 Subject: [PATCH 16/32] fix: relnote number --- docs/release-notes/{2327.feat.md => 2372.feat.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/release-notes/{2327.feat.md => 2372.feat.md} (100%) diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2372.feat.md similarity index 100% rename from docs/release-notes/2327.feat.md rename to docs/release-notes/2372.feat.md From 39800aadd7cb0084492683e3c4894de5e4fcc75b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:13:58 +0200 Subject: [PATCH 17/32] refactor: use `iter` --- src/anndata/_core/anndata.py | 189 ++++++++++++++----------------- src/anndata/_io/h5ad.py | 38 ++++--- src/anndata/_io/specs/methods.py | 20 ++-- src/anndata/types.py | 29 ----- tests/test_base.py | 6 +- 5 files changed, 117 insertions(+), 165 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index e059bb7fc..24288dd2e 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -21,13 +21,13 @@ from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning -from anndata.acc import A, AdAcc, AdRef, GraphAcc, LayerAcc, MultiAcc from .. import utils from .._settings import settings from ..compat import ( DaskArray, IndexManager, + XDataset, ZarrArray, _move_adj_mtx, has_xp, @@ -56,7 +56,7 @@ from .xarray import Dataset2D if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Generator, Iterable from os import PathLike from typing import Any, ClassVar, Literal @@ -66,9 +66,10 @@ from anndata.types import ReduceFunc from anndata.typing import RWAble - from ..acc import Array, MapAcc, RefAcc - from ..compat import CSArray, CSMatrix, XDataset - from ..typing import Index, Index1D, _Index1DNorm, _XDataType + from .._types import AnnDataElem + from ..acc import AdRef, Array, MapAcc, RefAcc + from ..compat import CSArray, CSMatrix + from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -526,44 +527,32 @@ def get_size(X: RWAble) -> int: return cs_to_bytes(X._to_backed()) elif issparse(X): return cs_to_bytes(X) - elif isinstance(X, dict): + elif isinstance(X, dict | MutableMapping): return sum(get_size(v) for v in X.values()) else: return X.__sizeof__() - def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( - X: RWAble, + def fold_size( + elem: _XDataType | AxisStorable | pd.DataFrame | XDataset, *, - accumulate: R, - ref_acc: RefAcc | AdRef | MapAcc | None, - ) -> R: - if isinstance(X, Raw): - ref_acc = X # type: ignore[assignment] - accumulate[Raw] += get_size(X.X) - accumulate[Raw] += get_size(X.var) - for key in X.varm: - accumulate[Raw] += get_size(X.varm[key]) - elif ref_acc is None: # "None but not Raw" is uns - accumulate[None] = get_size(self.uns) - if is_elem := ( - # an array of some sort i.e., from AdRef (from obs/var) or a reference to one - (is_ad_ref := isinstance(ref_acc, AdRef)) - or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) - ): - key = type(ref_acc.acc) if is_ad_ref else ref_acc.parent_type - accumulate[key] += get_size(X) - # if this is X or a parent elem maybe print it out. - if (is_x := ref_acc is A.X) or not is_elem: - if ref_acc is not None: - s = accumulate[AdAcc if is_x else type(ref_acc)] # type: ignore[assignment] - else: - s = accumulate[None] - if s > 0 and show_stratified: - from tqdm import tqdm + accumulate: dict[str, int], + attr_name: str | None, # TODO: type + ): + if elem is None: + size = 0 + elif elem is self.raw: + size = ( + get_size(elem.X) + + get_size(elem.var) + + sum(get_size(v) for v in elem.varm.values()) + ) + else: + size = get_size(elem) + accumulate[attr_name] = size + if size > 0 and show_stratified: + from tqdm import tqdm - print( - f"Size of {repr(ref_acc).replace('A.', '') if ref_acc is not None else 'uns'}: {tqdm.format_sizeof(s, 'B')}" - ) + print(f"Size of {attr_name}: {tqdm.format_sizeof(size, 'B')}") return accumulate return sum(self.reduce(fold_size, init=defaultdict(int)).values()) @@ -571,19 +560,11 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" - for attr in [ - "obs", - "var", - "uns", - "obsm", - "varm", - "layers", - "obsp", - "varp", - ]: - keys = getattr(self, attr).keys() - if len(keys) > 0: - descr += f"\n {attr}: {str(list(keys))[1:-1]}" + for attr_name, elem in self.iter(): + if attr_name not in {"raw", "X"}: + keys = elem.keys() + if len(keys) > 0: + descr += f"\n {attr_name}: {str(list(keys))[1:-1]}" return descr def __repr__(self) -> str: @@ -1389,27 +1370,16 @@ def to_memory(self, *, copy: bool = False) -> AnnData: mem = backed[backed.obs["cluster"] == "a", :].to_memory() """ new = {} - for attr_name in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - "uns", - ]: - attr = getattr(self, attr_name, None) + for attr_name, attr in self.iter(): if attr is not None: - new[attr_name] = to_memory(attr, copy=copy) - - if self.raw is not None: - new["raw"] = { - "X": to_memory(self.raw.X, copy=copy), - "var": to_memory(self.raw.var, copy=copy), - "varm": to_memory(self.raw.varm, copy=copy), - } + if attr is self.raw: + new["raw"] = { + "X": to_memory(self.raw.X, copy=copy), + "var": to_memory(self.raw.var, copy=copy), + "varm": to_memory(self.raw.varm, copy=copy), + } + else: + new[attr_name] = to_memory(attr, copy=copy) if self.isbacked: self.file.close() @@ -1442,6 +1412,28 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) + def iter( + self, + ) -> Generator[ + tuple[AnnDataElem, AxisStorable | _XDataType | Dataset2D | pd.DataFrame] + ]: + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "uns", + "raw", + ]: + was_closed = self.isbacked and not self.file.is_open + yield (attr_name, getattr(self, attr_name)) + if was_closed: + self.file.close() + def reduce[T]( self, func: ReduceFunc[T], @@ -1477,30 +1469,8 @@ def reduce[T]( An accumulated value """ accumulate = init - for attr_name in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - ]: - attr = getattr(self, attr_name) - acc = getattr(A, attr_name) - if order == "DFS-pre": - accumulate = func(attr, accumulate=accumulate, ref_acc=acc) - if attr_name != "X": - for elem_name in attr: - ref = acc[elem_name] if acc is not None else None - accumulate = func( - attr[elem_name], accumulate=accumulate, ref_acc=ref - ) - if order == "DFS-post": - accumulate = func(attr, accumulate=accumulate, ref_acc=acc) - accumulate = func(self.uns, accumulate=accumulate, ref_acc=None) - accumulate = func(self.raw, accumulate=accumulate, ref_acc=None) + for attr_name, attr in self.iter(): + accumulate = func(attr, accumulate=accumulate, attr_name=attr_name) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: @@ -1515,6 +1485,7 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: ------- Whether or not this object is writable. """ + from anndata._io.specs.registry import _REGISTRY writeable_elems = { @@ -1527,26 +1498,34 @@ def predicate( elem: RWAble, *, accumulate: bool, - ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + attr_name: str | None = None, # TODO: type ): - if isinstance(elem, Raw): + if elem is None: + return accumulate + if attr_name == "raw": accumulate = accumulate and type(elem.X) in writeable_elems return accumulate and all( type(e[attr]) in writeable_elems for e in [elem.var, elem.varm] for attr in e ) - if ref_acc is None and isinstance(elem, dict): + if attr_name in ( + "obs", + "obsm", + "varm", + "var", + "layers", + "varp", + "obsp", + "uns", + ) or isinstance(elem, pd.DataFrame | XDataset | MutableMapping): return accumulate and all( - predicate(e, accumulate=accumulate, ref_acc=None) - for e in elem.values() + predicate(elem[k], accumulate=accumulate) for k in elem ) - if isinstance(ref_acc, AdRef) or ref_acc is None: - if isinstance(elem, pd.Series): - # matches behavior in methods.py - elem = elem._values - return accumulate and type(elem) in writeable_elems - return accumulate + if isinstance(elem, pd.Series): + # matches behavior in methods.py + elem = elem._values + return accumulate and type(elem) in writeable_elems return self.reduce(predicate, init=True) diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index d0540de1c..cc8948e28 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from collections.abc import MutableMapping from functools import partial from pathlib import Path from types import MappingProxyType @@ -84,23 +85,26 @@ def write_h5ad( f = cast("h5py.Group", f["/"]) f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") - - _write_x( - f, - adata, # accessing adata.X reopens adata.file if it’s backed - is_backed=adata.isbacked and adata.filename == filepath, - as_dense=as_dense, - dataset_kwargs=dataset_kwargs, - ) - _write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs) - write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs) - write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) + for k, elem in adata.iter(): + if k == "X": + _write_x( + f, + adata, # accessing adata.X reopens adata.file if it’s backed + is_backed=adata.isbacked and adata.filename == filepath, + as_dense=as_dense, + dataset_kwargs=dataset_kwargs, + ) + elif k == "raw": + _write_raw( + f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs + ) + else: + write_elem( + f, + k, + dict(elem) if isinstance(elem, MutableMapping) else elem, + dataset_kwargs=dataset_kwargs, + ) def _write_x( diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 43b084a00..f2cfff554 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping from copy import copy from functools import partial from itertools import product @@ -286,17 +286,13 @@ def write_anndata( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - if adata.X is not None: - _writer.write_elem(g, "X", adata.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", adata.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) + for sub_key, elem in adata.iter(): + _writer.write_elem( + g, + sub_key, + dict(elem) if isinstance(elem, MutableMapping) else elem, + dataset_kwargs=dataset_kwargs, + ) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) diff --git a/src/anndata/types.py b/src/anndata/types.py index 712c09d51..aa23d10f2 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,9 +8,6 @@ from array_api.latest import ArrayNamespace - from anndata.acc import AdAcc, AdRef, MapAcc, RefAcc - from anndata.typing import RWAble - from ._core.anndata import AnnData @@ -51,29 +48,3 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... - - -class ReduceFunc[T](Protocol): - def __call__( - self, - elem: RWAble, - *, - accumulate: T, - ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, - ) -> T: - """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. - - Parameters - ---------- - elem - The current element. - accumulate - The value being accumulated. - ref_acc - A reference to help uses distinguish where they are in the `AnnData` object. - - Returns - ------- - An accumulated value - """ - ... diff --git a/tests/test_base.py b/tests/test_base.py index fe6c66ba3..c5895c8bc 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -192,6 +192,8 @@ def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): adata = gen_adata((10, 20)) if use_uns: adata.uns = {"foo": np.arange(10), "nested": {"here": np.arange(10)}} + else: + adata.uns = {} if use_raw: adata.raw = adata.copy() adata.__sizeof__(show_stratified=True) @@ -203,10 +205,10 @@ def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): "varm", "obsp", "varp", - *(["uns"] if use_uns else []), - *(["raw"] if use_raw else []), ]: assert attr in captured.out + assert use_uns == ("uns" in captured.out) + assert use_raw == ("raw" in captured.out) @pytest.mark.parametrize("attr", ["X", "layers", "obsm", "varm", "obsp", "varp"]) From 6cb401b45d4f60d0224639bf0fe17b1e4d445f11 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:29:59 +0200 Subject: [PATCH 18/32] fix: oops --- docs/concatenation.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/concatenation.rst b/docs/concatenation.rst index fe4046d75..b22b68d7a 100644 --- a/docs/concatenation.rst +++ b/docs/concatenation.rst @@ -26,7 +26,6 @@ Let's start off with an example: AnnData object with n_obs × n_vars = 700 × 765 obs: 'bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score', 'G2M_score', 'phase', 'louvain' var: 'n_counts', 'means', 'dispersions', 'dispersions_norm', 'highly_variable' - uns: 'bulk_labels_colors', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups' obsm: 'X_pca', 'X_umap' varm: 'PCs' obsp: ... @@ -165,9 +164,9 @@ First, our example case: >>> blobs AnnData object with n_obs × n_vars = 640 × 30 obs: 'blobs' - uns: 'pca' obsm: 'X_pca' varm: 'PCs' + uns: 'pca' Now we will split this object by the categorical `"blobs"` and recombine it to illustrate different merge strategies. @@ -181,9 +180,9 @@ Now we will split this object by the categorical `"blobs"` and recombine it to i >>> adatas[0] AnnData object with n_obs × n_vars = 128 × 30 obs: 'blobs' - uns: 'pca' obsm: 'X_pca', 'qc' varm: 'PCs', '0_qc' + uns: 'pca' `adatas` is now a list of datasets with disjoint sets of observations and a common set of variables. Each object has had QC metrics computed, with observation-wise metrics stored under `"qc"` in `.obsm`, and variable-wise metrics stored with a unique key for each subset. From 1dfdd96ebfa53a03eec0fa8e64b95147a91b8c8d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:37:37 +0200 Subject: [PATCH 19/32] fix: why was this deleted? --- src/anndata/types.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/anndata/types.py b/src/anndata/types.py index aa23d10f2..166f0897b 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -7,8 +7,13 @@ from typing import Any, Literal from array_api.latest import ArrayNamespace + from pandas import DataFrame + + from anndata.typing import AxisStorable, _XDataType from ._core.anndata import AnnData + from ._types import AnnDataElem + from .compat import XDataset @runtime_checkable @@ -48,3 +53,29 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... + + +class ReduceFunc[T](Protocol): + def __call__( + self, + elem: _XDataType | AxisStorable | DataFrame | XDataset, + *, + accumulate: T, + attr_name: AnnDataElem | None, + ) -> T: + """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + ref_acc + A reference to help uses distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... From 9ad937f7c3ba71931dea9c72aa002d459ce4e74b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:39:39 +0200 Subject: [PATCH 20/32] fix: doc string --- src/anndata/_core/anndata.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 24288dd2e..aa8d98775 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1439,16 +1439,8 @@ def reduce[T]( func: ReduceFunc[T], *, init: T, - order: Literal["DFS-pre", "DFS-post"] = "DFS-post", ) -> T: - """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. - - All visits inside the user-defined `func` (see :func:`types.ReduceFunc`) are distinguishable via the `ref_acc` + `elem` args. - Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. - Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. - Furthermore, neither element is descended into. - This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. - All other elements will have a non-`None` `ref_acc` argument indicating the path at which `elem` was created in the `AnnData`. + """Accumulate a value starting from init by iterating over the parent "elems"of the AnnData object i.e., raw, obs, varp etc. Parameters ---------- @@ -1456,13 +1448,6 @@ def reduce[T]( The function that performs the accumulation. init The starting value - order - How to visit the items in the reduce. - "DFS-pre" indicates that parent-elements like layers, obs, and varp get visited first. - "DFS-post" means they get visited afterwards. - The `AnnData` itself is not visited. - - Returns ------- From 0c03ffb4634cecda10af86ece7645a784a9754e7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:40:13 +0200 Subject: [PATCH 21/32] fix: docs --- src/anndata/_core/anndata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index aa8d98775..ace4aa356 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1417,6 +1417,7 @@ def iter( ) -> Generator[ tuple[AnnDataElem, AxisStorable | _XDataType | Dataset2D | pd.DataFrame] ]: + """Iterate over key-value pairs of the parent "elems" like aw, obs, varp etc""" for attr_name in [ "X", "obs", From f00db891b65e7d53f8c6ef9e3a53dd35e3e5de51 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:42:03 +0200 Subject: [PATCH 22/32] fix: remove `parent_type` --- src/anndata/acc/__init__.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/anndata/acc/__init__.py b/src/anndata/acc/__init__.py index 86215827b..ff6b80b6d 100644 --- a/src/anndata/acc/__init__.py +++ b/src/anndata/acc/__init__.py @@ -192,11 +192,6 @@ def _maybe_flatten(self, idx: I, a: Array) -> Array: return a.__array_namespace__().reshape(a, (a.size,)) return a.ravel() - @property - @abc.abstractmethod - def parent_type(self) -> type[MapAcc | AdAcc]: - """Get the parent to this reference accessor""" - @dataclass(frozen=True) class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): @@ -214,10 +209,6 @@ class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str | None """Key this accessor refers to, e.g. `A.layers['counts'].k == 'counts'`.""" - @property - def parent_type(self) -> type[MapAcc | AdAcc]: - return LayerMapAcc if self.k is not None else AdAcc - @overload def __getitem__(self, idx: Idx2D, /) -> R: ... @overload @@ -307,10 +298,6 @@ class MetaAcc[R: AdRef[str | None]](RefAcc[R, str | None]): dim: Literal["obs", "var"] """Axis this accessor refers to, e.g. `A.obs.dim == 'obs'`.""" - @property - def parent_type(self) -> type[MapAcc | AdAcc]: - return AdAcc - @property def index(self) -> R: """Index :class:`AdRef`, i.e. `A.obs.index` or `A.var.index`.""" @@ -393,10 +380,6 @@ class MultiAcc[R: AdRef[int]](RefAcc[R, int]): k: str """Key this accessor refers to, e.g. `A.varm['x'].k == 'x'`.""" - @property - def parent_type(self) -> type[MapAcc | AdAcc]: - return MultiMapAcc - @staticmethod def process_idx(i: object, /) -> int | list[int] | pd.Index[int]: if isinstance(i, tuple): @@ -480,10 +463,6 @@ class GraphAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str """Key this accessor refers to, e.g. `A.obsp['x'].k == 'x'`.""" - @property - def parent_type(self) -> type[MapAcc | AdAcc]: - return GraphMapAcc - def process_idx(self, idx: Idx2D, /) -> Idx2D: if not all(isinstance(i, str | slice) for i in idx): msg = f"Unsupported index {idx!r}" From 9fa978aaf2cfee4596ac06599502345b52f63178 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 10 Apr 2026 12:28:05 +0200 Subject: [PATCH 23/32] fix: writing none --- src/anndata/_io/specs/methods.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index f2cfff554..c3a6fea38 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -287,12 +287,13 @@ def write_anndata( ): g = f.require_group(k) for sub_key, elem in adata.iter(): - _writer.write_elem( - g, - sub_key, - dict(elem) if isinstance(elem, MutableMapping) else elem, - dataset_kwargs=dataset_kwargs, - ) + if not (sub_key == "X" and elem is None): + _writer.write_elem( + g, + sub_key, + dict(elem) if isinstance(elem, MutableMapping) else elem, + dataset_kwargs=dataset_kwargs, + ) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) From e7b201f5dfb264d0f42b6ec26437f6a1feca3ae4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 10 Apr 2026 12:45:03 +0200 Subject: [PATCH 24/32] fix: API changes --- docs/api.md | 1 + docs/release-notes/2372.feat.md | 2 +- src/anndata/_core/anndata.py | 39 +++++++------------------------- src/anndata/_io/h5ad.py | 4 ++-- src/anndata/_io/specs/methods.py | 4 ++-- src/anndata/utils.py | 30 +++++++++++++++++++++++- 6 files changed, 43 insertions(+), 37 deletions(-) diff --git a/docs/api.md b/docs/api.md index 5bdac918e..88d4abfaa 100644 --- a/docs/api.md +++ b/docs/api.md @@ -92,6 +92,7 @@ Writing a complete {class}`AnnData` object to disk in anndata’s native formats AnnData.write_h5ad AnnData.write_zarr + AnnData.can_write .. diff --git a/docs/release-notes/2372.feat.md b/docs/release-notes/2372.feat.md index 0a0feef86..ad4b80912 100644 --- a/docs/release-notes/2372.feat.md +++ b/docs/release-notes/2372.feat.md @@ -1 +1 @@ -New {meth}`AnnData.reduce` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` +New {meth}`AnnData.can_write` for checking if an `AnnData` can be written {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ace4aa356..6af6e9ee3 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -40,6 +40,7 @@ deprecated, deprecation_msg, ensure_df_homogeneous, + iter_outer, raise_value_error_if_multiindex_columns, set_module, warn, @@ -56,7 +57,7 @@ from .xarray import Dataset2D if TYPE_CHECKING: - from collections.abc import Generator, Iterable + from collections.abc import Iterable from os import PathLike from typing import Any, ClassVar, Literal @@ -66,7 +67,6 @@ from anndata.types import ReduceFunc from anndata.typing import RWAble - from .._types import AnnDataElem from ..acc import AdRef, Array, MapAcc, RefAcc from ..compat import CSArray, CSMatrix from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType @@ -555,12 +555,12 @@ def fold_size( print(f"Size of {attr_name}: {tqdm.format_sizeof(size, 'B')}") return accumulate - return sum(self.reduce(fold_size, init=defaultdict(int)).values()) + return sum(self._reduce(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" - for attr_name, elem in self.iter(): + for attr_name, elem in iter_outer(self): if attr_name not in {"raw", "X"}: keys = elem.keys() if len(keys) > 0: @@ -1370,7 +1370,7 @@ def to_memory(self, *, copy: bool = False) -> AnnData: mem = backed[backed.obs["cluster"] == "a", :].to_memory() """ new = {} - for attr_name, attr in self.iter(): + for attr_name, attr in iter_outer(self): if attr is not None: if attr is self.raw: new["raw"] = { @@ -1412,30 +1412,7 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - def iter( - self, - ) -> Generator[ - tuple[AnnDataElem, AxisStorable | _XDataType | Dataset2D | pd.DataFrame] - ]: - """Iterate over key-value pairs of the parent "elems" like aw, obs, varp etc""" - for attr_name in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - "uns", - "raw", - ]: - was_closed = self.isbacked and not self.file.is_open - yield (attr_name, getattr(self, attr_name)) - if was_closed: - self.file.close() - - def reduce[T]( + def _reduce[T]( self, func: ReduceFunc[T], *, @@ -1455,7 +1432,7 @@ def reduce[T]( An accumulated value """ accumulate = init - for attr_name, attr in self.iter(): + for attr_name, attr in iter_outer(self): accumulate = func(attr, accumulate=accumulate, attr_name=attr_name) return accumulate @@ -1513,7 +1490,7 @@ def predicate( elem = elem._values return accumulate and type(elem) in writeable_elems - return self.reduce(predicate, init=True) + return self._reduce(predicate, init=True) def var_names_make_unique(self, join: str = "-") -> None: # Important to go through the setter so obsm dataframes are updated too diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index cc8948e28..df6272567 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -24,7 +24,7 @@ _from_fixed_length_strings, ) from ..experimental import read_dispatched -from ..utils import warn +from ..utils import iter_outer, warn from .specs import read_elem, write_elem from .specs.registry import IOSpec, write_spec from .utils import ( @@ -85,7 +85,7 @@ def write_h5ad( f = cast("h5py.Group", f["/"]) f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") - for k, elem in adata.iter(): + for k, elem in iter_outer(adata): if k == "X": _write_x( f, diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index c3a6fea38..1dee0e3e0 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -41,7 +41,7 @@ from ..._settings import settings from ...compat import PANDAS_STRING_ARRAY_TYPES, PANDAS_SUPPORTS_NA_VALUE -from ...utils import warn +from ...utils import iter_outer, warn from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial if TYPE_CHECKING: @@ -286,7 +286,7 @@ def write_anndata( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - for sub_key, elem in adata.iter(): + for sub_key, elem in iter_outer(adata): if not (sub_key == "X" and elem is None): _writer.write_elem( g, diff --git a/src/anndata/utils.py b/src/anndata/utils.py index 9089f90a5..d3c4c9b83 100644 --- a/src/anndata/utils.py +++ b/src/anndata/utils.py @@ -20,9 +20,13 @@ from .logging import get_logger if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping, Sequence + from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import Any, LiteralString + from ._core.xarray import Dataset2D + from ._types import AnnDataElem + from .typing import AxisStorable, _XDataType + logger = get_logger(__name__) @@ -468,3 +472,27 @@ def module_get_attr_redirect( return getattr(mod, new_path) msg = f"module {full_old_module_path} has no attribute {attr_name!r}" raise AttributeError(msg) + + +def iter_outer( + adata, +) -> Generator[ + tuple[AnnDataElem, AxisStorable | _XDataType | Dataset2D | pd.DataFrame] +]: + """Iterate over key-value pairs of the parent "elems" like aw, obs, varp etc""" + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "uns", + "raw", + ]: + was_closed = adata.isbacked and not adata.file.is_open + yield (attr_name, getattr(adata, attr_name)) + if was_closed: + adata.file.close() From 95136a211a7524d602c91ea1eb75f09ef7e10b3e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 13 Apr 2026 10:45:36 +0200 Subject: [PATCH 25/32] fix use `set` --- src/anndata/_core/anndata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6af6e9ee3..bf673814f 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1472,7 +1472,7 @@ def predicate( for e in [elem.var, elem.varm] for attr in e ) - if attr_name in ( + if attr_name in { "obs", "obsm", "varm", @@ -1481,7 +1481,7 @@ def predicate( "varp", "obsp", "uns", - ) or isinstance(elem, pd.DataFrame | XDataset | MutableMapping): + } or isinstance(elem, pd.DataFrame | XDataset | MutableMapping): return accumulate and all( predicate(elem[k], accumulate=accumulate) for k in elem ) From 4eba690e9b8d474505a9b507fe8a2bc99e00722c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 13 Apr 2026 10:48:17 +0200 Subject: [PATCH 26/32] fix: docs --- docs/release-notes/2372.feat.md | 2 +- src/anndata/types.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/release-notes/2372.feat.md b/docs/release-notes/2372.feat.md index ad4b80912..b840039fe 100644 --- a/docs/release-notes/2372.feat.md +++ b/docs/release-notes/2372.feat.md @@ -1 +1 @@ -New {meth}`AnnData.can_write` for checking if an `AnnData` can be written {user}`ilan-gold` +New {meth}`anndata.AnnData.can_write` for checking if an `AnnData` can be written {user}`ilan-gold` diff --git a/src/anndata/types.py b/src/anndata/types.py index 166f0897b..27c3d7a0e 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -63,7 +63,7 @@ def __call__( accumulate: T, attr_name: AnnDataElem | None, ) -> T: - """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. + """Function to be called on each visit within `anndata.AnnData._reduce`. Parameters ---------- From fdd6b7ce0d6c17207bdfe4d6ef4eb6ccbb02ef3d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 13 Apr 2026 10:50:09 +0200 Subject: [PATCH 27/32] fix: remove unused docs / private type --- docs/api.md | 7 ------- src/anndata/_core/anndata.py | 2 +- src/anndata/_types.py | 32 ++++++++++++++++++++++++++++++++ src/anndata/types.py | 31 ------------------------------- 4 files changed, 33 insertions(+), 39 deletions(-) diff --git a/docs/api.md b/docs/api.md index 88d4abfaa..36b094e89 100644 --- a/docs/api.md +++ b/docs/api.md @@ -265,13 +265,6 @@ Types used by the former: abc.CSCDataset ``` -```{eval-rst} -.. autosummary:: - :toctree: generated/ - - types.ReduceFunc -``` - ```{eval-rst} diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index bf673814f..6b884c874 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -64,9 +64,9 @@ from scipy import sparse from zarr.storage import StoreLike - from anndata.types import ReduceFunc from anndata.typing import RWAble + from .._types import ReduceFunc from ..acc import AdRef, Array, MapAcc, RefAcc from ..compat import CSArray, CSMatrix from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType diff --git a/src/anndata/_types.py b/src/anndata/_types.py index 6006b31c3..514b8b1e1 100644 --- a/src/anndata/_types.py +++ b/src/anndata/_types.py @@ -14,7 +14,10 @@ from collections.abc import Mapping from typing import Any, TypeAlias + from pandas import DataFrame + from anndata._core.xarray import Dataset2D + from anndata.typing import AxisStorable, _XDataType from ._io.specs.registry import ( IOSpec, @@ -23,6 +26,9 @@ Reader, Writer, ) + from ._types import AnnDataElem + from .compat import XDataset + else: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 type S = StorageType type RWAble = typing.RWAble @@ -216,3 +222,29 @@ def __call__( ] type Join_T = Literal["inner", "outer"] + + +class ReduceFunc[T](Protocol): + def __call__( + self, + elem: _XDataType | AxisStorable | DataFrame | XDataset, + *, + accumulate: T, + attr_name: AnnDataElem | None, + ) -> T: + """Function to be called on each visit within `anndata.AnnData._reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + ref_acc + A reference to help uses distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... diff --git a/src/anndata/types.py b/src/anndata/types.py index 27c3d7a0e..aa23d10f2 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -7,13 +7,8 @@ from typing import Any, Literal from array_api.latest import ArrayNamespace - from pandas import DataFrame - - from anndata.typing import AxisStorable, _XDataType from ._core.anndata import AnnData - from ._types import AnnDataElem - from .compat import XDataset @runtime_checkable @@ -53,29 +48,3 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... - - -class ReduceFunc[T](Protocol): - def __call__( - self, - elem: _XDataType | AxisStorable | DataFrame | XDataset, - *, - accumulate: T, - attr_name: AnnDataElem | None, - ) -> T: - """Function to be called on each visit within `anndata.AnnData._reduce`. - - Parameters - ---------- - elem - The current element. - accumulate - The value being accumulated. - ref_acc - A reference to help uses distinguish where they are in the `AnnData` object. - - Returns - ------- - An accumulated value - """ - ... From 5760cb2ad93522a40935d505b28e4b468d518d2b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 14 Apr 2026 10:55:18 +0200 Subject: [PATCH 28/32] fix: nexting --- src/anndata/_core/anndata.py | 2 ++ tests/test_readwrite.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6b884c874..885a5a7f1 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1465,6 +1465,8 @@ def predicate( ): if elem is None: return accumulate + if isinstance(elem, AnnData): + return accumulate and elem.can_write(store_type=store_type) if attr_name == "raw": accumulate = accumulate and type(elem.X) in writeable_elems return accumulate and all( diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 43694da77..8a665c778 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -135,22 +135,27 @@ def test_can_write( @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +@pytest.mark.parametrize("should_nest", [True, False], ids=["nest", "no_nest"]) @pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) def test_can_not_write_with_custom_array( rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None, parent_elem: Literal["obs", "uns", "raw"], + *, + should_nest: bool, ): import pyarrow as pa adata, _ = rw if parent_elem == "raw": adata.raw = adata.copy() - getter = lambda: getattr(adata, parent_elem).var + getter = lambda adata: getattr(adata, parent_elem).var else: - getter = lambda: getattr(adata, parent_elem) - getter()["arrow_array"] = pd.arrays.ArrowExtensionArray( - pa.array([{"x": 1, "y": True}] * adata.shape[1]) + getter = lambda adata: getattr(adata, parent_elem) + if should_nest: + adata.uns["adata"] = adata.copy() + getter(adata.uns["adata"] if should_nest else adata)["arrow_array"] = ( + pd.arrays.ArrowExtensionArray(pa.array([{"x": 1, "y": True}] * adata.shape[1])) ) assert not adata.can_write(store_type=store_type) From 7382b67b86e90788ba85ec6473ada41e86c01bf3 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 14 Apr 2026 11:44:37 +0200 Subject: [PATCH 29/32] fix: ok --- tests/test_readwrite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 8a665c778..74b8f8221 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -14,7 +14,6 @@ import pandas as pd import pytest import zarr -import zarr.convenience from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix import anndata as ad From 3c747e1eef3aad38b2f960f067b3537d492b5a54 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 14 Apr 2026 11:57:39 +0200 Subject: [PATCH 30/32] fix: handle bad categoricals --- src/anndata/_core/anndata.py | 10 ++++++---- tests/test_readwrite.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 885a5a7f1..ff6e273d3 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1457,7 +1457,7 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate( + def predicate( # noqa: PLR0911 elem: RWAble, *, accumulate: bool, @@ -1467,6 +1467,11 @@ def predicate( return accumulate if isinstance(elem, AnnData): return accumulate and elem.can_write(store_type=store_type) + if isinstance(elem, pd.Categorical): + return accumulate and predicate(elem.categories, accumulate=accumulate) + if isinstance(elem, pd.Series): + # matches behavior in methods.py + return accumulate and predicate(elem._values, accumulate=accumulate) if attr_name == "raw": accumulate = accumulate and type(elem.X) in writeable_elems return accumulate and all( @@ -1487,9 +1492,6 @@ def predicate( return accumulate and all( predicate(elem[k], accumulate=accumulate) for k in elem ) - if isinstance(elem, pd.Series): - # matches behavior in methods.py - elem = elem._values return accumulate and type(elem) in writeable_elems return self._reduce(predicate, init=True) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 74b8f8221..cacc21249 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -133,6 +133,19 @@ def test_can_write( assert adata.can_write(store_type=store_type) +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_not_write_bad_categorical( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + + adata, _ = rw + adata.var["arrow_categorical_array"] = pd.Categorical.from_codes( + [i % 2 for i in range(adata.shape[1])], + categories=pd.arrays.IntervalArray.from_tuples([(0, 10), (20, 30)]), + ) + assert not adata.can_write(store_type=store_type) + + @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) @pytest.mark.parametrize("should_nest", [True, False], ids=["nest", "no_nest"]) @pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) From 371d5354ea180af37efcfce304c05331f6388872 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 14 Apr 2026 12:31:51 +0200 Subject: [PATCH 31/32] fix: handle index / awkward --- src/anndata/_core/anndata.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ff6e273d3..e528a4b0e 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -25,6 +25,7 @@ from .. import utils from .._settings import settings from ..compat import ( + AwkArray, DaskArray, IndexManager, XDataset, @@ -1469,13 +1470,20 @@ def predicate( # noqa: PLR0911 return accumulate and elem.can_write(store_type=store_type) if isinstance(elem, pd.Categorical): return accumulate and predicate(elem.categories, accumulate=accumulate) - if isinstance(elem, pd.Series): + if isinstance(elem, pd.Series | pd.Index): # matches behavior in methods.py return accumulate and predicate(elem._values, accumulate=accumulate) + if isinstance(elem, AwkArray): + import awkward as ak + + container = ak.to_buffers(ak.to_packed(elem)) + return accumulate and all( + predicate(v, accumulate=accumulate) for v in container[2].values() + ) if attr_name == "raw": accumulate = accumulate and type(elem.X) in writeable_elems return accumulate and all( - type(e[attr]) in writeable_elems + predicate(e[attr], accumulate=accumulate) for e in [elem.var, elem.varm] for attr in e ) From 5a0fded9d200340d4ecc2ab39432874316507ef7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 10:20:13 +0200 Subject: [PATCH 32/32] refactor: `can_write` -> `unwriteable` --- docs/api.md | 2 +- docs/release-notes/2372.feat.md | 2 +- src/anndata/_core/anndata.py | 9 ++++++--- tests/test_readwrite.py | 6 +++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/api.md b/docs/api.md index 36b094e89..14883f1e7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -92,7 +92,7 @@ Writing a complete {class}`AnnData` object to disk in anndata’s native formats AnnData.write_h5ad AnnData.write_zarr - AnnData.can_write + AnnData.unwriteable .. diff --git a/docs/release-notes/2372.feat.md b/docs/release-notes/2372.feat.md index b840039fe..61e9fa0bc 100644 --- a/docs/release-notes/2372.feat.md +++ b/docs/release-notes/2372.feat.md @@ -1 +1 @@ -New {meth}`anndata.AnnData.can_write` for checking if an `AnnData` can be written {user}`ilan-gold` +New {meth}`anndata.AnnData.unwriteable` for checking if an `AnnData` can be written {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index e528a4b0e..8717ca6af 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1437,7 +1437,7 @@ def _reduce[T]( accumulate = func(attr, accumulate=accumulate, attr_name=attr_name) return accumulate - def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + def unwriteable(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: """Whether or not an `AnnData` object can be written to disk for a given store type. Parameters @@ -1447,7 +1447,10 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: Returns ------- - Whether or not this object is writable. + Whether or not this object is writeable. + While the return type may change to include richer output about which elements cannot be written, + this new type's evaluation as a boolean will not change from the current behavior i.e., + `bool(adata.unwriteable())` will always evaluate the same. """ from anndata._io.specs.registry import _REGISTRY @@ -1467,7 +1470,7 @@ def predicate( # noqa: PLR0911 if elem is None: return accumulate if isinstance(elem, AnnData): - return accumulate and elem.can_write(store_type=store_type) + return accumulate and elem.unwriteable(store_type=store_type) if isinstance(elem, pd.Categorical): return accumulate and predicate(elem.categories, accumulate=accumulate) if isinstance(elem, pd.Series | pd.Index): diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index cacc21249..d9ea1e595 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -130,7 +130,7 @@ def test_can_write( rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None ): adata, _ = rw - assert adata.can_write(store_type=store_type) + assert adata.unwriteable(store_type=store_type) @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) @@ -143,7 +143,7 @@ def test_can_not_write_bad_categorical( [i % 2 for i in range(adata.shape[1])], categories=pd.arrays.IntervalArray.from_tuples([(0, 10), (20, 30)]), ) - assert not adata.can_write(store_type=store_type) + assert not adata.unwriteable(store_type=store_type) @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) @@ -169,7 +169,7 @@ def test_can_not_write_with_custom_array( getter(adata.uns["adata"] if should_nest else adata)["arrow_array"] = ( pd.arrays.ArrowExtensionArray(pa.array([{"x": 1, "y": True}] * adata.shape[1])) ) - assert not adata.can_write(store_type=store_type) + assert not adata.unwriteable(store_type=store_type) @pytest.mark.parametrize("typ", ARRAY_TYPES)