@@ -292,11 +292,16 @@ def unique_inverse(x):
292292 array_api_dev = x .device
293293 exec_q = array_api_dev .sycl_queue
294294 x_usm_type = x .usm_type
295- if x .ndim == 1 :
295+ ind_dt = default_device_index_type (exec_q )
296+ if x .ndim == 0 :
297+ return UniqueInverseResult (
298+ dpt .reshape (x , (1 ,), order = "C" , copy = True ),
299+ dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
300+ )
301+ elif x .ndim == 1 :
296302 fx = x
297303 else :
298304 fx = dpt .reshape (x , (x .size ,), order = "C" )
299- ind_dt = default_device_index_type (exec_q )
300305 sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
301306 unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
302307 if fx .size == 0 :
@@ -456,13 +461,24 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
456461 array_api_dev = x .device
457462 exec_q = array_api_dev .sycl_queue
458463 x_usm_type = x .usm_type
459- if x .ndim == 1 :
464+ ind_dt = default_device_index_type (exec_q )
465+ if x .ndim == 0 :
466+ uv = dpt .reshape (x , (1 ,), order = "C" , copy = True )
467+ return UniqueAllResult (
468+ uv ,
469+ dpt .zeros_like (uv , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
470+ dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
471+ dpt .ones_like (
472+ uv , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
473+ ),
474+ )
475+ elif x .ndim == 1 :
460476 fx = x
461477 else :
462478 fx = dpt .reshape (x , (x .size ,), order = "C" )
463- ind_dt = default_device_index_type (exec_q )
464479 sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
465480 unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
481+ print (unsorting_ids )
466482 if fx .size == 0 :
467483 # original array contains no data
468484 # so it can be safely returned as values
0 commit comments