File tree Expand file tree Collapse file tree 1 file changed +17
-8
lines changed
Expand file tree Collapse file tree 1 file changed +17
-8
lines changed Original file line number Diff line number Diff line change 1515# limitations under the License.
1616
1717import contextlib
18+ import itertools
1819import operator
1920
2021import numpy as np
@@ -236,20 +237,28 @@ def _nd_corners(arr_in, edge_items):
236237 sycl_queue = arr_in .sycl_queue ,
237238 )
238239
240+ split_dim = sum (
241+ arr_in .shape [dim ] > 2 * edge_items for dim in range (arr_ndim )
242+ )
243+ blocks = [
244+ list (ele ) for ele in list (itertools .product (* [[0 , 1 ]] * split_dim ))
245+ ]
246+ for dim in range (arr_ndim ):
247+ if arr_in .shape [dim ] <= 2 * edge_items :
248+ for blk in blocks :
249+ blk .insert (dim , 0 )
250+
239251 hev_list = []
240- for corner in range ( arr_ndim ** 2 ) :
252+ for blk in blocks :
241253 slices = ()
242- tmp = bin (corner ).replace ("0b" , "" ).zfill (arr_ndim )
243-
244254 for dim in reversed (range (arr_ndim )):
245- if arr_in .shape [dim ] < 2 * edge_items :
246- slices = (np .s_ [:],) + slices
247- else :
248- ind = (- 1 ) ** int (tmp [dim ]) * edge_items
249- if ind < 0 :
255+ if arr_in .shape [dim ] > 2 * edge_items :
256+ if blk [dim ] == 1 :
250257 slices = (np .s_ [- edge_items ::],) + slices
251258 else :
252259 slices = (np .s_ [:edge_items :],) + slices
260+ else :
261+ slices = (np .s_ [:],) + slices
253262 hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
254263 src = arr_in [slices ],
255264 dst = arr_out [slices ],
You can’t perform that action at this time.
0 commit comments