diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2bf9d25a1e6..6739f2118ee 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -253,8 +253,7 @@ def astype(data, dtype, *, xp=None, **kwargs): if xp is None: xp = get_array_namespace(data) - if xp == np: - # numpy currently doesn't have a astype: + if xp is np or not hasattr(xp, "astype"): return data.astype(dtype, **kwargs) return xp.astype(data, dtype, **kwargs) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index ca03c7e67a1..8fa3641c796 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -43,6 +43,7 @@ raise_if_dask_computes, requires_bottleneck, requires_cftime, + requires_cupy, requires_dask, requires_pyarrow, ) @@ -184,6 +185,20 @@ def test_where_extension_duck_array(self, categorical1, categorical2): where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) ).all() + @requires_cupy + def test_where_cupy_duck_array(self): + import cupy as cp + + arr = cp.array([[cp.nan, cp.nan], [2, 3], [4, 5]]) + mask = ~cp.isnan(arr) + da = DataArray(arr, dims=("x", "y"), name="example") + output = da.where(mask, 0) + + expected = np.array([[0, 0], [2, 3], [4, 5]]) + + assert isinstance(output.data, cp.ndarray) + assert_array_equal(output.to_numpy(), expected) + def test_concatenate_extension_duck_array(self, categorical1, categorical2): concate_res = concatenate( [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)]