Skip to content

Commit 0e9d602

Browse files
Fix repeat() to avoid non-0D scalar conversion
1 parent 0797836 commit 0e9d602

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,12 @@ def repeat(x, repeats, /, *, axis=None):
829829
if repeats.size == 1:
830830
scalar = True
831831
# bring the single element to the host
832-
repeats = int(repeats)
832+
if repeats.ndim == 0:
833+
repeats = int(repeats)
834+
else:
835+
# Get the single element explicitly
836+
# since non-0D arrays can not be converted to scalars
837+
repeats = int(repeats[0])
833838
if repeats < 0:
834839
raise ValueError("`repeats` elements must be positive")
835840
else:

0 commit comments

Comments
 (0)