Skip to content

Commit c7743b5

Browse files
committed
fix : hack around missing datetime64 in dlpack
1 parent ad89b66 commit c7743b5

2 files changed

Lines changed: 32 additions & 3 deletions

File tree

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,9 +1009,16 @@ def _to_strided_memory_view(
10091009
) -> tuple[StridedMemoryView, bool]:
10101010
if _driver.is_device_memory(obj):
10111011
return obj._strided_memory_view_shim, False
1012-
elif not isinstance(
1013-
obj, (np.ndarray, _UNSUPPORTED_DLPACK_TYPES)
1014-
) and hasattr(obj, "__dlpack__"):
1012+
elif (
1013+
not isinstance(obj, (np.ndarray, _UNSUPPORTED_DLPACK_TYPES))
1014+
and hasattr(obj, "__dlpack__")
1015+
and (
1016+
(dtype := getattr(obj, "dtype", None)) is None
1017+
or not issubclass(
1018+
getattr(dtype, "type", None), _UNSUPPORTED_DLPACK_TYPES
1019+
)
1020+
)
1021+
):
10151022
# numpy arrays need to be copied to the device
10161023
# so we can't view them as SMVs until then
10171024
#

numba_cuda/numba/cuda/tests/cudapy/test_datetime.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from numba.cuda.testing import CUDATestCase, skip_on_cudasim
99
import unittest
1010

11+
import pytest
12+
1113

1214
class TestCudaDateTime(CUDATestCase):
1315
def test_basic_datetime_kernel(self):
@@ -58,6 +60,26 @@ def timediff(start, end):
5860

5961
self.assertPreciseEqual(delta, arr2 - arr1)
6062

63+
@skip_on_cudasim("API unsupported in the simulator")
64+
def test_datetime_cupy_inputs(self):
65+
cp = pytest.importorskip("cupy")
66+
datetime_t = from_dtype(cp.dtype("datetime64[D]"))
67+
68+
@cuda.jit
69+
def assign(out, arr):
70+
for i in range(arr.shape[0]):
71+
out[i] = arr[i]
72+
73+
# TODO: cupy doesn't allow passing the datetime64[D] array directly
74+
arr = cp.array(
75+
np.arange("2005-02", "2006-02", dtype="datetime64[D]").view("int64")
76+
).view("datetime64[D]")
77+
78+
out = cp.empty_like(arr)
79+
assign[1, 1](out, arr)
80+
81+
self.assertPreciseEqual(arr.get(), out.get())
82+
6183
@skip_on_cudasim("ufunc API unsupported in the simulator")
6284
def test_gufunc(self):
6385
datetime_t = from_dtype(np.dtype("datetime64[D]"))

0 commit comments

Comments
 (0)