@@ -75,12 +75,14 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
7575 if x .ndim == 1 :
7676 fx = x
7777 else :
78- fx = dpt .reshape (x , (x .size ,), order = "C" , copy = False )
78+ fx = dpt .reshape (x , (x .size ,), order = "C" )
79+ if fx .size == 0 :
80+ return fx
7981 s = dpt .sort (fx )
8082 unique_mask = dpt .empty (fx .shape , dtype = "?" , sycl_queue = exec_q )
8183 dpt .not_equal (s [:- 1 ], s [1 :], out = unique_mask [1 :])
8284 unique_mask [0 ] = True
83- cumsum = dpt .empty (s .shape , dtype = dpt .int64 )
85+ cumsum = dpt .empty (s .shape , dtype = dpt .int64 , sycl_queue = exec_q )
8486 n_uniques = mask_positions (unique_mask , cumsum , sycl_queue = exec_q )
8587 if n_uniques == fx .size :
8688 return s
@@ -127,13 +129,15 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
127129 if x .ndim == 1 :
128130 fx = x
129131 else :
130- fx = dpt .reshape (x , (x .size ,), order = "C" , copy = False )
131- s = dpt .sort (x )
132+ fx = dpt .reshape (x , (x .size ,), order = "C" )
133+ ind_dt = default_device_index_type (exec_q )
134+ if fx .size == 0 :
135+ return UniqueCountsResult (fx , dpt .empty_like (fx , dtype = ind_dt ))
136+ s = dpt .sort (fx )
132137 unique_mask = dpt .empty (s .shape , dtype = "?" , sycl_queue = exec_q )
133138 dpt .not_equal (s [:- 1 ], s [1 :], out = unique_mask [1 :])
134139 unique_mask [0 ] = True
135- ind_dt = default_device_index_type (exec_q )
136- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
140+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
137141 # synchronizing call
138142 n_uniques = mask_positions (unique_mask , cumsum , sycl_queue = exec_q )
139143 if n_uniques == fx .size :
@@ -195,18 +199,20 @@ def unique_inverse(x):
195199 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
196200 array_api_dev = x .device
197201 exec_q = array_api_dev .sycl_queue
202+ ind_dt = default_device_index_type (exec_q )
198203 if x .ndim == 1 :
199204 fx = x
200205 else :
201- fx = dpt .reshape (x , (x .size ,), order = "C" , copy = False )
202- ind_dt = default_device_index_type (exec_q )
206+ fx = dpt .reshape (x , (x .size ,), order = "C" )
203207 sorting_ids = dpt .argsort (fx )
204208 unsorting_ids = dpt .argsort (sorting_ids )
209+ if fx .size == 0 :
210+ return UniqueInverseResult (fx , dpt .reshape (unsorting_ids , x .shape ))
205211 s = fx [sorting_ids ]
206212 unique_mask = dpt .empty (fx .shape , dtype = "?" , sycl_queue = exec_q )
207213 unique_mask [0 ] = True
208214 dpt .not_equal (s [:- 1 ], s [1 :], out = unique_mask [1 :])
209- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
215+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
210216 # synchronizing call
211217 n_uniques = mask_positions (unique_mask , cumsum , sycl_queue = exec_q )
212218 if n_uniques == fx .size :
@@ -251,7 +257,9 @@ def unique_inverse(x):
251257 ht_ev , _ = _full_usm_ndarray (fill_value = i , dst = _dst , sycl_queue = exec_q )
252258 ht_ev .wait ()
253259 pos = pos_next
254- return UniqueInverseResult (unique_vals , inv [unsorting_ids ])
260+ return UniqueInverseResult (
261+ unique_vals , dpt .reshape (inv [unsorting_ids ], x .shape )
262+ )
255263
256264
257265def unique_all (x : dpt .usm_ndarray ) -> UniqueAllResult :
@@ -289,22 +297,39 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
289297 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
290298 array_api_dev = x .device
291299 exec_q = array_api_dev .sycl_queue
300+ ind_dt = default_device_index_type (exec_q )
292301 if x .ndim == 1 :
293302 fx = x
294303 else :
295- fx = dpt .reshape (x , (x .size ,), order = "C" , copy = False )
296- ind_dt = default_device_index_type (exec_q )
304+ fx = dpt .reshape (x , (x .size ,), order = "C" )
297305 sorting_ids = dpt .argsort (fx )
298306 unsorting_ids = dpt .argsort (sorting_ids )
307+ if fx .size == 0 :
308+ # original array contains no data
309+ # so it can be safely returned as values
310+ return UniqueAllResult (
311+ fx ,
312+ sorting_ids ,
313+ dpt .reshape (unsorting_ids , x .shape ),
314+ dpt .empty_like (fx , dtype = ind_dt ),
315+ )
299316 s = fx [sorting_ids ]
300317 unique_mask = dpt .empty (fx .shape , dtype = "?" , sycl_queue = exec_q )
301318 dpt .not_equal (s [:- 1 ], s [1 :], out = unique_mask [1 :])
302319 unique_mask [0 ] = True
303- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
320+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
304321 # synchronizing call
305322 n_uniques = mask_positions (unique_mask , cumsum , sycl_queue = exec_q )
306323 if n_uniques == fx .size :
307- return UniqueInverseResult (s , unsorting_ids )
324+ _counts = dpt .ones (
325+ n_uniques , dtype = ind_dt , usm_type = x .usm_type , sycl_queue = exec_q
326+ )
327+ return UniqueAllResult (
328+ s ,
329+ sorting_ids ,
330+ dpt .reshape (unsorting_ids , x .shape ),
331+ _counts ,
332+ )
308333 unique_vals = dpt .empty (
309334 n_uniques , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
310335 )
@@ -346,6 +371,6 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
346371 return UniqueAllResult (
347372 unique_vals ,
348373 sorting_ids [cum_unique_counts [:- 1 ]],
349- inv [unsorting_ids ],
374+ dpt . reshape ( inv [unsorting_ids ], x . shape ) ,
350375 _counts ,
351376 )
0 commit comments