feat: Sampler interface and ability to add custom samplers #101
Merged
Commits (58)
ba2cbe9  resolve conflicts with main (selmanozleyen)
d951c56  load_obs thing was removed by auto formatting (selmanozleyen)
aa348fe  update tests to resolve conflict (selmanozleyen)
6220a8a  readthedocs merge (selmanozleyen)
309d33e  chore: clarify compatibility of `h5ad` + forward compat of old shuffl… (ilan-gold)
437f184  breaking: clarify obs handling + change output keys (#115) (ilan-gold)
679ee50  fix: header level (#116) (ilan-gold)
ad1ec55  merge changes (selmanozleyen)
c8b4395  apply suggestions (selmanozleyen)
62d7d48  checkout readme from main (selmanozleyen)
d7539bf  breaking: clarify obs handling + change output keys (#115) (ilan-gold)
0d7764d  parent 627eb08d699b9cb07ab24fa67775e1e794c07245 (selmanozleyen)
f7742a1  restore from main (selmanozleyen)
3d73e3a  fix: checking out: confused origin and upstream again... (selmanozleyen)
dff2a01  continuation of the upstream origin confusion fix (selmanozleyen)
0c13efa  breaking: clarify obs handling + change output keys (#115) (ilan-gold)
ffe23be  fix: header level (#116) (ilan-gold)
e10d300  Merge branch 'main' into feat/sampler (selmanozleyen)
5d522fe  refactor _prepare_dataset_and_obs (selmanozleyen)
d8168c1  update docstring for loadrequest (selmanozleyen)
c501646  separate files for samplers (selmanozleyen)
f9862b0  prepare_output is no longer needed (selmanozleyen)
6a1153e  clarify docs (selmanozleyen)
418f79a  fix overlook: already sorted batch_indices no need to resort them (selmanozleyen)
9b786f3  fix prepare_output refactor (selmanozleyen)
fc1661e  add todo (selmanozleyen)
d764adc  rename from leftover to remainder for clarity. since there is no left… (selmanozleyen)
5bc2751  simplify validate_sampler (selmanozleyen)
66d5d3c  remove old generic params (selmanozleyen)
0e8a472  add broad typing (selmanozleyen)
ae3e1bc  clarify todos and add username (selmanozleyen)
742605a  type and modify decorator (selmanozleyen)
cf30686  no need for lambdas in decorators (selmanozleyen)
261c5e8  make decorator compatible in multiple cases (selmanozleyen)
4402d4e  put ABC in abc folder (selmanozleyen)
0929849  update test with the fix (selmanozleyen)
899cc18  qualname for fix. no sampler in public API (selmanozleyen)
89e7ccd  check coverage when shuffled otherwise also check order (selmanozleyen)
0356374  fix to prev commit (selmanozleyen)
87f1ccb  clarify doc (selmanozleyen)
8a7f8c2  update worker tests (selmanozleyen)
a85634a  new * location for ChunkSampler (selmanozleyen)
61efe81  add typing but can revert if too verbose (selmanozleyen)
faaf525  remove unused fields. (maybe linter check can be added) (selmanozleyen)
eecb0b1  remove old SO link (selmanozleyen)
2cd09ca  don't put generators into np.all !! (selmanozleyen)
4fcb553  apply typing and docstring suggestion (selmanozleyen)
b53d685  change in folder structure (selmanozleyen)
d464caa  make batch sampler getter (selmanozleyen)
0b4883b  remove empty line (selmanozleyen)
c450586  apply docstring suggestions for Loader args (selmanozleyen)
15a11b9  remove empty line (selmanozleyen)
84c9124  conf.py is same as main (selmanozleyen)
c38160d  change shuffle (selmanozleyen)
b8472e3  remove todo (selmanozleyen)
e620194  update to match old behaviour (selmanozleyen)
1b38afd  Merge branch 'main' into feat/sampler (selmanozleyen)
7237023  put vstack inside accumulate chunks (selmanozleyen)
`annbatch/sampler/__init__.py` (new file, +12 lines; file path inferred from the package imports):

```python
"""Sampler classes for efficient chunk-based data access.

This module provides samplers optimized for chunk-based data access patterns.
"""

from annbatch.sampler import abc
from annbatch.sampler._chunk_sampler import ChunkSampler

__all__ = [
    "ChunkSampler",
    "abc",
]
```
`annbatch/sampler/_chunk_sampler.py` (new file, +182 lines):

```python
"""Sampler classes for efficient chunk-based data access."""

from __future__ import annotations

import math
from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np

from annbatch.sampler.abc import Sampler
from annbatch.utils import check_lt_1, split_given_size

if TYPE_CHECKING:
    from collections.abc import Iterator

    from annbatch.types import LoadRequest
    from annbatch.utils import WorkerHandle


class ChunkSampler(Sampler):
    """Chunk-based sampler for batched data access.

    Parameters
    ----------
    batch_size
        Number of observations per batch.
    chunk_size
        Size of each chunk, i.e., the range of each chunk yielded.
    mask
        A slice defining the observation range to sample from (start:stop).
    shuffle
        Whether to shuffle chunk and index order.
    preload_nchunks
        Number of chunks to load per iteration.
    drop_last
        Whether to drop the last incomplete batch.
    rng
        Random number generator for shuffling.
    """

    _batch_size: int
    _chunk_size: int
    _shuffle: bool
    _preload_nchunks: int
    _mask: slice
    _drop_last: bool
    _rng: np.random.Generator

    def __init__(
        self,
        chunk_size: int,
        preload_nchunks: int,
        batch_size: int,
        *,
        mask: slice | None = None,
        shuffle: bool = False,
        drop_last: bool = False,
        rng: np.random.Generator | None = None,
    ):
        if mask is None:
            mask = slice(0, None)
        if mask.step is not None and mask.step != 1:
            raise ValueError(f"mask.step must be 1, but got {mask.step}")
        start, stop = mask.start or 0, mask.stop
        if start < 0:
            raise ValueError("mask.start must be >= 0")
        if stop is not None and start >= stop:
            raise ValueError("mask.start must be < mask.stop when mask.stop is specified")

        check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"])
        preload_size = chunk_size * preload_nchunks

        if batch_size > preload_size:
            raise ValueError(
                "batch_size cannot exceed chunk_size * preload_nchunks. "
                f"Got batch_size={batch_size}, but max is {preload_size}."
            )
        if preload_size % batch_size != 0:
            raise ValueError(
                "chunk_size * preload_nchunks must be divisible by batch_size. "
                f"Got {preload_size} % {batch_size} = {preload_size % batch_size}."
            )
        self._rng = rng or np.random.default_rng()
        self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle
        self._preload_nchunks, self._mask, self._drop_last = (
            preload_nchunks,
            slice(start, stop),
            drop_last,
        )

    def validate(self, n_obs: int) -> None:
        """Validate the sampler configuration against the loader's n_obs.

        Parameters
        ----------
        n_obs
            The total number of observations in the loader.

        Raises
        ------
        ValueError
            If the sampler configuration is invalid for the given n_obs.
        """
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        if stop > n_obs:
            raise ValueError(
                f"Sampler mask.stop ({stop}) exceeds loader n_obs ({n_obs}). "
                "The sampler range must be within the loader's observations."
            )
        if start >= stop:
            raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).")

    def _get_worker_handle(self) -> WorkerHandle | None:
        worker_handle = None
        if find_spec("torch"):
            from torch.utils.data import get_worker_info

            from annbatch.utils import WorkerHandle

            if get_worker_info() is not None:
                worker_handle = WorkerHandle()
        # Worker-mode validation: only check when there are multiple workers.
        # With batch_size=1, every batch is exactly 1 item, so no partial batches exist.
        if (
            worker_handle is not None
            and worker_handle.num_workers > 1
            and not self._drop_last
            and self._batch_size != 1
        ):
            raise ValueError("When using DataLoader with multiple workers drop_last=False is not supported.")
        return worker_handle

    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        worker_handle = self._get_worker_handle()
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        # Compute chunks directly from the resolved mask range.
        # Create chunk indices for possible shuffling and worker sharding.
        chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size))
        if self._shuffle:
            # TODO(selmanozleyen): maybe this should be done in a worker-aware way?
            self._rng.shuffle(chunk_indices)
        chunks = self._compute_chunks(chunk_indices, start, stop)
        # Worker sharding: each worker gets a disjoint subset of chunks.
        if worker_handle is not None:
            chunks = worker_handle.get_part_for_worker(chunks)
        # Set up the iterator for chunks and the batch indices for splits.
        in_memory_size = self._chunk_size * self._preload_nchunks
        chunks_per_batch = split_given_size(chunks, self._preload_nchunks)
        batch_indices = np.arange(in_memory_size)
        split_batch_indices = split_given_size(batch_indices, self._batch_size)
        for batch_chunks in chunks_per_batch[:-1]:
            if self._shuffle:
                # Avoid copies by shuffling in place, since `self._shuffle` should not change mid-training.
                self._rng.shuffle(batch_indices)
                split_batch_indices = split_given_size(batch_indices, self._batch_size)
            yield {"chunks": batch_chunks, "splits": split_batch_indices}
        # On the last yield, drop the last uneven batch and create new batch_indices,
        # since the in-memory size of this last yield could be divisible by batch_size
        # but smaller than preload_nchunks * chunk_size.
        final_chunks = chunks_per_batch[-1]
        total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks))
        if self._drop_last:
            total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size
        batch_indices = split_given_size(
            (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch),
            self._batch_size,
        )
        yield {"chunks": final_chunks, "splits": batch_indices}

    def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]:
        """Compute chunks from start and stop indices.

        This function computes the chunks for the data to load. The chunks are
        computed such that the last chunk is the incomplete chunk if the total
        number of observations is not divisible by the chunk size. It is also
        supposed to work with shuffled chunk indices, so that the last chunk
        computed isn't always the incomplete one.
        """
        n_chunks, pivot_index = len(chunk_indices), chunk_indices[-1]
        offsets = np.ones(n_chunks + 1, dtype=int) * self._chunk_size
        offsets[0] = start
        offsets[pivot_index + 1] = incomplete if (incomplete := (stop - start) % self._chunk_size) else self._chunk_size
        offsets = np.cumsum(offsets)
        starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices]
        return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)]
```
`annbatch/sampler/abc/__init__.py` (new file, +5 lines):

```python
from annbatch.sampler.abc._sampler import Sampler

__all__ = [
    "Sampler",
]
```
`annbatch/sampler/abc/_sampler.py` (new file, +70 lines):

```python
"""Sampler classes for efficient chunk-based data access."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterator

    from annbatch.types import LoadRequest


class Sampler(ABC):
    """Base sampler class.

    Samplers control how data is batched and loaded from the underlying datasets.
    """

    def sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Sample load requests given the total number of observations.

        Parameters
        ----------
        n_obs
            The total number of observations available.

        Yields
        ------
        LoadRequest
            Load requests for batching data.
        """
        self.validate(n_obs)
        yield from self._sample(n_obs)

    @abstractmethod
    def validate(self, n_obs: int) -> None:
        """Validate the sampler configuration against the given n_obs.

        This method is called at the start of each `sample()` call.
        Override this method to add custom validation for sampler parameters.

        Parameters
        ----------
        n_obs
            The total number of observations in the loader.

        Raises
        ------
        ValueError
            If the sampler configuration is invalid for the given n_obs.
        """

    @abstractmethod
    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Implementation of the sample method.

        This method is called by the sample method to perform the actual sampling
        after validation has passed.

        Parameters
        ----------
        n_obs
            The total number of observations available.

        Yields
        ------
        LoadRequest
            Load requests for batching data.
        """
```