diff --git a/dpctl/tensor/_types.pxi b/dpctl/tensor/_types.pxi index f6682f7b66..8ffae77572 100644 --- a/dpctl/tensor/_types.pxi +++ b/dpctl/tensor/_types.pxi @@ -102,7 +102,7 @@ cdef str _make_typestr(int typenum): return type_to_str[typenum] + str(type_bytesize(typenum)) -cdef int typenum_from_format(str s) except *: +cdef int typenum_from_format(str s): """ Internal utility to convert string describing type format @@ -110,20 +110,21 @@ cdef int typenum_from_format(str s) except *: Shortcuts for formats are i, u, d, D """ if not s: - raise TypeError("Format string '" + s + "' cannot be empty.") + return -1 try: dt = np.dtype(s) - except Exception as e: - raise TypeError("Format '" + s + "' is not understood.") from e + except Exception: + return -1 if (dt.byteorder == ">"): - raise TypeError("Format '" + s + "' can only have native byteorder.") + return -2 return dt.num + cdef int descr_to_typenum(object dtype): "Returns typenum for argumentd dtype that has attribute descr, assumed numpy.dtype" obj = getattr(dtype, 'descr') if (not isinstance(obj, list) or len(obj) != 1): - return -1 + return -1 # token for ValueError obj = obj[0] if (not isinstance(obj, tuple) or len(obj) != 2 or obj[0]): return -1 @@ -133,7 +134,7 @@ cdef int descr_to_typenum(object dtype): return typenum_from_format(obj) -cdef int dtype_to_typenum(dtype) except *: +cdef int dtype_to_typenum(dtype): if isinstance(dtype, str): return typenum_from_format(dtype) elif isinstance(dtype, bytes): @@ -143,9 +144,11 @@ cdef int dtype_to_typenum(dtype) except *: else: try: dt = np.dtype(dtype) - if hasattr(dt, 'descr'): - return descr_to_typenum(dt) - else: - return -1 + except TypeError: + return -3 except Exception: return -1 + if hasattr(dt, 'descr'): + return descr_to_typenum(dt) + else: + return -3 # token for TypeError diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 7ddf191296..62ebc83a86 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -188,9 +188,15 @@ cdef class usm_ndarray: raise TypeError("Argument shape must be a list or a tuple.") nd = len(shape) typenum = dtype_to_typenum(dtype) + if (typenum < 0): + if typenum == -2: + raise ValueError("Data type '" + str(dtype) + "' can only have native byteorder.") + elif typenum == -1: + raise ValueError("Data type '" + str(dtype) + "' is not understood.") + raise TypeError(f"Expected string or a dtype object, got {type(dtype)}") itemsize = type_bytesize(typenum) if (itemsize < 1): - raise TypeError("dtype=" + dtype + " is not supported.") + raise TypeError("dtype=" + np.dtype(dtype).name + " is not supported.") # allocate host C-arrays for shape, strides err = _from_input_shape_strides( nd, shape, strides, itemsize, ord(order), diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 9e24847f28..485451623f 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -95,13 +95,23 @@ def test_dtypes(dtype): assert expected_fmt == actual_fmt -@pytest.mark.parametrize("dtype", ["", ">f4", "invalid", 123]) +@pytest.mark.parametrize( + "dtype", + [ + "", + ">f4", + "invalid", + 123, + np.dtype(">f4"), + np.dtype([("a", ">f4"), ("b", "i4")]), + ], +) def test_dtypes_invalid(dtype): with pytest.raises((TypeError, ValueError)): dpt.usm_ndarray((1,), dtype=dtype) -@pytest.mark.parametrize("dt", ["d", "c16"]) +@pytest.mark.parametrize("dt", ["f", "c8"]) def test_properties(dt): """ Test that properties execute @@ -1181,6 +1191,11 @@ def test_empty_like(dt, usm_kind): assert X.sycl_queue == Y.sycl_queue +def test_empty_unexpected_data_type(): + with pytest.raises(TypeError): + dpt.empty(1, dtype=np.object_) + + @pytest.mark.parametrize( "dt", _all_dtypes,