Skip to content

Commit 8691627

Browse files
committed
repr: use has_xp() for array-api detection in ArrayAPIFormatter
Two-tier detection: tier 1 uses the canonical has_xp() protocol check from anndata.compat (catches JAX, numpy >=2.0); tier 2 falls back to duck-typing (shape/dtype/ndim) for arrays that don't yet implement the full protocol (PyTorch, TensorFlow). Also uses __array_namespace__() for backend label resolution and updates the stale PR reference from scverse#2063 to scverse#2071.
1 parent 640692e commit 8691627

2 files changed

Lines changed: 113 additions & 114 deletions

File tree

src/anndata/_repr/formatters.py

Lines changed: 58 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
CSS_NESTED_ANNDATA,
4646
CSS_TEXT_MUTED,
4747
)
48+
from ..compat import has_xp
4849
from .components import render_category_list
4950
from .lazy import get_lazy_categorical_info, is_lazy_column
5051
from .registry import (
@@ -72,6 +73,8 @@
7273
)
7374

7475
if TYPE_CHECKING:
76+
from typing import ClassVar
77+
7578
from .registry import FormatterContext
7679

7780

@@ -765,75 +768,77 @@ class ArrayAPIFormatter(TypeFormatter):
765768
"""
766769
Formatter for Array-API compatible arrays (JAX, PyTorch, TensorFlow, etc.).
767770
768-
Future-proofing notes:
769-
- PR #2063 (https://github.com/scverse/anndata/pull/2063) adds Array-API compatibility
770-
- Handles JAX arrays, PyTorch tensors, TensorFlow tensors, and other array-like objects
771-
- Uses duck typing to detect array-like objects without isinstance checks
772-
- Low priority (50) ensures specific formatters (numpy, cupy, etc.) are tried first
771+
Detection strategy (two tiers):
773772
774-
References:
775-
- Array API Standard: https://data-apis.org/array-api/latest/
773+
1. :func:`~anndata.compat.has_xp` — canonical check for arrays implementing the
774+
`Array API standard <https://data-apis.org/array-api/latest/>`_ (e.g. JAX, CuPy).
775+
2. Duck-typing fallback — catches arrays with ``shape``/``dtype``/``ndim`` that do
776+
not (yet) implement the full protocol (e.g. PyTorch tensors, TensorFlow tensors).
777+
778+
Low priority (50) ensures specific formatters (numpy, cupy, dask, etc.)
779+
are tried first.
776780
"""
777781

778782
priority = 50 # Lower than specific formatters (numpy=110, cupy=120) but higher than builtins
779783

784+
_FRIENDLY_NAMES: ClassVar[dict[str, str]] = {
785+
"jax": "JAX",
786+
"jaxlib": "JAX",
787+
"torch": "PyTorch",
788+
"tensorflow": "TensorFlow",
789+
"tf": "TensorFlow",
790+
"mxnet": "MXNet",
791+
"cupy": "CuPy",
792+
}
793+
794+
# Modules already handled by dedicated formatters (defensive guard;
795+
# the priority system normally prevents reaching this formatter).
796+
_HANDLED_MODULES: ClassVar[tuple[str, ...]] = (
797+
"numpy",
798+
"pandas",
799+
"scipy.sparse",
800+
"cupy",
801+
"cupyx",
802+
"awkward",
803+
"dask",
804+
)
805+
780806
def can_format(self, obj: object, context: FormatterContext) -> bool:
781-
# Duck typing: Check for array-like attributes
782-
# Must have shape, dtype, and ndim (array-api standard)
783-
has_array_attrs = (
784-
hasattr(obj, "shape") and hasattr(obj, "dtype") and hasattr(obj, "ndim")
785-
)
807+
# Tier 1: full Array API protocol (JAX, CuPy ≥12, numpy ≥2.0, …)
808+
if has_xp(obj):
809+
# numpy has its own formatter
810+
return not isinstance(obj, np.ndarray)
786811

787-
if not has_array_attrs:
812+
# Tier 2: duck-typing for arrays that expose shape/dtype/ndim
813+
# but don't implement the full protocol (PyTorch, TensorFlow, …)
814+
if not (
815+
hasattr(obj, "shape") and hasattr(obj, "dtype") and hasattr(obj, "ndim")
816+
):
788817
return False
789818

790-
# Exclude types that already have specific formatters
791-
# This prevents conflicts and ensures more specific formatters take precedence
819+
# Exclude types that have dedicated formatters
792820
module = type(obj).__module__
793-
already_handled = isinstance(
794-
obj, (np.ndarray, pd.DataFrame, pd.Series)
795-
) or module.startswith((
796-
"numpy",
797-
"pandas",
798-
"scipy.sparse",
799-
"cupy", # Has CuPyArrayFormatter
800-
"cupyx",
801-
"awkward", # Has AwkwardArrayFormatter
802-
"dask", # Has DaskArrayFormatter
803-
))
804-
805-
return not already_handled
821+
return not isinstance(obj, np.ndarray) and not module.startswith(
822+
self._HANDLED_MODULES
823+
)
806824

