1515# limitations under the License.
1616
1717import contextlib
18+ import itertools
1819import operator
1920
2021import numpy as np
@@ -223,10 +224,11 @@ def print_options(*args, **kwargs):
223224
224225
225226def _nd_corners (arr_in , edge_items ):
226- arr_ndim = arr_in .ndim
227+ _shape = arr_in .shape
228+ max_shape = 2 * edge_items + 1
227229 res_shape = tuple (
228- 2 * edge_items if arr_in . shape [i ] > 2 * edge_items else arr_in . shape [i ]
229- for i in range (arr_ndim )
230+ max_shape if _shape [i ] > max_shape else _shape [i ]
231+ for i in range (arr_in . ndim )
230232 )
231233
232234 arr_out = dpt .empty (
@@ -236,29 +238,27 @@ def _nd_corners(arr_in, edge_items):
236238 sycl_queue = arr_in .sycl_queue ,
237239 )
238240
241+ blocks = []
242+ for i in range (len (_shape )):
243+ if _shape [i ] > max_shape :
244+ blocks .append (
245+ (
246+ np .s_ [:edge_items ],
247+ np .s_ [- edge_items :],
248+ )
249+ )
250+ else :
251+ blocks .append ((np .s_ [:],))
252+
239253 hev_list = []
240- for corner in range (arr_ndim ** 2 ):
241- slices = ()
242- tmp = bin (corner ).replace ("0b" , "" ).zfill (arr_ndim )
243-
244- for dim in reversed (range (arr_ndim )):
245- if arr_in .shape [dim ] < 2 * edge_items :
246- slices = (np .s_ [:],) + slices
247- else :
248- ind = (- 1 ) ** int (tmp [dim ]) * edge_items
249- if ind < 0 :
250- slices = (np .s_ [- edge_items ::],) + slices
251- else :
252- slices = (np .s_ [:edge_items :],) + slices
254+ for slc in itertools .product (* blocks ):
253255 hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
254- src = arr_in [slices ],
255- dst = arr_out [slices ],
256- sycl_queue = arr_in .sycl_queue ,
256+ src = arr_in [slc ], dst = arr_out [slc ], sycl_queue = arr_in .sycl_queue
257257 )
258258 hev_list .append (hev )
259259
260260 dpctl .SyclEvent .wait_for (hev_list )
261- return arr_out
261+ return dpt . asnumpy ( arr_out )
262262
263263
264264def usm_ndarray_str (
@@ -365,8 +365,7 @@ def usm_ndarray_str(
365365 edge_items = options ["edgeitems" ]
366366
367367 if x .size > threshold :
368- # need edge_items + 1 elements for np.array2string to abbreviate
369- data = dpt .asnumpy (_nd_corners (x , edge_items + 1 ))
368+ data = _nd_corners (x , edge_items )
370369 options ["threshold" ] = 0
371370 else :
372371 data = dpt .asnumpy (x )
0 commit comments