File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1009,9 +1009,16 @@ def _to_strided_memory_view(
10091009) -> tuple [StridedMemoryView , bool ]:
10101010 if _driver .is_device_memory (obj ):
10111011 return obj ._strided_memory_view_shim , False
1012- elif not isinstance (
1013- obj , (np .ndarray , _UNSUPPORTED_DLPACK_TYPES )
1014- ) and hasattr (obj , "__dlpack__" ):
1012+ elif (
1013+ not isinstance (obj , (np .ndarray , _UNSUPPORTED_DLPACK_TYPES ))
1014+ and hasattr (obj , "__dlpack__" )
1015+ and (
1016+ (dtype := getattr (obj , "dtype" , None )) is None
1017+ or not issubclass (
1018+ getattr (dtype , "type" , None ), _UNSUPPORTED_DLPACK_TYPES
1019+ )
1020+ )
1021+ ):
10151022 # numpy arrays need to be copied to the device
10161023 # so we can't view them as SMVs until then
10171024 #
Original file line number Diff line number Diff line change 88from numba .cuda .testing import CUDATestCase , skip_on_cudasim
99import unittest
1010
11+ import pytest
12+
1113
1214class TestCudaDateTime (CUDATestCase ):
1315 def test_basic_datetime_kernel (self ):
@@ -58,6 +60,26 @@ def timediff(start, end):
5860
5961 self .assertPreciseEqual (delta , arr2 - arr1 )
6062
63+ @skip_on_cudasim ("API unsupported in the simulator" )
64+ def test_datetime_cupy_inputs (self ):
65+ cp = pytest .importorskip ("cupy" )
66+ datetime_t = from_dtype (cp .dtype ("datetime64[D]" ))
67+
68+ @cuda .jit
69+ def assign (out , arr ):
70+ for i in range (arr .shape [0 ]):
71+ out [i ] = arr [i ]
72+
73+ # TODO: cupy doesn't allow passing the datetime64[D] array directly
74+ arr = cp .array (
75+ np .arange ("2005-02" , "2006-02" , dtype = "datetime64[D]" ).view ("int64" )
76+ ).view ("datetime64[D]" )
77+
78+ out = cp .empty_like (arr )
79+ assign [1 , 1 ](out , arr )
80+
81+ self .assertPreciseEqual (arr .get (), out .get ())
82+
6183 @skip_on_cudasim ("ufunc API unsupported in the simulator" )
6284 def test_gufunc (self ):
6385 datetime_t = from_dtype (np .dtype ("datetime64[D]" ))
You can’t perform that action at this time.
0 commit comments