807825
def format(self, obj: object, context: FormatterContext) -> FormattedOutput:
808-
# Extract module and type information
809-
module_name = type(obj).__module__.split(".")[0] # e.g., "jax" from "jax.numpy"
810826
type_name = type(obj).__name__
811-
812827
shape_str = " × ".join(format_number(s) for s in obj.shape)
813828
dtype_str = str(obj.dtype)
814829

815-
# Detect common array-api backends
816-
# Known backends: JAX, PyTorch, TensorFlow, MXNet, etc.
817-
known_backends = {
818-
"jax": "JAX",
819-
"jaxlib": "JAX",
820-
"torch": "PyTorch",
821-
"tensorflow": "TensorFlow",
822-
"tf": "TensorFlow",
823-
"mxnet": "MXNet",
824-
}
825-
backend_label = known_backends.get(module_name, module_name)
826-
827-
# Try to get device information (GPU vs CPU)
830+
# Derive backend label: prefer __array_namespace__, fall back to module name
831+
backend_label = type(obj).__module__.split(".")[0]
832+
with contextlib.suppress(Exception):
833+
xp = obj.__array_namespace__()
834+
ns_name = getattr(xp, "__name__", "") or type(xp).__module__
835+
backend_label = ns_name.split(".")[0]
836+
backend_label = self._FRIENDLY_NAMES.get(backend_label, backend_label)
837+
838+
# Device info (present on array-api arrays; also on PyTorch/CuPy)
828839
device_info = ""
829-
try:
830-
if hasattr(obj, "device"):
831-
device_info = f" on {obj.device}"
832-
elif hasattr(obj, "device_buffer"): # JAX
833-
device_info = f" on {obj.device_buffer.device()}"
834-
except Exception: # noqa: BLE001
835-
# Intentional broad catch: device access varies by backend and can fail
836-
pass
840+
with contextlib.suppress(Exception):
841+
device_info = f" on {obj.device}"
837842

838843
# For obsm/varm sections, show number of columns in preview
839844
preview = None
@@ -1126,7 +1131,7 @@ def _register_builtin_formatters() -> None:
11261131
LazyColumnFormatter(),
11271132
# Low-medium priority (Array-API compatible arrays)
11281133
# Must come after specific array formatters (numpy, cupy, etc.) but before builtins
1129-
# Handles JAX, PyTorch, TensorFlow arrays added in PR #2063
1134+
# Handles JAX, PyTorch, TensorFlow arrays (PR #2071 added array-api support)
11301135
ArrayAPIFormatter(),
11311136
# Low priority (builtins)
11321137
NoneFormatter(),

tests/repr/test_repr_formatters.py

Lines changed: 55 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,33 @@
2020
from .conftest import HAS_AWKWARD, HAS_DASK
2121

2222

23+
def _make_array_api_mock(module: str, *, shape, dtype, device="cpu"):
24+
"""Create a mock array satisfying the SupportsArrayApi protocol.
25+
26+
The mock has all attributes required by ``anndata.types.SupportsArrayApi``
27+
(``shape``, ``device``, ``__array_namespace__``, ``to_device``,
28+
``__dlpack__``, ``__dlpack_device__``) so that ``has_xp()`` returns True.
29+
"""
30+
ns_module = type("Namespace", (), {"__name__": module.split(".")[0]})()
31+
32+
cls = type(
33+
"MockArrayAPI",
34+
(),
35+
{
36+
"shape": shape,
37+
"dtype": dtype,
38+
"ndim": len(shape),
39+
"device": device,
40+
"__array_namespace__": lambda self, **kw: ns_module,
41+
"to_device": lambda self, dev, /, **kw: self,
42+
"__dlpack__": lambda self, **kw: None,
43+
"__dlpack_device__": lambda self: (1, 0),
44+
},
45+
)
46+
cls.__module__ = module
47+
return cls()
48+
49+
2350
class TestNumpyFormatters:
2451
"""Tests for NumPy array formatters."""
2552

@@ -506,39 +533,35 @@ def test_array_api_formatter_jax_like(self):
506533
from anndata._repr.formatters import ArrayAPIFormatter
507534
from anndata._repr.registry import FormatterContext
508535

509-
class MockJAXArray:
510-
def __init__(self):
511-
self.shape = (100, 50)
512-
self.dtype = np.float32
513-
self.ndim = 2
514-
self.device = "gpu:0"
515-
516-
MockJAXArray.__module__ = "jax.numpy"
536+
mock_arr = _make_array_api_mock(
537+
"jax.numpy", shape=(100, 50), dtype=np.float32, device="gpu:0"
538+
)
517539

518540
formatter = ArrayAPIFormatter()
519-
mock_arr = MockJAXArray()
520541

521542
assert formatter.can_format(mock_arr, FormatterContext())
522543
result = formatter.format(mock_arr, FormatterContext())
523544

524-
assert "MockJAXArray" in result.type_name
525545
assert "100" in result.type_name
526546
assert "50" in result.type_name
527547
assert result.css_class == "anndata-dtype--array-api"
528-
assert "JAX" in result.tooltip
548+
assert "JAX" in result.tooltip or "jax" in result.tooltip.lower()
529549
assert "gpu:0" in result.tooltip
530550

