@@ -56,9 +56,9 @@ struct MaskedExtractStridedFunctor
5656 char *dst_data_p,
5757 size_t orthog_iter_size,
5858 size_t masked_iter_size,
59- OrthogIndexerT orthog_src_dst_indexer_,
60- MaskedSrcIndexerT masked_src_indexer_,
61- MaskedDstIndexerT masked_dst_indexer_)
59+ const OrthogIndexerT & orthog_src_dst_indexer_,
60+ const MaskedSrcIndexerT & masked_src_indexer_,
61+ const MaskedDstIndexerT & masked_dst_indexer_)
6262 : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
6363 orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
6464 orthog_src_dst_indexer(orthog_src_dst_indexer_),
@@ -106,13 +106,14 @@ struct MaskedExtractStridedFunctor
106106 char *dst_cp = nullptr ;
107107 size_t orthog_nelems = 0 ;
108108 size_t masked_nelems = 0 ;
109- OrthogIndexerT
110- orthog_src_dst_indexer; // has nd, shape, src_strides, dst_strides for
111- // dimensions that ARE NOT masked
112- MaskedSrcIndexerT masked_src_indexer; // has nd, shape, src_strides for
113- // dimensions that ARE masked
114- MaskedDstIndexerT
115- masked_dst_indexer; // has 1, dst_strides for dimensions that ARE masked
109+ // has nd, shape, src_strides, dst_strides for
110+ // dimensions that ARE NOT masked
111+ const OrthogIndexerT orthog_src_dst_indexer;
112+ // has nd, shape, src_strides for
113+ // dimensions that ARE masked
114+ const MaskedSrcIndexerT masked_src_indexer;
115+ // has 1, dst_strides for dimensions that ARE masked
116+ const MaskedDstIndexerT masked_dst_indexer;
116117};
117118
118119template <typename OrthogIndexerT,
@@ -127,9 +128,9 @@ struct MaskedPlaceStridedFunctor
127128 const char *rhs_data_p,
128129 size_t orthog_iter_size,
129130 size_t masked_iter_size,
130- OrthogIndexerT orthog_dst_rhs_indexer_,
131- MaskedDstIndexerT masked_dst_indexer_,
132- MaskedRhsIndexerT masked_rhs_indexer_)
131+ const OrthogIndexerT & orthog_dst_rhs_indexer_,
132+ const MaskedDstIndexerT & masked_dst_indexer_,
133+ const MaskedRhsIndexerT & masked_rhs_indexer_)
133134 : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
134135 orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
135136 orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
@@ -177,13 +178,14 @@ struct MaskedPlaceStridedFunctor
177178 const char *rhs_cp = nullptr ;
178179 size_t orthog_nelems = 0 ;
179180 size_t masked_nelems = 0 ;
180- OrthogIndexerT
181- orthog_dst_rhs_indexer; // has nd, shape, dst_strides, rhs_strides for
182- // dimensions that ARE NOT masked
183- MaskedDstIndexerT masked_dst_indexer; // has nd, shape, dst_strides for
184- // dimensions that ARE masked
185- MaskedRhsIndexerT
186- masked_rhs_indexer; // has 1, rhs_strides for dimensions that ARE masked
181+ // has nd, shape, dst_strides, rhs_strides for
182+ // dimensions that ARE NOT masked
183+ const OrthogIndexerT orthog_dst_rhs_indexer;
184+ // has nd, shape, dst_strides for
185+ // dimensions that ARE masked
186+ const MaskedDstIndexerT masked_dst_indexer;
187+ // has 1, rhs_strides for dimensions that ARE masked
188+ const MaskedRhsIndexerT masked_rhs_indexer;
187189};
188190
189191// ======= Masked extraction ================================
@@ -226,12 +228,12 @@ sycl::event masked_extract_all_slices_strided_impl(
226228 // using StridedIndexer;
227229 // using TwoZeroOffsets_Indexer;
228230
229- TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
231+ constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
230232
231233 /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
232234 * *_packed_shape_strides) */
233- StridedIndexer masked_src_indexer (nd, 0 , packed_src_shape_strides);
234- Strided1DIndexer masked_dst_indexer (0 , dst_size, dst_stride);
235+ const StridedIndexer masked_src_indexer (nd, 0 , packed_src_shape_strides);
236+ const Strided1DIndexer masked_dst_indexer (0 , dst_size, dst_stride);
235237
236238 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
237239 cgh.depends_on (depends);
@@ -283,17 +285,16 @@ sycl::event masked_extract_some_slices_strided_impl(
283285 const char *cumsum_p,
284286 char *dst_p,
285287 int orthog_nd,
286- const ssize_t
287- *packed_ortho_src_dst_shape_strides, // [ortho_shape, ortho_src_strides,
288- // ortho_dst_strides], length
289- // 3*ortho_nd
288+ // [ortho_shape, ortho_src_strides, // ortho_dst_strides],
289+ // length 3*ortho_nd
290+ const ssize_t *packed_ortho_src_dst_shape_strides,
290291 ssize_t ortho_src_offset,
291292 ssize_t ortho_dst_offset,
292293 int masked_nd,
293- const ssize_t *packed_masked_src_shape_strides, // [masked_src_shape,
294- // masked_src_strides],
295- // length 2*masked_nd
296- ssize_t masked_dst_size, // mask_dst is 1D
294+ // [masked_src_shape, masked_src_strides] ,
295+ // length 2*masked_nd, mask_dst is 1D
296+ const ssize_t *packed_masked_src_shape_strides,
297+ ssize_t masked_dst_size,
297298 ssize_t masked_dst_stride,
298299 const std::vector<sycl::event> &depends = {})
299300{
@@ -302,13 +303,14 @@ sycl::event masked_extract_some_slices_strided_impl(
302303 // using StridedIndexer;
303304 // using TwoOffsets_StridedIndexer;
304305
305- TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306+ const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
306307 orthog_nd, ortho_src_offset, ortho_dst_offset,
307308 packed_ortho_src_dst_shape_strides};
308309
309- StridedIndexer masked_src_indexer{masked_nd, 0 ,
310- packed_masked_src_shape_strides};
311- Strided1DIndexer masked_dst_indexer{0 , masked_dst_size, masked_dst_stride};
310+ const StridedIndexer masked_src_indexer{masked_nd, 0 ,
311+ packed_masked_src_shape_strides};
312+ const Strided1DIndexer masked_dst_indexer{0 , masked_dst_size,
313+ masked_dst_stride};
312314
313315 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
314316 cgh.depends_on (depends);
@@ -403,12 +405,12 @@ sycl::event masked_place_all_slices_strided_impl(
403405 ssize_t rhs_stride,
404406 const std::vector<sycl::event> &depends = {})
405407{
406- TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
408+ constexpr TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
407409
408410 /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
409411 * *_packed_shape_strides) */
410- StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
411- Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
412+ const StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
413+ const Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
412414
413415 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
414416 cgh.depends_on (depends);
@@ -460,30 +462,29 @@ sycl::event masked_place_some_slices_strided_impl(
460462 const char *cumsum_p,
461463 const char *rhs_p,
462464 int orthog_nd,
463- const ssize_t
464- *packed_ortho_dst_rhs_shape_strides, // [ortho_shape, ortho_dst_strides,
465- // ortho_rhs_strides], length
466- // 3*ortho_nd
465+ // [ortho_shape, ortho_dst_strides, ortho_rhs_strides],
466+ // length 3*ortho_nd
467+ const ssize_t *packed_ortho_dst_rhs_shape_strides,
467468 ssize_t ortho_dst_offset,
468469 ssize_t ortho_rhs_offset,
469470 int masked_nd,
470- const ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape,
471- // masked_dst_strides],
472- // length 2*masked_nd
473- ssize_t masked_rhs_size, // mask_dst is 1D
471+ // [masked_dst_shape, masked_dst_strides] ,
472+ // length 2*masked_nd, mask_dst is 1D
473+ const ssize_t *packed_masked_dst_shape_strides,
474+ ssize_t masked_rhs_size,
474475 ssize_t masked_rhs_stride,
475476 const std::vector<sycl::event> &depends = {})
476477{
477- TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478+ const TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
478479 orthog_nd, ortho_dst_offset, ortho_rhs_offset,
479480 packed_ortho_dst_rhs_shape_strides};
480481
481482 /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
482483 * *_packed_shape_strides) */
483- StridedIndexer masked_dst_indexer{masked_nd, 0 ,
484- packed_masked_dst_shape_strides};
485- Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
486- masked_rhs_stride};
484+ const StridedIndexer masked_dst_indexer{masked_nd, 0 ,
485+ packed_masked_dst_shape_strides};
486+ const Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
487+ masked_rhs_stride};
487488
488489 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
489490 cgh.depends_on (depends);
0 commit comments