Conversation
…a into ig/array_api_continue import merge
flying-sheep
left a comment
There was a problem hiding this comment.
OK, I just went over general code style, nothing JAX-related
src/anndata/_core/merge.py
Outdated
| # Force to NumPy (materializes JAX/Cubed); fine for small tests, | ||
| # but may be slow or fail on large/lazy arrays |
There was a problem hiding this comment.
This code doesn’t just run for tests though. Also are you sure that this is a good idea for arrays with pandas dtypes?
There was a problem hiding this comment.
Yeah, I was initially forcing everything to NumPy, but that’s no longer the case. I’ve updated it so the it should preserve arrays with pandas dtypes.
src/anndata/_core/merge.py
Outdated
| return False | ||
|
|
||
|
|
||
| def _to_numpy_if_array_api(x): |
There was a problem hiding this comment.
there should be no second copy of this that’s slightly different, only one!
src/anndata/_core/anndata.py
Outdated
| dest = self._adata_ref._X | ||
| # Handles read-only NumPy views from backend arrays like JAX by | ||
| # making a writable copy so in-place assignment on views can succeed. | ||
| if isinstance(dest, np.ndarray) and not dest.flags.writeable: | ||
| dest = np.array(dest, copy=True) # make a fresh, writable buffer | ||
| self._adata_ref._X = dest |
There was a problem hiding this comment.
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle
src/anndata/_core/merge.py
Outdated
| hasattr(x, "dtype") and is_extension_array_dtype(x.dtype) | ||
| ): | ||
| return x | ||
| return np.asarray(x) |
There was a problem hiding this comment.
Ok nice this is the right direction no doubt! So what we want here probably is not to rely on asarray but dlpack to do the conversion. In short:
- We should have a check in
_apply_to_arrayto see if something is array-api compatible but not a numpy ndarray. - If this case is true, dlpack into numpy, recursively call
_apply_to_array - Then use dlpack to take the output of the recursive call to the original type before we went to numpy.
Does that make sense?
There was a problem hiding this comment.
I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.
… with copying introduced as an extra precaution
bf5194f to
cddb604
Compare
flying-sheep
left a comment
There was a problem hiding this comment.
looks great, I just have some clarifying questions!
Co-authored-by: Philipp A. <flying-sheep@web.de>
Co-authored-by: Philipp A. <flying-sheep@web.de>
for more information, see https://pre-commit.ci
| # As of 2023 dlpack, it must be possible for a library to export to this, see: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html#array_api.array.__dlpack__ | ||
| # However, https://github.com/numpy/numpy/issues/20742 means we can't roundtrip jax arrays using dlpack so better to just let numpy do its thing in asarray. | ||
| self.add_array(np.asarray(self.get_default())) | ||
| res = np.from_dlpack(self._manager[(1, 0)]) |
There was a problem hiding this comment.
| res = np.from_dlpack(self._manager[(1, 0)]) | |
| res = np.from_dlpack(self._manager[1, 0]) |
Two-tier detection: tier 1 uses the canonical has_xp() protocol check from anndata.compat (catches JAX, numpy >=2.0); tier 2 falls back to duck-typing (shape/dtype/ndim) for arrays that don't yet implement the full protocol (PyTorch, TensorFlow). Also uses __array_namespace__() for backend label resolution and updates stale PR scverse#2063 → scverse#2071.
fixes #1731, fixes #1195, fixes #697
First step in getting anndata concat and test generation to work properly with JAX, (and Cubed potentially), without just converting everything into NumPy.
Random data creation and shape handling use xp.asarray so arrays stay in their original backend where possible. I also updated concat paths to actually check types before converting, added helpers for sparse detection and array API checks, and made sure backend arrays only get turned into NumPy when absolutely necessary. This fixes a bunch of concat-related test failures.
It’s still not perfect. Some pandas calls in concat still force conversion to NumPy, so the data gets copied instead of being used directly. Cubed support is only a placeholder right now. Type detection might still be a bit too broad, which can lead to extra conversions. Works for NumPy and JAX in tests, but I haven’t tried other backends.