4040]
4141
4242
43+ def _typestr_has_fp64 (arr_typestr ):
44+ return arr_typestr in ["f8" , "c16" ]
45+
46+
47+ def _typestr_has_fp16 (arr_typestr ):
48+ return arr_typestr in ["f2" ]
49+
50+
@pytest.fixture(params=_usm_types_list)
def usm_type(request):
    """Parametrized fixture: supplies one entry of ``_usm_types_list`` per run."""
    return request.param
@@ -95,6 +103,14 @@ def test_copy1d_c_contig(src_typestr, dst_typestr):
95103 q = dpctl .SyclQueue ()
96104 except dpctl .SyclQueueCreationError :
97105 pytest .skip ("Queue could not be created" )
106+ if not q .sycl_device .has_aspect_fp64 and (
107+ _typestr_has_fp64 (src_typestr ) or _typestr_has_fp64 (dst_typestr )
108+ ):
109+ pytest .skip ("Device does not support double precision" )
110+ if not q .sycl_device .has_aspect_fp16 and (
111+ _typestr_has_fp16 (src_typestr ) or _typestr_has_fp16 (dst_typestr )
112+ ):
113+ pytest .skip ("Device does not support half precision" )
98114 src_dt = np .dtype (src_typestr )
99115 dst_dt = np .dtype (dst_typestr )
100116 Xnp = _random_vector (4096 , src_dt )
@@ -113,6 +129,14 @@ def test_copy1d_strided(src_typestr, dst_typestr):
113129 q = dpctl .SyclQueue ()
114130 except dpctl .SyclQueueCreationError :
115131 pytest .skip ("Queue could not be created" )
132+ if not q .sycl_device .has_aspect_fp64 and (
133+ _typestr_has_fp64 (src_typestr ) or _typestr_has_fp64 (dst_typestr )
134+ ):
135+ pytest .skip ("Device does not support double precision" )
136+ if not q .sycl_device .has_aspect_fp16 and (
137+ _typestr_has_fp16 (src_typestr ) or _typestr_has_fp16 (dst_typestr )
138+ ):
139+ pytest .skip ("Device does not support half precision" )
116140 src_dt = np .dtype (src_typestr )
117141 dst_dt = np .dtype (dst_typestr )
118142 Xnp = _random_vector (4096 , src_dt )
@@ -131,7 +155,12 @@ def test_copy1d_strided(src_typestr, dst_typestr):
131155 assert are_close (Ynp , dpt .asnumpy (Y ))
132156
133157 # now 0-strided source
134- X = dpt .usm_ndarray ((4096 ,), dtype = src_typestr , strides = (0 ,))
158+ X = dpt .usm_ndarray (
159+ (4096 ,),
160+ dtype = src_typestr ,
161+ strides = (0 ,),
162+ buffer_ctor_kwargs = {"queue" : q },
163+ )
135164 X [0 ] = Xnp [0 ]
136165 Y = dpt .empty (X .shape , dtype = dst_typestr , sycl_queue = q )
137166 hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (src = X , dst = Y , sycl_queue = q )
@@ -145,6 +174,14 @@ def test_copy1d_strided2(src_typestr, dst_typestr):
145174 q = dpctl .SyclQueue ()
146175 except dpctl .SyclQueueCreationError :
147176 pytest .skip ("Queue could not be created" )
177+ if not q .sycl_device .has_aspect_fp64 and (
178+ _typestr_has_fp64 (src_typestr ) or _typestr_has_fp64 (dst_typestr )
179+ ):
180+ pytest .skip ("Device does not support double precision" )
181+ if not q .sycl_device .has_aspect_fp16 and (
182+ _typestr_has_fp16 (src_typestr ) or _typestr_has_fp16 (dst_typestr )
183+ ):
184+ pytest .skip ("Device does not support half precision" )
148185 src_dt = np .dtype (src_typestr )
149186 dst_dt = np .dtype (dst_typestr )
150187 Xnp = _random_vector (4096 , src_dt )
@@ -172,6 +209,14 @@ def test_copy2d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2):
172209 q = dpctl .SyclQueue ()
173210 except dpctl .SyclQueueCreationError :
174211 pytest .skip ("Queue could not be created" )
212+ if not q .sycl_device .has_aspect_fp64 and (
213+ _typestr_has_fp64 (src_typestr ) or _typestr_has_fp64 (dst_typestr )
214+ ):
215+ pytest .skip ("Device does not support double precision" )
216+ if not q .sycl_device .has_aspect_fp16 and (
217+ _typestr_has_fp16 (src_typestr ) or _typestr_has_fp16 (dst_typestr )
218+ ):
219+ pytest .skip ("Device does not support half precision" )
175220
176221 src_dt = np .dtype (src_typestr )
177222 dst_dt = np .dtype (dst_typestr )
@@ -188,16 +233,16 @@ def test_copy2d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2):
188233 slice (None , None , st1 * sgn1 ),
189234 slice (None , None , st2 * sgn2 ),
190235 ]
191- Y = dpt .empty ((n1 , n2 ), dtype = dst_dt )
236+ Y = dpt .empty ((n1 , n2 ), dtype = dst_dt , device = X . device )
192237 hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (src = X , dst = Y , sycl_queue = q )
193238 Ynp = _force_cast (Xnp , dst_dt )
194239 hev .wait ()
195240 assert are_close (Ynp , dpt .asnumpy (Y ))
196- Yst = dpt .empty ((2 * n1 , n2 ), dtype = dst_dt )[::2 , ::- 1 ]
241+ Yst = dpt .empty ((2 * n1 , n2 ), dtype = dst_dt , device = X . device )[::2 , ::- 1 ]
197242 hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (
198243 src = X , dst = Yst , sycl_queue = q
199244 )
200- Y = dpt .empty ((n1 , n2 ), dtype = dst_dt )
245+ Y = dpt .empty ((n1 , n2 ), dtype = dst_dt , device = X . device )
201246 hev2 , ev2 = ti ._copy_usm_ndarray_into_usm_ndarray (
202247 src = Yst , dst = Y , sycl_queue = q , depends = [ev ]
203248 )
@@ -220,6 +265,14 @@ def test_copy3d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2, st3, sgn3):
220265 except dpctl .SyclQueueCreationError :
221266 pytest .skip ("Queue could not be created" )
222267
268+ if not q .sycl_device .has_aspect_fp64 and (
269+ _typestr_has_fp64 (src_typestr ) or _typestr_has_fp64 (dst_typestr )
270+ ):
271+ pytest .skip ("Device does not support double precision" )
272+ if not q .sycl_device .has_aspect_fp16 and (
273+ _typestr_has_fp16 (src_typestr ) or _typestr_has_fp16 (dst_typestr )
274+ ):
275+ pytest .skip ("Device does not support half precision" )
223276 src_dt = np .dtype (src_typestr )
224277 dst_dt = np .dtype (dst_typestr )
225278 n1 , n2 , n3 = 5 , 4 , 6
@@ -237,16 +290,16 @@ def test_copy3d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2, st3, sgn3):
237290 slice (None , None , st2 * sgn2 ),
238291 slice (None , None , st3 * sgn3 ),
239292 ]
240- Y = dpt .empty ((n1 , n2 , n3 ), dtype = dst_dt )
293+ Y = dpt .empty ((n1 , n2 , n3 ), dtype = dst_dt , device = X . device )
241294 hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (src = X , dst = Y , sycl_queue = q )
242295 Ynp = _force_cast (Xnp , dst_dt )
243296 hev .wait ()
244297 assert are_close (Ynp , dpt .asnumpy (Y )), "1"
245- Yst = dpt .empty ((2 * n1 , n2 , n3 ), dtype = dst_dt )[::2 , ::- 1 ]
298+ Yst = dpt .empty ((2 * n1 , n2 , n3 ), dtype = dst_dt , device = X . device )[::2 , ::- 1 ]
246299 hev2 , ev2 = ti ._copy_usm_ndarray_into_usm_ndarray (
247300 src = X , dst = Yst , sycl_queue = q
248301 )
249- Y2 = dpt .empty ((n1 , n2 , n3 ), dtype = dst_dt )
302+ Y2 = dpt .empty ((n1 , n2 , n3 ), dtype = dst_dt , device = X . device )
250303 hev3 , ev3 = ti ._copy_usm_ndarray_into_usm_ndarray (
251304 src = Yst , dst = Y2 , sycl_queue = q , depends = [ev2 ]
252305 )
0 commit comments