Skip to content
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
ba2cbe9
resolve conflicts with main
selmanozleyen Jan 18, 2026
d951c56
load_obs thing was removed by auto formatting
selmanozleyen Jan 18, 2026
aa348fe
update tests to resolve conflict
selmanozleyen Jan 18, 2026
6220a8a
readthedocs merge
selmanozleyen Jan 18, 2026
309d33e
chore: clarify compatibility of `h5ad` + forward compat of old shuffl…
ilan-gold Jan 19, 2026
437f184
breaking: clarify obs handling + change output keys (#115)
ilan-gold Jan 19, 2026
679ee50
fix: header level (#116)
ilan-gold Jan 19, 2026
ad1ec55
merge changes
selmanozleyen Jan 19, 2026
c8b4395
apply suggestions
selmanozleyen Jan 19, 2026
62d7d48
checkout readme from main
selmanozleyen Jan 19, 2026
d7539bf
breaking: clarify obs handling + change output keys (#115)
ilan-gold Jan 19, 2026
0d7764d
parent 627eb08d699b9cb07ab24fa67775e1e794c07245
selmanozleyen Jan 18, 2026
f7742a1
restore from main
selmanozleyen Jan 19, 2026
3d73e3a
fix: checking out: confused origin and upstream again...
selmanozleyen Jan 19, 2026
dff2a01
continuation of the upstream origin confusion fix
selmanozleyen Jan 19, 2026
0c13efa
breaking: clarify obs handling + change output keys (#115)
ilan-gold Jan 19, 2026
ffe23be
fix: header level (#116)
ilan-gold Jan 19, 2026
e10d300
Merge branch 'main' into feat/sampler
selmanozleyen Jan 19, 2026
5d522fe
refactor _prepare_dataset_and_obs
selmanozleyen Jan 19, 2026
d8168c1
update docstring for loadrequest
selmanozleyen Jan 19, 2026
c501646
separate files for samplers
selmanozleyen Jan 19, 2026
f9862b0
prepare_output is no longer needed
selmanozleyen Jan 19, 2026
6a1153e
clarify docs
selmanozleyen Jan 19, 2026
418f79a
fix overlook: already sorted batch_indices no need to resort them
selmanozleyen Jan 19, 2026
9b786f3
fix prepare_output refactor
selmanozleyen Jan 19, 2026
fc1661e
add todo
selmanozleyen Jan 19, 2026
d764adc
rename from leftover to remainder for clarity. since there is no left…
selmanozleyen Jan 19, 2026
5bc2751
simplify validate_sampler
selmanozleyen Jan 19, 2026
66d5d3c
remove old generic params
selmanozleyen Jan 19, 2026
0e8a472
add broad typing
selmanozleyen Jan 19, 2026
ae3e1bc
clarify todos and add username
selmanozleyen Jan 20, 2026
742605a
type and modify decorator
selmanozleyen Jan 20, 2026
cf30686
no need for lambdas in decorators
selmanozleyen Jan 20, 2026
261c5e8
make decorator compatible in multiple cases
selmanozleyen Jan 20, 2026
4402d4e
put ABC in abc folder
selmanozleyen Jan 20, 2026
0929849
update test with the fix
selmanozleyen Jan 20, 2026
899cc18
qualname for fix. no sampler in public API
selmanozleyen Jan 20, 2026
89e7ccd
check coverage when shuffled otherwise also check order
selmanozleyen Jan 20, 2026
0356374
fix to prev commit
selmanozleyen Jan 20, 2026
87f1ccb
clarify doc
selmanozleyen Jan 20, 2026
8a7f8c2
update worker tests
selmanozleyen Jan 20, 2026
a85634a
new * location for ChunkSampler
selmanozleyen Jan 20, 2026
61efe81
add typing but can revert if too verbose
selmanozleyen Jan 20, 2026
faaf525
remove unused fields. (maybe linter check can be added)
selmanozleyen Jan 20, 2026
eecb0b1
remove old SO link
selmanozleyen Jan 20, 2026
2cd09ca
don't put generators into np.all !!
selmanozleyen Jan 20, 2026
4fcb553
apply typing and docstring suggestion
selmanozleyen Jan 21, 2026
b53d685
change in folder structure
selmanozleyen Jan 21, 2026
d464caa
make batch sampler getter
selmanozleyen Jan 21, 2026
0b4883b
remove empty line
selmanozleyen Jan 21, 2026
c450586
apply docstring suggestions for Loader args
selmanozleyen Jan 21, 2026
15a11b9
remove empty line
selmanozleyen Jan 21, 2026
84c9124
conf.py is same as main
selmanozleyen Jan 21, 2026
c38160d
change shuffle
selmanozleyen Jan 21, 2026
b8472e3
remove todo
selmanozleyen Jan 21, 2026
e620194
update to match old behaviour
selmanozleyen Jan 21, 2026
1b38afd
Merge branch 'main' into feat/sampler
selmanozleyen Jan 21, 2026
7237023
put vstack inside accumulate chunks
selmanozleyen Jan 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@
# If building the documentation fails because of a missing link that is outside your control,
# you can add an exception to this list.
# ("py:class", "igraph.Graph"),
("py:class", "annbatch.types.TypeAliasType")
("py:class", "annbatch.types.TypeAliasType"),
# this is not exposed in the public API yet
("py:class", "annbatch.sampler.abc._sampler.Sampler"),
Comment thread
selmanozleyen marked this conversation as resolved.
Outdated
]

qualname_overrides = {
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ omit = [
"**/test_*.py",
]

[[tool.mypy.overrides]]
module = [ "anndata.*", "cupyx.*", "cupy.*" ]
ignore_missing_imports = true

[tool.cruft]
skip = [
"tests",
Expand Down
10 changes: 8 additions & 2 deletions src/annbatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from importlib.metadata import version

from . import types
from . import sampler, types
from .io import DatasetCollection, write_sharded
from .loader import Loader

__version__ = version("annbatch")

__all__ = ["Loader", "write_sharded", "DatasetCollection", "types"]
__all__ = [
"Loader",
"DatasetCollection",
"types",
"sampler",
Comment thread
selmanozleyen marked this conversation as resolved.
Outdated
"write_sharded",
]
305 changes: 146 additions & 159 deletions src/annbatch/loader.py

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions src/annbatch/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Sampler classes for efficient chunk-based data access.

This module provides samplers optimized for chunk-based data access patterns.
"""

from annbatch.sampler import abc
from annbatch.sampler._chunk_sampler import ChunkSampler

# Public API of the sampler subpackage: the concrete ChunkSampler and the
# `abc` module holding the Sampler base class.
__all__ = ["ChunkSampler", "abc"]
182 changes: 182 additions & 0 deletions src/annbatch/sampler/_chunk_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Sampler classes for efficient chunk-based data access."""

from __future__ import annotations

import math
from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np

from annbatch.sampler.abc import Sampler
from annbatch.utils import check_lt_1, split_given_size

if TYPE_CHECKING:
from collections.abc import Iterator

from annbatch.types import LoadRequest
from annbatch.utils import WorkerHandle


class ChunkSampler(Sampler):
    """Chunk-based sampler for batched data access.

    Observations are read as contiguous chunks of ``chunk_size`` rows,
    ``preload_nchunks`` chunks per yielded request; each preloaded group is
    then split into batches of ``batch_size`` rows.

    Parameters
    ----------
    chunk_size
        Size of each chunk i.e. the range of each chunk yielded.
    preload_nchunks
        Number of chunks to load per iteration.
    batch_size
        Number of observations per batch.
    mask
        A slice defining the observation range to sample from (start:stop).
        ``step`` must be ``None`` or ``1``.
    shuffle
        Whether to shuffle chunk and index order.
    drop_last
        Whether to drop the last incomplete batch.
    rng
        Random number generator for shuffling.
    """

    # Configuration captured in __init__; see the class docstring for meaning.
    _batch_size: int
    _chunk_size: int
    _shuffle: bool
    _preload_nchunks: int
    # Normalized to slice(start, stop); stop may be None until resolved
    # against n_obs in validate()/_sample().
    _mask: slice
    _drop_last: bool
    _rng: np.random.Generator

    def __init__(
        self,
        chunk_size: int,
        preload_nchunks: int,
        batch_size: int,
        *,
        mask: slice | None = None,
        shuffle: bool = False,
        drop_last: bool = False,
        rng: np.random.Generator | None = None,
    ):
        # A missing mask means "all observations"; stop stays None and is
        # resolved lazily once n_obs is known.
        if mask is None:
            mask = slice(0, None)
        if mask.step is not None and mask.step != 1:
            raise ValueError(f"mask.step must be 1, but got {mask.step}")
        start, stop = mask.start or 0, mask.stop
        if start < 0:
            raise ValueError("mask.start must be >= 0")
        if stop is not None and start >= stop:
            raise ValueError("mask.start must be < mask.stop when mask.stop is specified")

        check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"])
        # Number of observations held in memory per yielded LoadRequest.
        preload_size = chunk_size * preload_nchunks

        if batch_size > preload_size:
            raise ValueError(
                "batch_size cannot exceed chunk_size * preload_nchunks. "
                f"Got batch_size={batch_size}, but max is {preload_size}."
            )
        # Divisibility keeps every full preloaded group splittable into whole
        # batches, so only the final (possibly short) group can be uneven.
        if preload_size % batch_size != 0:
            raise ValueError(
                "chunk_size * preload_nchunks must be divisible by batch_size. "
                f"Got {preload_size} % {batch_size} = {preload_size % batch_size}."
            )
        self._rng = rng or np.random.default_rng()
        self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle
        self._preload_nchunks, self._mask, self._drop_last = (
            preload_nchunks,
            slice(start, stop),
            drop_last,
        )

    def validate(self, n_obs: int) -> None:
        """Validate the sampler configuration against the loader's n_obs.

        An unspecified ``mask.stop`` defaults to ``n_obs``.

        Parameters
        ----------
        n_obs
            The total number of observations in the loader.

        Raises
        ------
        ValueError
            If the sampler configuration is invalid for the given n_obs.
        """
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        if stop > n_obs:
            raise ValueError(
                f"Sampler mask.stop ({stop}) exceeds loader n_obs ({n_obs}). "
                "The sampler range must be within the loader's observations."
            )
        if start >= stop:
            raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).")

    def _get_worker_handle(self) -> WorkerHandle | None:
        """Return a handle for the current torch DataLoader worker, if any.

        Returns ``None`` when torch is not installed or when not running
        inside a DataLoader worker process.

        Raises
        ------
        ValueError
            If run with multiple workers while ``drop_last`` is False and
            ``batch_size`` is not 1 (partial batches cannot be handled then).
        """
        worker_handle = None
        if find_spec("torch"):
            from torch.utils.data import get_worker_info

            from annbatch.utils import WorkerHandle

            if get_worker_info() is not None:
                worker_handle = WorkerHandle()
        # Worker mode validation - only check when there are multiple workers
        # With batch_size=1, every batch is exactly 1 item, so no partial batches exist
        if (
            worker_handle is not None
            and worker_handle.num_workers > 1
            and not self._drop_last
            and self._batch_size != 1
        ):
            raise ValueError("When using DataLoader with multiple workers drop_last=False is not supported.")
        return worker_handle

    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Yield load requests covering the masked range of ``n_obs``."""
        worker_handle = self._get_worker_handle()
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        # Compute chunks directly from resolved mask range
        # Create chunk indices for possible shuffling and worker sharding
        chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size))
        if self._shuffle:
            # TODO(selmanozleyen): maybe this should be done worker-aware way?
            self._rng.shuffle(chunk_indices)
        chunks = self._compute_chunks(chunk_indices, start, stop)
        # Worker sharding: each worker gets a disjoint subset of chunks
        if worker_handle is not None:
            chunks = worker_handle.get_part_for_worker(chunks)
        # Set up the iterator for chunks and the batch indices for splits
        in_memory_size = self._chunk_size * self._preload_nchunks
        chunks_per_batch = split_given_size(chunks, self._preload_nchunks)
        batch_indices = np.arange(in_memory_size)
        split_batch_indices = split_given_size(batch_indices, self._batch_size)
        for batch_chunks in chunks_per_batch[:-1]:
            if self._shuffle:
                # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training
                self._rng.shuffle(batch_indices)
                split_batch_indices = split_given_size(batch_indices, self._batch_size)
            yield {"chunks": batch_chunks, "splits": split_batch_indices}
        # The final group may hold fewer than chunk_size * preload_nchunks
        # observations (even when that total is divisible by batch_size), so
        # fresh, correctly sized batch indices are built for it; drop_last
        # trims the trailing partial batch if requested.
        final_chunks = chunks_per_batch[-1]
        total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks))
        if self._drop_last:
            total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size
        batch_indices = split_given_size(
            (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch),
            self._batch_size,
        )
        yield {"chunks": final_chunks, "splits": batch_indices}

    def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]:
        """Compute chunks from start and stop indices.

        This function is used to compute the chunks for the data to load.
        The chunks are computed such that the last chunk is the incomplete chunk if the total number of observations is not divisible by the chunk size.
        Supposed to also work with shuffled chunk indices so that the last chunk computed isn't always the incomplete chunk.
        """
        # The chunk iterated *last* (the pivot) absorbs the remainder, so a
        # shuffled order does not always end on a short chunk.
        n_chunks, pivot_index = len(chunk_indices), chunk_indices[-1]
        # offsets[0] anchors the range at `start`; every other entry is a chunk length.
        offsets = np.ones(n_chunks + 1, dtype=int) * self._chunk_size
        offsets[0] = start
        offsets[pivot_index + 1] = incomplete if (incomplete := (stop - start) % self._chunk_size) else self._chunk_size
        # Cumulative sum turns lengths into positional boundaries; indexing by
        # chunk_indices then reorders the slices into iteration order.
        offsets = np.cumsum(offsets)
        starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices]
        return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)]
5 changes: 5 additions & 0 deletions src/annbatch/sampler/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from annbatch.sampler.abc._sampler import Sampler

# Public API of the abc subpackage: the abstract Sampler base class.
__all__ = ["Sampler"]
70 changes: 70 additions & 0 deletions src/annbatch/sampler/abc/_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Sampler classes for efficient chunk-based data access."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Iterator

from annbatch.types import LoadRequest


class Sampler(ABC):
    """Abstract base class for samplers.

    A sampler decides how observations are grouped and batched, emitting
    load requests that the loader turns into actual data reads.
    """

    def sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Yield load requests covering ``n_obs`` observations.

        Parameters
        ----------
        n_obs
            Total number of observations available for sampling.

        Yields
        ------
        LoadRequest
            Requests describing which chunks to load and how to split them.
        """
        # This is a generator function, so validation runs lazily when
        # iteration starts rather than when sample() is called.
        self.validate(n_obs)
        for request in self._sample(n_obs):
            yield request

    @abstractmethod
    def validate(self, n_obs: int) -> None:
        """Check that this sampler's configuration is valid for ``n_obs``.

        Invoked automatically at the start of every `sample()` iteration;
        override to add custom validation of sampler parameters.

        Parameters
        ----------
        n_obs
            Total number of observations in the loader.

        Raises
        ------
        ValueError
            If the configuration cannot be applied to ``n_obs`` observations.
        """

    @abstractmethod
    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Produce the actual load requests once validation has passed.

        Parameters
        ----------
        n_obs
            Total number of observations available for sampling.

        Yields
        ------
        LoadRequest
            Requests describing which chunks to load and how to split them.
        """
19 changes: 19 additions & 0 deletions src/annbatch/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@
type OutputInMemoryArray_T = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray | Tensor


class LoadRequest(TypedDict):
"""Load request from sampler.

This is the request format Loader will expect from the sampler.
    Not satisfying the constraints documented here may result in unexpected behavior.

Attributes
----------
chunks
Chunks to load - a list of slices with a range of chunk_size except the last one which may be smaller but not empty.
splits
How the concatenation of chunks should be split into batches.
Comment thread
selmanozleyen marked this conversation as resolved.
Outdated
A list of splits, last one may be partial but not empty i.e. 1 <= len(last_split) <= batch_size.
"""

chunks: list[slice]
splits: list[np.ndarray]


class LoaderOutput[OutputInMemoryArray: OutputInMemoryArray_T](TypedDict):
"""The output of the loader, the "data matrix" with its obs, optional, and index, also optional."""

Expand Down
Loading
Loading