feat: support array-api #2071

Merged
ilan-gold merged 204 commits into scverse:main from amalia-k510:ig/array_api_continue
Feb 6, 2026

Conversation

Contributor

@amalia-k510 amalia-k510 commented Aug 11, 2025

fixes #1731, fixes #1195, fixes #697

First step in getting anndata concat and test generation to work properly with JAX (and potentially Cubed) without converting everything to NumPy.

Random data creation and shape handling use xp.asarray so arrays stay in their original backend where possible. I also updated concat paths to actually check types before converting, added helpers for sparse detection and array API checks, and made sure backend arrays only get turned into NumPy when absolutely necessary. This fixes a bunch of concat-related test failures.
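A minimal sketch of the "stay in the original backend" idea described above, not the code from this PR. The helper names `get_namespace` and `random_like` are hypothetical; the only assumption about the backend is that its arrays expose `__array_namespace__()` as the array API standard describes, with NumPy used as the fallback.

```python
import numpy as np


def get_namespace(x):
    """Return the array's own namespace if it advertises one, else NumPy."""
    get_xp = getattr(x, "__array_namespace__", None)
    return get_xp() if get_xp is not None else np


def random_like(template, shape, seed=0):
    """Create random data on the host, then hand it to the template's backend.

    Hypothetical helper: xp.asarray keeps the result in the same backend as
    `template` instead of always returning a NumPy array.
    """
    xp = get_namespace(template)
    host = np.random.default_rng(seed).random(shape)
    return xp.asarray(host)


sample = random_like(np.zeros(2), (3, 2))
```

With a NumPy template the result is a plain `np.ndarray`; with a JAX template the same call would be expected to return a JAX array, since `xp` is then the JAX namespace.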

It’s still not perfect. Some pandas calls in concat still force conversion to NumPy, so the data gets copied instead of being used directly. Cubed support is only a placeholder right now. Type detection might still be a bit too broad, which can lead to extra conversions. Works for NumPy and JAX in tests, but I haven’t tried other backends.

Member

@flying-sheep flying-sheep left a comment

OK, I just went over general code style, nothing JAX-related

Comment on lines +662 to +663
# Force to NumPy (materializes JAX/Cubed); fine for small tests,
# but may be slow or fail on large/lazy arrays
Member

This code doesn’t just run for tests though. Also are you sure that this is a good idea for arrays with pandas dtypes?

Contributor Author

Yeah, I was initially forcing everything to NumPy, but that’s no longer the case. I’ve updated it so that it should preserve arrays with pandas dtypes.

return False


def _to_numpy_if_array_api(x):
Member

there should be no second copy of this that’s slightly different, only one!

Comment on lines +670 to +675
dest = self._adata_ref._X
# Handles read-only NumPy views from backend arrays like JAX by
# making a writable copy so in-place assignment on views can succeed.
if isinstance(dest, np.ndarray) and not dest.flags.writeable:
dest = np.array(dest, copy=True) # make a fresh, writable buffer
self._adata_ref._X = dest
Contributor

I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle
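For reference, a small sketch of what "letting the error be thrown" looks like with plain NumPy. The helper `make_read_only` is hypothetical and only simulates a read-only view such as one obtained from another backend; the point is that in-place assignment raises `ValueError`, and the caller can make a writable copy explicitly if that is what they want.

```python
import numpy as np


def make_read_only(n):
    """Hypothetical stand-in for a read-only view handed over by a backend."""
    arr = np.arange(n)
    arr.flags.writeable = False
    return arr


read_only = make_read_only(4)

try:
    read_only[0] = 99  # in-place write on a read-only buffer
    raised = False
except ValueError:
    raised = True  # NumPy refuses the write instead of silently copying

# The caller, not the library, decides to pay for a writable copy.
writable = np.array(read_only, copy=True)
writable[0] = 99
```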

hasattr(x, "dtype") and is_extension_array_dtype(x.dtype)
):
return x
return np.asarray(x)
Contributor

Ok nice this is the right direction no doubt! So what we want here probably is not to rely on asarray but dlpack to do the conversion. In short:

  1. We should have a check in _apply_to_array to see if something is array-api compatible but not a numpy ndarray.
  2. If this case is true, dlpack into numpy, recursively call _apply_to_array
  3. Then use dlpack to take the output of the recursive call to the original type before we went to numpy.

Does that make sense?
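The three steps above could be sketched roughly as follows. This is an illustration under assumptions, not the `_apply_to_array` from this PR: `apply_via_numpy` and `is_non_numpy_array_api` are hypothetical names, the backend is assumed to provide `from_dlpack` in its namespace as the array API standard specifies, and a later comment in this thread notes a NumPy issue that can prevent this round trip for JAX arrays.

```python
import numpy as np


def is_non_numpy_array_api(x):
    """Step 1: array-API compatible, but not a NumPy ndarray."""
    return hasattr(x, "__array_namespace__") and not isinstance(x, np.ndarray)


def apply_via_numpy(x, func):
    """Apply `func` (written for NumPy) and return the input's array type."""
    if is_non_numpy_array_api(x):
        xp = x.__array_namespace__()
        as_numpy = np.from_dlpack(x)              # step 2: DLPack into NumPy
        result = apply_via_numpy(as_numpy, func)  # step 2: recursive call
        if isinstance(result, np.ndarray):
            return xp.from_dlpack(result)         # step 3: back to original type
        return result  # non-array results (e.g. scalars) pass through
    return func(x)


doubled = apply_via_numpy(np.arange(3), lambda a: a * 2)
```

For a NumPy input the function simply calls `func`; the DLPack branch only runs for other array-API backends.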

Contributor

I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.

@ilan-gold force-pushed the ig/array_api_continue branch from bf5194f to cddb604 on February 3, 2026 12:05
Member

@flying-sheep flying-sheep left a comment

looks great, I just have some clarifying questions!

# As of 2023 dlpack, it must be possible for a library to export to this, see: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html#array_api.array.__dlpack__
# However, https://github.com/numpy/numpy/issues/20742 means we can't roundtrip jax arrays using dlpack so better to just let numpy do its thing in asarray.
self.add_array(np.asarray(self.get_default()))
res = np.from_dlpack(self._manager[(1, 0)])
Member

Suggested change
- res = np.from_dlpack(self._manager[(1, 0)])
+ res = np.from_dlpack(self._manager[1, 0])
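The two spellings are equivalent in Python: a comma-separated subscript is already passed to `__getitem__` as a tuple, so the extra parentheses are redundant. A tiny sketch with a hypothetical `KeyRecorder` class standing in for the manager:

```python
class KeyRecorder:
    """Hypothetical container that returns whatever key it was indexed with."""

    def __getitem__(self, key):
        return key


rec = KeyRecorder()
with_parens = rec[(1, 0)]
without_parens = rec[1, 0]
```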

@ilan-gold ilan-gold merged commit ed34f6b into scverse:main Feb 6, 2026
24 checks passed
katosh added a commit to settylab/anndata that referenced this pull request Feb 10, 2026
Two-tier detection: tier 1 uses the canonical has_xp() protocol check
from anndata.compat (catches JAX, numpy >=2.0); tier 2 falls back to
duck-typing (shape/dtype/ndim) for arrays that don't yet implement the
full protocol (PyTorch, TensorFlow). Also uses __array_namespace__()
for backend label resolution and updates stale PR scverse#2063 → scverse#2071.
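A rough sketch of the two-tier detection this commit message describes, under stated assumptions. It does not use `has_xp()` from `anndata.compat`; tier 1 is approximated here by checking for `__array_namespace__`, and `detect_tier`, `backend_label`, `ProtocolArray`, and `DuckArray` are hypothetical names for illustration only.

```python
import types


class ProtocolArray:
    """Stand-in for an array implementing the array-API protocol."""
    shape, dtype, ndim = (2,), "float32", 1

    def __array_namespace__(self, api_version=None):
        # Hypothetical namespace; a real backend returns its own module.
        return types.ModuleType("fake_backend")


class DuckArray:
    """Stand-in for an array exposing only shape/dtype/ndim."""
    shape, dtype, ndim = (2,), "float32", 1


def detect_tier(x):
    """Return 1 for protocol arrays, 2 for duck-typed arrays, else None."""
    if hasattr(x, "__array_namespace__"):
        return 1
    if all(hasattr(x, name) for name in ("shape", "dtype", "ndim")):
        return 2
    return None


def backend_label(x):
    """Resolve a label from the namespace, falling back to the type's module."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__().__name__
    return type(x).__module__.split(".")[0]
```

Checking the protocol first means arrays that do implement it are never classified by the looser duck-typing rule.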
Development

Successfully merging this pull request may close these issues.

Add support for sparse arrays (pydata) to .varp and .obsp fields
General array-api support
Support JAX as a backend

3 participants