diff --git a/dpctl/tensor/_print.py b/dpctl/tensor/_print.py index f1e20a12c4..914f555a77 100644 --- a/dpctl/tensor/_print.py +++ b/dpctl/tensor/_print.py @@ -316,7 +316,11 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None): dtype_str = "dtype={}".format(x.dtype.name) bottom_len = len(s) - (s.rfind("\n") + 1) next_line = bottom_len + len(dtype_str) + 1 > line_width - dtype_str = ",\n" + dtype_str if next_line else ", " + dtype_str + dtype_str = ( + ",\n" + " " * len(prefix) + dtype_str + if next_line + else ", " + dtype_str + ) else: dtype_str = "" diff --git a/dpctl/tests/test_usm_ndarray_print.py b/dpctl/tests/test_usm_ndarray_print.py index 47e4910921..05a4a2b8a9 100644 --- a/dpctl/tests/test_usm_ndarray_print.py +++ b/dpctl/tests/test_usm_ndarray_print.py @@ -211,6 +211,16 @@ def test_print_repr(self): x = dpt.arange(4, dtype="i4", sycl_queue=q) assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)" + dpt.set_print_options(linewidth=1) + np.testing.assert_equal( + repr(x), + "usm_ndarray([0," + "\n 1," + "\n 2," + "\n 3]," + "\n dtype=int32)", + ) + def test_print_repr_abbreviated(self): q = get_queue_or_skip() @@ -237,6 +247,19 @@ def test_print_repr_abbreviated(self): "\n [6, ..., 8]], dtype=int32)", ) + dpt.set_print_options(linewidth=1) + np.testing.assert_equal( + repr(y), + "usm_ndarray([[0," + "\n ...," + "\n 2]," + "\n ...," + "\n [6," + "\n ...," + "\n 8]]," + "\n dtype=int32)", + ) + @pytest.mark.parametrize( "dtype", [