From 6383dd15d92b78dcc39429f6060a7fdd5b584f95 Mon Sep 17 00:00:00 2001 From: Dominik Date: Fri, 12 Dec 2025 16:21:25 +0100 Subject: [PATCH 01/14] adata.extensions module --- src/anndata/_repr/__init__.py | 30 ++++++-- src/anndata/extensions.py | 110 ++++++++++++++++++++++++++++++ tests/test_repr_html.py | 30 ++++---- tests/visual_inspect_repr_html.py | 6 +- 4 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 src/anndata/extensions.py diff --git a/src/anndata/_repr/__init__.py b/src/anndata/_repr/__init__.py index ea378af8a..c60c1a7db 100644 --- a/src/anndata/_repr/__init__.py +++ b/src/anndata/_repr/__init__.py @@ -10,6 +10,18 @@ - Support for nested AnnData objects - Graceful handling of unknown types +.. note:: + + For extending AnnData with custom formatters, prefer importing from + :mod:`anndata.extensions` which provides the public API:: + + from anndata.extensions import ( + register_formatter, + TypeFormatter, + SectionFormatter, + FormattedOutput, + ) + Extensibility ------------- The system is designed to be extensible via two registry patterns: @@ -24,7 +36,7 @@ Example - format by Python type:: - from anndata._repr import register_formatter, TypeFormatter, FormattedOutput + from anndata.extensions import register_formatter, TypeFormatter, FormattedOutput @register_formatter @@ -42,8 +54,12 @@ def format(self, obj, context): Example - format by embedded type hint (for tagged data in uns):: - from anndata._repr import register_formatter, TypeFormatter, FormattedOutput - from anndata._repr import extract_uns_type_hint + from anndata.extensions import ( + register_formatter, + TypeFormatter, + FormattedOutput, + extract_uns_type_hint, + ) @register_formatter @@ -80,8 +96,12 @@ def format(self, obj, context): Example:: - from anndata._repr import register_formatter, SectionFormatter - from anndata._repr import FormattedEntry, FormattedOutput + from anndata.extensions import ( + register_formatter, + SectionFormatter, + FormattedEntry, + FormattedOutput, + ) @register_formatter diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py new file mode 100644 index 000000000..a9bed0233 --- /dev/null +++ b/src/anndata/extensions.py @@ -0,0 +1,110 @@ +""" +Public API for extending AnnData functionality. + +This module provides registration mechanisms for: + +1. **Accessors** - Add custom namespaces to AnnData objects (e.g., `adata.myns.method()`) +2. **HTML Formatters** - Customize how types are displayed in Jupyter notebooks + +Examples +-------- +Register a custom accessor namespace:: + + import anndata as ad + from anndata.extensions import register_anndata_namespace + + @register_anndata_namespace("transform") + class TransformAccessor: + def __init__(self, adata: ad.AnnData): + self._adata = adata + + def log1p(self): + import numpy as np + self._adata.X = np.log1p(self._adata.X) + return self._adata + + # Usage: adata.transform.log1p() + +Register a custom HTML formatter for a type:: + + from anndata.extensions import register_formatter, TypeFormatter, FormattedOutput + + @register_formatter + class MyArrayFormatter(TypeFormatter): + priority = 100 # Higher = checked first + + def can_format(self, obj): + return isinstance(obj, MyArrayType) + + def format(self, obj, context): + return FormattedOutput( + type_name=f"MyArray {obj.shape}", + css_class="dtype-custom", + ) + +Register a custom section formatter (for packages like TreeData, SpatialData):: + + from anndata.extensions import register_formatter, SectionFormatter + from anndata.extensions import FormattedEntry, FormattedOutput + + @register_formatter + class ObstSectionFormatter(SectionFormatter): + section_name = "obst" + after_section = "obsm" # Position in display order + + def should_show(self, obj): + return hasattr(obj, "obst") and len(obj.obst) > 0 + + def get_entries(self, obj, context): + return [ + FormattedEntry( + key=k, + output=FormattedOutput(type_name=f"Tree ({v.n_nodes} nodes)"), + ) + for k, v in obj.obst.items() + ] + +See Also +-------- +anndata._repr : Full documentation of the HTML representation system +""" + +from __future__ import annotations + +# Accessor registration (from PR #1870) +from anndata._core.extensions import register_anndata_namespace + +# HTML representation formatters +from anndata._repr import ( + # Core formatter classes + FormattedEntry, + FormattedOutput, + FormatterContext, + FormatterRegistry, + SectionFormatter, + TypeFormatter, + # Registration function + register_formatter, + # Global registry instance + formatter_registry, + # Type hint utilities for tagged data + UNS_TYPE_HINT_KEY, + extract_uns_type_hint, +) + +__all__ = [ + # Accessor registration + "register_anndata_namespace", + # HTML formatter registration + "register_formatter", + "TypeFormatter", + "SectionFormatter", + "FormattedOutput", + "FormattedEntry", + "FormatterContext", + "FormatterRegistry", + "formatter_registry", + # Type hint utilities + "extract_uns_type_hint", + "UNS_TYPE_HINT_KEY", +] diff --git a/tests/test_repr_html.py b/tests/test_repr_html.py index 22969f3ca..33a8eeb7c 100644 --- a/tests/test_repr_html.py +++ b/tests/test_repr_html.py @@ -689,14 +689,14 @@ class TestFormatterRegistry: def test_registry_has_formatters(self): """Test registry contains registered formatters.""" - from anndata._repr.registry import formatter_registry + from anndata.extensions import formatter_registry # Should have some formatters registered assert len(formatter_registry._type_formatters) > 0 def test_custom_formatter_registration(self): """Test registering a custom formatter.""" - from anndata._repr.registry import ( + from anndata.extensions import ( FormattedOutput, FormatterContext, TypeFormatter, @@ -739,7 +739,7 @@ def format(self, obj: Any, context: FormatterContext) -> FormattedOutput: def test_fallback_formatter_for_unknown_types(self): """Test fallback formatter handles unknown types gracefully.""" - from anndata._repr.registry import FormatterContext, formatter_registry + from anndata.extensions import FormatterContext, formatter_registry class UnknownType: """An unknown type not in the registry.""" @@ -756,7 +756,7 @@ class UnknownType: def test_formatter_priority_order(self): """Test formatters are checked in priority order.""" - from anndata._repr.registry import formatter_registry + from anndata.extensions import formatter_registry # Verify formatters are sorted by priority (highest first) priorities = [f.priority for f in formatter_registry._type_formatters] @@ -764,7 +764,7 @@ def test_formatter_priority_order(self): def test_formatter_sections_filtering(self): """Test formatters are only applied to specified sections.""" - from anndata._repr.registry import ( + from anndata.extensions import ( FormattedOutput, FormatterContext, TypeFormatter, @@ -808,7 +808,7 @@ def format(self, obj: Any, context: FormatterContext) -> FormattedOutput: def test_formatter_sections_none_applies_everywhere(self): """Test formatters with sections=None apply to all sections.""" - from anndata._repr.registry import ( + from anndata.extensions import ( FormattedOutput, FormatterContext, TypeFormatter, @@ -846,7 +846,7 @@ def format(self, obj: Any, context: FormatterContext) -> FormattedOutput: def test_extension_type_graceful_handling(self): """Test extension types (like TreeData, MuData) are handled gracefully.""" - from anndata._repr.registry import FormatterContext, formatter_registry + from anndata.extensions import FormatterContext, formatter_registry # Simulate an extension type that has AnnData-like attributes # We create the class in a way that properly sets __module__ @@ -1704,7 +1704,7 @@ def test_extract_type_hint_malformed_string_format(self): def test_type_formatter_for_tagged_uns_data(self): """Test using TypeFormatter to handle tagged data in uns.""" - from anndata._repr import ( + from anndata.extensions import ( FormattedOutput, TypeFormatter, extract_uns_type_hint, @@ -1761,7 +1761,7 @@ def test_unregistered_type_hint_shows_import_message(self): def test_formatter_error_handled_gracefully(self): """Test that TypeFormatter errors don't crash the repr.""" - from anndata._repr import ( + from anndata.extensions import ( TypeFormatter, extract_uns_type_hint, formatter_registry, @@ -1815,7 +1815,7 @@ def test_string_format_type_hint_in_html(self): def test_type_hint_key_constant_exported(self): """Test that UNS_TYPE_HINT_KEY constant is properly exported.""" - from anndata._repr import UNS_TYPE_HINT_KEY + from anndata.extensions import UNS_TYPE_HINT_KEY assert UNS_TYPE_HINT_KEY == "__anndata_repr__" @@ -1919,7 +1919,7 @@ class TestSectionFormatterCoverage: def test_section_formatter_default_methods(self): """Test SectionFormatter default method implementations.""" - from anndata._repr.registry import SectionFormatter + from anndata.extensions import SectionFormatter class TestSectionFormatter(SectionFormatter): @property @@ -2606,14 +2606,14 @@ class TestRegistryAbstractMethods: def test_type_formatter_is_abstract(self): """Verify TypeFormatter cannot be instantiated directly.""" - from anndata._repr.registry import TypeFormatter + from anndata.extensions import TypeFormatter with pytest.raises(TypeError): TypeFormatter() def test_section_formatter_is_abstract(self): """Verify SectionFormatter cannot be instantiated directly.""" - from anndata._repr.registry import SectionFormatter + from anndata.extensions import SectionFormatter with pytest.raises(TypeError): SectionFormatter() @@ -2624,7 +2624,7 @@ class TestCustomHtmlContent: def test_inline_html_content(self): """Test inline (non-expandable) custom HTML content.""" - from anndata._repr.registry import ( + from anndata.extensions import ( FormattedOutput, TypeFormatter, formatter_registry, @@ -2669,7 +2669,7 @@ def format(self, obj, context): def test_expandable_html_content(self): """Test expandable custom HTML content (e.g., for TreeData visualization).""" - from anndata._repr.registry import ( + from anndata.extensions import ( FormattedOutput, TypeFormatter, formatter_registry, diff --git a/tests/visual_inspect_repr_html.py b/tests/visual_inspect_repr_html.py index 1dd66c21f..9408e3af9 100644 --- a/tests/visual_inspect_repr_html.py +++ b/tests/visual_inspect_repr_html.py @@ -21,7 +21,7 @@ import anndata as ad from anndata import AnnData -from anndata._repr import ( +from anndata.extensions import ( FormattedOutput, TypeFormatter, extract_uns_type_hint, @@ -40,7 +40,7 @@ import networkx as nx from treedata import TreeData - from anndata._repr import ( + from anndata.extensions import ( FormattedEntry, FormattedOutput, FormatterContext, @@ -238,7 +238,7 @@ def get_entries(self, obj, context: FormatterContext) -> list[FormattedEntry]: try: from mudata import MuData - from anndata._repr import ( + from anndata.extensions import ( FormattedEntry, FormattedOutput, FormatterContext, # noqa: TC001 From ee2d8af46195a323d42ecfe7099b61e6e6d7ddc4 Mon Sep 17 00:00:00 2001 From: Dominik Date: Fri, 12 Dec 2025 16:43:54 +0100 Subject: [PATCH 02/14] Fix Ruff RUF022 on __all__ --- src/anndata/_repr/__init__.py | 6 +++++- src/anndata/extensions.py | 17 +++++++++++------ tests/visual_inspect_repr_html.py | 4 ++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/anndata/_repr/__init__.py b/src/anndata/_repr/__init__.py index c60c1a7db..2ef1f48b8 100644 --- a/src/anndata/_repr/__init__.py +++ b/src/anndata/_repr/__init__.py @@ -36,7 +36,11 @@ Example - format by Python type:: - from anndata.extensions import register_formatter, TypeFormatter, FormattedOutput + from anndata.extensions import ( + register_formatter, + TypeFormatter, + FormattedOutput, + ) @register_formatter diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py index a9bed0233..67a0d5de5 100644 --- a/src/anndata/extensions.py +++ b/src/anndata/extensions.py @@ -13,6 +13,7 @@ import anndata as ad from anndata.extensions import register_anndata_namespace + @register_anndata_namespace("transform") class TransformAccessor: def __init__(self, adata: ad.AnnData): @@ -20,15 +21,18 @@ def __init__(self, adata: ad.AnnData): def log1p(self): import numpy as np + self._adata.X = np.log1p(self._adata.X) return self._adata + # Usage: adata.transform.log1p() Register a custom HTML formatter for a type:: from anndata.extensions import register_formatter, TypeFormatter, FormattedOutput + @register_formatter class MyArrayFormatter(TypeFormatter): priority = 100 # Higher = checked first @@ -47,6 +51,7 @@ def format(self, obj, context): from anndata.extensions import register_formatter, SectionFormatter from anndata.extensions import FormattedEntry, FormattedOutput + @register_formatter class ObstSectionFormatter(SectionFormatter): section_name = "obst" @@ -76,6 +81,8 @@ def get_entries(self, obj, context): # HTML representation formatters from anndata._repr import ( + # Type hint utilities for tagged data + UNS_TYPE_HINT_KEY, # Core formatter classes FormattedEntry, FormattedOutput, @@ -83,16 +90,14 @@ def get_entries(self, obj, context): FormatterRegistry, SectionFormatter, TypeFormatter, - # Registration function - register_formatter, + extract_uns_type_hint, # Global registry instance formatter_registry, - # Type hint utilities for tagged data - UNS_TYPE_HINT_KEY, - extract_uns_type_hint, + # Registration function + register_formatter, ) -__all__ = [ +__all__ = [ # noqa: RUF022 # organized by category, not alphabetically # Accessor registration "register_anndata_namespace", # HTML formatter registration diff --git a/tests/visual_inspect_repr_html.py b/tests/visual_inspect_repr_html.py index 9408e3af9..6b62198dd 100644 --- a/tests/visual_inspect_repr_html.py +++ b/tests/visual_inspect_repr_html.py @@ -238,6 +238,8 @@ def get_entries(self, obj, context: FormatterContext) -> list[FormattedEntry]: try: from mudata import MuData + from anndata._repr.html import generate_repr_html + from anndata._repr.utils import format_number from anndata.extensions import ( FormattedEntry, FormattedOutput, @@ -245,8 +247,6 @@ def get_entries(self, obj, context: FormatterContext) -> list[FormattedEntry]: SectionFormatter, register_formatter, ) - from anndata._repr.html import generate_repr_html - from anndata._repr.utils import format_number HAS_MUDATA = True From 9377ad86de8e09dd8448d1e66f083505cbdc83bd Mon Sep 17 00:00:00 2001 From: Dominik Date: Fri, 12 Dec 2025 18:00:28 +0100 Subject: [PATCH 03/14] unified accessor + section viz pattern --- src/anndata/_core/extensions.py | 150 +++++++++++++++++++++++++++++- tests/test_repr_html.py | 91 ++++++++++++++++++ tests/visual_inspect_repr_html.py | 75 +++++++++++++++ 3 files changed, 315 insertions(+), 1 deletion(-) diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 180a15a61..031ba3917 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -10,12 +10,17 @@ if TYPE_CHECKING: from collections.abc import Callable + from anndata._repr.registry import FormattedEntry, FormatterContext + # Based off of the extension framework in Polars # https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py __all__ = ["register_anndata_namespace"] +# Protocol for accessors that provide section visualization +REPR_SECTION_METHOD = "_repr_section_" + # Reserved namespaces include accessors built into AnnData (currently there are none) # and all current attributes of AnnData @@ -121,6 +126,76 @@ def _check_namespace_signature(ns_class: type) -> None: raise TypeError(msg) +def _create_accessor_section_formatter( + name: str, ns_class: type[ExtensionNamespace] +) -> None: + """Create and register a SectionFormatter for an accessor with _repr_section_ method. + + This enables unified accessor + visualization registration. When an accessor + class defines a `_repr_section_` method, a SectionFormatter is automatically + registered that delegates to the accessor instance. + + Parameters + ---------- + name + The accessor name (used as section name) + ns_class + The accessor class that has a _repr_section_ method + """ + from anndata._repr.registry import ( + FormatterContext, + SectionFormatter, + register_formatter, + ) + + # Get optional section configuration from class attributes + after_section = getattr(ns_class, "section_after", None) + display_name = getattr(ns_class, "section_display_name", name) + tooltip = getattr(ns_class, "section_tooltip", "") + + class AccessorSectionFormatter(SectionFormatter): + """Auto-generated SectionFormatter that delegates to accessor._repr_section_.""" + + @property + def section_name(self) -> str: + return name + + @property + def display_name(self) -> str: + return display_name + + @property + def after_section(self) -> str | None: + return after_section + + @property + def tooltip(self) -> str: + return tooltip + + def should_show(self, obj: AnnData) -> bool: + if not hasattr(obj, name): + return False + accessor = getattr(obj, name) + if not hasattr(accessor, REPR_SECTION_METHOD): + return False + # Call _repr_section_ to check if it returns entries + result = getattr(accessor, REPR_SECTION_METHOD)(FormatterContext()) + return result is not None and len(result) > 0 + + def get_entries( + self, obj: AnnData, context: FormatterContext + ) -> list[FormattedEntry]: + accessor = getattr(obj, name) + result = getattr(accessor, REPR_SECTION_METHOD)(context) + return result if result is not None else [] + + # Give it a meaningful name for debugging + AccessorSectionFormatter.__name__ = f"{ns_class.__name__}SectionFormatter" + AccessorSectionFormatter.__qualname__ = f"{ns_class.__name__}SectionFormatter" + + register_formatter(AccessorSectionFormatter()) + + def _create_namespace[NameSpT: ExtensionNamespace]( name: str, cls: type[AnnData] ) -> Callable[[type[NameSpT]], type[NameSpT]]: @@ -138,6 +213,11 @@ def namespace(ns_class: type[NameSpT]) -> type[NameSpT]: ) setattr(cls, name, AccessorNameSpace(name, ns_class)) cls._accessors.add(name) + + # Auto-register SectionFormatter if accessor has _repr_section_ method + if hasattr(ns_class, REPR_SECTION_METHOD): + _create_accessor_section_formatter(name, ns_class) + return ns_class return namespace @@ -169,13 +249,31 @@ def register_anndata_namespace[NameSpT: ExtensionNamespace]( ----- Implementation requirements: - 1. The decorated class must have an `__init__` method that accepts exactly one parameter + 1. The decorated class must have an `__init__`` method that accepts exactly one parameter (besides `self`) named `adata` and annotated with type :class:`~anndata.AnnData`. 2. The namespace will be initialized with the AnnData object on first access and then cached on the instance. 3. If the namespace name conflicts with an existing namespace, a warning is issued. 4. If the namespace name conflicts with a built-in AnnData attribute, an AttributeError is raised. + HTML Representation + ~~~~~~~~~~~~~~~~~~~ + If the accessor class defines a ``_repr_section_`` method, a section will automatically + be added to the HTML representation. This enables unified accessor + visualization + registration with a single decorator. + + The ``_repr_section_`` method should have the signature:: + + def _repr_section_(self, context: FormatterContext) -> list[FormattedEntry] | None: + '''Return entries for HTML repr, or None to hide section.''' + ... + + Optional class attributes for section configuration: + + - ``section_after``: Section name after which this section appears (e.g., "obsm") + - ``section_display_name``: Display name for the section header (defaults to accessor name) + - ``section_tooltip``: Tooltip text for the section header + Examples -------- Simple transformation namespace with two methods: @@ -233,5 +331,55 @@ def register_anndata_namespace[NameSpT: ExtensionNamespace]( >>> adata.transform.arcsinh() # Transforms X and returns the AnnData object AnnData object with n_obs × n_vars = 100 × 2000 layers: 'log1p', 'arcsinh' + + Accessor with HTML section visualization: + + .. code-block:: python + + from anndata.extensions import ( + register_anndata_namespace, + FormattedEntry, + FormattedOutput, + ) + + + @register_anndata_namespace("spatial") + class SpatialAccessor: + # Optional: configure section positioning and display + section_after = "obsm" + section_display_name = "spatial" + section_tooltip = "Spatial data (images, coordinates)" + + def __init__(self, adata: ad.AnnData): + self._adata = adata + + @property + def images(self): + return self._adata.uns.get("spatial_images", {}) + + def add_image(self, key, image): + if "spatial_images" not in self._adata.uns: + self._adata.uns["spatial_images"] = {} + self._adata.uns["spatial_images"][key] = image + + def _repr_section_(self, context) -> list[FormattedEntry] | None: + '''Return entries for HTML repr, or None to hide section.''' + if not self.images: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput( + type_name=f"Image {v.shape}", + css_class="dtype-array", + ), + ) + for k, v in self.images.items() + ] + + + # Usage: + adata.spatial.add_image("hires", np.zeros((100, 100, 3))) + adata._repr_html_() # Shows "spatial" section with "hires" entry """ return _create_namespace(name, AnnData) diff --git a/tests/test_repr_html.py b/tests/test_repr_html.py index 33a8eeb7c..29ac9e027 100644 --- a/tests/test_repr_html.py +++ b/tests/test_repr_html.py @@ -2619,6 +2619,97 @@ def test_section_formatter_is_abstract(self): SectionFormatter() +class TestUnifiedAccessorSection: + """Tests for unified accessor + section visualization via _repr_section_.""" + + def test_accessor_with_repr_section_creates_section(self): + """Test that accessor with _repr_section_ automatically gets a section.""" + from anndata.extensions import ( + FormattedEntry, + FormattedOutput, + formatter_registry, + register_anndata_namespace, + ) + + # Register accessor with _repr_section_ + @register_anndata_namespace("unified_test") + class UnifiedTestAccessor: + section_after = "obsm" + + def __init__(self, adata: AnnData): + self._adata = adata + + @property + def items(self): + return self._adata.uns.get("unified_items", {}) + + def add_item(self, key, value): + if "unified_items" not in self._adata.uns: + self._adata.uns["unified_items"] = {} + self._adata.uns["unified_items"][key] = value + + def _repr_section_(self, context): + if not self.items: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput(type_name=f"Item: {v}"), + ) + for k, v in self.items.items() + ] + + try: + # Verify section formatter was registered + assert "unified_test" in formatter_registry._section_formatters + + # Test that section appears in HTML when items exist + adata = AnnData(np.zeros((5, 3))) + adata.unified_test.add_item("test_key", "test_value") + + html = adata._repr_html_() + assert "unified_test" in html + assert "test_key" in html + assert "Item: test_value" in html + + # Test that section is hidden when no items + adata2 = AnnData(np.zeros((5, 3))) + html2 = adata2._repr_html_() + assert "unified_test" not in html2 + finally: + # Cleanup: remove the registered accessor and formatter + if hasattr(AnnData, "unified_test"): + delattr(AnnData, "unified_test") + AnnData._accessors.discard("unified_test") + formatter_registry._section_formatters.pop("unified_test", None) + + def test_accessor_without_repr_section_no_section(self): + """Test that accessor without _repr_section_ doesn't create a section.""" + from anndata.extensions import formatter_registry, register_anndata_namespace + + # Register accessor WITHOUT _repr_section_ + @register_anndata_namespace("no_section_test") + class NoSectionAccessor: + def __init__(self, adata: AnnData): + self._adata = adata + + def do_something(self): + return "done" + + try: + # Verify no section formatter was registered + assert "no_section_test" not in formatter_registry._section_formatters + + # Accessor should still work + adata = AnnData(np.zeros((5, 3))) + assert adata.no_section_test.do_something() == "done" + finally: + # Cleanup + if hasattr(AnnData, "no_section_test"): + delattr(AnnData, "no_section_test") + AnnData._accessors.discard("no_section_test") + + class TestCustomHtmlContent: """Tests for custom HTML content in Type Formatters.""" diff --git a/tests/visual_inspect_repr_html.py b/tests/visual_inspect_repr_html.py index 6b62198dd..d8070dfcd 100644 --- a/tests/visual_inspect_repr_html.py +++ b/tests/visual_inspect_repr_html.py @@ -22,12 +22,67 @@ import anndata as ad from anndata import AnnData from anndata.extensions import ( + FormattedEntry, FormattedOutput, TypeFormatter, extract_uns_type_hint, + register_anndata_namespace, register_formatter, ) +# ============================================================================= +# Example: Unified accessor + section visualization +# ============================================================================= +# This demonstrates how to create an accessor that automatically gets a section +# in the HTML repr by defining a _repr_section_ method. + + +@register_anndata_namespace("spatial_demo") +class SpatialDemoAccessor: + """Demo accessor showing unified accessor + section visualization. + + This accessor provides functionality to store spatial images and + automatically displays them in the HTML representation. + """ + + section_after = "obsm" # Position section after obsm + section_display_name = "spatial" # Display name in HTML + section_tooltip = "Spatial data (images, coordinates)" + + def __init__(self, adata: AnnData): + self._adata = adata + + @property + def images(self) -> dict: + """Get stored spatial images.""" + return self._adata.uns.get("_spatial_images", {}) + + def add_image(self, key: str, image: np.ndarray) -> None: + """Add a spatial image.""" + if "_spatial_images" not in self._adata.uns: + self._adata.uns["_spatial_images"] = {} + self._adata.uns["_spatial_images"][key] = image + + def _repr_section_(self, context) -> list[FormattedEntry] | None: + """Return entries for HTML repr, or None to hide section. + + This method is automatically called by the HTML repr system + when this accessor is registered with register_anndata_namespace. + """ + if not self.images: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput( + type_name=f"Image {v.shape}", + css_class="dtype-array", + ), + ) + for k, v in self.images.items() + ] + + # Check optional dependencies try: import dask.array as da @@ -1205,6 +1260,26 @@ def format(self, obj, context): else: print(" 19. MuData (skipped - mudata not installed)") + # Test 20: Unified accessor + section visualization + print(" 20. Unified accessor + section visualization (spatial_demo)") + adata_spatial = AnnData(np.random.randn(50, 20).astype(np.float32)) + adata_spatial.obs["cluster"] = pd.Categorical(["A", "B", "C"] * 16 + ["A", "B"]) + adata_spatial.obsm["X_spatial"] = np.random.randn(50, 2).astype(np.float32) + # Use the spatial_demo accessor to add images + adata_spatial.spatial_demo.add_image("hires", np.zeros((1000, 1000, 3))) + adata_spatial.spatial_demo.add_image("lowres", np.zeros((200, 200, 3))) + adata_spatial.spatial_demo.add_image("segmentation", np.zeros((1000, 1000))) + sections.append(( + "20. Unified Accessor + Section (spatial_demo)", + adata_spatial._repr_html_(), + "Demonstrates the unified accessor + section pattern. The @register_anndata_namespace " + "decorator registers both the accessor (adata.spatial_demo) AND a section in the HTML repr. " + "The accessor class defines _repr_section_(self, context) which returns a list of " + "FormattedEntry objects. Optional class attributes: section_after (positioning), " + "section_display_name, section_tooltip. This is the recommended pattern " + "for external packages (SpatialData, MuData) to add both functionality and visualization.", + )) + # Generate HTML file output_path = Path(__file__).parent / "repr_html_visual_test.html" html_content = create_html_page(sections) From 8696776d56f30992ea45c5df93333c39549c4427 Mon Sep 17 00:00:00 2001 From: Dominik Date: Fri, 12 Dec 2025 18:08:34 +0100 Subject: [PATCH 04/14] accessor section viz doc_url --- src/anndata/_core/extensions.py | 7 +++++ tests/test_repr_html.py | 47 +++++++++++++++++++++++++++++++ tests/visual_inspect_repr_html.py | 1 + 3 files changed, 55 insertions(+) diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 031ba3917..f7703ac19 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -152,6 +152,7 @@ class defines a `_repr_section_` method, a SectionFormatter is automatically after_section = getattr(ns_class, "section_after", None) display_name = getattr(ns_class, "section_display_name", name) tooltip = getattr(ns_class, "section_tooltip", "") + doc_url = getattr(ns_class, "section_doc_url", None) class AccessorSectionFormatter(SectionFormatter): """Auto-generated SectionFormatter that delegates to accessor._repr_section_.""" @@ -172,6 +173,10 @@ def after_section(self) -> str | None: def tooltip(self) -> str: return tooltip + @property + def doc_url(self) -> str | None: + return doc_url + def should_show(self, obj: AnnData) -> bool: if not hasattr(obj, name): return False @@ -273,6 +278,7 @@ def _repr_section_(self, context: FormatterContext) -> list[FormattedEntry] | No - ``section_after``: Section name after which this section appears (e.g., "obsm") - ``section_display_name``: Display name for the section header (defaults to accessor name) - ``section_tooltip``: Tooltip text for the section header + - ``section_doc_url``: URL to documentation (shown as link icon in header) Examples -------- @@ -349,6 +355,7 @@ class SpatialAccessor: section_after = "obsm" section_display_name = "spatial" section_tooltip = "Spatial data (images, coordinates)" + section_doc_url = "https://spatialdata.readthedocs.io/" def __init__(self, adata: ad.AnnData): self._adata = adata diff --git a/tests/test_repr_html.py b/tests/test_repr_html.py index 29ac9e027..759cf295d 100644 --- a/tests/test_repr_html.py +++ b/tests/test_repr_html.py @@ -2709,6 +2709,53 @@ def do_something(self): delattr(AnnData, "no_section_test") AnnData._accessors.discard("no_section_test") + def test_accessor_section_doc_url(self): + """Test that section_doc_url is passed through to the SectionFormatter.""" + from anndata.extensions import ( + FormattedEntry, + FormattedOutput, + formatter_registry, + register_anndata_namespace, + ) + + @register_anndata_namespace("docurl_test") + class DocUrlTestAccessor: + section_after = "obsm" + section_display_name = "docurl" + section_tooltip = "Test tooltip" + section_doc_url = "https://example.com/docs" + + def __init__(self, adata: AnnData): + self._adata = adata + + def _repr_section_(self, context): + return [ + FormattedEntry( + key="item", + output=FormattedOutput(type_name="test"), + ) + ] + + try: + # Verify section formatter was registered with doc_url + assert "docurl_test" in formatter_registry._section_formatters + section_formatter = formatter_registry._section_formatters["docurl_test"] + assert section_formatter.doc_url == "https://example.com/docs" + assert section_formatter.display_name == "docurl" + assert section_formatter.tooltip == "Test tooltip" + assert section_formatter.after_section == "obsm" + + # Test that doc URL appears in HTML + adata = AnnData(np.zeros((5, 3))) + html = adata._repr_html_() + assert "https://example.com/docs" in html + finally: + # Cleanup + if hasattr(AnnData, "docurl_test"): + delattr(AnnData, "docurl_test") + AnnData._accessors.discard("docurl_test") + formatter_registry._section_formatters.pop("docurl_test", None) + class TestCustomHtmlContent: """Tests for custom HTML content in Type Formatters.""" diff --git a/tests/visual_inspect_repr_html.py b/tests/visual_inspect_repr_html.py index d8070dfcd..52a140416 100644 --- a/tests/visual_inspect_repr_html.py +++ b/tests/visual_inspect_repr_html.py @@ -48,6 +48,7 @@ class SpatialDemoAccessor: section_after = "obsm" # Position section after obsm section_display_name = "spatial" # Display name in HTML section_tooltip = "Spatial data (images, coordinates)" + section_doc_url = "https://spatialdata.scverse.org/" # Documentation link def __init__(self, adata: AnnData): self._adata = adata From aa1c696f5a4d539e21b0d6326b83defc2b0364cf Mon Sep 17 00:00:00 2001 From: Dominik Date: Sun, 29 Mar 2026 23:54:51 -0700 Subject: [PATCH 05/14] feat: register_aligned_section for pluggable AnnData sections Add register_aligned_section() to anndata.extensions that allows external packages to register new axis-aligned sections (like obsm, layers) on AnnData without subclassing. A registered section gets: - Property accessor (adata.obst) - Axis-aligned storage with validation - Automatic subsetting (adata[:10].obst works) - IO integration (write/read to h5ad and zarr) - Repr discovery (shows in repr output) - Init kwargs (AnnData(obst={...})) Changes: - aligned_mapping.py: AlignedMappingProperty lazily inits backing store - extensions.py: SectionRegistration dataclass + register_aligned_section() - anndata.py: _registered_sections ClassVar, **extra_sections in init, registered sections in _gen_repr - methods.py: write_anndata/read_anndata iterate registered sections - h5ad.py: write_h5ad iterates registered sections --- src/anndata/_core/aligned_mapping.py | 6 +- src/anndata/_core/anndata.py | 11 +++ src/anndata/_core/extensions.py | 115 ++++++++++++++++++++++++++- src/anndata/_io/h5ad.py | 7 ++ src/anndata/_io/specs/methods.py | 11 +++ 5 files changed, 148 insertions(+), 2 deletions(-) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 3ac1c33d7..3406306eb 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -420,7 +420,11 @@ def __get__(self, obj: None | AnnData, objtype: type | None = None) -> T: # this needs to return a `property` instance, e.g. for Sphinx return self # type: ignore if not obj.is_view: - return self.construct(obj, store=getattr(obj, f"_{self.name}")) + store = getattr(obj, f"_{self.name}", None) + if store is None: + store = {} + setattr(obj, f"_{self.name}", store) + return self.construct(obj, store=store) parent_anndata = obj._adata_ref idxs = (obj._oidx, obj._vidx) parent: AlignedMapping = getattr(parent_anndata, self.name) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index b4e7fb3c2..e2715efe9 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -205,6 +205,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641 ) _accessors: ClassVar[set[str]] = set() + _registered_sections: ClassVar[dict] = {} # str -> SectionRegistration # view attributes _adata_ref: AnnData | None @@ -242,6 +243,7 @@ def __init__( # noqa: PLR0913 varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None, oidx: _Index1DNorm | int | np.integer | None = None, vidx: _Index1DNorm | int | np.integer | None = None, + **extra_sections, ): # check for any multi-indices that aren’t later checked in coerce_array for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]: @@ -270,6 +272,7 @@ def __init__( # noqa: PLR0913 varp=varp, filename=filename, filemode=filemode, + **extra_sections, ) def _init_as_view( @@ -361,6 +364,7 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 shape=None, filename=None, filemode=None, + **extra_sections, ): # view attributes self._is_view = False @@ -509,6 +513,12 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 # layers self.layers = layers + # registered sections (e.g., obst, vart from extensions) + for sec_name in self._registered_sections: + value = extra_sections.get(sec_name) + if value is not None: + setattr(self, sec_name, value) + @old_positionals("show_stratified", "with_disk") def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False @@ -556,6 +566,7 @@ def _gen_repr(self, n_obs, n_vars) -> str: "layers", "obsp", "varp", + *self._registered_sections, ]: keys = getattr(self, attr).keys() if len(keys) > 0: diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index f7703ac19..55c79d99f 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -16,7 +16,7 @@ # Based off of the extension framework in Polars # https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py -__all__ = ["register_anndata_namespace"] +__all__ = ["register_anndata_namespace", "register_aligned_section", "SectionRegistration"] # Protocol for accessors that provide section visualization REPR_SECTION_METHOD = "_repr_section_" @@ -390,3 +390,116 @@ def _repr_section_(self, context) -> list[FormattedEntry] | None: adata._repr_html_() # Shows "spatial" section with "hires" entry """ return _create_namespace(name, AnnData) + + +# --------------------------------------------------------------------------- +# Section registration +# --------------------------------------------------------------------------- + +from dataclasses import dataclass +from typing import Literal + + +@dataclass(frozen=True) +class SectionRegistration: + """Metadata for a registered aligned section. + + Instances are stored in ``AnnData._registered_sections``. + """ + + name: str + """Attribute name on AnnData (e.g., ``"obst"``).""" + mapping_type: Literal["axis", "pairwise", "layers"] + """Which AlignedMapping family to use.""" + axis: Literal[0, 1] | None + """``0`` for obs-aligned, ``1`` for var-aligned, ``None`` for layers-like.""" + allow_df: bool + """Whether DataFrames are allowed as values.""" + io_key: str + """Key used in h5ad/zarr files.""" + + +def register_aligned_section( + name: str, + *, + axis: Literal[0, 1] | None = None, + mapping_type: Literal["axis", "pairwise", "layers"] = "axis", + allow_df: bool = True, + io_key: str | None = None, +) -> None: + """Register a new axis-aligned section on :class:`~anndata.AnnData`. + + This allows external packages to add new mappings (like ``obsm``, ``layers``) + that participate in subsetting, IO, repr, and traversal without subclassing. + + Parameters + ---------- + name + Attribute name on AnnData (e.g., ``"obst"``). Becomes ``adata.obst``. + axis + ``0`` for obs-aligned, ``1`` for var-aligned, ``None`` for both-axes + (layers-like). + mapping_type + ``"axis"`` for :class:`AxisArrays` (like obsm/varm), + ``"pairwise"`` for :class:`PairwiseArrays` (like obsp/varp), + ``"layers"`` for :class:`Layers`. + allow_df + Whether to allow DataFrames as values. + io_key + Key used in h5ad/zarr files. Defaults to *name*. + + Examples + -------- + .. code-block:: python + + import anndata as ad + from anndata.extensions import register_aligned_section + + # Register at import time + register_aligned_section("obst", axis=0, mapping_type="axis") + + adata = ad.AnnData(obs=pd.DataFrame(index=["c1", "c2", "c3"])) + adata.obst["lineage"] = np.eye(3) # validates shape against n_obs + sub = adata[:2] # sub.obst["lineage"] is subsetted + adata.write("test.h5ad") # obst is written + adata2 = ad.read_h5ad("test.h5ad") # obst is read back + """ + from .aligned_mapping import ( + AlignedMappingProperty, + AxisArrays, + Layers, + PairwiseArrays, + ) + + if name in _reserved_namespaces: + msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" + raise AttributeError(msg) + if name in AnnData._registered_sections: + msg = f"Section {name!r} is already registered" + raise ValueError(msg) + + # Select the right aligned mapping class + cls_map = { + "axis": AxisArrays, + "pairwise": PairwiseArrays, + "layers": Layers, + } + if mapping_type not in cls_map: + msg = f"Unknown mapping_type: {mapping_type!r}. Must be one of {list(cls_map)}" + raise ValueError(msg) + cls = cls_map[mapping_type] + + # Create and attach the property descriptor + prop = AlignedMappingProperty(name, cls, axis) + setattr(AnnData, name, prop) + + # Register in the class-level registry + reg = SectionRegistration( + name=name, + mapping_type=mapping_type, + axis=axis, + allow_df=allow_df, + io_key=io_key or name, + ) + AnnData._registered_sections[name] = reg + _reserved_namespaces.add(name) diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index d0540de1c..5f527546c 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -101,6 +101,13 @@ def write_h5ad( write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) + # Write registered sections (e.g., obst, vart from extensions) + for sec_name, sec_info in adata._registered_sections.items(): + mapping = getattr(adata, sec_name, None) + if mapping is not None and len(mapping) > 0: + write_elem( + f, sec_info.io_key, dict(mapping), dataset_kwargs=dataset_kwargs + ) def _write_x( diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 43b084a00..dd037b00c 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -297,6 +297,13 @@ def write_anndata( _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) + # Write registered sections (e.g., obst, vart from extensions) + for sec_name, sec_info in adata._registered_sections.items(): + mapping = getattr(adata, sec_name, None) + if mapping is not None and len(mapping) > 0: + _writer.write_elem( + g, sec_info.io_key, dict(mapping), dataset_kwargs=dataset_kwargs + ) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) @@ -321,6 +328,10 @@ def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnDat ]: if k in elem: d[k] = _reader.read_elem(elem[k]) + # Read registered sections (e.g., obst, vart from extensions) + for sec_name, sec_info in AnnData._registered_sections.items(): + if sec_info.io_key in elem: + d[sec_name] = _reader.read_elem(elem[sec_info.io_key]) return AnnData(**d) From 0573becd64267bbe66b1a6926f2aee3603cbc024 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 00:01:22 -0700 Subject: [PATCH 06/14] fix: copy-on-write and attrname for registered sections - AlignedMappingProperty.construct sets _attrname_override so registered sections report their own name (e.g., "obst") instead of the default ("obsm") - AlignedView propagates _attrname_override from parent mapping - _mutated_copy includes registered sections in the copy loop - _init_as_actual copies registered sections when init from AnnData - _default_attrname replaces attrname in concrete bases (LayersBase, AxisArraysBase, PairwiseArraysBase) to support the override pattern - Add comprehensive test suite (35 tests) covering storage, validation, subsetting, copy-on-write, IO roundtrip, repr, and TreeData-like workflow --- src/anndata/_core/aligned_mapping.py | 32 +- src/anndata/_core/anndata.py | 8 +- src/anndata/_core/extensions.py | 6 +- tests/test_registered_sections.py | 431 +++++++++++++++++++++++++++ 4 files changed, 467 insertions(+), 10 deletions(-) create mode 100644 tests/test_registered_sections.py diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 3406306eb..38b0863be 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -98,10 +98,19 @@ def _validate_value(self, val: Value, key: str) -> Value: name = f"{self.attrname.title().rstrip('s')} {key!r}" return coerce_array(val, name=name, allow_df=self._allow_df) + _attrname_override: str | None = None + @property - @abstractmethod def attrname(self) -> str: """What attr for the AnnData is this?""" + if self._attrname_override is not None: + return self._attrname_override + return self._default_attrname + + @property + @abstractmethod + def _default_attrname(self) -> str: + """Default attr name derived from axis (e.g., 'obsm', 'varp').""" @property @abstractmethod @@ -151,6 +160,9 @@ def __init__(self, parent_mapping: P, parent_view: AnnData, subset_idx: I) -> No self.parent_mapping = parent_mapping self._parent = parent_view self.subset_idx = subset_idx + # Propagate attrname override from actual to view (for registered sections) + if parent_mapping._attrname_override is not None: + self._attrname_override = parent_mapping._attrname_override if hasattr(parent_mapping, "_axis"): # LayersBase has no _axis, the rest does self._axis = parent_mapping._axis # type: ignore @@ -237,7 +249,7 @@ class AxisArraysBase(AlignedMappingBase): _axis: Literal[0, 1] @property - def attrname(self) -> str: + def _default_attrname(self) -> str: return f"{self.dim}m" @property @@ -311,9 +323,12 @@ class LayersBase(AlignedMappingBase): """ _allow_df: ClassVar = False - attrname: ClassVar[Literal["layers"]] = "layers" axes: ClassVar[tuple[Literal[0], Literal[1]]] = (0, 1) + @property + def _default_attrname(self) -> str: + return "layers" + class Layers(AlignedActual, LayersBase): pass @@ -339,7 +354,7 @@ class PairwiseArraysBase(AlignedMappingBase): _axis: Literal[0, 1] @property - def attrname(self) -> str: + def _default_attrname(self) -> str: return f"{self.dim}p" @property @@ -402,8 +417,13 @@ class AlignedMappingProperty[T: AlignedMapping](property): def construct(self, obj: AnnData, *, store: MutableMapping[str, Value]) -> T: if self.axis is None: - return self.cls(obj, store=store) - return self.cls(obj, axis=self.axis, store=store) + mapping = self.cls(obj, store=store) + else: + mapping = self.cls(obj, axis=self.axis, store=store) + # Override attrname for registered sections (e.g., "obst" instead of "obsm") + if mapping._default_attrname != self.name: + mapping._attrname_override = self.name + return mapping @property def fget(self) -> Callable[[], None]: diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index e2715efe9..c6a9880a6 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -395,6 +395,12 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 if any((obs, var, uns, obsm, varm, obsp, varp)): msg = "If `X` is a dict no further arguments must be provided." raise ValueError(msg) + # Copy registered sections from source AnnData + for sec_name in self._registered_sections: + if sec_name not in extra_sections: + src_mapping = getattr(X, sec_name, None) + if src_mapping is not None and len(src_mapping) > 0: + extra_sections[sec_name] = dict(src_mapping) X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = ( X._X, X.obs, @@ -1424,7 +1430,7 @@ def _mutated_copy(self, **kwargs) -> AnnData: raise NotImplementedError(msg) new = {} - for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers"]: + for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", *self._registered_sections]: if key in kwargs: new[key] = kwargs[key] else: diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 55c79d99f..fe1287bd6 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -471,12 +471,12 @@ def register_aligned_section( PairwiseArrays, ) - if name in _reserved_namespaces: - msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" - raise AttributeError(msg) if name in AnnData._registered_sections: msg = f"Section {name!r} is already registered" raise ValueError(msg) + if name in _reserved_namespaces: + msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" + raise AttributeError(msg) # Select the right aligned mapping class cls_map = { diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py new file mode 100644 index 000000000..9aa0ce67a --- /dev/null +++ b/tests/test_registered_sections.py @@ -0,0 +1,431 @@ +"""Tests for register_aligned_section. + +Validates that registered sections behave like built-in sections (obsm, layers, etc.) +for storage, subsetting, IO, repr, and init. Uses TreeData-like and SpatialData-like +scenarios to test real-world extension patterns. +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from scipy.sparse import csr_matrix + +import anndata as ad +from anndata._core.extensions import register_aligned_section + + +# --------------------------------------------------------------------------- +# Fixtures: register sections once per test session +# --------------------------------------------------------------------------- + +# Use module-scoped registration so sections persist across tests. +# This mirrors real-world usage where registration happens at import time. + + +@pytest.fixture(autouse=True, scope="module") +def _register_test_sections(): + """Register TreeData-like and SpatialData-like sections for all tests.""" + # TreeData pattern: axis-aligned tree mappings + if "obst" not in ad.AnnData._registered_sections: + register_aligned_section("obst", axis=0, mapping_type="axis") + if "vart" not in ad.AnnData._registered_sections: + register_aligned_section("vart", axis=1, mapping_type="axis") + + # Pairwise pattern (like obsp but for a custom section) + if "obsd" not in ad.AnnData._registered_sections: + register_aligned_section("obsd", axis=0, mapping_type="pairwise") + + # Layers-like pattern (aligned to both obs and var) + if "extra_layers" not in ad.AnnData._registered_sections: + register_aligned_section("extra_layers", axis=None, mapping_type="layers") + + +@pytest.fixture +def adata(): + """Basic AnnData for testing.""" + return ad.AnnData( + X=np.ones((5, 3)), + obs=pd.DataFrame({"group": list("aabbc")}, index=[f"c{i}" for i in range(5)]), + var=pd.DataFrame({"gene": [f"g{i}" for i in range(3)]}, index=[f"v{i}" for i in range(3)]), + ) + + +# --------------------------------------------------------------------------- +# Registration API +# --------------------------------------------------------------------------- + + +class TestRegistrationAPI: + def test_register_creates_property(self): + """Registered section is accessible as a property on AnnData.""" + assert hasattr(ad.AnnData, "obst") + adata = ad.AnnData(np.ones((3, 4))) + # Should return an empty mapping + assert len(adata.obst) == 0 + + def test_register_duplicate_raises(self): + """Registering the same section twice raises ValueError.""" + with pytest.raises(ValueError, match="already registered"): + register_aligned_section("obst", axis=0) + + def test_register_reserved_name_raises(self): + """Registering a reserved name raises AttributeError.""" + with pytest.raises(AttributeError, match="conflicts with"): + register_aligned_section("obs", axis=0) + + def test_register_invalid_mapping_type_raises(self): + """Invalid mapping_type raises ValueError.""" + with pytest.raises(ValueError, match="Unknown mapping_type"): + register_aligned_section("bad_section", axis=0, mapping_type="invalid") + + def test_section_in_registry(self): + """Registered section appears in _registered_sections.""" + assert "obst" in ad.AnnData._registered_sections + reg = ad.AnnData._registered_sections["obst"] + assert reg.name == "obst" + assert reg.axis == 0 + assert reg.mapping_type == "axis" + + +# --------------------------------------------------------------------------- +# Storage and Validation (TreeData-like: obst, vart) +# --------------------------------------------------------------------------- + + +class TestAxisAlignedStorage: + def test_store_and_retrieve(self, adata): + """Can store and retrieve arrays in registered section.""" + tree = np.random.rand(5, 3) + adata.obst["lineage"] = tree + assert "lineage" in adata.obst + np.testing.assert_array_equal(adata.obst["lineage"], tree) + + def test_wrong_axis_shape_raises(self, adata): + """Storing array with wrong obs dimension raises.""" + with pytest.raises(ValueError, match="shape"): + adata.obst["bad"] = np.ones((10, 3)) # 10 != n_obs=5 + + def test_var_aligned_section(self, adata): + """var-aligned section validates against n_vars.""" + tree = np.random.rand(3, 2) # n_vars=3 + adata.vart["gene_tree"] = tree + assert adata.vart["gene_tree"].shape == (3, 2) + + def test_var_aligned_wrong_shape_raises(self, adata): + """var-aligned section rejects wrong shape.""" + with pytest.raises(ValueError, match="shape"): + adata.vart["bad"] = np.ones((10, 2)) # 10 != n_vars=3 + + def test_multiple_entries(self, adata): + """Can store multiple entries in a section.""" + adata.obst["tree1"] = np.eye(5) + adata.obst["tree2"] = np.random.rand(5, 4) + assert set(adata.obst.keys()) == {"tree1", "tree2"} + + def test_delete_entry(self, adata): + """Can delete entries from registered section.""" + adata.obst["tree"] = np.eye(5) + del adata.obst["tree"] + assert "tree" not in adata.obst + + def test_sparse_matrix(self, adata): + """Can store sparse matrices.""" + sparse = csr_matrix(np.eye(5)) + adata.obst["sparse_tree"] = sparse + assert adata.obst["sparse_tree"].shape == (5, 5) + + def test_dataframe_in_axis_section(self, adata): + """DataFrames are allowed in axis sections (like obsm).""" + df = pd.DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [5, 4, 3, 2, 1]}, + index=adata.obs_names, + ) + adata.obst["df_tree"] = df + assert adata.obst["df_tree"].shape == (5, 2) + + +# --------------------------------------------------------------------------- +# Pairwise Storage (like obsp) +# --------------------------------------------------------------------------- + + +class TestPairwiseStorage: + def test_pairwise_section(self, adata): + """Pairwise section stores square matrices.""" + dist = np.random.rand(5, 5) + adata.obsd["distances"] = dist + np.testing.assert_array_equal(adata.obsd["distances"], dist) + + def test_pairwise_wrong_shape_raises(self, adata): + """Pairwise section rejects non-square matrices.""" + with pytest.raises(ValueError, match="shape"): + adata.obsd["bad"] = np.ones((5, 3)) # not square + + +# --------------------------------------------------------------------------- +# Layers-like Storage (both axes) +# --------------------------------------------------------------------------- + + +class TestLayersLikeStorage: + def test_layers_like_section(self, adata): + """Layers-like section stores (n_obs, n_vars) matrices.""" + data = np.random.rand(5, 3) + adata.extra_layers["normalized"] = data + np.testing.assert_array_equal(adata.extra_layers["normalized"], data) + + def test_layers_like_wrong_shape_raises(self, adata): + """Layers-like section rejects wrong shape.""" + with pytest.raises(ValueError, match="shape"): + adata.extra_layers["bad"] = np.ones((5, 10)) # 10 != n_vars=3 + + +# --------------------------------------------------------------------------- +# Subsetting +# --------------------------------------------------------------------------- + + +class TestSubsetting: + def test_obs_subset(self, adata): + """Subsetting obs subsets axis-0 registered sections.""" + adata.obst["tree"] = np.arange(15).reshape(5, 3) + sub = adata[:3] + assert sub.obst["tree"].shape == (3, 3) + np.testing.assert_array_equal(sub.obst["tree"], np.arange(15).reshape(5, 3)[:3]) + + def test_var_subset(self, adata): + """Subsetting var subsets axis-1 registered sections.""" + adata.vart["gene_tree"] = np.arange(6).reshape(3, 2) + sub = adata[:, :2] + assert sub.vart["gene_tree"].shape == (2, 2) + + def test_pairwise_subset(self, adata): + """Subsetting obs subsets pairwise registered sections.""" + adata.obsd["dist"] = np.eye(5) + sub = adata[:3] + assert sub.obsd["dist"].shape == (3, 3) + + def test_layers_like_subset(self, adata): + """Subsetting subsets both axes of layers-like sections.""" + adata.extra_layers["data"] = np.arange(15).reshape(5, 3) + sub = adata[:3, :2] + assert sub.extra_layers["data"].shape == (3, 2) + + def test_view_copy_on_write(self, adata): + """Writing to a view's registered section triggers copy-on-write.""" + adata.obst["tree"] = np.eye(5) + sub = adata[:3] + sub.obst["new_tree"] = np.ones((3, 2)) + # sub should now be an actual (not view) with the new entry + assert not sub.is_view + assert "new_tree" in sub.obst + # original should be unchanged + assert "new_tree" not in adata.obst + + +# --------------------------------------------------------------------------- +# Init with kwargs +# --------------------------------------------------------------------------- + + +class TestInitKwargs: + def test_init_with_registered_section(self): + """Can pass registered section data as init kwarg.""" + adata = ad.AnnData( + np.ones((3, 4)), + obs=pd.DataFrame(index=["c1", "c2", "c3"]), + obst={"tree": np.eye(3)}, + ) + assert "tree" in adata.obst + assert adata.obst["tree"].shape == (3, 3) + + def test_init_with_multiple_sections(self): + """Can pass multiple registered sections at init.""" + adata = ad.AnnData( + np.ones((3, 4)), + obs=pd.DataFrame(index=["c1", "c2", "c3"]), + var=pd.DataFrame(index=["v1", "v2", "v3", "v4"]), + obst={"tree": np.eye(3)}, + vart={"gene_tree": np.eye(4)}, + ) + assert "tree" in adata.obst + assert "gene_tree" in adata.vart + + def test_init_without_registered_section(self): + """AnnData works normally without passing registered sections.""" + adata = ad.AnnData(np.ones((3, 4))) + assert len(adata.obst) == 0 + assert len(adata.vart) == 0 + + +# --------------------------------------------------------------------------- +# IO Roundtrip (h5ad) +# --------------------------------------------------------------------------- + + +class TestH5adRoundtrip: + def test_write_read_roundtrip(self, adata, tmp_path): + """Registered sections survive h5ad write/read.""" + adata.obst["tree"] = np.eye(5) + adata.vart["gene_tree"] = np.random.rand(3, 2) + + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + + assert "tree" in adata2.obst + np.testing.assert_array_equal(adata2.obst["tree"], np.eye(5)) + assert "gene_tree" in adata2.vart + assert adata2.vart["gene_tree"].shape == (3, 2) + + def test_empty_section_not_written(self, adata, tmp_path): + """Empty registered sections are not written to disk.""" + import h5py + + path = tmp_path / "test.h5ad" + adata.write(path) + + with h5py.File(path, "r") as f: + assert "obst" not in f + assert "vart" not in f + + def test_subset_then_write(self, adata, tmp_path): + """Can subset, then write, and registered sections are preserved.""" + adata.obst["tree"] = np.arange(15).reshape(5, 3) + sub = adata[:3].copy() + + path = tmp_path / "test.h5ad" + sub.write(path) + sub2 = ad.read_h5ad(path) + + assert sub2.obst["tree"].shape == (3, 3) + + def test_read_without_section_registered(self, tmp_path): + """Files with extra groups are read fine even without registration. + + The extra data is silently skipped (backward compat). + """ + # Write with registration + adata = ad.AnnData( + np.ones((3, 4)), + obs=pd.DataFrame(index=["c1", "c2", "c3"]), + obst={"tree": np.eye(3)}, + ) + path = tmp_path / "test.h5ad" + adata.write(path) + + # Simulate reading without registration by checking the file directly + import h5py + + with h5py.File(path, "r") as f: + assert "obst" in f # Data is in the file + # Standard read_h5ad would skip it if not registered, + # but since we registered in this session, it will be read. + # This test just verifies the file format. + + +# --------------------------------------------------------------------------- +# Repr +# --------------------------------------------------------------------------- + + +class TestRepr: + def test_repr_shows_registered_section(self, adata): + """Registered sections appear in repr when non-empty.""" + adata.obst["tree"] = np.eye(5) + r = repr(adata) + assert "obst" in r + assert "tree" in r + + def test_repr_hides_empty_section(self, adata): + """Empty registered sections don't appear in repr.""" + r = repr(adata) + assert "obst" not in r + + def test_repr_html_shows_registered_section(self, adata): + """Registered sections appear in HTML repr.""" + adata.obst["tree"] = np.eye(5) + html = adata._repr_html_() + if html is not None: # HTML repr may not be enabled + assert "obst" in html + + +# --------------------------------------------------------------------------- +# Copy +# --------------------------------------------------------------------------- + + +class TestCopy: + def test_copy_preserves_sections(self, adata): + """copy() preserves registered sections.""" + adata.obst["tree"] = np.eye(5) + adata2 = adata.copy() + assert "tree" in adata2.obst + np.testing.assert_array_equal(adata2.obst["tree"], np.eye(5)) + + def test_copy_is_independent(self, adata): + """Modifications to copy don't affect original.""" + adata.obst["tree"] = np.eye(5) + adata2 = adata.copy() + adata2.obst["new"] = np.ones((5, 2)) + assert "new" not in adata.obst + + +# --------------------------------------------------------------------------- +# TreeData-like Scenario +# --------------------------------------------------------------------------- + + +class TestTreeDataScenario: + """End-to-end test mimicking how TreeData would use section registration.""" + + def test_treedata_workflow(self, tmp_path): + """Full TreeData-like workflow: create, populate, subset, IO.""" + # 1. Create AnnData with tree data + n_obs, n_vars = 10, 5 + adata = ad.AnnData( + X=np.random.rand(n_obs, n_vars), + obs=pd.DataFrame( + {"cell_type": pd.Categorical(["A"] * 5 + ["B"] * 5)}, + index=[f"cell_{i}" for i in range(n_obs)], + ), + var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_vars)]), + ) + + # 2. Add tree data (serialized as arrays, like TreeData does) + # In reality these would be serialized DiGraphs + lineage_tree = np.random.rand(n_obs, 4) # obs-aligned tree embedding + gene_tree = np.random.rand(n_vars, 3) # var-aligned tree embedding + + adata.obst["lineage"] = lineage_tree + adata.vart["phylogeny"] = gene_tree + + # 3. Verify storage + assert adata.obst["lineage"].shape == (n_obs, 4) + assert adata.vart["phylogeny"].shape == (n_vars, 3) + + # 4. Subset (like filtering cells) + mask = adata.obs["cell_type"] == "A" + sub = adata[mask] + assert sub.obst["lineage"].shape == (5, 4) + assert sub.vart["phylogeny"].shape == (n_vars, 3) # var unchanged + + # 5. Write and read back + path = tmp_path / "treedata.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + + assert set(adata2.obst.keys()) == {"lineage"} + assert set(adata2.vart.keys()) == {"phylogeny"} + np.testing.assert_array_almost_equal(adata2.obst["lineage"], lineage_tree) + np.testing.assert_array_almost_equal(adata2.vart["phylogeny"], gene_tree) + + # 6. Repr includes custom sections + r = repr(adata2) + assert "obst" in r + assert "vart" in r From 23302d9ff8abeb7e220169d9f967e963191d4ac7 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 00:01:55 -0700 Subject: [PATCH 07/14] feat: re-export register_aligned_section from anndata.extensions --- src/anndata/extensions.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py index 67a0d5de5..ccad0e7fa 100644 --- a/src/anndata/extensions.py +++ b/src/anndata/extensions.py @@ -4,7 +4,9 @@ This module provides registration mechanisms for: 1. **Accessors** - Add custom namespaces to AnnData objects (e.g., `adata.myns.method()`) -2. **HTML Formatters** - Customize how types are displayed in Jupyter notebooks +2. **Aligned Sections** - Add new axis-aligned mappings (e.g., `adata.obst`) with full + subsetting, IO, repr, and init support — no subclassing needed +3. **HTML Formatters** - Customize how types are displayed in Jupyter notebooks Examples -------- @@ -79,6 +81,9 @@ def get_entries(self, obj, context): # Accessor registration (from PR #1870) from anndata._core.extensions import register_anndata_namespace +# Section registration (pluggable aligned mappings) +from anndata._core.extensions import register_aligned_section, SectionRegistration + # HTML representation formatters from anndata._repr import ( # Type hint utilities for tagged data @@ -100,6 +105,9 @@ def get_entries(self, obj, context): __all__ = [ # noqa: RUF022 # organized by category, not alphabetically # Accessor registration "register_anndata_namespace", + # Section registration + "register_aligned_section", + "SectionRegistration", # HTML formatter registration "register_formatter", "TypeFormatter", From 9e0a14a89e42c38202b295dfc08956bc8c1223c4 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 10:12:02 -0700 Subject: [PATCH 08/14] feat: replace register_aligned_section with decorator-based register_section New @register_section decorator with: - Alignment as tuple of "obs"/"var" axes: ("obs",), ("obs","var"), ("obs","obs"), ("var","var"), () for unaligned - Custom value_type enforcement - Custom validate/subset/serialize/deserialize methods - Custom repr_entry for HTML repr - Auto-registers SectionFormatter for HTML repr New container classes in section_registry.py: - SectionMapping: validates on assignment (type, alignment, custom) - SectionMappingView: subsets on access, copy-on-write on mutation - SectionProperty: descriptor creating ephemeral containers 45 tests covering all alignment combinations, custom validation, custom IO, subsetting, copy-on-write, init kwargs, TreeData-like and SpatialData-like scenarios. --- src/anndata/_core/anndata.py | 2 +- src/anndata/_core/extensions.py | 249 +++++++---- src/anndata/_core/section_registry.py | 265 ++++++++++++ src/anndata/_io/h5ad.py | 23 +- src/anndata/_io/specs/methods.py | 19 +- src/anndata/extensions.py | 9 +- tests/test_registered_sections.py | 580 +++++++++++++++----------- 7 files changed, 795 insertions(+), 352 deletions(-) create mode 100644 src/anndata/_core/section_registry.py diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index c6a9880a6..032667dd8 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -205,7 +205,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641 ) _accessors: ClassVar[set[str]] = set() - _registered_sections: ClassVar[dict] = {} # str -> SectionRegistration + _registered_sections: ClassVar[dict] = {} # str -> SectionSpec # view attributes _adata_ref: AnnData | None diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index fe1287bd6..04f33e389 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -16,7 +16,7 @@ # Based off of the extension framework in Polars # https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py -__all__ = ["register_anndata_namespace", "register_aligned_section", "SectionRegistration"] +__all__ = ["register_anndata_namespace", "register_section", "SectionSpec"] # Protocol for accessors that provide section visualization REPR_SECTION_METHOD = "_repr_section_" @@ -396,110 +396,193 @@ def _repr_section_(self, context) -> list[FormattedEntry] | None: # Section registration # --------------------------------------------------------------------------- -from dataclasses import dataclass +from collections.abc import Callable from typing import Literal +from .section_registry import SectionProperty, SectionSpec -@dataclass(frozen=True) -class SectionRegistration: - """Metadata for a registered aligned section. - Instances are stored in ``AnnData._registered_sections``. - """ - - name: str - """Attribute name on AnnData (e.g., ``"obst"``).""" - mapping_type: Literal["axis", "pairwise", "layers"] - """Which AlignedMapping family to use.""" - axis: Literal[0, 1] | None - """``0`` for obs-aligned, ``1`` for var-aligned, ``None`` for layers-like.""" - allow_df: bool - """Whether DataFrames are allowed as values.""" - io_key: str - """Key used in h5ad/zarr files.""" - - -def register_aligned_section( +def register_section( name: str, *, - axis: Literal[0, 1] | None = None, - mapping_type: Literal["axis", "pairwise", "layers"] = "axis", - allow_df: bool = True, + alignment: tuple[Literal["obs", "var"], ...] = (), io_key: str | None = None, -) -> None: - """Register a new axis-aligned section on :class:`~anndata.AnnData`. +) -> Callable[[type], type]: + """Register a new section on :class:`~anndata.AnnData`. - This allows external packages to add new mappings (like ``obsm``, ``layers``) - that participate in subsetting, IO, repr, and traversal without subclassing. + Decorator that creates a section from a class definition. The class + can optionally define methods and attributes to customize behavior. Parameters ---------- name Attribute name on AnnData (e.g., ``"obst"``). Becomes ``adata.obst``. - axis - ``0`` for obs-aligned, ``1`` for var-aligned, ``None`` for both-axes - (layers-like). - mapping_type - ``"axis"`` for :class:`AxisArrays` (like obsm/varm), - ``"pairwise"`` for :class:`PairwiseArrays` (like obsp/varp), - ``"layers"`` for :class:`Layers`. - allow_df - Whether to allow DataFrames as values. + alignment + Tuple of axes each dimension is aligned to. Examples: + ``("obs",)`` for obs-aligned (like obsm), + ``("obs", "var")`` for both axes (like layers), + ``("obs", "obs")`` for pairwise (like obsp), + ``()`` for unaligned. io_key Key used in h5ad/zarr files. Defaults to *name*. + Class Attributes (all optional) + -------------------------------- + value_type : type + Type check on assignment (e.g., ``nx.DiGraph``). + section_after : str + Position in repr (e.g., ``"obsm"``). + section_tooltip : str + Hover text in HTML repr. + section_doc_url : str + Documentation link in HTML repr. + + Class Methods (all optional, must be static) + --------------------------------------------- + validate(key, value) + Custom validation on assignment. Raise on invalid. + subset(value, idx) + Custom subsetting for ``adata[idx]``. Default uses anndata's + ``_subset`` dispatch (works for arrays, sparse, DataFrames). + serialize(value) + Custom serialization for IO. Return a serializable object. + deserialize(data) + Custom deserialization for IO. + repr_entry(key, value, context) + Custom HTML repr formatting. Return ``FormattedOutput``. + Examples -------- + Simple axis-aligned section (arrays, no custom behavior): + .. code-block:: python - import anndata as ad - from anndata.extensions import register_aligned_section + @register_section("obst", alignment=("obs",)) + class ObstSection: + pass + + Full-featured section (TreeData-like): + + .. code-block:: python + + @register_section("obst", alignment=("obs",)) + class ObstSection: + value_type = nx.DiGraph + section_after = "obsm" + section_tooltip = "Observation trees" + + @staticmethod + def validate(key, value): + if not nx.is_tree(value): + raise ValueError(f"{key} must be a tree") + + @staticmethod + def subset(value, idx): + return subset_tree(value, idx) - # Register at import time - register_aligned_section("obst", axis=0, mapping_type="axis") + @staticmethod + def serialize(value): + return digraph_to_json(value) - adata = ad.AnnData(obs=pd.DataFrame(index=["c1", "c2", "c3"])) - adata.obst["lineage"] = np.eye(3) # validates shape against n_obs - sub = adata[:2] # sub.obst["lineage"] is subsetted - adata.write("test.h5ad") # obst is written - adata2 = ad.read_h5ad("test.h5ad") # obst is read back + @staticmethod + def deserialize(data): + return json_to_digraph(data) + + Unaligned section (SpatialData-like): + + .. code-block:: python + + @register_section("images", alignment=()) + class ImagesSection: + value_type = MultiscaleImage """ - from .aligned_mapping import ( - AlignedMappingProperty, - AxisArrays, - Layers, - PairwiseArrays, - ) - if name in AnnData._registered_sections: - msg = f"Section {name!r} is already registered" - raise ValueError(msg) - if name in _reserved_namespaces: - msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" - raise AttributeError(msg) - - # Select the right aligned mapping class - cls_map = { - "axis": AxisArrays, - "pairwise": PairwiseArrays, - "layers": Layers, - } - if mapping_type not in cls_map: - msg = f"Unknown mapping_type: {mapping_type!r}. Must be one of {list(cls_map)}" - raise ValueError(msg) - cls = cls_map[mapping_type] - - # Create and attach the property descriptor - prop = AlignedMappingProperty(name, cls, axis) - setattr(AnnData, name, prop) - - # Register in the class-level registry - reg = SectionRegistration( - name=name, - mapping_type=mapping_type, - axis=axis, - allow_df=allow_df, - io_key=io_key or name, + def decorator(cls: type) -> type: + if name in AnnData._registered_sections: + msg = f"Section {name!r} is already registered" + raise ValueError(msg) + if name in _reserved_namespaces: + msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" + raise AttributeError(msg) + + # Extract optional methods and attributes from the class + spec = SectionSpec( + name=name, + alignment=alignment, + io_key=io_key or name, + value_type=getattr(cls, "value_type", None), + validate_fn=getattr(cls, "validate", None), + subset_fn=getattr(cls, "subset", None), + serialize_fn=getattr(cls, "serialize", None), + deserialize_fn=getattr(cls, "deserialize", None), + repr_entry_fn=getattr(cls, "repr_entry", None), + section_after=getattr(cls, "section_after", None), + section_tooltip=getattr(cls, "section_tooltip", ""), + section_doc_url=getattr(cls, "section_doc_url", None), + ) + + # Create and attach the property descriptor + prop = SectionProperty(spec) + setattr(AnnData, name, prop) + + # Register + AnnData._registered_sections[name] = spec + _reserved_namespaces.add(name) + + # Auto-register SectionFormatter for HTML repr if repr metadata is present + if spec.section_after or spec.repr_entry_fn: + _create_section_repr_formatter(spec) + + return cls + + return decorator + + +def _create_section_repr_formatter(spec: SectionSpec) -> None: + """Auto-register a SectionFormatter for a registered section.""" + from anndata._repr.registry import ( + FormattedEntry, + FormattedOutput, + FormatterContext, + SectionFormatter, + register_formatter, ) - AnnData._registered_sections[name] = reg - _reserved_namespaces.add(name) + + class RegisteredSectionFormatter(SectionFormatter): + @property + def section_name(self) -> str: + return spec.name + + @property + def after_section(self) -> str | None: + return spec.section_after + + @property + def tooltip(self) -> str: + return spec.section_tooltip + + @property + def doc_url(self) -> str | None: + return spec.section_doc_url + + def should_show(self, obj: AnnData) -> bool: + mapping = getattr(obj, spec.name, None) + return mapping is not None and len(mapping) > 0 + + def get_entries( + self, obj: AnnData, context: FormatterContext + ) -> list[FormattedEntry]: + mapping = getattr(obj, spec.name) + entries = [] + for k in mapping: + if spec.repr_entry_fn is not None: + output = spec.repr_entry_fn(k, mapping[k], context) + else: + output = FormattedOutput( + type_name=type(mapping[k]).__name__, + ) + entries.append(FormattedEntry(key=k, output=output)) + return entries + + RegisteredSectionFormatter.__name__ = f"{spec.name}SectionFormatter" + register_formatter(RegisteredSectionFormatter()) diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py new file mode 100644 index 000000000..27fa35613 --- /dev/null +++ b/src/anndata/_core/section_registry.py @@ -0,0 +1,265 @@ +"""Pluggable section registry for AnnData. + +Provides the infrastructure for :func:`~anndata.extensions.register_section`: +container classes, view handling, and property descriptors that let external +packages add new sections to AnnData without subclassing. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterator, Mapping, MutableMapping +from copy import copy +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from .views import view_update + +if TYPE_CHECKING: + from anndata import AnnData + + from .._repr.registry import FormattedOutput, FormatterContext + + +def _axis_len(value: Any, dim: int) -> int | None: + """Get length of value along a dimension, or None if not applicable.""" + if hasattr(value, "shape"): + shape = value.shape + if dim < len(shape): + return shape[dim] + return None + + +@dataclass(frozen=True) +class SectionSpec: + """Complete specification for a registered section. + + Created by :func:`register_section` from the decorated class. + """ + + name: str + """Attribute name on AnnData (e.g., ``"obst"``).""" + alignment: tuple[Literal["obs", "var"], ...] + """Axes each dimension is aligned to. Empty tuple for unaligned.""" + io_key: str + """Key used in h5ad/zarr files.""" + + # Optional callbacks extracted from the section class + value_type: type | None = None + validate_fn: Callable[[str, Any], None] | None = None + subset_fn: Callable[[Any, Any], Any] | None = None + serialize_fn: Callable[[Any], Any] | None = None + deserialize_fn: Callable[[Any], Any] | None = None + repr_entry_fn: Callable[[str, Any, FormatterContext], FormattedOutput] | None = None + + # Repr metadata + section_after: str | None = None + section_tooltip: str = "" + section_doc_url: str | None = None + + +class SectionMapping(MutableMapping): + """Container for a registered section's data. + + Validates values on assignment using the section's spec (type check, + alignment validation, custom validator). + """ + + def __init__( + self, parent: AnnData, spec: SectionSpec, data: dict | None = None + ) -> None: + self._parent = parent + self._spec = spec + self._data: dict[str, Any] = data if data is not None else {} + + def __repr__(self) -> str: + return f"{self._spec.name}: {', '.join(map(repr, self._data.keys()))}" + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: + # Type check + if self._spec.value_type is not None and not isinstance( + value, self._spec.value_type + ): + msg = ( + f"Values in {self._spec.name!r} must be {self._spec.value_type.__name__}, " + f"got {type(value).__name__}" + ) + raise TypeError(msg) + # Alignment validation + self._validate_alignment(key, value) + # Custom validation + if self._spec.validate_fn is not None: + self._spec.validate_fn(key, value) + self._data[key] = value + + def __delitem__(self, key: str) -> None: + del self._data[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, key: object) -> bool: + return key in self._data + + def _validate_alignment(self, key: str, value: Any) -> None: + """Check that value dimensions match the expected axes.""" + for i, axis in enumerate(self._spec.alignment): + expected = ( + self._parent.n_obs if axis == "obs" else self._parent.n_vars + ) + actual = _axis_len(value, i) + if actual is not None and actual != expected: + n_name = "n_obs" if axis == "obs" else "n_vars" + msg = ( + f"Value for {self._spec.name}[{key!r}] has shape[{i}]={actual}, " + f"expected {expected} ({n_name})" + ) + raise ValueError(msg) + + def copy(self) -> dict[str, Any]: + """Return a deep copy of the underlying data.""" + return { + k: copy(v) if not hasattr(v, "copy") else v.copy() + for k, v in self._data.items() + } + + +class SectionMappingView(Mapping): + """Read-only view of a registered section that subsets on access. + + Writing triggers copy-on-write via anndata's view_update mechanism. + """ + + def __init__( + self, + parent_mapping: SectionMapping, + parent_view: AnnData, + obs_idx: Any, + var_idx: Any, + ) -> None: + self._parent_mapping = parent_mapping + self._parent = parent_view + self._spec = parent_mapping._spec + self._obs_idx = obs_idx + self._var_idx = var_idx + + def __repr__(self) -> str: + return f"{self._spec.name} (view): {', '.join(map(repr, self._parent_mapping._data.keys()))}" + + def __getitem__(self, key: str) -> Any: + value = self._parent_mapping[key] + if not self._spec.alignment: + return value # unaligned, no subsetting + return self._subset_value(value) + + def __setitem__(self, key: str, value: Any) -> None: + from .._warnings import ImplicitModificationWarning + from ..utils import warn + + warn( + f"Setting element `.{self._spec.name}[{key!r}]` of view, " + "initializing view as actual.", + ImplicitModificationWarning, + ) + with view_update(self._parent, self._spec.name, ()) as new_mapping: + new_mapping[key] = value + + def __delitem__(self, key: str) -> None: + from .._warnings import ImplicitModificationWarning + from ..utils import warn + + if key not in self: + msg = f"{key!r} not found in view of {self._spec.name}" + raise KeyError(msg) + warn( + f"Removing element `.{self._spec.name}[{key!r}]` of view, " + "initializing view as actual.", + ImplicitModificationWarning, + ) + with view_update(self._parent, self._spec.name, ()) as new_mapping: + del new_mapping[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._parent_mapping) + + def __len__(self) -> int: + return len(self._parent_mapping) + + def __contains__(self, key: object) -> bool: + return key in self._parent_mapping + + def _subset_value(self, value: Any) -> Any: + """Subset a value according to the alignment tuple.""" + idx = self._build_index() + if self._spec.subset_fn is not None: + return self._spec.subset_fn(value, idx) + # Default: use anndata's _subset + from .index import _subset + + return _subset(value, idx) + + def _build_index(self) -> tuple: + """Build the index tuple from alignment and view indices.""" + indices = [] + for axis in self._spec.alignment: + if axis == "obs": + indices.append(self._obs_idx) + elif axis == "var": + indices.append(self._var_idx) + return tuple(indices) + + def copy(self) -> dict[str, Any]: + """Copy with subsetting applied.""" + return {k: self[k].copy() if hasattr(self[k], "copy") else self[k] for k in self} + + +class SectionProperty: + """Descriptor for registered sections on AnnData. + + Creates ephemeral SectionMapping / SectionMappingView on access, + similar to AlignedMappingProperty for built-in sections. + """ + + def __init__(self, spec: SectionSpec) -> None: + self.spec = spec + + def __get__(self, obj: AnnData | None, objtype: type | None = None) -> Any: + if obj is None: + return self + if not obj.is_view: + data = getattr(obj, f"_{self.spec.name}", None) + if data is None: + data = {} + setattr(obj, f"_{self.spec.name}", data) + return SectionMapping(obj, self.spec, data) + # View: create subsetting view + parent = obj._adata_ref + parent_mapping = getattr(parent, self.spec.name) + return SectionMappingView( + parent_mapping, obj, obj._oidx, obj._vidx + ) + + def __set__( + self, obj: AnnData, value: Mapping[str, Any] | None + ) -> None: + if value is None: + value = {} + if isinstance(value, (SectionMapping, SectionMappingView)): + value = dict(value) + elif isinstance(value, Mapping): + value = dict(value) + # Validate all values via SectionMapping + mapping = SectionMapping(obj, self.spec, {}) + for k, v in value.items(): + mapping[k] = v # validates each + if obj.is_view: + obj._init_as_actual(obj.copy()) + setattr(obj, f"_{self.spec.name}", mapping._data) + + def __delete__(self, obj: AnnData) -> None: + setattr(obj, f"_{self.spec.name}", {}) diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index 5f527546c..c7cc99db2 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -102,12 +102,14 @@ def write_h5ad( write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) # Write registered sections (e.g., obst, vart from extensions) - for sec_name, sec_info in adata._registered_sections.items(): + for sec_name, spec in adata._registered_sections.items(): mapping = getattr(adata, sec_name, None) if mapping is not None and len(mapping) > 0: - write_elem( - f, sec_info.io_key, dict(mapping), dataset_kwargs=dataset_kwargs - ) + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in mapping.items()} + else: + data = dict(mapping) + write_elem(f, spec.io_key, data, dataset_kwargs=dataset_kwargs) def _write_x( @@ -269,13 +271,22 @@ def read_h5ad( def callback(read_func, elem_name: str, elem: StorageType, iospec: IOSpec): if iospec.encoding_type == "anndata" or elem_name.endswith("/"): - return AnnData(**{ + d = { # This is covering up backwards compat in the anndata initializer # In most cases we should be able to call `func(elen[k])` instead k: read_dispatched(elem[k], callback) for k in elem if not k.startswith("raw.") - }) + } + # Deserialize registered sections + for sec_name, spec in AnnData._registered_sections.items(): + if spec.io_key in d and spec.deserialize_fn is not None: + data = d[spec.io_key] + if isinstance(data, dict): + d[spec.io_key] = { + k: spec.deserialize_fn(v) for k, v in data.items() + } + return AnnData(**d) elif elem_name.startswith("/raw."): return None elif elem_name == "/X" and "X" in as_sparse: diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index dd037b00c..370202772 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -298,12 +298,14 @@ def write_anndata( _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) # Write registered sections (e.g., obst, vart from extensions) - for sec_name, sec_info in adata._registered_sections.items(): + for sec_name, spec in adata._registered_sections.items(): mapping = getattr(adata, sec_name, None) if mapping is not None and len(mapping) > 0: - _writer.write_elem( - g, sec_info.io_key, dict(mapping), dataset_kwargs=dataset_kwargs - ) + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in mapping.items()} + else: + data = dict(mapping) + _writer.write_elem(g, spec.io_key, data, dataset_kwargs=dataset_kwargs) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) @@ -329,9 +331,12 @@ def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnDat if k in elem: d[k] = _reader.read_elem(elem[k]) # Read registered sections (e.g., obst, vart from extensions) - for sec_name, sec_info in AnnData._registered_sections.items(): - if sec_info.io_key in elem: - d[sec_name] = _reader.read_elem(elem[sec_info.io_key]) + for sec_name, spec in AnnData._registered_sections.items(): + if spec.io_key in elem: + data = _reader.read_elem(elem[spec.io_key]) + if spec.deserialize_fn is not None and isinstance(data, dict): + data = {k: spec.deserialize_fn(v) for k, v in data.items()} + d[sec_name] = data return AnnData(**d) diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py index ccad0e7fa..b0b1a6efe 100644 --- a/src/anndata/extensions.py +++ b/src/anndata/extensions.py @@ -81,8 +81,9 @@ def get_entries(self, obj, context): # Accessor registration (from PR #1870) from anndata._core.extensions import register_anndata_namespace -# Section registration (pluggable aligned mappings) -from anndata._core.extensions import register_aligned_section, SectionRegistration +# Section registration (pluggable sections with custom alignment, IO, validation) +from anndata._core.extensions import register_section +from anndata._core.section_registry import SectionSpec # HTML representation formatters from anndata._repr import ( @@ -106,8 +107,8 @@ def get_entries(self, obj, context): # Accessor registration "register_anndata_namespace", # Section registration - "register_aligned_section", - "SectionRegistration", + "register_section", + "SectionSpec", # HTML formatter registration "register_formatter", "TypeFormatter", diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index 9aa0ce67a..4919ceac1 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -1,14 +1,13 @@ -"""Tests for register_aligned_section. +"""Tests for register_section decorator. -Validates that registered sections behave like built-in sections (obsm, layers, etc.) -for storage, subsetting, IO, repr, and init. Uses TreeData-like and SpatialData-like -scenarios to test real-world extension patterns. +Validates that registered sections behave correctly for all alignment +combinations, custom validation, custom subsetting, custom IO, and +HTML repr integration. Uses TreeData-like and SpatialData-like scenarios. """ from __future__ import annotations -import tempfile -from pathlib import Path +import json import numpy as np import pandas as pd @@ -16,33 +15,95 @@ from scipy.sparse import csr_matrix import anndata as ad -from anndata._core.extensions import register_aligned_section +from anndata.extensions import register_section # --------------------------------------------------------------------------- # Fixtures: register sections once per test session # --------------------------------------------------------------------------- -# Use module-scoped registration so sections persist across tests. -# This mirrors real-world usage where registration happens at import time. - @pytest.fixture(autouse=True, scope="module") def _register_test_sections(): - """Register TreeData-like and SpatialData-like sections for all tests.""" - # TreeData pattern: axis-aligned tree mappings - if "obst" not in ad.AnnData._registered_sections: - register_aligned_section("obst", axis=0, mapping_type="axis") - if "vart" not in ad.AnnData._registered_sections: - register_aligned_section("vart", axis=1, mapping_type="axis") + """Register test sections for all alignment combinations.""" + # obs-aligned (like obsm) + if "sec_obs" not in ad.AnnData._registered_sections: + + @register_section("sec_obs", alignment=("obs",)) + class SecObs: + pass + + # var-aligned (like varm) + if "sec_var" not in ad.AnnData._registered_sections: + + @register_section("sec_var", alignment=("var",)) + class SecVar: + pass + + # Both axes (like layers) + if "sec_both" not in ad.AnnData._registered_sections: + + @register_section("sec_both", alignment=("obs", "var")) + class SecBoth: + pass + + # Pairwise obs (like obsp) + if "sec_pair_obs" not in ad.AnnData._registered_sections: + + @register_section("sec_pair_obs", alignment=("obs", "obs")) + class SecPairObs: + pass + + # Pairwise var (like varp) + if "sec_pair_var" not in ad.AnnData._registered_sections: + + @register_section("sec_pair_var", alignment=("var", "var")) + class SecPairVar: + pass + + # Unaligned (like SpatialData images) + if "sec_unaligned" not in ad.AnnData._registered_sections: + + @register_section("sec_unaligned", alignment=()) + class SecUnaligned: + pass + + # Custom type validation (TreeData-like) + if "sec_typed" not in ad.AnnData._registered_sections: + + @register_section("sec_typed", alignment=("obs",)) + class SecTyped: + value_type = np.ndarray - # Pairwise pattern (like obsp but for a custom section) - if "obsd" not in ad.AnnData._registered_sections: - register_aligned_section("obsd", axis=0, mapping_type="pairwise") + @staticmethod + def validate(key, value): + if value.ndim != 2: + msg = f"{key} must be 2D" + raise ValueError(msg) - # Layers-like pattern (aligned to both obs and var) - if "extra_layers" not in ad.AnnData._registered_sections: - register_aligned_section("extra_layers", axis=None, mapping_type="layers") + # Custom serialize/deserialize + if "sec_custom_io" not in ad.AnnData._registered_sections: + + @register_section("sec_custom_io", alignment=("obs",)) + class SecCustomIO: + @staticmethod + def serialize(value): + # Convert dict to JSON string for storage + return json.dumps(value) + + @staticmethod + def deserialize(data): + return json.loads(data) + + # Custom subset + if "sec_custom_subset" not in ad.AnnData._registered_sections: + + @register_section("sec_custom_subset", alignment=("obs",)) + class SecCustomSubset: + @staticmethod + def subset(value, idx): + # Custom: return a dict describing the subset + return {"original": value, "subset_idx": idx} @pytest.fixture @@ -51,7 +112,7 @@ def adata(): return ad.AnnData( X=np.ones((5, 3)), obs=pd.DataFrame({"group": list("aabbc")}, index=[f"c{i}" for i in range(5)]), - var=pd.DataFrame({"gene": [f"g{i}" for i in range(3)]}, index=[f"v{i}" for i in range(3)]), + var=pd.DataFrame(index=[f"v{i}" for i in range(3)]), ) @@ -62,271 +123,287 @@ def adata(): class TestRegistrationAPI: def test_register_creates_property(self): - """Registered section is accessible as a property on AnnData.""" - assert hasattr(ad.AnnData, "obst") + assert hasattr(ad.AnnData, "sec_obs") adata = ad.AnnData(np.ones((3, 4))) - # Should return an empty mapping - assert len(adata.obst) == 0 + assert len(adata.sec_obs) == 0 def test_register_duplicate_raises(self): - """Registering the same section twice raises ValueError.""" with pytest.raises(ValueError, match="already registered"): - register_aligned_section("obst", axis=0) + register_section("sec_obs", alignment=("obs",))(type("Dup", (), {})) def test_register_reserved_name_raises(self): - """Registering a reserved name raises AttributeError.""" with pytest.raises(AttributeError, match="conflicts with"): - register_aligned_section("obs", axis=0) + register_section("obs", alignment=("obs",))(type("Bad", (), {})) - def test_register_invalid_mapping_type_raises(self): - """Invalid mapping_type raises ValueError.""" - with pytest.raises(ValueError, match="Unknown mapping_type"): - register_aligned_section("bad_section", axis=0, mapping_type="invalid") - - def test_section_in_registry(self): - """Registered section appears in _registered_sections.""" - assert "obst" in ad.AnnData._registered_sections - reg = ad.AnnData._registered_sections["obst"] - assert reg.name == "obst" - assert reg.axis == 0 - assert reg.mapping_type == "axis" + def test_all_sections_in_registry(self): + for name in [ + "sec_obs", + "sec_var", + "sec_both", + "sec_pair_obs", + "sec_pair_var", + "sec_unaligned", + "sec_typed", + ]: + assert name in ad.AnnData._registered_sections # --------------------------------------------------------------------------- -# Storage and Validation (TreeData-like: obst, vart) +# Obs-aligned: alignment=("obs",) # --------------------------------------------------------------------------- -class TestAxisAlignedStorage: +class TestObsAligned: def test_store_and_retrieve(self, adata): - """Can store and retrieve arrays in registered section.""" - tree = np.random.rand(5, 3) - adata.obst["lineage"] = tree - assert "lineage" in adata.obst - np.testing.assert_array_equal(adata.obst["lineage"], tree) - - def test_wrong_axis_shape_raises(self, adata): - """Storing array with wrong obs dimension raises.""" + arr = np.random.rand(5, 3) + adata.sec_obs["x"] = arr + np.testing.assert_array_equal(adata.sec_obs["x"], arr) + + def test_wrong_shape_raises(self, adata): with pytest.raises(ValueError, match="shape"): - adata.obst["bad"] = np.ones((10, 3)) # 10 != n_obs=5 + adata.sec_obs["bad"] = np.ones((10, 3)) - def test_var_aligned_section(self, adata): - """var-aligned section validates against n_vars.""" - tree = np.random.rand(3, 2) # n_vars=3 - adata.vart["gene_tree"] = tree - assert adata.vart["gene_tree"].shape == (3, 2) + def test_sparse(self, adata): + adata.sec_obs["sp"] = csr_matrix(np.eye(5)) + assert adata.sec_obs["sp"].shape == (5, 5) - def test_var_aligned_wrong_shape_raises(self, adata): - """var-aligned section rejects wrong shape.""" - with pytest.raises(ValueError, match="shape"): - adata.vart["bad"] = np.ones((10, 2)) # 10 != n_vars=3 - - def test_multiple_entries(self, adata): - """Can store multiple entries in a section.""" - adata.obst["tree1"] = np.eye(5) - adata.obst["tree2"] = np.random.rand(5, 4) - assert set(adata.obst.keys()) == {"tree1", "tree2"} - - def test_delete_entry(self, adata): - """Can delete entries from registered section.""" - adata.obst["tree"] = np.eye(5) - del adata.obst["tree"] - assert "tree" not in adata.obst - - def test_sparse_matrix(self, adata): - """Can store sparse matrices.""" - sparse = csr_matrix(np.eye(5)) - adata.obst["sparse_tree"] = sparse - assert adata.obst["sparse_tree"].shape == (5, 5) - - def test_dataframe_in_axis_section(self, adata): - """DataFrames are allowed in axis sections (like obsm).""" - df = pd.DataFrame( - {"a": [1, 2, 3, 4, 5], "b": [5, 4, 3, 2, 1]}, - index=adata.obs_names, - ) - adata.obst["df_tree"] = df - assert adata.obst["df_tree"].shape == (5, 2) + def test_subset_obs(self, adata): + adata.sec_obs["x"] = np.arange(15).reshape(5, 3) + sub = adata[:3] + assert sub.sec_obs["x"].shape == (3, 3) + + def test_subset_var_unchanged(self, adata): + adata.sec_obs["x"] = np.arange(15).reshape(5, 3) + sub = adata[:, :2] + # obs-aligned section not affected by var subsetting + assert sub.sec_obs["x"].shape == (5, 3) # --------------------------------------------------------------------------- -# Pairwise Storage (like obsp) +# Var-aligned: alignment=("var",) # --------------------------------------------------------------------------- -class TestPairwiseStorage: - def test_pairwise_section(self, adata): - """Pairwise section stores square matrices.""" - dist = np.random.rand(5, 5) - adata.obsd["distances"] = dist - np.testing.assert_array_equal(adata.obsd["distances"], dist) +class TestVarAligned: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(3, 2) + adata.sec_var["x"] = arr + np.testing.assert_array_equal(adata.sec_var["x"], arr) - def test_pairwise_wrong_shape_raises(self, adata): - """Pairwise section rejects non-square matrices.""" + def test_wrong_shape_raises(self, adata): with pytest.raises(ValueError, match="shape"): - adata.obsd["bad"] = np.ones((5, 3)) # not square + adata.sec_var["bad"] = np.ones((10, 2)) + + def test_subset_var(self, adata): + adata.sec_var["x"] = np.arange(6).reshape(3, 2) + sub = adata[:, :2] + assert sub.sec_var["x"].shape == (2, 2) + + def test_subset_obs_unchanged(self, adata): + adata.sec_var["x"] = np.arange(6).reshape(3, 2) + sub = adata[:3] + assert sub.sec_var["x"].shape == (3, 2) # --------------------------------------------------------------------------- -# Layers-like Storage (both axes) +# Both axes: alignment=("obs", "var") # --------------------------------------------------------------------------- -class TestLayersLikeStorage: - def test_layers_like_section(self, adata): - """Layers-like section stores (n_obs, n_vars) matrices.""" - data = np.random.rand(5, 3) - adata.extra_layers["normalized"] = data - np.testing.assert_array_equal(adata.extra_layers["normalized"], data) +class TestBothAxes: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(5, 3) + adata.sec_both["x"] = arr + np.testing.assert_array_equal(adata.sec_both["x"], arr) - def test_layers_like_wrong_shape_raises(self, adata): - """Layers-like section rejects wrong shape.""" + def test_wrong_obs_shape_raises(self, adata): with pytest.raises(ValueError, match="shape"): - adata.extra_layers["bad"] = np.ones((5, 10)) # 10 != n_vars=3 + adata.sec_both["bad"] = np.ones((10, 3)) + + def test_wrong_var_shape_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_both["bad"] = np.ones((5, 10)) + + def test_subset_both(self, adata): + adata.sec_both["x"] = np.arange(15).reshape(5, 3) + sub = adata[:3, :2] + assert sub.sec_both["x"].shape == (3, 2) # --------------------------------------------------------------------------- -# Subsetting +# Pairwise obs: alignment=("obs", "obs") # --------------------------------------------------------------------------- -class TestSubsetting: - def test_obs_subset(self, adata): - """Subsetting obs subsets axis-0 registered sections.""" - adata.obst["tree"] = np.arange(15).reshape(5, 3) +class TestPairwiseObs: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(5, 5) + adata.sec_pair_obs["dist"] = arr + np.testing.assert_array_equal(adata.sec_pair_obs["dist"], arr) + + def test_non_square_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_pair_obs["bad"] = np.ones((5, 3)) + + def test_subset(self, adata): + adata.sec_pair_obs["dist"] = np.eye(5) sub = adata[:3] - assert sub.obst["tree"].shape == (3, 3) - np.testing.assert_array_equal(sub.obst["tree"], np.arange(15).reshape(5, 3)[:3]) + assert sub.sec_pair_obs["dist"].shape == (3, 3) + + +# --------------------------------------------------------------------------- +# Pairwise var: alignment=("var", "var") +# --------------------------------------------------------------------------- + + +class TestPairwiseVar: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(3, 3) + adata.sec_pair_var["corr"] = arr + np.testing.assert_array_equal(adata.sec_pair_var["corr"], arr) - def test_var_subset(self, adata): - """Subsetting var subsets axis-1 registered sections.""" - adata.vart["gene_tree"] = np.arange(6).reshape(3, 2) + def test_subset(self, adata): + adata.sec_pair_var["corr"] = np.eye(3) sub = adata[:, :2] - assert sub.vart["gene_tree"].shape == (2, 2) + assert sub.sec_pair_var["corr"].shape == (2, 2) + + +# --------------------------------------------------------------------------- +# Unaligned: alignment=() +# --------------------------------------------------------------------------- + + +class TestUnaligned: + def test_store_anything(self, adata): + adata.sec_unaligned["img"] = np.random.rand(100, 100, 3) + assert adata.sec_unaligned["img"].shape == (100, 100, 3) + + def test_no_shape_validation(self, adata): + # Any shape is fine for unaligned + adata.sec_unaligned["a"] = np.ones((1,)) + adata.sec_unaligned["b"] = np.ones((999, 888)) + assert len(adata.sec_unaligned) == 2 - def test_pairwise_subset(self, adata): - """Subsetting obs subsets pairwise registered sections.""" - adata.obsd["dist"] = np.eye(5) + def test_subset_unchanged(self, adata): + adata.sec_unaligned["img"] = np.random.rand(100, 100, 3) sub = adata[:3] - assert sub.obsd["dist"].shape == (3, 3) + # Unaligned data is not subsetted + assert sub.sec_unaligned["img"].shape == (100, 100, 3) - def test_layers_like_subset(self, adata): - """Subsetting subsets both axes of layers-like sections.""" - adata.extra_layers["data"] = np.arange(15).reshape(5, 3) - sub = adata[:3, :2] - assert sub.extra_layers["data"].shape == (3, 2) + def test_non_array_values(self, adata): + adata.sec_unaligned["config"] = {"key": "value"} + assert adata.sec_unaligned["config"] == {"key": "value"} - def test_view_copy_on_write(self, adata): - """Writing to a view's registered section triggers copy-on-write.""" - adata.obst["tree"] = np.eye(5) + +# --------------------------------------------------------------------------- +# Custom type validation +# --------------------------------------------------------------------------- + + +class TestCustomValidation: + def test_type_check(self, adata): + adata.sec_typed["x"] = np.eye(5) + assert adata.sec_typed["x"].shape == (5, 5) + + def test_wrong_type_raises(self, adata): + with pytest.raises(TypeError, match="must be ndarray"): + adata.sec_typed["bad"] = [[1, 2], [3, 4]] + + def test_custom_validate(self, adata): + with pytest.raises(ValueError, match="must be 2D"): + adata.sec_typed["bad"] = np.ones(5) # 1D, not 2D + + +# --------------------------------------------------------------------------- +# Custom subset +# --------------------------------------------------------------------------- + + +class TestCustomSubset: + def test_custom_subset_fn(self, adata): + adata.sec_custom_subset["x"] = np.eye(5) sub = adata[:3] - sub.obst["new_tree"] = np.ones((3, 2)) - # sub should now be an actual (not view) with the new entry - assert not sub.is_view - assert "new_tree" in sub.obst - # original should be unchanged - assert "new_tree" not in adata.obst + result = sub.sec_custom_subset["x"] + assert isinstance(result, dict) + assert "original" in result + assert "subset_idx" in result # --------------------------------------------------------------------------- -# Init with kwargs +# Init kwargs # --------------------------------------------------------------------------- class TestInitKwargs: - def test_init_with_registered_section(self): - """Can pass registered section data as init kwarg.""" + def test_init_with_section(self): adata = ad.AnnData( np.ones((3, 4)), obs=pd.DataFrame(index=["c1", "c2", "c3"]), - obst={"tree": np.eye(3)}, + sec_obs={"x": np.eye(3)}, ) - assert "tree" in adata.obst - assert adata.obst["tree"].shape == (3, 3) + assert "x" in adata.sec_obs def test_init_with_multiple_sections(self): - """Can pass multiple registered sections at init.""" adata = ad.AnnData( np.ones((3, 4)), obs=pd.DataFrame(index=["c1", "c2", "c3"]), var=pd.DataFrame(index=["v1", "v2", "v3", "v4"]), - obst={"tree": np.eye(3)}, - vart={"gene_tree": np.eye(4)}, + sec_obs={"x": np.eye(3)}, + sec_var={"y": np.eye(4)}, ) - assert "tree" in adata.obst - assert "gene_tree" in adata.vart + assert "x" in adata.sec_obs + assert "y" in adata.sec_var - def test_init_without_registered_section(self): - """AnnData works normally without passing registered sections.""" + def test_init_without_section(self): adata = ad.AnnData(np.ones((3, 4))) - assert len(adata.obst) == 0 - assert len(adata.vart) == 0 + assert len(adata.sec_obs) == 0 # --------------------------------------------------------------------------- -# IO Roundtrip (h5ad) +# IO roundtrip (h5ad) # --------------------------------------------------------------------------- class TestH5adRoundtrip: - def test_write_read_roundtrip(self, adata, tmp_path): - """Registered sections survive h5ad write/read.""" - adata.obst["tree"] = np.eye(5) - adata.vart["gene_tree"] = np.random.rand(3, 2) - + def test_write_read_obs_aligned(self, adata, tmp_path): + adata.sec_obs["x"] = np.eye(5) path = tmp_path / "test.h5ad" adata.write(path) adata2 = ad.read_h5ad(path) + assert "x" in adata2.sec_obs + np.testing.assert_array_equal(adata2.sec_obs["x"], np.eye(5)) - assert "tree" in adata2.obst - np.testing.assert_array_equal(adata2.obst["tree"], np.eye(5)) - assert "gene_tree" in adata2.vart - assert adata2.vart["gene_tree"].shape == (3, 2) + def test_write_read_both_axes(self, adata, tmp_path): + adata.sec_both["x"] = np.arange(15).reshape(5, 3).astype(float) + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_equal( + adata2.sec_both["x"], np.arange(15).reshape(5, 3) + ) def test_empty_section_not_written(self, adata, tmp_path): - """Empty registered sections are not written to disk.""" import h5py path = tmp_path / "test.h5ad" adata.write(path) - with h5py.File(path, "r") as f: - assert "obst" not in f - assert "vart" not in f + assert "sec_obs" not in f + + def test_custom_serialize_deserialize(self, adata, tmp_path): + adata.sec_custom_io["config"] = {"lr": 0.001, "epochs": 100} + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert adata2.sec_custom_io["config"] == {"lr": 0.001, "epochs": 100} def test_subset_then_write(self, adata, tmp_path): - """Can subset, then write, and registered sections are preserved.""" - adata.obst["tree"] = np.arange(15).reshape(5, 3) + adata.sec_obs["x"] = np.arange(15).reshape(5, 3).astype(float) sub = adata[:3].copy() - path = tmp_path / "test.h5ad" sub.write(path) sub2 = ad.read_h5ad(path) - - assert sub2.obst["tree"].shape == (3, 3) - - def test_read_without_section_registered(self, tmp_path): - """Files with extra groups are read fine even without registration. - - The extra data is silently skipped (backward compat). - """ - # Write with registration - adata = ad.AnnData( - np.ones((3, 4)), - obs=pd.DataFrame(index=["c1", "c2", "c3"]), - obst={"tree": np.eye(3)}, - ) - path = tmp_path / "test.h5ad" - adata.write(path) - - # Simulate reading without registration by checking the file directly - import h5py - - with h5py.File(path, "r") as f: - assert "obst" in f # Data is in the file - # Standard read_h5ad would skip it if not registered, - # but since we registered in this session, it will be read. - # This test just verifies the file format. + assert sub2.sec_obs["x"].shape == (3, 3) # --------------------------------------------------------------------------- @@ -335,24 +412,12 @@ def test_read_without_section_registered(self, tmp_path): class TestRepr: - def test_repr_shows_registered_section(self, adata): - """Registered sections appear in repr when non-empty.""" - adata.obst["tree"] = np.eye(5) - r = repr(adata) - assert "obst" in r - assert "tree" in r + def test_repr_shows_section(self, adata): + adata.sec_obs["x"] = np.eye(5) + assert "sec_obs" in repr(adata) - def test_repr_hides_empty_section(self, adata): - """Empty registered sections don't appear in repr.""" - r = repr(adata) - assert "obst" not in r - - def test_repr_html_shows_registered_section(self, adata): - """Registered sections appear in HTML repr.""" - adata.obst["tree"] = np.eye(5) - html = adata._repr_html_() - if html is not None: # HTML repr may not be enabled - assert "obst" in html + def test_repr_hides_empty(self, adata): + assert "sec_obs" not in repr(adata) # --------------------------------------------------------------------------- @@ -361,32 +426,34 @@ def test_repr_html_shows_registered_section(self, adata): class TestCopy: - def test_copy_preserves_sections(self, adata): - """copy() preserves registered sections.""" - adata.obst["tree"] = np.eye(5) + def test_copy_preserves(self, adata): + adata.sec_obs["x"] = np.eye(5) adata2 = adata.copy() - assert "tree" in adata2.obst - np.testing.assert_array_equal(adata2.obst["tree"], np.eye(5)) + assert "x" in adata2.sec_obs + np.testing.assert_array_equal(adata2.sec_obs["x"], np.eye(5)) def test_copy_is_independent(self, adata): - """Modifications to copy don't affect original.""" - adata.obst["tree"] = np.eye(5) + adata.sec_obs["x"] = np.eye(5) adata2 = adata.copy() - adata2.obst["new"] = np.ones((5, 2)) - assert "new" not in adata.obst + adata2.sec_obs["new"] = np.ones((5, 2)) + assert "new" not in adata.sec_obs + + def test_view_copy_on_write(self, adata): + adata.sec_obs["x"] = np.eye(5) + sub = adata[:3] + sub.sec_obs["new"] = np.ones((3, 2)) + assert not sub.is_view + assert "new" in sub.sec_obs + assert "new" not in adata.sec_obs # --------------------------------------------------------------------------- -# TreeData-like Scenario +# TreeData-like end-to-end scenario # --------------------------------------------------------------------------- class TestTreeDataScenario: - """End-to-end test mimicking how TreeData would use section registration.""" - - def test_treedata_workflow(self, tmp_path): - """Full TreeData-like workflow: create, populate, subset, IO.""" - # 1. Create AnnData with tree data + def test_full_workflow(self, tmp_path): n_obs, n_vars = 10, 5 adata = ad.AnnData( X=np.random.rand(n_obs, n_vars), @@ -397,35 +464,46 @@ def test_treedata_workflow(self, tmp_path): var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_vars)]), ) - # 2. Add tree data (serialized as arrays, like TreeData does) - # In reality these would be serialized DiGraphs - lineage_tree = np.random.rand(n_obs, 4) # obs-aligned tree embedding - gene_tree = np.random.rand(n_vars, 3) # var-aligned tree embedding - - adata.obst["lineage"] = lineage_tree - adata.vart["phylogeny"] = gene_tree + # Store tree embeddings + adata.sec_obs["lineage"] = np.random.rand(n_obs, 4) + adata.sec_var["phylogeny"] = np.random.rand(n_vars, 3) - # 3. Verify storage - assert adata.obst["lineage"].shape == (n_obs, 4) - assert adata.vart["phylogeny"].shape == (n_vars, 3) - - # 4. Subset (like filtering cells) + # Subset mask = adata.obs["cell_type"] == "A" sub = adata[mask] - assert sub.obst["lineage"].shape == (5, 4) - assert sub.vart["phylogeny"].shape == (n_vars, 3) # var unchanged + assert sub.sec_obs["lineage"].shape == (5, 4) + assert sub.sec_var["phylogeny"].shape == (n_vars, 3) - # 5. Write and read back + # IO roundtrip path = tmp_path / "treedata.h5ad" adata.write(path) adata2 = ad.read_h5ad(path) + assert set(adata2.sec_obs.keys()) == {"lineage"} + assert set(adata2.sec_var.keys()) == {"phylogeny"} + + # Repr + assert "sec_obs" in repr(adata2) + assert "sec_var" in repr(adata2) - assert set(adata2.obst.keys()) == {"lineage"} - assert set(adata2.vart.keys()) == {"phylogeny"} - np.testing.assert_array_almost_equal(adata2.obst["lineage"], lineage_tree) - np.testing.assert_array_almost_equal(adata2.vart["phylogeny"], gene_tree) - # 6. Repr includes custom sections - r = repr(adata2) - assert "obst" in r - assert "vart" in r +# --------------------------------------------------------------------------- +# SpatialData-like end-to-end scenario +# --------------------------------------------------------------------------- + + +class TestSpatialDataScenario: + def test_unaligned_images(self, adata, tmp_path): + # Store images of arbitrary size + adata.sec_unaligned["hires"] = np.random.rand(200, 200, 3) + adata.sec_unaligned["lowres"] = np.random.rand(50, 50, 3) + + # Subsetting obs doesn't affect images + sub = adata[:3] + assert sub.sec_unaligned["hires"].shape == (200, 200, 3) + + # IO roundtrip + path = tmp_path / "spatial.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert set(adata2.sec_unaligned.keys()) == {"hires", "lowres"} + assert adata2.sec_unaligned["hires"].shape == (200, 200, 3) From 1baf049714cbf644283b0db27d888c58ab411503 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 10:14:52 -0700 Subject: [PATCH 09/14] feat: accept string for single-axis alignment in register_section alignment="obs" is now equivalent to alignment=("obs",). Updated docstring examples and tests to use the string form. --- src/anndata/_core/extensions.py | 15 ++++++++++----- tests/test_registered_sections.py | 14 +++++++------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 04f33e389..0b47f6d6b 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -405,7 +405,7 @@ def _repr_section_(self, context) -> list[FormattedEntry] | None: def register_section( name: str, *, - alignment: tuple[Literal["obs", "var"], ...] = (), + alignment: Literal["obs", "var"] | tuple[Literal["obs", "var"], ...] = (), io_key: str | None = None, ) -> Callable[[type], type]: """Register a new section on :class:`~anndata.AnnData`. @@ -418,8 +418,9 @@ def register_section( name Attribute name on AnnData (e.g., ``"obst"``). Becomes ``adata.obst``. alignment - Tuple of axes each dimension is aligned to. Examples: - ``("obs",)`` for obs-aligned (like obsm), + Axes each dimension is aligned to. A string for single-axis + alignment, or a tuple for multi-axis. Examples: + ``"obs"`` for obs-aligned (like obsm), ``("obs", "var")`` for both axes (like layers), ``("obs", "obs")`` for pairwise (like obsp), ``()`` for unaligned. @@ -457,7 +458,7 @@ def register_section( .. code-block:: python - @register_section("obst", alignment=("obs",)) + @register_section("obst", alignment="obs") class ObstSection: pass @@ -465,7 +466,7 @@ class ObstSection: .. code-block:: python - @register_section("obst", alignment=("obs",)) + @register_section("obst", alignment="obs") class ObstSection: value_type = nx.DiGraph section_after = "obsm" @@ -497,6 +498,10 @@ class ImagesSection: value_type = MultiscaleImage """ + # Normalize alignment: string → 1-tuple + if isinstance(alignment, str): + alignment = (alignment,) + def decorator(cls: type) -> type: if name in AnnData._registered_sections: msg = f"Section {name!r} is already registered" diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index 4919ceac1..6712447cd 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -29,14 +29,14 @@ def _register_test_sections(): # obs-aligned (like obsm) if "sec_obs" not in ad.AnnData._registered_sections: - @register_section("sec_obs", alignment=("obs",)) + @register_section("sec_obs", alignment="obs") class SecObs: pass # var-aligned (like varm) if "sec_var" not in ad.AnnData._registered_sections: - @register_section("sec_var", alignment=("var",)) + @register_section("sec_var", alignment="var") class SecVar: pass @@ -71,7 +71,7 @@ class SecUnaligned: # Custom type validation (TreeData-like) if "sec_typed" not in ad.AnnData._registered_sections: - @register_section("sec_typed", alignment=("obs",)) + @register_section("sec_typed", alignment="obs") class SecTyped: value_type = np.ndarray @@ -84,7 +84,7 @@ def validate(key, value): # Custom serialize/deserialize if "sec_custom_io" not in ad.AnnData._registered_sections: - @register_section("sec_custom_io", alignment=("obs",)) + @register_section("sec_custom_io", alignment="obs") class SecCustomIO: @staticmethod def serialize(value): @@ -98,7 +98,7 @@ def deserialize(data): # Custom subset if "sec_custom_subset" not in ad.AnnData._registered_sections: - @register_section("sec_custom_subset", alignment=("obs",)) + @register_section("sec_custom_subset", alignment="obs") class SecCustomSubset: @staticmethod def subset(value, idx): @@ -129,11 +129,11 @@ def test_register_creates_property(self): def test_register_duplicate_raises(self): with pytest.raises(ValueError, match="already registered"): - register_section("sec_obs", alignment=("obs",))(type("Dup", (), {})) + register_section("sec_obs", alignment="obs")(type("Dup", (), {})) def test_register_reserved_name_raises(self): with pytest.raises(AttributeError, match="conflicts with"): - register_section("obs", alignment=("obs",))(type("Bad", (), {})) + register_section("obs", alignment="obs")(type("Bad", (), {})) def test_all_sections_in_registry(self): for name in [ From 4ed4fb5d7476dce332b19354e337de6923a954c3 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 11:56:20 -0700 Subject: [PATCH 10/14] feat: add xarray DataArray section example and tests Demonstrates using register_section with custom types: xr.DataArray as layer values with serialize/deserialize for h5ad IO roundtrip. Shows that custom types work end-to-end: storage with type enforcement, alignment validation, subsetting, copy, IO, repr. 6 new tests (51 total). --- tests/test_registered_sections.py | 96 ++++++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 1 deletion(-) diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index 6712447cd..e5630c247 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -2,7 +2,8 @@ Validates that registered sections behave correctly for all alignment combinations, custom validation, custom subsetting, custom IO, and -HTML repr integration. Uses TreeData-like and SpatialData-like scenarios. +HTML repr integration. Uses TreeData-like, SpatialData-like, and +xarray scenarios. """ from __future__ import annotations @@ -12,6 +13,7 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from scipy.sparse import csr_matrix import anndata as ad @@ -105,6 +107,21 @@ def subset(value, idx): # Custom: return a dict describing the subset return {"original": value, "subset_idx": idx} + # xarray layers (custom type with serialize/deserialize) + if "xr_layers" not in ad.AnnData._registered_sections: + + @register_section("xr_layers", alignment=("obs", "var")) + class XarrayLayers: + value_type = xr.DataArray + + @staticmethod + def serialize(value): + return value.values # xarray → numpy for h5ad + + @staticmethod + def deserialize(data): + return xr.DataArray(data) # numpy → xarray on read + @pytest.fixture def adata(): @@ -491,6 +508,83 @@ def test_full_workflow(self, tmp_path): # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + +class TestXarrayScenario: + """xarray DataArrays as layer values with custom serialization.""" + + def test_store_xarray(self, adata): + da = xr.DataArray(np.random.rand(5, 3), dims=["obs", "var"]) + adata.xr_layers["normalized"] = da + assert isinstance(adata.xr_layers["normalized"], xr.DataArray) + assert adata.xr_layers["normalized"].shape == (5, 3) + + def test_type_enforcement(self, adata): + with pytest.raises(TypeError, match="must be DataArray"): + adata.xr_layers["bad"] = np.ones((5, 3)) + + def test_alignment_validation(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.xr_layers["bad"] = xr.DataArray(np.ones((10, 3))) + + def test_subset(self, adata): + da = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=["obs", "var"]) + adata.xr_layers["data"] = da + sub = adata[:3, :2] + result = sub.xr_layers["data"] + assert result.shape == (3, 2) + + def test_io_roundtrip(self, adata, tmp_path): + da = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=["obs", "var"]) + adata.xr_layers["data"] = da + path = tmp_path / "xr.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + # Deserialized back to xarray + assert isinstance(adata2.xr_layers["data"], xr.DataArray) + np.testing.assert_array_equal( + adata2.xr_layers["data"].values, np.arange(15.0).reshape(5, 3) + ) + + def test_full_workflow(self, tmp_path): + """End-to-end: store, subset, copy, IO with xarray layers.""" + adata = ad.AnnData( + X=np.ones((10, 5)), + obs=pd.DataFrame(index=[f"c{i}" for i in range(10)]), + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + xr_layers={ + "scaled": xr.DataArray(np.random.rand(10, 5), dims=["obs", "var"]), + }, + ) + + # Subset preserves type + sub = adata[:5] + assert isinstance(sub.xr_layers["scaled"], xr.DataArray) + assert sub.xr_layers["scaled"].shape == (5, 5) + + # Copy preserves type + copy = adata.copy() + assert isinstance(copy.xr_layers["scaled"], xr.DataArray) + + # IO roundtrip preserves type + path = tmp_path / "xr_workflow.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert isinstance(adata2.xr_layers["scaled"], xr.DataArray) + assert adata2.xr_layers["scaled"].shape == (10, 5) + + # Repr shows section + assert "xr_layers" in repr(adata) + + +# --------------------------------------------------------------------------- +# SpatialData-like end-to-end scenario +# --------------------------------------------------------------------------- + + class TestSpatialDataScenario: def test_unaligned_images(self, adata, tmp_path): # Store images of arbitrary size From 6b315c84f38b451d349825b0d240a6a8fb75f730 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 12:00:47 -0700 Subject: [PATCH 11/14] feat: N-dimensional alignment with biology examples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support >2D alignment tuples with proper subsetting. anndata's built-in _subset only handles ≤2D, so SectionMappingView implements N-D fancy indexing via np.ix_ for higher dimensions. New biology-motivated test cases: - cellcomm: alignment=("obs", "obs", "var") for ligand-receptor cell-cell communication tensors (CellChat, LIANA, CellPhoneDB) - genereg: alignment=("obs", "var", "var") for cell-specific gene regulatory networks (SCENIC, CellOracle, Dictys) 67 tests total, all passing. --- src/anndata/_core/section_registry.py | 31 ++++- tests/test_registered_sections.py | 179 ++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py index 27fa35613..338f877f7 100644 --- a/src/anndata/_core/section_registry.py +++ b/src/anndata/_core/section_registry.py @@ -198,7 +198,36 @@ def _subset_value(self, value: Any) -> Any: idx = self._build_index() if self._spec.subset_fn is not None: return self._spec.subset_fn(value, idx) - # Default: use anndata's _subset + # Default subsetting: handle N-dimensional alignment + # anndata's _subset is designed for ≤2D, so for higher dims + # we do the indexing directly. + import numpy as np + + from anndata.compat import IndexManager + + if isinstance(idx, tuple) and len(idx) > 2: + # Convert IndexManagers to numpy arrays + resolved = [] + for ix in idx: + if isinstance(ix, IndexManager): + resolved.append(np.asarray(ix)) + else: + resolved.append(ix) + # Use np.ix_ for fancy indexing on non-slice dims + fancy_dims = [ + i for i, ix in enumerate(resolved) if not isinstance(ix, slice) + ] + if fancy_dims: + # Build an open mesh for fancy-indexed dims + fancy_arrs = [resolved[i] for i in fancy_dims] + mesh = np.ix_(*fancy_arrs) + # Build the full index tuple + full_idx = list(resolved) + for mi, di in enumerate(fancy_dims): + full_idx[di] = mesh[mi] + return value[tuple(full_idx)] + return value[tuple(resolved)] + # ≤2D: use anndata's _subset from .index import _subset return _subset(value, idx) diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index e5630c247..62c889f8c 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -97,6 +97,24 @@ def serialize(value): def deserialize(data): return json.loads(data) + # Cell-cell communication tensor: (sender, receiver, gene) + if "cellcomm" not in ad.AnnData._registered_sections: + + @register_section("cellcomm", alignment=("obs", "obs", "var")) + class CellCommSection: + """Ligand-receptor communication scores (sender × receiver × gene).""" + section_after = "obsp" + section_tooltip = "Cell-cell communication" + + # Cell-specific gene-gene interactions: (obs, var, var) + if "genereg" not in ad.AnnData._registered_sections: + + @register_section("genereg", alignment=("obs", "var", "var")) + class GeneRegSection: + """Cell-specific gene regulatory networks (cell × gene × gene).""" + section_after = "varp" + section_tooltip = "Gene regulation per cell" + # Custom subset if "sec_custom_subset" not in ad.AnnData._registered_sections: @@ -513,6 +531,167 @@ def test_full_workflow(self, tmp_path): # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Cell-cell communication tensor: alignment=("obs", "obs", "var") +# --------------------------------------------------------------------------- + + +class TestCellCommunication: + """3D tensor for ligand-receptor communication scores. + + Tools like CellChat, LIANA, and CellPhoneDB compute communication + strengths between cell pairs mediated by specific genes. The natural + shape is (sender_cell, receiver_cell, gene). With alignment=("obs", + "obs", "var"), the tensor subsets correctly when filtering cells or genes. + """ + + def test_store_tensor(self, adata): + comm = np.random.rand(5, 5, 3) + adata.cellcomm["lr_scores"] = comm + assert adata.cellcomm["lr_scores"].shape == (5, 5, 3) + + def test_validates_obs_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((10, 10, 3)) + + def test_validates_var_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((5, 5, 10)) + + def test_validates_square_obs(self, adata): + """Sender and receiver must both be n_obs.""" + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((5, 3, 3)) + + def test_subset_cells(self, adata): + """Filtering cells subsets both sender and receiver dims.""" + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:3] + assert sub.cellcomm["lr"].shape == (3, 3, 3) + + def test_subset_genes(self, adata): + """Filtering genes subsets the third dim.""" + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:, :2] + assert sub.cellcomm["lr"].shape == (5, 5, 2) + + def test_subset_both(self, adata): + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:3, :2] + assert sub.cellcomm["lr"].shape == (3, 3, 2) + + def test_io_roundtrip(self, adata, tmp_path): + comm = np.random.rand(5, 5, 3) + adata.cellcomm["lr_scores"] = comm + path = tmp_path / "comm.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_almost_equal( + adata2.cellcomm["lr_scores"], comm + ) + + def test_workflow(self, tmp_path): + """End-to-end: simulate CellChat-like analysis.""" + n_obs, n_vars = 20, 50 + adata = ad.AnnData( + X=np.random.rand(n_obs, n_vars), + obs=pd.DataFrame( + {"cell_type": pd.Categorical(["T"] * 10 + ["B"] * 10)}, + index=[f"cell_{i}" for i in range(n_obs)], + ), + var=pd.DataFrame( + {"is_ligand": [True] * 25 + [False] * 25}, + index=[f"gene_{i}" for i in range(n_vars)], + ), + ) + + # Compute communication scores (simulated) + adata.cellcomm["cellchat"] = np.random.rand(n_obs, n_obs, n_vars) + + # Filter to T cells only + t_cells = adata.obs["cell_type"] == "T" + sub = adata[t_cells] + assert sub.cellcomm["cellchat"].shape == (10, 10, n_vars) + + # Filter to ligand genes only + ligands = adata.var["is_ligand"] + sub2 = adata[:, ligands] + assert sub2.cellcomm["cellchat"].shape == (n_obs, n_obs, 25) + + # IO roundtrip + path = tmp_path / "cellchat.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert adata2.cellcomm["cellchat"].shape == (n_obs, n_obs, n_vars) + + +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Cell-specific gene regulation: alignment=("obs", "var", "var") +# --------------------------------------------------------------------------- + + +class TestGeneRegulation: + """3D tensor for cell-specific gene regulatory networks. + + Each cell has its own gene-gene interaction matrix (e.g., inferred + from single-cell GRN methods like SCENIC, CellOracle, or Dictys). + Shape is (cell, source_gene, target_gene). Subsetting cells reduces + the first dim, subsetting genes reduces both gene dims. + """ + + def test_store_tensor(self, adata): + grn = np.random.rand(5, 3, 3) + adata.genereg["scenic"] = grn + assert adata.genereg["scenic"].shape == (5, 3, 3) + + def test_validates_obs_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((10, 3, 3)) + + def test_validates_var_dims(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((5, 10, 3)) # source wrong + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((5, 3, 10)) # target wrong + + def test_subset_cells(self, adata): + """Filtering cells subsets the first dim only.""" + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:3] + assert sub.genereg["grn"].shape == (3, 3, 3) + + def test_subset_genes(self, adata): + """Filtering genes subsets both gene dims (source and target).""" + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:, :2] + assert sub.genereg["grn"].shape == (5, 2, 2) + + def test_subset_both(self, adata): + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:3, :2] + assert sub.genereg["grn"].shape == (3, 2, 2) + + def test_io_roundtrip(self, adata, tmp_path): + grn = np.random.rand(5, 3, 3) + adata.genereg["scenic"] = grn + path = tmp_path / "grn.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_almost_equal( + adata2.genereg["scenic"], grn + ) + + +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + class TestXarrayScenario: """xarray DataArrays as layer values with custom serialization.""" From 387015df9804671f8ede77b9bd698e16f6c4d02e Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 13:28:30 -0700 Subject: [PATCH 12/14] feat: factored tensor accessor + ruff formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add TestFactoredTensor: sections store compact rank-R factors (n_obs × rank) and (n_vars × rank), accessor reconstructs the full (obs × obs × var) tensor on demand via einsum. Includes point queries without materializing the tensor. Demonstrates combining register_section (for factor storage with axis-aligned subsetting and IO) with register_anndata_namespace (for the tensor reconstruction API and HTML repr). 73 tests total, all passing. Ruff formatting applied. --- src/anndata/_core/anndata.py | 11 +- src/anndata/_core/extensions.py | 3 +- src/anndata/_core/section_registry.py | 24 ++--- src/anndata/extensions.py | 4 +- tests/test_registered_sections.py | 148 ++++++++++++++++++++++++-- 5 files changed, 161 insertions(+), 29 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 032667dd8..da0cf0962 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1430,7 +1430,16 @@ def _mutated_copy(self, **kwargs) -> AnnData: raise NotImplementedError(msg) new = {} - for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", *self._registered_sections]: + for key in [ + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + *self._registered_sections, + ]: if key in kwargs: new[key] = kwargs[key] else: diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 0b47f6d6b..6ac53637f 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -16,7 +16,7 @@ # Based off of the extension framework in Polars # https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py -__all__ = ["register_anndata_namespace", "register_section", "SectionSpec"] +__all__ = ["SectionSpec", "register_anndata_namespace", "register_section"] # Protocol for accessors that provide section visualization REPR_SECTION_METHOD = "_repr_section_" @@ -548,7 +548,6 @@ def _create_section_repr_formatter(spec: SectionSpec) -> None: from anndata._repr.registry import ( FormattedEntry, FormattedOutput, - FormatterContext, SectionFormatter, register_formatter, ) diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py index 338f877f7..5ede0fc32 100644 --- a/src/anndata/_core/section_registry.py +++ b/src/anndata/_core/section_registry.py @@ -9,7 +9,7 @@ from collections.abc import Callable, Iterator, Mapping, MutableMapping from copy import copy -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal from .views import view_update @@ -109,9 +109,7 @@ def __contains__(self, key: object) -> bool: def _validate_alignment(self, key: str, value: Any) -> None: """Check that value dimensions match the expected axes.""" for i, axis in enumerate(self._spec.alignment): - expected = ( - self._parent.n_obs if axis == "obs" else self._parent.n_vars - ) + expected = self._parent.n_obs if axis == "obs" else self._parent.n_vars actual = _axis_len(value, i) if actual is not None and actual != expected: n_name = "n_obs" if axis == "obs" else "n_vars" @@ -244,7 +242,9 @@ def _build_index(self) -> tuple: def copy(self) -> dict[str, Any]: """Copy with subsetting applied.""" - return {k: self[k].copy() if hasattr(self[k], "copy") else self[k] for k in self} + return { + k: self[k].copy() if hasattr(self[k], "copy") else self[k] for k in self + } class SectionProperty: @@ -269,18 +269,14 @@ def __get__(self, obj: AnnData | None, objtype: type | None = None) -> Any: # View: create subsetting view parent = obj._adata_ref parent_mapping = getattr(parent, self.spec.name) - return SectionMappingView( - parent_mapping, obj, obj._oidx, obj._vidx - ) + return SectionMappingView(parent_mapping, obj, obj._oidx, obj._vidx) - def __set__( - self, obj: AnnData, value: Mapping[str, Any] | None - ) -> None: + def __set__(self, obj: AnnData, value: Mapping[str, Any] | None) -> None: if value is None: value = {} - if isinstance(value, (SectionMapping, SectionMappingView)): - value = dict(value) - elif isinstance(value, Mapping): + if isinstance(value, (SectionMapping, SectionMappingView)) or isinstance( + value, Mapping + ): value = dict(value) # Validate all values via SectionMapping mapping = SectionMapping(obj, self.spec, {}) diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py index b0b1a6efe..2bb623d2a 100644 --- a/src/anndata/extensions.py +++ b/src/anndata/extensions.py @@ -79,10 +79,8 @@ def get_entries(self, obj, context): from __future__ import annotations # Accessor registration (from PR #1870) -from anndata._core.extensions import register_anndata_namespace - # Section registration (pluggable sections with custom alignment, IO, validation) -from anndata._core.extensions import register_section +from anndata._core.extensions import register_anndata_namespace, register_section from anndata._core.section_registry import SectionSpec # HTML representation formatters diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index 62c889f8c..c85a27ed0 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -19,7 +19,6 @@ import anndata as ad from anndata.extensions import register_section - # --------------------------------------------------------------------------- # Fixtures: register sections once per test session # --------------------------------------------------------------------------- @@ -103,6 +102,7 @@ def deserialize(data): @register_section("cellcomm", alignment=("obs", "obs", "var")) class CellCommSection: """Ligand-receptor communication scores (sender × receiver × gene).""" + section_after = "obsp" section_tooltip = "Cell-cell communication" @@ -112,6 +112,7 @@ class CellCommSection: @register_section("genereg", alignment=("obs", "var", "var")) class GeneRegSection: """Cell-specific gene regulatory networks (cell × gene × gene).""" + section_after = "varp" section_tooltip = "Gene regulation per cell" @@ -125,6 +126,19 @@ def subset(value, idx): # Custom: return a dict describing the subset return {"original": value, "subset_idx": idx} + # Factored tensor: store rank-R factors, reconstruct on demand + if "comm_obs" not in ad.AnnData._registered_sections: + + @register_section("comm_obs", alignment="obs") + class CommObs: + """Cell factor matrix (n_obs × rank) for communication tensor.""" + + if "comm_var" not in ad.AnnData._registered_sections: + + @register_section("comm_var", alignment="var") + class CommVar: + """Gene factor matrix (n_vars × rank) for communication tensor.""" + # xarray layers (custom type with serialize/deserialize) if "xr_layers" not in ad.AnnData._registered_sections: @@ -413,9 +427,7 @@ def test_write_read_both_axes(self, adata, tmp_path): path = tmp_path / "test.h5ad" adata.write(path) adata2 = ad.read_h5ad(path) - np.testing.assert_array_equal( - adata2.sec_both["x"], np.arange(15).reshape(5, 3) - ) + np.testing.assert_array_equal(adata2.sec_both["x"], np.arange(15).reshape(5, 3)) def test_empty_section_not_written(self, adata, tmp_path): import h5py @@ -586,9 +598,7 @@ def test_io_roundtrip(self, adata, tmp_path): path = tmp_path / "comm.h5ad" adata.write(path) adata2 = ad.read_h5ad(path) - np.testing.assert_array_almost_equal( - adata2.cellcomm["lr_scores"], comm - ) + np.testing.assert_array_almost_equal(adata2.cellcomm["lr_scores"], comm) def test_workflow(self, tmp_path): """End-to-end: simulate CellChat-like analysis.""" @@ -682,10 +692,130 @@ def test_io_roundtrip(self, adata, tmp_path): path = tmp_path / "grn.h5ad" adata.write(path) adata2 = ad.read_h5ad(path) - np.testing.assert_array_almost_equal( - adata2.genereg["scenic"], grn + np.testing.assert_array_almost_equal(adata2.genereg["scenic"], grn) + + +# --------------------------------------------------------------------------- +# Factored tensor: sections + accessor (scalable communication analysis) +# --------------------------------------------------------------------------- + + +class TestFactoredTensor: + """Store rank-R factors in sections, reconstruct tensor via accessor. + + For million-cell datasets, a dense (n_obs × n_obs × n_vars) tensor + is infeasible. Instead, store compact factors (n_obs × rank) and + (n_vars × rank), and reconstruct on demand. The factors subset + correctly, serialize to h5ad, and the accessor provides the tensor API. + """ + + @pytest.fixture(autouse=True) + def _ensure_accessor(self): + """Create and register the accessor (idempotent).""" + if hasattr(ad.AnnData, "comm"): + return + from anndata.extensions import ( + FormattedEntry, + FormattedOutput, + register_anndata_namespace, ) + @register_anndata_namespace("comm") + class CellCommAccessor: + section_after = "obsp" + section_tooltip = "Cell-cell communication (factored)" + + def __init__(self, adata: ad.AnnData): + self._adata = adata + + def tensor(self, key="default"): + """Reconstruct (obs × obs × var) tensor from factors.""" + U = self._adata.comm_obs[key] + V = self._adata.comm_var[key] + return np.einsum("ir,jr,kr->ijk", U, U, V) + + def query(self, sender, receiver, gene, key="default"): + """O(rank) point query without materializing tensor.""" + U = self._adata.comm_obs[key] + V = self._adata.comm_var[key] + i = self._adata.obs_names.get_loc(sender) + j = self._adata.obs_names.get_loc(receiver) + k = self._adata.var_names.get_loc(gene) + return float(U[i] @ (U[j] * V[k])) + + def _repr_section_(self, context): + keys = list(self._adata.comm_obs.keys()) + if not keys: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput( + type_name=f"rank-{self._adata.comm_obs[k].shape[1]} factors", + preview=( + f"({self._adata.comm_obs[k].shape[0]} cells " + f"× {self._adata.comm_var[k].shape[0]} genes)" + ), + ), + ) + for k in keys + ] + + def test_store_factors(self, adata): + n_obs, n_vars, rank = 5, 3, 2 + adata.comm_obs["lr"] = np.random.rand(n_obs, rank) + adata.comm_var["lr"] = np.random.rand(n_vars, rank) + assert adata.comm_obs["lr"].shape == (n_obs, rank) + assert adata.comm_var["lr"].shape == (n_vars, rank) + + def test_reconstruct_tensor(self, adata): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + tensor = adata.comm.tensor("lr") + assert tensor.shape == (5, 5, 3) + + def test_point_query(self, adata): + rank = 3 + U = np.random.rand(5, rank) + V = np.random.rand(3, rank) + adata.comm_obs["lr"] = U + adata.comm_var["lr"] = V + score = adata.comm.query("c0", "c1", "v0", "lr") + expected = float(U[0] @ (U[1] * V[0])) + assert abs(score - expected) < 1e-10 + + def test_subset_preserves_reconstruction(self, adata): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + full_tensor = adata.comm.tensor("lr") + + sub = adata[:3, :2] + sub_tensor = sub.comm.tensor("lr") + assert sub_tensor.shape == (3, 3, 2) + np.testing.assert_array_almost_equal(sub_tensor, full_tensor[:3, :3, :2]) + + def test_io_roundtrip(self, adata, tmp_path): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + tensor_before = adata.comm.tensor("lr") + + path = tmp_path / "factored.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + tensor_after = adata2.comm.tensor("lr") + np.testing.assert_array_almost_equal(tensor_before, tensor_after) + + def test_compression_ratio(self): + """Factors are orders of magnitude smaller than dense tensor.""" + n_obs, n_vars, rank = 1000, 500, 10 + factor_bytes = (n_obs * rank + n_vars * rank) * 8 # float64 + tensor_bytes = n_obs * n_obs * n_vars * 8 + ratio = tensor_bytes / factor_bytes + assert ratio > 100 # ~33,000× for these sizes + # --------------------------------------------------------------------------- # xarray DataArray layers (custom type + serialize/deserialize) From cd2b37d1da5a7dc261ba4727e2a58c0c4e933eba Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 18:40:10 -0700 Subject: [PATCH 13/14] feat: centralize section list via built-in section registry Register all built-in sections (X, obs, var, uns, obsm, varm, obsp, varp, layers, raw) in _registered_sections with SectionSpec metadata. Add iter_sections() utility for filtered iteration with options for kind filtering, empty-section skipping. Replace hardcoded section lists in: - _gen_repr: uses iter_sections(exclude_kinds={"X", "raw"}) - _mutated_copy: uses iter_sections(kinds={"dataframe", "mapping"}) - write_h5ad: uses iter_sections(exclude_kinds={"X", "raw"}) - write_anndata: same - read_anndata: iterates _registered_sections.values() The five aligned mapping sections (obsm, varm, obsp, varp, layers), both DataFrames (obs, var), uns, and all extension sections are now discovered from a single registry. Only X and raw retain special handling due to their unique structure. --- src/anndata/_core/anndata.py | 54 +++++----- src/anndata/_core/extensions.py | 6 +- src/anndata/_core/section_registry.py | 145 ++++++++++++++++++++++++-- src/anndata/_io/h5ad.py | 34 +++--- src/anndata/_io/specs/methods.py | 51 +++------ tests/test_registered_sections.py | 5 +- 6 files changed, 203 insertions(+), 92 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index da0cf0962..ab65b8cd3 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -395,8 +395,11 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 if any((obs, var, uns, obsm, varm, obsp, varp)): msg = "If `X` is a dict no further arguments must be provided." raise ValueError(msg) - # Copy registered sections from source AnnData - for sec_name in self._registered_sections: + # Copy extension sections from source AnnData + # (built-in sections are handled by the explicit unpacking below) + for sec_name, spec in self._registered_sections.items(): + if spec.builtin: + continue if sec_name not in extra_sections: src_mapping = getattr(X, sec_name, None) if src_mapping is not None and len(src_mapping) > 0: @@ -561,22 +564,17 @@ def cs_to_bytes(X) -> int: return sum(sizes.values()) def _gen_repr(self, n_obs, n_vars) -> str: + from .section_registry import iter_sections + backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" - for attr in [ - "obs", - "var", - "uns", - "obsm", - "varm", - "layers", - "obsp", - "varp", - *self._registered_sections, - ]: - keys = getattr(self, attr).keys() + for spec, value in iter_sections(self, exclude_kinds={"X", "raw"}): + try: + keys = value.keys() + except Exception: # noqa: BLE001 + continue if len(keys) > 0: - descr += f"\n {attr}: {str(list(keys))[1:-1]}" + descr += f"\n {spec.name}: {str(list(keys))[1:-1]}" return descr def __repr__(self) -> str: @@ -1430,20 +1428,13 @@ def _mutated_copy(self, **kwargs) -> AnnData: raise NotImplementedError(msg) new = {} - for key in [ - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - *self._registered_sections, - ]: - if key in kwargs: - new[key] = kwargs[key] + from .section_registry import iter_sections + + for spec, value in iter_sections(self, kinds={"dataframe", "mapping"}): + if spec.name in kwargs: + new[spec.name] = kwargs[spec.name] else: - new[key] = getattr(self, key).copy() + new[spec.name] = value.copy() if "X" in kwargs: new["X"] = kwargs["X"] elif self._has_X(): @@ -2180,6 +2171,13 @@ def _remove_unused_categories_xr( pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. +# Populate _registered_sections with built-in section specs. +# Must happen after AnnData class definition is complete. +from .section_registry import _init_builtin_sections # noqa: E402 + +_init_builtin_sections(AnnData) + + def _check_2d_shape(X): """\ Check shape of array or sparse matrix. diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 6ac53637f..4fabf0f49 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from typing import Literal from anndata._repr.registry import FormattedEntry, FormatterContext @@ -396,10 +397,7 @@ def _repr_section_(self, context) -> list[FormattedEntry] | None: # Section registration # --------------------------------------------------------------------------- -from collections.abc import Callable -from typing import Literal - -from .section_registry import SectionProperty, SectionSpec +from .section_registry import SectionProperty, SectionSpec # noqa: E402 def register_section( diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py index 5ede0fc32..173a28d83 100644 --- a/src/anndata/_core/section_registry.py +++ b/src/anndata/_core/section_registry.py @@ -7,13 +7,17 @@ from __future__ import annotations -from collections.abc import Callable, Iterator, Mapping, MutableMapping +from collections.abc import Mapping, MutableMapping from copy import copy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING from .views import view_update +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Any, Literal + if TYPE_CHECKING: from anndata import AnnData @@ -33,7 +37,8 @@ def _axis_len(value: Any, dim: int) -> int | None: class SectionSpec: """Complete specification for a registered section. - Created by :func:`register_section` from the decorated class. + Created by :func:`register_section` from the decorated class, + or internally for built-in sections. """ name: str @@ -42,6 +47,10 @@ class SectionSpec: """Axes each dimension is aligned to. Empty tuple for unaligned.""" io_key: str """Key used in h5ad/zarr files.""" + kind: Literal["X", "dataframe", "mapping", "unstructured", "raw"] = "mapping" + """Section category. Used by :func:`iter_sections` for filtering.""" + builtin: bool = False + """Whether this is a built-in section (vs. registered by an extension).""" # Optional callbacks extracted from the section class value_type: type | None = None @@ -274,9 +283,7 @@ def __get__(self, obj: AnnData | None, objtype: type | None = None) -> Any: def __set__(self, obj: AnnData, value: Mapping[str, Any] | None) -> None: if value is None: value = {} - if isinstance(value, (SectionMapping, SectionMappingView)) or isinstance( - value, Mapping - ): + if isinstance(value, (SectionMapping, SectionMappingView, Mapping)): value = dict(value) # Validate all values via SectionMapping mapping = SectionMapping(obj, self.spec, {}) @@ -288,3 +295,129 @@ def __set__(self, obj: AnnData, value: Mapping[str, Any] | None) -> None: def __delete__(self, obj: AnnData) -> None: setattr(obj, f"_{self.spec.name}", {}) + + +# --------------------------------------------------------------------------- +# Built-in section specs (metadata only — the actual descriptors are +# AlignedMappingProperty instances already on the AnnData class) +# --------------------------------------------------------------------------- + +#: Ordered list of all built-in sections, used to seed ``_registered_sections``. +BUILTIN_SECTIONS: list[SectionSpec] = [ + SectionSpec(name="X", alignment=("obs", "var"), io_key="X", kind="X", builtin=True), + SectionSpec( + name="obs", alignment=("obs",), io_key="obs", kind="dataframe", builtin=True + ), + SectionSpec( + name="var", alignment=("var",), io_key="var", kind="dataframe", builtin=True + ), + SectionSpec( + name="uns", alignment=(), io_key="uns", kind="unstructured", builtin=True + ), + SectionSpec( + name="obsm", alignment=("obs",), io_key="obsm", kind="mapping", builtin=True + ), + SectionSpec( + name="varm", alignment=("var",), io_key="varm", kind="mapping", builtin=True + ), + SectionSpec( + name="layers", + alignment=("obs", "var"), + io_key="layers", + kind="mapping", + builtin=True, + ), + SectionSpec( + name="obsp", + alignment=("obs", "obs"), + io_key="obsp", + kind="mapping", + builtin=True, + ), + SectionSpec( + name="varp", + alignment=("var", "var"), + io_key="varp", + kind="mapping", + builtin=True, + ), + SectionSpec(name="raw", alignment=("obs",), io_key="raw", kind="raw", builtin=True), +] + + +def _init_builtin_sections(cls: type[AnnData]) -> None: + """Populate ``_registered_sections`` with built-in section specs. + + Called once during AnnData class setup. Does NOT create descriptors — + the built-in ``AlignedMappingProperty`` instances are already on the class. + """ + for spec in BUILTIN_SECTIONS: + cls._registered_sections[spec.name] = spec + + +# --------------------------------------------------------------------------- +# Section iteration utility +# --------------------------------------------------------------------------- + + +def iter_sections( + adata: AnnData, + *, + kinds: set[str] | None = None, + exclude_kinds: set[str] | None = None, + only_nonempty: bool = False, +) -> Iterator[tuple[SectionSpec, Any]]: + """Iterate over AnnData sections with optional filtering. + + Yields ``(spec, value)`` pairs for each section, where *value* is + the result of ``getattr(adata, spec.name)``. + + Parameters + ---------- + adata + AnnData to iterate over. + kinds + If given, only yield sections whose ``kind`` is in this set. + E.g., ``{"mapping"}`` for dict-like sections (obsm, layers, …). + exclude_kinds + If given, skip sections whose ``kind`` is in this set. + E.g., ``{"unstructured", "raw"}`` to skip uns and raw. + only_nonempty + If ``True``, skip sections that are empty or ``None``. + + Examples + -------- + All mapping sections (built-in + registered): + + >>> for spec, mapping in iter_sections(adata, kinds={"mapping"}): + ... print(spec.name, list(mapping.keys())) + + Everything except uns and raw: + + >>> for spec, value in iter_sections(adata, exclude_kinds={"unstructured", "raw"}): + ... ... + + Non-empty sections for repr: + + >>> for spec, value in iter_sections(adata, only_nonempty=True): + ... print(spec.name) + """ + for spec in adata._registered_sections.values(): + if kinds is not None and spec.kind not in kinds: + continue + if exclude_kinds is not None and spec.kind in exclude_kinds: + continue + try: + value = getattr(adata, spec.name, None) + except Exception: # noqa: BLE001 + # Crashing objects in aligned mappings (adversarial data) + continue + if only_nonempty: + if value is None: + continue + try: + if len(value) == 0: + continue + except TypeError: + pass # no len, treat as non-empty + yield spec, value diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index c7cc99db2..9ec0b93a0 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -93,23 +93,21 @@ def write_h5ad( dataset_kwargs=dataset_kwargs, ) _write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs) - write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs) - write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) - # Write registered sections (e.g., obst, vart from extensions) - for sec_name, spec in adata._registered_sections.items(): - mapping = getattr(adata, sec_name, None) - if mapping is not None and len(mapping) > 0: - if spec.serialize_fn is not None: - data = {k: spec.serialize_fn(v) for k, v in mapping.items()} - else: - data = dict(mapping) - write_elem(f, spec.io_key, data, dataset_kwargs=dataset_kwargs) + + # Write all non-X/raw sections via the unified registry + from anndata._core.section_registry import iter_sections + + for spec, value in iter_sections(adata, exclude_kinds={"X", "raw"}): + # Skip empty mappings (but always write DataFrames — they carry the index) + if spec.kind != "dataframe" and len(value) == 0: + continue + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in value.items()} + elif spec.kind == "dataframe": + data = value # write DataFrame directly + else: + data = dict(value) # mappings and uns → dict + write_elem(f, spec.io_key, data, dataset_kwargs=dataset_kwargs) def _write_x( @@ -279,7 +277,7 @@ def callback(read_func, elem_name: str, elem: StorageType, iospec: IOSpec): if not k.startswith("raw.") } # Deserialize registered sections - for sec_name, spec in AnnData._registered_sections.items(): + for spec in AnnData._registered_sections.values(): if spec.io_key in d and spec.deserialize_fn is not None: data = d[spec.io_key] if isinstance(data, dict): diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 370202772..408ed192f 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -285,27 +285,25 @@ def write_anndata( _writer: Writer, dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): + from anndata._core.section_registry import iter_sections + g = f.require_group(k) + # X and raw need special handling if adata.X is not None: _writer.write_elem(g, "X", adata.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", adata.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) - # Write registered sections (e.g., obst, vart from extensions) - for sec_name, spec in adata._registered_sections.items(): - mapping = getattr(adata, sec_name, None) - if mapping is not None and len(mapping) > 0: - if spec.serialize_fn is not None: - data = {k: spec.serialize_fn(v) for k, v in mapping.items()} - else: - data = dict(mapping) - _writer.write_elem(g, spec.io_key, data, dataset_kwargs=dataset_kwargs) + # All other sections via the unified registry + for spec, value in iter_sections(adata, exclude_kinds={"X", "raw"}): + # Skip empty mappings (but always write DataFrames — they carry the index) + if spec.kind != "dataframe" and len(value) == 0: + continue + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in value.items()} + elif spec.kind == "dataframe": + data = value # write DataFrame directly + else: + data = dict(value) # mappings and uns → dict + _writer.write_elem(g, spec.io_key, data, dataset_kwargs=dataset_kwargs) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) @@ -316,27 +314,12 @@ def write_anndata( @_REGISTRY.register_read(ZarrGroup, IOSpec("raw", "0.1.0")) def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnData: d = {} - for k in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - "uns", - "raw", - ]: - if k in elem: - d[k] = _reader.read_elem(elem[k]) - # Read registered sections (e.g., obst, vart from extensions) - for sec_name, spec in AnnData._registered_sections.items(): + for spec in AnnData._registered_sections.values(): if spec.io_key in elem: data = _reader.read_elem(elem[spec.io_key]) if spec.deserialize_fn is not None and isinstance(data, dict): data = {k: spec.deserialize_fn(v) for k, v in data.items()} - d[sec_name] = data + d[spec.name] = data return AnnData(**d) diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py index c85a27ed0..40fc32263 100644 --- a/tests/test_registered_sections.py +++ b/tests/test_registered_sections.py @@ -25,7 +25,7 @@ @pytest.fixture(autouse=True, scope="module") -def _register_test_sections(): +def _register_test_sections(): # noqa: PLR0912 """Register test sections for all alignment combinations.""" # obs-aligned (like obsm) if "sec_obs" not in ad.AnnData._registered_sections: @@ -181,7 +181,8 @@ def test_register_duplicate_raises(self): register_section("sec_obs", alignment="obs")(type("Dup", (), {})) def test_register_reserved_name_raises(self): - with pytest.raises(AttributeError, match="conflicts with"): + # "obs" is a built-in registered section, so it's already registered + with pytest.raises(ValueError, match="already registered"): register_section("obs", alignment="obs")(type("Bad", (), {})) def test_all_sections_in_registry(self): From 76594d353dbb1bf0f9dc1b20869cc12914cd4bd8 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 30 Mar 2026 18:49:53 -0700 Subject: [PATCH 14/14] fix: raw alignment should be empty, not ('obs',) Raw manages its own subsetting internally (X along obs, var/varm unchanged). The alignment tuple shouldn't imply it behaves like an obs-aligned mapping. --- src/anndata/_core/section_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py index 173a28d83..833e2ab1e 100644 --- a/src/anndata/_core/section_registry.py +++ b/src/anndata/_core/section_registry.py @@ -341,7 +341,7 @@ def __delete__(self, obj: AnnData) -> None: kind="mapping", builtin=True, ), - SectionSpec(name="raw", alignment=("obs",), io_key="raw", kind="raw", builtin=True), + SectionSpec(name="raw", alignment=(), io_key="raw", kind="raw", builtin=True), ]