2020from .conftest import HAS_AWKWARD , HAS_DASK
2121
2222
23+ def _make_array_api_mock (module : str , * , shape , dtype , device = "cpu" ):
24+ """Create a mock array satisfying the SupportsArrayApi protocol.
25+
26+ The mock has all attributes required by ``anndata.types.SupportsArrayApi``
27+ (``shape``, ``device``, ``__array_namespace__``, ``to_device``,
28+ ``__dlpack__``, ``__dlpack_device__``) so that ``has_xp()`` returns True.
29+ """
30+ ns_module = type ("Namespace" , (), {"__name__" : module .split ("." )[0 ]})()
31+
32+ cls = type (
33+ "MockArrayAPI" ,
34+ (),
35+ {
36+ "shape" : shape ,
37+ "dtype" : dtype ,
38+ "ndim" : len (shape ),
39+ "device" : device ,
40+ "__array_namespace__" : lambda self , ** kw : ns_module ,
41+ "to_device" : lambda self , dev , / , ** kw : self ,
42+ "__dlpack__" : lambda self , ** kw : None ,
43+ "__dlpack_device__" : lambda self : (1 , 0 ),
44+ },
45+ )
46+ cls .__module__ = module
47+ return cls ()
48+
49+
2350class TestNumpyFormatters :
2451 """Tests for NumPy array formatters."""
2552
@@ -506,39 +533,35 @@ def test_array_api_formatter_jax_like(self):
506533 from anndata ._repr .formatters import ArrayAPIFormatter
507534 from anndata ._repr .registry import FormatterContext
508535
509- class MockJAXArray :
510- def __init__ (self ):
511- self .shape = (100 , 50 )
512- self .dtype = np .float32
513- self .ndim = 2
514- self .device = "gpu:0"
515-
516- MockJAXArray .__module__ = "jax.numpy"
536+ mock_arr = _make_array_api_mock (
537+ "jax.numpy" , shape = (100 , 50 ), dtype = np .float32 , device = "gpu:0"
538+ )
517539
518540 formatter = ArrayAPIFormatter ()
519- mock_arr = MockJAXArray ()
520541
521542 assert formatter .can_format (mock_arr , FormatterContext ())
522543 result = formatter .format (mock_arr , FormatterContext ())
523544
524- assert "MockJAXArray" in result .type_name
525545 assert "100" in result .type_name
526546 assert "50" in result .type_name
527547 assert result .css_class == "anndata-dtype--array-api"
528- assert "JAX" in result .tooltip
548+ assert "JAX" in result .tooltip or "jax" in result . tooltip . lower ()
529549 assert "gpu:0" in result .tooltip
530550
531- def test_array_api_formatter_pytorch_like (self ):
532- """Test Array-API formatter with PyTorch-like tensor."""
551+ def test_array_api_formatter_pytorch_like_duck_typing (self ):
552+ """Test Array-API formatter with PyTorch-like tensor (duck-typing tier).
553+
554+ PyTorch doesn't implement __array_namespace__, so it falls back
555+ to the duck-typing check (shape/dtype/ndim).
556+ """
533557 from anndata ._repr .formatters import ArrayAPIFormatter
534558 from anndata ._repr .registry import FormatterContext
535559
536560 class MockTorchTensor :
537- def __init__ (self ):
538- self .shape = (64 , 32 )
539- self .dtype = "torch.float32"
540- self .ndim = 2
541- self .device = "cuda:0"
561+ shape = (64 , 32 )
562+ dtype = "torch.float32"
563+ ndim = 2
564+ device = "cuda:0"
542565
543566 MockTorchTensor .__module__ = "torch"
544567
@@ -548,34 +571,24 @@ def __init__(self):
548571 assert formatter .can_format (mock_tensor , FormatterContext ())
549572 result = formatter .format (mock_tensor , FormatterContext ())
550573
551- assert "MockTorchTensor " in result .type_name
552- assert "PyTorch " in result .tooltip
574+ assert "PyTorch " in result .tooltip or "torch" in result . tooltip . lower ()
575+ assert "cuda:0 " in result .tooltip
553576
554- def test_array_api_formatter_device_buffer (self ):
555- """Test Array-API formatter with device_buffer attribute ."""
577+ def test_array_api_formatter_protocol_array (self ):
578+ """Test Array-API formatter with full protocol array (tier 1) ."""
556579 from anndata ._repr .formatters import ArrayAPIFormatter
557580 from anndata ._repr .registry import FormatterContext
558581
559- class MockDeviceBuffer :
560- def device (self ):
561- return "tpu:0"
562-
563- class MockJAXArrayWithBuffer :
564- def __init__ (self ):
565- self .shape = (10 , 5 )
566- self .dtype = np .float32
567- self .ndim = 2
568- self .device_buffer = MockDeviceBuffer ()
569-
570- MockJAXArrayWithBuffer .__module__ = "jaxlib.xla_extension"
582+ mock_arr = _make_array_api_mock (
583+ "jax.numpy" , shape = (10 , 5 ), dtype = np .float32 , device = "tpu:0"
584+ )
571585
572586 formatter = ArrayAPIFormatter ()
573- mock_arr = MockJAXArrayWithBuffer ()
574587
575588 assert formatter .can_format (mock_arr , FormatterContext ())
576589 result = formatter .format (mock_arr , FormatterContext ())
577590
578- assert "JAX" in result .tooltip
591+ assert "JAX" in result .tooltip or "jax" in result . tooltip . lower ()
579592 assert "tpu:0" in result .tooltip
580593
581594 def test_array_api_formatter_excludes_numpy (self ):
@@ -677,22 +690,9 @@ def test_array_api_formatter_with_mock_jax_array(self):
677690 from anndata ._repr .formatters import ArrayAPIFormatter
678691 from anndata ._repr .registry import FormatterContext
679692
680- class MockJAXArray :
681- def __init__ (self ):
682- self .shape = (100 , 50 )
683- self .dtype = np .dtype ("float32" )
684- self .ndim = 2
685-
686- @property
687- def __module__ (self ):
688- return "jax.numpy"
689-
690- def __class__ (self ):
691- return type ("DeviceArray" , (), {})
692-
693- mock_array = MockJAXArray ()
694- type(mock_array ).__module__ = "jax.numpy"
695- type(mock_array ).__name__ = "DeviceArray"
693+ mock_array = _make_array_api_mock (
694+ "jax.numpy" , shape = (100 , 50 ), dtype = np .dtype ("float32" ), device = "cpu"
695+ )
696696
697697 formatter = ArrayAPIFormatter ()
698698 can_format = formatter .can_format (mock_array , FormatterContext ())
@@ -987,17 +987,11 @@ def test_array_api_formatter_obsm_preview(self):
987987 from anndata ._repr .formatters import ArrayAPIFormatter
988988 from anndata ._repr .registry import FormatterContext
989989
990- # Mock JAX-like array
991- class MockJAXArray :
992- def __init__ (self ):
993- self .shape = (100 , 15 )
994- self .dtype = np .float32
995- self .ndim = 2
996-
997- MockJAXArray .__module__ = "jax.numpy"
990+ arr = _make_array_api_mock (
991+ "jax.numpy" , shape = (100 , 15 ), dtype = np .float32 , device = "cpu"
992+ )
998993
999994 formatter = ArrayAPIFormatter ()
1000- arr = MockJAXArray ()
1001995
1002996 # No preview outside obsm/varm
1003997 result = formatter .format (arr , FormatterContext (section = "uns" ))
0 commit comments