Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions xarray/compat/array_api_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import datetime as dt

import numpy as np

from xarray.namedarray.pycompat import array_type

builtin_types = (
bool,
int,
float,
complex,
str,
bytes,
dt.datetime,
dt.timedelta,
)


def is_weak_scalar_type(t):
return isinstance(t, bool | int | float | complex | str | bytes)
Expand Down Expand Up @@ -38,12 +51,15 @@ def _future_array_api_result_type(*arrays_and_dtypes, xp):


def result_type(*arrays_and_dtypes, xp) -> np.dtype:
if xp is np or any(
isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
):
return xp.result_type(*arrays_and_dtypes)
else:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
try:
if xp is np or any(
isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
):
return xp.result_type(*arrays_and_dtypes)
else:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
except TypeError:
return np.dtype(object)


def get_array_namespace(*values):
Expand Down
17 changes: 6 additions & 11 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,11 @@ def should_promote_to_object(
"""
np_result_types = set()
for arr_or_dtype in arrays_and_dtypes:
try:
result_type = array_api_compat.result_type(
maybe_promote_to_variable_width(arr_or_dtype), xp=xp
)
if isinstance(result_type, np.dtype):
np_result_types.add(result_type)
except TypeError:
# passing individual objects to xp.result_type (i.e., what `array_api_compat.result_type` calls) means NEP-18 implementations won't have
# a chance to intercept special values (such as NA) that numpy core cannot handle.
# Thus they are considered as types that don't need promotion i.e., the `arr_or_dtype` that rose the `TypeError` will not contribute to `np_result_types`.
pass
result_type = array_api_compat.result_type(
maybe_promote_to_variable_width(arr_or_dtype), xp=xp
)
if isinstance(result_type, np.dtype):
np_result_types.add(result_type)

if np_result_types:
for left, right in PROMOTE_TO_OBJECT:
Expand Down Expand Up @@ -322,6 +316,7 @@ def result_type(

if should_promote_to_object(arrays_and_dtypes, xp):
return np.dtype(object)

maybe_promote = functools.partial(
maybe_promote_to_variable_width,
# let extension arrays handle their own str/bytes
Expand Down
10 changes: 3 additions & 7 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,9 @@ def maybe_coerce_to_str(index, original_coords):
"""
from xarray.core import dtypes

try:
result_type = dtypes.result_type(*original_coords)
except TypeError:
pass
else:
if result_type.kind in "SU":
index = np.asarray(index, dtype=result_type.type)
result_type = dtypes.result_type(*original_coords)
if result_type.kind in "SU":
index = np.asarray(index, dtype=result_type.type)

return index

Expand Down
4 changes: 4 additions & 0 deletions xarray/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class DummyArrayAPINamespace:
([np.dtype("<U2"), str], np.dtype("U")),
([np.dtype("S3"), np.bytes_], np.dtype("S")),
([np.dtype("S10"), bytes], np.dtype("S")),
([type("Foo", (object,), {"foo": "bar"})()], np.object_),
([np.float32, type("Foo", (object,), {"foo": "bar"})()], np.object_),
([np.str_, type("Foo", (object,), {"foo": "bar"})()], np.object_),
([np.bytes_, type("Foo", (object,), {"foo": "bar"})()], np.object_),
],
)
def test_result_type(args, expected) -> None:
Expand Down
Loading