From 69985a0dc07b3016ab8b6f3a2ef3f76c11f88990 Mon Sep 17 00:00:00 2001 From: gpalla Date: Wed, 17 Dec 2025 13:33:10 -0800 Subject: [PATCH 1/2] new AnnDataField API for multiple columns (and potentially multiple attributes) from anndata --- .vscode/settings.json | 6 +- README.md | 7 ++- docs/index.md | 6 +- docs/notebooks/example.ipynb | 6 +- src/annbatch/__init__.py | 3 +- src/annbatch/fields.py | 32 +++++++++++ src/annbatch/loader.py | 105 +++++++++++++++++++++++------------ tests/conftest.py | 5 +- tests/test_dataset.py | 101 ++++++++++++++++++++++++--------- 9 files changed, 201 insertions(+), 70 deletions(-) create mode 100644 src/annbatch/fields.py diff --git a/.vscode/settings.json b/.vscode/settings.json index e034b91f..a6bcf494 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,5 +14,9 @@ }, "python.analysis.typeCheckingMode": "basic", "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": ["-vv", "--color=yes"], + "python.testing.pytestArgs": [ + "-vv", + "--color=yes" + ], + "cursorpyright.analysis.typeCheckingMode": "basic", } diff --git a/README.md b/README.md index e48c3347..50be856f 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Data loading: ```python from pathlib import Path -from annbatch import Loader +from annbatch import AnnDataField, Loader import anndata as ad import zarr @@ -122,7 +122,10 @@ with ad.settings.override(remove_unused_categories=False): ) for p in Path("path/to/output/collection").glob("*.zarr") ], - obs_keys=["label_column", "batch_column"], + adata_fields={ + "label": AnnDataField(attr="obs", key="label_column"), + "batch": AnnDataField(attr="obs", key="batch_column"), + }, ) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) diff --git a/docs/index.md b/docs/index.md index 9043ce97..42da7c07 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,6 +33,8 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python +from annbatch import AnnDataField, Loader + ds = Loader( batch_size=4096, chunk_size=32, @@ -46,7 +48,9 @@ ds = Loader( ) for p in PATH_TO_STORE.glob("*.zarr") ], - obs_keys="label_column", + adata_fields={ + "label": AnnDataField(attr="obs", key="label_column"), + }, ) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 3f8aef26..414738f7 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -219,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "tags": [ "hide-output" @@ -240,7 +240,7 @@ "source": [ "import anndata as ad\n", "\n", - "from annbatch import Loader\n", + "from annbatch import AnnDataField, Loader\n", "\n", "ds = Loader(\n", " batch_size=4096, # Total number of obs per yielded batch\n", @@ -260,7 +260,7 @@ " )\n", " for p in COLLECTION_PATH.glob(\"*.zarr\")\n", " ],\n", - " obs_keys=\"cell_type\",\n", + " adata_fields={\"cell_type\": AnnDataField(attr=\"obs\", key=\"cell_type\")},\n", ")" ] }, diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 544d9070..abcf67fa 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -1,9 +1,10 @@ from importlib.metadata import version from . 
import types +from .fields import AnnDataField from .io import add_to_collection, create_anndata_collection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"] +__all__ = ["AnnDataField", "Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"] diff --git a/src/annbatch/fields.py b/src/annbatch/fields.py new file mode 100644 index 00000000..cf378b6e --- /dev/null +++ b/src/annbatch/fields.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal + +import numpy as np +from anndata import AnnData + + +@dataclass(frozen=True) +class AnnDataField: + """ + Minimal, extensible field accessor for AnnData-like objects. + + This is intentionally small: for now only `attr="obs"` is supported. + The design mirrors Cellarium's `AnnDataField` and can be extended to `X`, + `layers`, `obsm`, etc. + """ + + attr: Literal["obs"] + key: str + convert_fn: Callable[[Any], Any] | None = None + + def __call__(self, adata: AnnData) -> np.ndarray: + if self.attr != "obs": + raise NotImplementedError(f"AnnDataField(attr={self.attr!r}) is not supported yet.") + + value = adata.obs[self.key] + if self.convert_fn is not None: + value = self.convert_fn(value) + return np.asarray(value) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 0e572336..9563e1d1 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -16,6 +16,7 @@ from scipy import sparse as sp from zarr import Array as ZarrArray +from annbatch.fields import AnnDataField from annbatch.types import BackingArray_T, InputInMemoryArray_T, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, @@ -35,8 +36,18 @@ CupyCSRMatrix = NoneType CupyArray = NoneType try: - from torch.utils.data import IterableDataset as _IterableDataset -except ImportError: + import warnings + + with warnings.catch_warnings(): + # Some environments emit a FutureWarning from an optional NVML shim during torch import. + # We don't want `annbatch` imports to fail under test runners that treat warnings as errors. + warnings.filterwarnings( + "ignore", + message=r"The pynvml package is deprecated\..*", + category=FutureWarning, + ) + from torch.utils.data import IterableDataset as _IterableDataset +except (ImportError, Warning): class _IterableDataset: pass @@ -116,7 +127,7 @@ class Loader[BackingArray: BackingArray_T, InputInMemoryArray: InputInMemoryArra """ _train_datasets: list[BackingArray] - _labels: list[np.ndarray] | None = None + _labels: list[dict[str, np.ndarray]] | None = None _return_index: bool = False _batch_size: int = 1 _shapes: list[tuple[int, int]] @@ -234,7 +245,7 @@ def add_anndatas( self, adatas: list[ad.AnnData], layer_keys: list[str | None] | str | None = None, - obs_keys: list[str] | str | None = None, + adata_fields: dict[str, AnnDataField] | None = None, ) -> Self: """Append anndatas to this dataset. @@ -242,17 +253,16 @@ def add_anndatas( ---------- adatas List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix. - obs_keys - List of :attr:`anndata.AnnData.obs` column labels. layer_keys List of :attr:`anndata.AnnData.layers` keys, and if None, :attr:`anndata.AnnData.X` will be used. + adata_fields + Mapping from output key to an :class:`~annbatch.AnnDataField` describing how to extract labels. 
""" self._used_anndata_adder = True if isinstance(layer_keys, str | None): layer_keys = [layer_keys] * len(adatas) - if isinstance(obs_keys, str | None): - obs_keys = [obs_keys] * len(adatas) - elem_to_keys = dict(zip(["layer", "obs"], [layer_keys, obs_keys], strict=True)) + + elem_to_keys = dict(zip(["layer"], [layer_keys], strict=True)) check_lt_1( [len(adatas)] + sum((([len(k)] if k is not None else []) for k in elem_to_keys.values()), []), ["Number of anndatas"] @@ -261,16 +271,15 @@ def add_anndatas( [], ), ) - for adata, obs_key, layer_key in zip(adatas, obs_keys, layer_keys, strict=True): - kwargs = {"obs_key": obs_key, "layer_key": layer_key} - self.add_anndata(adata, **kwargs) + for adata, layer_key in zip(adatas, layer_keys, strict=True): + self.add_anndata(adata, layer_key=layer_key, adata_fields=adata_fields) return self def add_anndata( self, adata: ad.AnnData, layer_key: str | None = None, - obs_key: str | None = None, + adata_fields: dict[str, AnnDataField] | None = None, ) -> Self: """Append an anndata to this dataset. @@ -278,20 +287,25 @@ def add_anndata( ---------- adata A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix. - obs_key - :attr:`anndata.AnnData.obs` column labels. layer_key :attr:`anndata.AnnData.layers` keys, and if None, :attr:`anndata.AnnData.X` will be used. + adata_fields + Mapping from output key to an :class:`~annbatch.AnnDataField` describing how to extract labels. """ self._used_anndata_adder = True dataset = adata.X if layer_key is None else adata.layers[layer_key] if not isinstance(dataset, BackingArray_T.__value__): raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") - obs = adata.obs[obs_key].to_numpy() if obs_key is not None else None + obs: dict[str, np.ndarray] | None + if adata_fields is not None: + obs = {out_key: field(adata) for out_key, field in adata_fields.items()} + else: + obs = None + # `dataset` is runtime-validated above; keep mypy/pyright happy without `cast()`. self.add_dataset(cast("BackingArray", dataset), obs) return self - def add_datasets(self, datasets: list[BackingArray], obs: list[np.ndarray] | None = None) -> Self: + def add_datasets(self, datasets: list[BackingArray], obs: list[dict[str, np.ndarray] | None] | None = None) -> Self: """Append datasets to this dataset. Parameters @@ -300,15 +314,18 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[np.ndarray] | Non List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`. + List of label dictionaries, generally derived from :attr:`anndata.AnnData.obs`. """ + obs_list: list[dict[str, np.ndarray] | None] if obs is None: - obs = [None] * len(datasets) - for ds, o in zip(datasets, obs, strict=True): + obs_list = [None] * len(datasets) + else: + obs_list = obs + for ds, o in zip(datasets, obs_list, strict=True): self.add_dataset(ds, o) return self - def add_dataset(self, dataset: BackingArray, obs: np.ndarray | None = None) -> Self: + def add_dataset(self, dataset: BackingArray, obs: dict[str, np.ndarray] | None = None) -> Self: """Append a dataset to this dataset. 
         Parameters
@@ -316,7 +333,7 @@ def add_dataset(self, dataset: BackingArray, obs: np.ndarray | None = None) -> S
         dataset
             A :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` object, generally from :attr:`anndata.AnnData.X`.
         obs
-            :class:`numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`.
+            Label dictionary, generally derived from :attr:`anndata.AnnData.obs`.
         """
         if len(self._train_datasets) > 0:
             if self._labels is None and obs is not None:
@@ -331,6 +348,22 @@ def add_dataset(self, dataset: BackingArray, obs: np.ndarray | None = None) -> S
             raise ValueError(
                 f"All datasets on a given loader must be of the same type {self.dataset_type} but got {type(dataset)}"
             )
+        if obs is not None:
+            if len(obs) == 0:
+                raise ValueError("If `obs` is provided it must be a non-empty label dictionary.")
+            # Each label array must have one entry per observation; key consistency is checked below.
+            for k, v in obs.items():
+                if len(v) != dataset.shape[0]:
+                    raise ValueError(
+                        f"Label field {k!r} must have length equal to n_obs ({dataset.shape[0]}), got {len(v)}."
+                    )
+            if self._labels is not None:
+                expected_keys = set(self._labels[0].keys())
+                found_keys = set(obs.keys())
+                if found_keys != expected_keys:
+                    raise ValueError(
+                        f"All datasets must have the same label keys. Expected {sorted(expected_keys)}, got {sorted(found_keys)}."
+                    )
         if not isinstance(dataset, BackingArray_T.__value__):
             raise TypeError(f"Cannot add dataset of type {type(dataset)}")
         if isinstance(dataset, ad.abc.CSRDataset) and not dataset.backend == "zarr":
@@ -342,7 +375,8 @@ def add_dataset(self, dataset: BackingArray, obs: np.ndarray | None = None) -> S
         self._shapes = self._shapes + [dataset.shape]
         self._train_datasets = datasets
         if self._labels is not None:  # labels exist
-            self._labels += [obs]
+            # `obs` cannot be None here due to earlier invariants.
+            self._labels += [cast("dict[str, np.ndarray]", obs)]
         elif obs is not None:  # labels don't exist yet, but are being added for the first time
             self._labels = [obs]
         return self
@@ -587,7 +621,8 @@ async def _index_datasets(
     def __iter__(
         self,
     ) -> Iterator[
-        tuple[OutputInMemoryArray_T, None | np.ndarray] | tuple[OutputInMemoryArray_T, None | np.ndarray, np.ndarray]
+        tuple[OutputInMemoryArray_T, None | dict[str, np.ndarray]]
+        | tuple[OutputInMemoryArray_T, None | dict[str, np.ndarray], np.ndarray]
     ]:
         """Iterate over the on-disk csr datasets.
@@ -628,15 +663,12 @@ def __iter__( else: chunks_converted = [self._np_module.asarray(c) for c in chunks] # Accumulate labels - labels: None | list[np.ndarray] = None + labels: None | list[dict[str, np.ndarray]] = None if self._labels is not None: labels = [] for dataset_idx in dataset_index_to_slices.keys(): - labels += [ - self._labels[dataset_idx][ - np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]]) - ] - ] + idxs = np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]]) + labels.append({k: self._labels[dataset_idx][k][idxs] for k in self._labels[dataset_idx].keys()}) # Accumulate indices if necessary indices: None | list[np.ndarray] = None if self._return_index: @@ -661,9 +693,12 @@ def __iter__( else mod.vstack([in_memory_data, *chunks_converted]) ) if self._labels is not None: - in_memory_labels = ( - np.concatenate(labels) if in_memory_labels is None else np.concatenate([in_memory_labels, *labels]) - ) + if in_memory_labels is None: + in_memory_labels = {k: np.concatenate([d[k] for d in labels]) for k in labels[0].keys()} + else: + in_memory_labels = { + k: np.concatenate([in_memory_labels[k], *[d[k] for d in labels]]) for k in in_memory_labels + } if self._return_index: in_memory_indices = ( np.concatenate(indices) @@ -681,7 +716,7 @@ def __iter__( if s.shape[0] == self._batch_size: res = [ in_memory_data[s], - in_memory_labels[s] if self._labels is not None else None, + ({k: v[s] for k, v in in_memory_labels.items()} if self._labels is not None else None), ] if self._return_index: res += [in_memory_indices[s]] @@ -692,7 +727,7 @@ def __iter__( if (s.shape[0] % self._batch_size) != 0: in_memory_data = in_memory_data[s] if in_memory_labels is not None: - in_memory_labels = in_memory_labels[s] + in_memory_labels = {k: v[s] for k, v in in_memory_labels.items()} if in_memory_indices is not None: in_memory_indices = in_memory_indices[s] else: diff --git a/tests/conftest.py b/tests/conftest.py index aa012e34..90a05062 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,10 @@ def adata_with_zarr_path_same_var_space(tmpdir_factory, n_shards: int = 3) -> Ge adata = ad.AnnData( X=np.random.random((n_cells_per_shard, feature_dim)).astype("f4"), obs=pd.DataFrame( - {"label": np.random.default_rng().integers(0, 5, size=n_cells_per_shard)}, + { + "label": np.random.default_rng().integers(0, 5, size=n_cells_per_shard), + "batch": np.full((n_cells_per_shard,), shard, dtype=np.int32), + }, index=np.arange(n_cells_per_shard).astype(str), ), layers={ diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 180baf86..2761d34d 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,7 @@ from importlib.util import find_spec from types import NoneType -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, TypedDict, cast import anndata as ad import h5py @@ -12,7 +12,7 @@ import scipy.sparse as sp import zarr -from annbatch import Loader +from annbatch import AnnDataField, Loader try: from cupy import ndarray as CupyArray @@ -28,7 +28,7 @@ class Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array - obs: np.ndarray + obs: dict[str, np.ndarray] class ListData: @@ -40,36 +40,39 @@ def open_sparse(path: Path, *, use_zarrs: bool = False, use_anndata: bool = Fals old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + obs_df = cast("pd.DataFrame", 
ad.io.read_elem(zarr.open(path)["obs"])) data = { "dataset": ad.io.sparse_dataset(zarr.open(path)["layers"]["sparse"]), - "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), + "obs": {"label": obs_df["label"].to_numpy()}, } if use_anndata: - return ad.AnnData(X=data["dataset"], obs=pd.DataFrame({"label": data["obs"]})) - return data + return ad.AnnData(X=data["dataset"], obs=obs_df) + return cast("Data", data) def open_dense(path: Path, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + obs_df = cast("pd.DataFrame", ad.io.read_elem(zarr.open(path)["obs"])) data = { "dataset": zarr.open(path)["X"], - "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), + "obs": {"label": obs_df["label"].to_numpy()}, } if use_anndata: - return ad.AnnData(X=data["dataset"], obs=pd.DataFrame({"label": data["obs"]})) - return data + return ad.AnnData(X=data["dataset"], obs=obs_df) + return cast("Data", data) def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: - return ( + return cast( + "ListData | list[ad.AnnData]", { "datasets": [d["dataset"] for d in datas], "obs": [d["obs"] for d in datas], } if all(isinstance(d, dict) for d in datas) - else datas + else datas, ) @@ -86,7 +89,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: open_func=open_func, batch_size=batch_size, preload_to_gpu=preload_to_gpu, - obs_keys=obs_keys: Loader( + adata_fields=adata_fields: Loader( shuffle=shuffle, chunk_size=chunk_size, preload_nchunks=preload_nchunks, @@ -96,64 +99,68 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: to_torch=False, ).add_anndatas( [open_func(p, use_zarrs=use_zarrs, use_anndata=True) for p in path.glob("*.zarr")], - obs_keys=obs_keys, + adata_fields=adata_fields, ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_type={open_func.__name__[5:]}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-adata_fields={list(adata_fields.keys()) if adata_fields else None}-dataset_type={open_func.__name__[5:]}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] marks=pytest.mark.skipif( find_spec("cupy") is None and preload_to_gpu, reason="need cupy installed", ), ) - for chunk_size, preload_nchunks, obs_keys, open_func, layer_keys, batch_size, preload_to_gpu in [ + for chunk_size, preload_nchunks, open_func, layer_keys, batch_size, preload_to_gpu, adata_fields in [ elem for preload_to_gpu in [True, False] - for obs_keys in [None, "label"] for open_func in [open_sparse, open_dense] + for adata_fields in [ + None, + {"label": AnnDataField(attr="obs", key="label")}, + {"label": AnnDataField(attr="obs", key="label"), "batch": AnnDataField(attr="obs", key="batch")}, + ] for elem in [ [ 1, 5, - obs_keys, open_func, None, 1, preload_to_gpu, + adata_fields, ], # singleton chunk size [ 5, 1, - obs_keys, open_func, None, 1, preload_to_gpu, + adata_fields, ], # singleton preload [ 10, 5, - obs_keys, open_func, None, 5, preload_to_gpu, + adata_fields, ], # batch size divides total in memory size evenly [ 10, 5, - obs_keys, open_func, None, 50, preload_to_gpu, + adata_fields, ], # batch size equal to 
in-memory size loading [ 10, 5, - obs_keys, open_func, None, 14, preload_to_gpu, + adata_fields, ], # batch size does not divide in memory size evenly ] ] @@ -178,6 +185,7 @@ def test_store_load_dataset( indices = [] expected_data = adata.X if is_dense else adata.layers["sparse"].toarray() for batch in loader: + assert len(batch) == 3 x, label, index = batch n_elems += x.shape[0] # Check feature dimension @@ -195,18 +203,59 @@ def test_store_load_dataset( if not shuffle: np.testing.assert_allclose(stacked, expected_data) if len(labels) > 0: - expected_labels = adata.obs["label"] - np.testing.assert_allclose( - np.concatenate(labels).ravel(), - expected_labels, - ) + # labels are dict[str, np.ndarray] + labels_by_key = {k: np.concatenate([d[k] for d in labels]).ravel() for k in labels[0].keys()} + for k, got in labels_by_key.items(): + np.testing.assert_allclose(got, adata.obs[k].to_numpy()) else: if len(indices) > 0: indices = np.concatenate(indices).ravel() np.testing.assert_allclose(stacked, expected_data[indices]) + if len(labels) > 0: + for k in labels[0].keys(): + got = np.concatenate([d[k] for d in labels]).ravel() + np.testing.assert_allclose(got, adata.obs[k].to_numpy()[indices]) assert n_elems == adata.shape[0] +def test_adata_fields_convert_fn(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], use_zarrs: bool): + # Smoke test that adata_fields controls output key name and applies convert_fn. + path = adata_with_zarr_path_same_var_space[1] + old_pipeline = zarr.config.get("codec_pipeline.path") + with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + adatas = [ + ad.AnnData( + X=zarr.open(p)["X"], + obs=ad.io.read_elem(zarr.open(p)["obs"]), + ) + for p in path.glob("*.zarr") + ] + + ds = Loader( + shuffle=False, + chunk_size=10, + preload_nchunks=2, + return_index=True, + batch_size=7, + preload_to_gpu=False, + to_torch=False, + ).add_anndatas( + adatas, + adata_fields={ + "y": AnnDataField(attr="obs", key="label", convert_fn=lambda s: s.to_numpy(dtype=np.int64) + 1), + "b": AnnDataField(attr="obs", key="batch"), + }, + ) + + batch = next(iter(ds)) + assert len(batch) == 3 + x, labels, idx = batch + assert isinstance(labels, dict) + assert set(labels.keys()) == {"y", "b"} + # spot-check transform applied + np.testing.assert_allclose(labels["y"], adata_with_zarr_path_same_var_space[0].obs["label"].to_numpy()[idx] + 1) + + @pytest.mark.parametrize( "gen_loader", [ From eff85d6c137fb6dd38d156514b373f5898a82699 Mon Sep 17 00:00:00 2001 From: gpalla Date: Wed, 17 Dec 2025 16:13:25 -0800 Subject: [PATCH 2/2] AnnDataField for all fields. 
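
Generalize `AnnDataField` beyond `obs` columns: `attr` now names any
AnnData attribute (resolved via `operator.attrgetter`, e.g. "obs",
"layers", "obsm"), `key` optionally indexes into it, and `convert_fn`
runs before the final `numpy.asarray`.

A minimal end-to-end sketch of the resulting API (it mirrors the new
tests/test_fields.py below; the store path and field names are
illustrative only):

    import anndata as ad
    import numpy as np
    import pandas as pd
    import zarr

    from annbatch import AnnDataField, Loader

    # Tiny zarr-backed AnnData; annbatch requires a zarr-backed data matrix.
    g = zarr.open_group("dummy.zarr", mode="w", zarr_format=3)
    g.create_array("X", data=np.arange(12, dtype="f4").reshape(4, 3), chunks=(2, 3))
    adata = ad.AnnData(
        X=g["X"],
        obs=pd.DataFrame({"label": [0, 1, 2, 3]}, index=list("abcd")),
        obsm={"X_pca": np.zeros((4, 2), dtype="f4")},
    )

    loader = Loader(
        batch_size=2,
        chunk_size=2,
        preload_nchunks=1,
        shuffle=False,
        preload_to_gpu=False,
        to_torch=False,
    ).add_anndatas(
        [adata],
        adata_fields={
            # obs column, with convert_fn applied before numpy.asarray
            "y": AnnDataField(attr="obs", key="label", convert_fn=lambda s: s.to_numpy() + 1),
            # whole obsm entry, sliced per batch alongside X
            "pca": AnnDataField(attr="obsm", key="X_pca"),
        },
    )

    for x, labels in loader:
        # labels == {"y": ndarray, "pca": ndarray}, row-aligned with x
        ...

Note that obs-derived values reach `convert_fn` as pandas Series, so
pandas methods such as `.map()` remain available there (exercised in
tests/test_fields.py).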
--- pyproject.toml | 2 +- src/annbatch/fields.py | 43 +++++++++++++++------ src/annbatch/loader.py | 3 +- tests/test_dataset.py | 3 +- tests/test_fields.py | 75 ++++++++++++++++++++++++++++++++++++ tests/test_store_creation.py | 5 ++- 6 files changed, 116 insertions(+), 15 deletions(-) create mode 100644 tests/test_fields.py diff --git a/pyproject.toml b/pyproject.toml index 808bc6a9..f064922a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dependencies = [ "anndata[lazy]>=0.12.6", - "dask>=2025.9.0", + "dask>=2025.9", "pandas>=2.2.2,<3", "scipy>1.15,<1.17", # for debug logging (referenced from the issue template) diff --git a/src/annbatch/fields.py b/src/annbatch/fields.py index cf378b6e..a72419ac 100644 --- a/src/annbatch/fields.py +++ b/src/annbatch/fields.py @@ -1,11 +1,15 @@ from __future__ import annotations -from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Literal +from operator import attrgetter +from typing import TYPE_CHECKING, Any import numpy as np -from anndata import AnnData + +if TYPE_CHECKING: + from collections.abc import Callable + + from anndata import AnnData @dataclass(frozen=True) @@ -13,20 +17,37 @@ class AnnDataField: """ Minimal, extensible field accessor for AnnData-like objects. - This is intentionally small: for now only `attr="obs"` is supported. - The design mirrors Cellarium's `AnnDataField` and can be extended to `X`, - `layers`, `obsm`, etc. + Mirrors Cellarium's `AnnDataField`: select an AnnData attribute via `attr`, + optionally index into it via `key`, and optionally apply `convert_fn`. """ - attr: Literal["obs"] - key: str + attr: str + key: list[str] | str | None = None convert_fn: Callable[[Any], Any] | None = None def __call__(self, adata: AnnData) -> np.ndarray: - if self.attr != "obs": - raise NotImplementedError(f"AnnDataField(attr={self.attr!r}) is not supported yet.") + """Extract this field from an AnnData-like object. + + `attr` is looked up on `adata` (e.g. ``"X"``, ``"obs"``, ``"layers"``, ``"obsm"``). + If `key` is provided, the attribute is indexed with `key` (e.g. a column in + ``obs`` or an entry in ``layers``). If ``convert_fn`` is provided, it is applied + to the selected value; otherwise the selected value is converted via + :func:`numpy.asarray`. + + Parameters + ---------- + adata + AnnData-like object to read from. + + Returns + ------- + numpy.ndarray + Array representation of the selected field. 
+ """ + value = attrgetter(self.attr)(adata) + if self.key is not None: + value = value[self.key] - value = adata.obs[self.key] if self.convert_fn is not None: value = self.convert_fn(value) return np.asarray(value) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 9563e1d1..d85f734a 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -16,7 +16,6 @@ from scipy import sparse as sp from zarr import Array as ZarrArray -from annbatch.fields import AnnDataField from annbatch.types import BackingArray_T, InputInMemoryArray_T, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, @@ -57,6 +56,8 @@ class _IterableDataset: from collections.abc import Iterator from types import ModuleType + from annbatch.fields import AnnDataField + # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2761d34d..e35dbd33 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,7 +7,6 @@ import anndata as ad import h5py import numpy as np -import pandas as pd import pytest import scipy.sparse as sp import zarr @@ -25,6 +24,8 @@ from collections.abc import Callable from pathlib import Path + import pandas as pd + class Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array diff --git a/tests/test_fields.py b/tests/test_fields.py new file mode 100644 index 00000000..3dfbb1c5 --- /dev/null +++ b/tests/test_fields.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import zarr + +from annbatch import AnnDataField, Loader + + +def test_anndatafield_works_via_loader(tmp_path): + n_obs, n_vars = 4, 3 + X = np.arange(n_obs * n_vars, dtype=np.float32).reshape(n_obs, n_vars) + counts = X + 1 + X_pca = np.arange(n_obs * 2, dtype=np.float32).reshape(n_obs, 2) + + store = zarr.open_group(tmp_path / "dummy.zarr", mode="w", zarr_format=3) + store.create_array("X", data=X, chunks=(2, n_vars)) + layers_g = store.create_group("layers") + layers_g.create_array("counts", data=counts, chunks=(2, n_vars)) + + obs = pd.DataFrame( + { + "label_int": np.array([0, 1, 2, 3], dtype=np.int64), + "label_str": ["a", "b", "c", "d"], + }, + index=[str(i) for i in range(n_obs)], + ) + adata = ad.AnnData( + X=store["X"], + obs=obs, + layers={"counts": layers_g["counts"]}, + obsm={"X_pca": X_pca}, + ) + + mapping = {"a": 1, "b": 2, "c": 3, "d": 4} + ds = Loader( + shuffle=False, + chunk_size=2, + preload_nchunks=1, + batch_size=2, + return_index=True, + preload_to_gpu=False, + to_torch=False, + ).add_anndatas( + [adata], + adata_fields={ + "label_int": AnnDataField(attr="obs", key="label_int"), + "label_int_str": AnnDataField(attr="obs", key="label_int", convert_fn=lambda s: s.astype(str)), + "label_str_int": AnnDataField(attr="obs", key="label_str", convert_fn=lambda s: s.map(mapping).to_numpy()), + "counts": AnnDataField(attr="layers", key="counts"), + "X_pca": AnnDataField(attr="obsm", key="X_pca"), + }, + ) + + xs, idxs = [], [] + labels = {k: [] for k in ["label_int", "label_int_str", "label_str_int", "counts", "X_pca"]} + for x, y, idx in ds: + xs.append(x) + idxs.append(idx) + for k in labels: + labels[k].append(y[k]) + + idxs = np.concatenate(idxs).ravel() + np.testing.assert_array_equal(np.vstack(xs), X) + np.testing.assert_array_equal(np.concatenate(labels["label_int"]).ravel(), obs["label_int"].to_numpy()) + np.testing.assert_array_equal( + np.concatenate(labels["label_int_str"]).ravel(), obs["label_int"].astype(str).to_numpy() + ) 
+ np.testing.assert_array_equal( + np.concatenate(labels["label_str_int"]).ravel(), obs["label_str"].map(mapping).to_numpy() + ) + np.testing.assert_array_equal(np.vstack(labels["counts"]), counts) + np.testing.assert_array_equal(np.vstack(labels["X_pca"]), X_pca) + np.testing.assert_array_equal(idxs, np.arange(n_obs)) diff --git a/tests/test_store_creation.py b/tests/test_store_creation.py index 5e9f099c..3b63ce38 100644 --- a/tests/test_store_creation.py +++ b/tests/test_store_creation.py @@ -40,7 +40,10 @@ def test_write_sharded_shard_size_too_big(tmp_path: Path, chunk_size: int, expec def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "layers", "raw"], tmp_path: Path): adata_1 = ad.AnnData(X=np.random.randn(10, 20)) extra_args = { - elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} + # For `obs`, avoid a string column which outer-joins into mixed str/NaN and can trip Zarr's vlen-utf8 encoder. + elem_name: { + "arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else np.arange(10) + } } adata_2 = ad.AnnData(X=np.random.randn(10, 20), **extra_args) path_1 = tmp_path / "just_x.h5ad"