@@ -422,19 +422,23 @@ def stack(arrays, axis=0):
422422 return res
423423
424424
425- def can_cast (array_and_dtype_from , dtype_to , casting = "safe" ):
425+ def can_cast (from_ , to , casting = "safe" ):
426426 """
427427 can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428428
429429 Determines if one data type can be cast to another data type according \
430430 Type Promotion Rules rules.
431431 """
432- if not isinstance (dtype_to , dpt .dtype ):
432+ if isinstance (to , dpt .usm_ndarray ):
433433 raise TypeError ("Expected dtype type." )
434434
435- dtype_from = dpt .dtype (array_and_dtype_from )
435+ dtype_to = dpt .dtype (to )
436436
437- _supported_dtype ([dtype_to , dtype_from ])
437+ dtype_from = (
438+ from_ .dtype if isinstance (from_ , dpt .usm_ndarray ) else dpt .dtype (from_ )
439+ )
440+
441+ _supported_dtype ([dtype_from , dtype_to ])
438442
439443 return np .can_cast (dtype_from , dtype_to , casting )
440444
@@ -447,7 +451,10 @@ def result_type(*arrays_and_dtypes):
447451 Returns the dtype that results from applying the Type Promotion Rules to \
448452 the arguments.
449453 """
450- dtypes = [dpt .dtype (X ) for X in arrays_and_dtypes ]
454+ dtypes = [
455+ X .dtype if isinstance (X , dpt .usm_ndarray ) else dpt .dtype (X )
456+ for X in arrays_and_dtypes
457+ ]
451458
452459 _supported_dtype (dtypes )
453460
@@ -460,6 +467,8 @@ def iinfo(type):
460467
461468 Returns machine limits for integer data types.
462469 """
470+ if isinstance (type , dpt .usm_ndarray ):
471+ raise TypeError ("Expected dtype type, get {to}." )
463472 _supported_dtype ([dpt .dtype (type )])
464473 return np .iinfo (type )
465474
@@ -470,11 +479,14 @@ def finfo(type):
470479
471480 Returns machine limits for float data types.
472481 """
482+ if isinstance (type , dpt .usm_ndarray ):
483+ raise TypeError ("Expected dtype type, get {to}." )
473484 _supported_dtype ([dpt .dtype (type )])
474485 return np .finfo (type )
475486
476487
477488def _supported_dtype (dtypes ):
478- if not all (dtype .char in "?bBhHiIlLqQefdFD" for dtype in dtypes ):
479- raise ValueError ("Unsupported dtype encountered." )
489+ for dtype in dtypes :
490+ if dtype .char not in "?bBhHiIlLqQefdFD" :
491+ raise ValueError (f"Dpctl doesn't support dtype { dtype } ." )
480492 return True
0 commit comments