@@ -247,21 +247,37 @@ namespace matx
247247 * @param indices indices
248248 * @return Value after broadcasting
249249 */
250- template <ElementsPerThread EPT, typename T, typename ... Is, std::enable_if_t <std::conjunction_v<std::is_integral<Is>...>, bool > = true >
250+ // Const-qualified RHS fetch
251+ template <ElementsPerThread EPT, typename T, typename ... Is, std::enable_if_t <std::conjunction_v<std::is_integral<Is>...> && std::is_const_v<std::remove_reference_t <T>>, bool > = true >
252+ __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype (auto ) get_matx_value(T &&i, Is... indices)
253+ {
254+ using OpT = remove_cvref_t <T>;
255+ constexpr int RANK = OpT::Rank ();
256+ const OpT &ci = i;
257+ if constexpr (RANK == int (sizeof ...(Is)) || RANK == matxNoRank) {
258+ return ci.template operator ()<EPT>(indices...);
259+ }
260+ else
261+ {
262+ using seq = offset_sequence_t <sizeof ...(Is) - RANK, std::make_index_sequence<RANK>>;
263+ auto tup = cuda::std::make_tuple (indices...);
264+ auto sliced_tup = select_tuple (std::forward<decltype (tup)>(tup), seq{});
265+ return cuda::std::apply ([&](auto ... args) {
266+ return ci.template operator ()<EPT>(args...);
267+ }, sliced_tup);
268+ }
269+ }
270+
271+ // Non-const fetch preserves original behavior (may return refs)
272+ template <ElementsPerThread EPT, typename T, typename ... Is, std::enable_if_t <std::conjunction_v<std::is_integral<Is>...> && !std::is_const_v<std::remove_reference_t <T>>, bool > = true >
251273 __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype (auto ) get_matx_value(T &&i, Is... indices)
252274 {
253275 constexpr int RANK = remove_cvref_t <T>::Rank ();
254276 if constexpr (RANK == int (sizeof ...(Is)) || RANK == matxNoRank) {
255- // If we're only indexing with the same number of arguments as the rank of the operator, just return operator()
256277 return cuda::std::forward<T>(i).template operator ()<EPT>(indices...);
257278 }
258279 else
259280 {
260- // Otherwise we need to broadcast by constructing a large set of indices
261- // Construct an integer sequence of the length of the tuple, but only using the last indices. We construct an offset sequence
262- // to index into the broadcasted dimensions. For example, if T is a 3D tensor and we want to index as a 5D, we take the indices
263- // {0, 1, 2} we'd normally index with, and add the difference in rank (2), to get {2, 3, 4}. Another way to think of this is it
264- // simply chops off the first sizeof...(Is) - RANK indices since they're not used for operator().
265281 using seq = offset_sequence_t <sizeof ...(Is) - RANK, std::make_index_sequence<RANK>>;
266282 auto tup = cuda::std::make_tuple (indices...);
267283 auto sliced_tup = select_tuple (std::forward<decltype (tup)>(tup), seq{});
@@ -271,25 +287,41 @@ namespace matx
271287 }
272288 }
273289
274- template <ElementsPerThread EPT, typename T, typename IdxType, size_t N>
290+ template <ElementsPerThread EPT, typename T, typename IdxType, size_t N, std::enable_if_t <std::is_const_v<std::remove_reference_t <T>>, bool > = true >
291+ __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype (auto ) get_matx_value(T &&i, const cuda::std::array<IdxType, N> idx)
292+ {
293+ using OpT = remove_cvref_t <T>;
294+ constexpr int RANK = OpT::Rank ();
295+ const OpT &ci = i;
296+ if constexpr (RANK == N || RANK == matxNoRank) {
297+ return cuda::std::apply ([&ci](auto ... args) -> decltype (auto ) {
298+ return ci.template operator ()<EPT>(args...);
299+ }, idx);
300+ } else {
301+ cuda::std::array<index_t , RANK> nbc_idx;
302+ cuda::std::copy (idx.begin () + (N - RANK), idx.end (), nbc_idx.begin ());
303+ return cuda::std::apply ([&ci](auto ... args) -> decltype (auto ) {
304+ return ci.template operator ()<EPT>(args...);
305+ }, nbc_idx);
306+ }
307+ }
308+
309+ template <ElementsPerThread EPT, typename T, typename IdxType, size_t N, std::enable_if_t <!std::is_const_v<std::remove_reference_t <T>>, bool > = true >
275310 __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype (auto ) get_matx_value(T &&i, const cuda::std::array<IdxType, N> idx)
276311 {
277312 constexpr int RANK = remove_cvref_t <T>::Rank ();
278313 if constexpr (RANK == N || RANK == matxNoRank) {
279- // If we're only indexing with the same number of arguments as the rank of the operator, just return operator()
280314 return cuda::std::apply ([&i](auto ... args) -> decltype (auto ) {
281315 return cuda::std::forward<T>(i).template operator ()<EPT>(args...);
282- }, idx);
283- }
284- else
285- {
286- cuda::std::array<index_t , RANK> nbc_idx; // non-broadcast indices
316+ }, idx);
317+ } else {
318+ cuda::std::array<index_t , RANK> nbc_idx;
287319 cuda::std::copy (idx.begin () + (N - RANK), idx.end (), nbc_idx.begin ());
288320 return cuda::std::apply ([&i](auto ... args) -> decltype (auto ) {
289321 return cuda::std::forward<T>(i).template operator ()<EPT>(args...);
290322 }, nbc_idx);
291323 }
292- }
324+ }
293325
294326
295327 template <ElementsPerThread EPT, typename T, typename ... Is, std::enable_if_t <std::conjunction_v<std::is_integral<Is>...>, bool > = true >
@@ -317,7 +349,8 @@ namespace matx
317349 {
318350 return i;
319351 }
320- }
352+ }
353+
321354
322355 template <typename T> __MATX_INLINE__ std::string to_short_str () {
323356 if constexpr (!is_complex_v<T>) {
0 commit comments