531-
def test_array_api_formatter_pytorch_like(self):
532-
"""Test Array-API formatter with PyTorch-like tensor."""
551+
def test_array_api_formatter_pytorch_like_duck_typing(self):
552+
"""Test Array-API formatter with PyTorch-like tensor (duck-typing tier).
553+
554+
PyTorch doesn't implement __array_namespace__, so it falls back
555+
to the duck-typing check (shape/dtype/ndim).
556+
"""
533557
from anndata._repr.formatters import ArrayAPIFormatter
534558
from anndata._repr.registry import FormatterContext
535559

536560
class MockTorchTensor:
537-
def __init__(self):
538-
self.shape = (64, 32)
539-
self.dtype = "torch.float32"
540-
self.ndim = 2
541-
self.device = "cuda:0"
561+
shape = (64, 32)
562+
dtype = "torch.float32"
563+
ndim = 2
564+
device = "cuda:0"
542565

543566
MockTorchTensor.__module__ = "torch"
544567

@@ -548,34 +571,24 @@ def __init__(self):
548571
assert formatter.can_format(mock_tensor, FormatterContext())
549572
result = formatter.format(mock_tensor, FormatterContext())
550573

551-
assert "MockTorchTensor" in result.type_name
552-
assert "PyTorch" in result.tooltip
574+
assert "PyTorch" in result.tooltip or "torch" in result.tooltip.lower()
575+
assert "cuda:0" in result.tooltip
553576

554-
def test_array_api_formatter_device_buffer(self):
555-
"""Test Array-API formatter with device_buffer attribute."""
577+
def test_array_api_formatter_protocol_array(self):
578+
"""Test Array-API formatter with full protocol array (tier 1)."""
556579
from anndata._repr.formatters import ArrayAPIFormatter
557580
from anndata._repr.registry import FormatterContext
558581

559-
class MockDeviceBuffer:
560-
def device(self):
561-
return "tpu:0"
562-
563-
class MockJAXArrayWithBuffer:
564-
def __init__(self):
565-
self.shape = (10, 5)
566-
self.dtype = np.float32
567-
self.ndim = 2
568-
self.device_buffer = MockDeviceBuffer()
569-
570-
MockJAXArrayWithBuffer.__module__ = "jaxlib.xla_extension"
582+
mock_arr = _make_array_api_mock(
583+
"jax.numpy", shape=(10, 5), dtype=np.float32, device="tpu:0"
584+
)
571585

572586
formatter = ArrayAPIFormatter()
573-
mock_arr = MockJAXArrayWithBuffer()
574587

575588
assert formatter.can_format(mock_arr, FormatterContext())
576589
result = formatter.format(mock_arr, FormatterContext())
577590

578-
assert "JAX" in result.tooltip
591+
assert "JAX" in result.tooltip or "jax" in result.tooltip.lower()
579592
assert "tpu:0" in result.tooltip
580593

581594
def test_array_api_formatter_excludes_numpy(self):
@@ -677,22 +690,9 @@ def test_array_api_formatter_with_mock_jax_array(self):
677690
from anndata._repr.formatters import ArrayAPIFormatter
678691
from anndata._repr.registry import FormatterContext
679692

680-
class MockJAXArray:
681-
def __init__(self):
682-
self.shape = (100, 50)
683-
self.dtype = np.dtype("float32")
684-
self.ndim = 2
685-
686-
@property
687-
def __module__(self):
688-
return "jax.numpy"
689-
690-
def __class__(self):
691-
return type("DeviceArray", (), {})
692-
693-
mock_array = MockJAXArray()
694-
type(mock_array).__module__ = "jax.numpy"
695-
type(mock_array).__name__ = "DeviceArray"
693+
mock_array = _make_array_api_mock(
694+
"jax.numpy", shape=(100, 50), dtype=np.dtype("float32"), device="cpu"
695+
)
696696

697697
formatter = ArrayAPIFormatter()
698698
can_format = formatter.can_format(mock_array, FormatterContext())
@@ -987,17 +987,11 @@ def test_array_api_formatter_obsm_preview(self):
987987
from anndata._repr.formatters import ArrayAPIFormatter
988988
from anndata._repr.registry import FormatterContext
989989

990-
# Mock JAX-like array
991-
class MockJAXArray:
992-
def __init__(self):
993-
self.shape = (100, 15)
994-
self.dtype = np.float32
995-
self.ndim = 2
996-
997-
MockJAXArray.__module__ = "jax.numpy"
990+
arr = _make_array_api_mock(
991+
"jax.numpy", shape=(100, 15), dtype=np.float32, device="cpu"
992+
)
998993

999994
formatter = ArrayAPIFormatter()
1000-
arr = MockJAXArray()
1001995

1002996
# No preview outside obsm/varm
1003997
result = formatter.format(arr, FormatterContext(section="uns"))

0 commit comments

Comments
 (0)