Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@
},
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestEnabled": true,
"python.testing.pytestArgs": ["-vv", "--color=yes"],
"python.testing.pytestArgs": [
"-vv",
"--color=yes"
],
"cursorpyright.analysis.typeCheckingMode": "basic",
}
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Data loading:
```python
from pathlib import Path

from annbatch import Loader
from annbatch import AnnDataField, Loader
import anndata as ad
import zarr

Expand All @@ -122,7 +122,10 @@ with ad.settings.override(remove_unused_categories=False):
)
for p in Path("path/to/output/collection").glob("*.zarr")
],
obs_keys=["label_column", "batch_column"],
adata_fields={
"label": AnnDataField(attr="obs", key="label_column"),
"batch": AnnDataField(attr="obs", key="batch_column"),
},
)

# Iterate over dataloader (plugin replacement for torch.utils.DataLoader)
Expand Down
6 changes: 5 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ See the [zarr docs on sharding][] for more information.
#### Chunked access

```python
from annbatch import AnnDataField, Loader

ds = Loader(
batch_size=4096,
chunk_size=32,
Expand All @@ -46,7 +48,9 @@ ds = Loader(
)
for p in PATH_TO_STORE.glob("*.zarr")
],
obs_keys="label_column",
adata_fields={
"label": AnnDataField(attr="obs", key="label_column"),
},
)

# Iterate over dataloader (plugin replacement for torch.utils.DataLoader)
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
Expand All @@ -240,7 +240,7 @@
"source": [
"import anndata as ad\n",
"\n",
"from annbatch import Loader\n",
"from annbatch import AnnDataField, Loader\n",
"\n",
"ds = Loader(\n",
" batch_size=4096, # Total number of obs per yielded batch\n",
Expand All @@ -260,7 +260,7 @@
" )\n",
" for p in COLLECTION_PATH.glob(\"*.zarr\")\n",
" ],\n",
" obs_keys=\"cell_type\",\n",
" adata_fields={\"cell_type\": AnnDataField(attr=\"obs\", key=\"cell_type\")},\n",
")"
]
},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"anndata[lazy]>=0.12.6",
"dask>=2025.9.0",
"dask>=2025.9",
"pandas>=2.2.2,<3",
"scipy>1.15,<1.17",
# for debug logging (referenced from the issue template)
Expand Down
3 changes: 2 additions & 1 deletion src/annbatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from importlib.metadata import version

from . import types
from .fields import AnnDataField
from .io import add_to_collection, create_anndata_collection, write_sharded
from .loader import Loader

__version__ = version("annbatch")

__all__ = ["Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"]
__all__ = ["AnnDataField", "Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"]
53 changes: 53 additions & 0 deletions src/annbatch/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from dataclasses import dataclass
from operator import attrgetter
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
from collections.abc import Callable

from anndata import AnnData


@dataclass(frozen=True)
class AnnDataField:
"""
Minimal, extensible field accessor for AnnData-like objects.

Mirrors Cellarium's `AnnDataField`: select an AnnData attribute via `attr`,
optionally index into it via `key`, and optionally apply `convert_fn`.
"""

attr: str
key: list[str] | str | None = None
convert_fn: Callable[[Any], Any] | None = None

def __call__(self, adata: AnnData) -> np.ndarray:
"""Extract this field from an AnnData-like object.

`attr` is looked up on `adata` (e.g. ``"X"``, ``"obs"``, ``"layers"``, ``"obsm"``).
If `key` is provided, the attribute is indexed with `key` (e.g. a column in
``obs`` or an entry in ``layers``). If ``convert_fn`` is provided, it is applied
to the selected value; otherwise the selected value is converted via
:func:`numpy.asarray`.

Parameters
----------
adata
AnnData-like object to read from.

Returns
-------
numpy.ndarray
Array representation of the selected field.
"""
value = attrgetter(self.attr)(adata)
if self.key is not None:
value = value[self.key]

if self.convert_fn is not None:
value = self.convert_fn(value)
return np.asarray(value)
Loading
Loading