Skip to content

Commit 1df7137

Browse files
authored
Fixed const issues seen in user's code (#1044)
1 parent 30f172d commit 1df7137

4 files changed

Lines changed: 83 additions & 39 deletions

File tree

include/matx/core/tensor_utils.h

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -247,21 +247,37 @@ namespace matx
247247
* @param indices indices
248248
* @return Value after broadcasting
249249
*/
250-
template <ElementsPerThread EPT, typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
250+
// Const-qualified RHS fetch
251+
template <ElementsPerThread EPT, typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...> && std::is_const_v<std::remove_reference_t<T>>, bool> = true>
252+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, Is... indices)
253+
{
254+
using OpT = remove_cvref_t<T>;
255+
constexpr int RANK = OpT::Rank();
256+
const OpT &ci = i;
257+
if constexpr (RANK == int(sizeof...(Is)) || RANK == matxNoRank) {
258+
return ci.template operator()<EPT>(indices...);
259+
}
260+
else
261+
{
262+
using seq = offset_sequence_t<sizeof...(Is) - RANK, std::make_index_sequence<RANK>>;
263+
auto tup = cuda::std::make_tuple(indices...);
264+
auto sliced_tup = select_tuple(std::forward<decltype(tup)>(tup), seq{});
265+
return cuda::std::apply([&](auto... args) {
266+
return ci.template operator()<EPT>(args...);
267+
}, sliced_tup);
268+
}
269+
}
270+
271+
// Non-const fetch preserves original behavior (may return refs)
272+
template <ElementsPerThread EPT, typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...> && !std::is_const_v<std::remove_reference_t<T>>, bool> = true>
251273
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, Is... indices)
252274
{
253275
constexpr int RANK = remove_cvref_t<T>::Rank();
254276
if constexpr (RANK == int(sizeof...(Is)) || RANK == matxNoRank) {
255-
// If we're only indexing with the same number of arguments as the rank of the operator, just return operator()
256277
return cuda::std::forward<T>(i).template operator()<EPT>(indices...);
257278
}
258279
else
259280
{
260-
// Otherwise we need to broadcast by constructing a large set of indices
261-
// Construct an integer sequence of the length of the tuple, but only using the last indices. We construct an offset sequence
262-
// to index into the broadcasted dimensions. For example, if T is a 3D tensor and we want to index as a 5D, we take the indices
263-
// {0, 1, 2} we'd normally index with, and add the difference in rank (2), to get {2, 3, 4}. Another way to think of this is it
264-
// simply chops off the first sizeof...(Is) - RANK indices since they're not used for operator().
265281
using seq = offset_sequence_t<sizeof...(Is) - RANK, std::make_index_sequence<RANK>>;
266282
auto tup = cuda::std::make_tuple(indices...);
267283
auto sliced_tup = select_tuple(std::forward<decltype(tup)>(tup), seq{});
@@ -271,25 +287,41 @@ namespace matx
271287
}
272288
}
273289

274-
template <ElementsPerThread EPT, typename T, typename IdxType, size_t N>
290+
template <ElementsPerThread EPT, typename T, typename IdxType, size_t N, std::enable_if_t<std::is_const_v<std::remove_reference_t<T>>, bool> = true>
291+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, const cuda::std::array<IdxType, N> idx)
292+
{
293+
using OpT = remove_cvref_t<T>;
294+
constexpr int RANK = OpT::Rank();
295+
const OpT &ci = i;
296+
if constexpr (RANK == N || RANK == matxNoRank) {
297+
return cuda::std::apply([&ci](auto... args) -> decltype(auto) {
298+
return ci.template operator()<EPT>(args...);
299+
}, idx);
300+
} else {
301+
cuda::std::array<index_t, RANK> nbc_idx;
302+
cuda::std::copy(idx.begin() + (N - RANK), idx.end(), nbc_idx.begin());
303+
return cuda::std::apply([&ci](auto... args) -> decltype(auto) {
304+
return ci.template operator()<EPT>(args...);
305+
}, nbc_idx);
306+
}
307+
}
308+
309+
template <ElementsPerThread EPT, typename T, typename IdxType, size_t N, std::enable_if_t<!std::is_const_v<std::remove_reference_t<T>>, bool> = true>
275310
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, const cuda::std::array<IdxType, N> idx)
276311
{
277312
constexpr int RANK = remove_cvref_t<T>::Rank();
278313
if constexpr (RANK == N || RANK == matxNoRank) {
279-
// If we're only indexing with the same number of arguments as the rank of the operator, just return operator()
280314
return cuda::std::apply([&i](auto... args) -> decltype(auto) {
281315
return cuda::std::forward<T>(i).template operator()<EPT>(args...);
282-
}, idx);
283-
}
284-
else
285-
{
286-
cuda::std::array<index_t, RANK> nbc_idx; // non-broadcast indices
316+
}, idx);
317+
} else {
318+
cuda::std::array<index_t, RANK> nbc_idx;
287319
cuda::std::copy(idx.begin() + (N - RANK), idx.end(), nbc_idx.begin());
288320
return cuda::std::apply([&i](auto... args) -> decltype(auto) {
289321
return cuda::std::forward<T>(i).template operator()<EPT>(args...);
290322
}, nbc_idx);
291323
}
292-
}
324+
}
293325

294326

295327
template <ElementsPerThread EPT, typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
@@ -317,7 +349,8 @@ namespace matx
317349
{
318350
return i;
319351
}
320-
}
352+
}
353+
321354

322355
template <typename T> __MATX_INLINE__ std::string to_short_str() {
323356
if constexpr (!is_complex_v<T>) {

include/matx/operators/binary_operators.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,21 +118,24 @@ namespace matx
118118
}
119119

120120
template <detail::ElementsPerThread EPT, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
121-
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ decltype(auto) operator()(Is... indices) const
121+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(Is... indices) const
122122
{
123-
auto i1 = get_value<EPT>(in1_, indices...);
124-
auto i2 = get_value<EPT>(in2_, indices...);
123+
// Bind operands as const to ensure RHS value-return semantics for composite ops
124+
const auto &lhs = in1_;
125+
const auto &rhs = in2_;
126+
const auto i1 = get_value<EPT>(lhs, indices...);
127+
const auto i2 = get_value<EPT>(rhs, indices...);
125128
return op_.template operator()<EPT>(i1, i2);
126129
}
127130

128131
template <typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
129-
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ decltype(auto) operator()(Is... indices) const
132+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(Is... indices) const
130133
{
131134
return this->template operator()<detail::ElementsPerThread::ONE>(indices...);
132135
}
133136

134137
template <ElementsPerThread EPT, typename ArrayType, std::enable_if_t<is_std_array_v<ArrayType>, bool> = true>
135-
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const ArrayType &idx) const noexcept
138+
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto operator()(const ArrayType &idx) const noexcept
136139
{
137140
return cuda::std::apply([&](auto &&...args) {
138141
return this->operator()<EPT>(args...);

include/matx/operators/concat.h

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,16 @@ namespace matx
9090
}
9191
}
9292

93+
94+
// Non-const path returns references where available (used for LHS writes)
9395
template <ElementsPerThread EPT, int I = 0, int N>
94-
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array<index_t,RANK> &indices) const {
96+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(cuda::std::array<index_t,RANK> &indices) {
9597
if constexpr ( I == N ) {
9698
// This should never happen, but we return a fake value from the first tuple element anyways
97-
const auto &op = cuda::std::get<0>(ops_);
99+
auto &op = cuda::std::get<0>(ops_);
98100
return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { return op.template operator()<EPT>(call_args...); }, indices);
99101
} else {
100-
const auto &op = cuda::std::get<I>(ops_);
102+
auto &op = cuda::std::get<I>(ops_);
101103
auto idx = indices[axis_];
102104
auto size = op.Size(axis_);
103105
// If in range of this operator
@@ -107,30 +109,34 @@ namespace matx
107109
} else {
108110
// otherwise remove this operator and recurse
109111
indices[axis_] -= size;
110-
return GetVal<EPT, I+1, N>(indices);
112+
return get_impl<EPT, I+1, N>(indices);
111113
}
112114
}
113115
}
114116

115-
117+
// Const path: unify scalar return type to value_type to avoid ref/value conflicts
116118
template <ElementsPerThread EPT, int I = 0, int N>
117-
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array<index_t,RANK> &indices) {
119+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto get_impl(cuda::std::array<index_t,RANK> &indices) const {
120+
using return_t = cuda::std::conditional_t<
121+
(EPT == ElementsPerThread::ONE),
122+
value_type,
123+
Vector<value_type, static_cast<index_t>(EPT)>>;
118124
if constexpr ( I == N ) {
119-
// This should never happen, but we return a fake value from the first tuple element anyways
120-
auto &op = cuda::std::get<0>(ops_);
121-
return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { return op.template operator()<EPT>(call_args...); }, indices);
125+
const auto &op = cuda::std::get<0>(ops_);
126+
return cuda::std::apply([&](auto &&...call_args) -> return_t {
127+
return op.template operator()<EPT>(call_args...);
128+
}, indices);
122129
} else {
123-
auto &op = cuda::std::get<I>(ops_);
130+
const auto &op = cuda::std::get<I>(ops_);
124131
auto idx = indices[axis_];
125132
auto size = op.Size(axis_);
126-
// If in range of this operator
127133
if(idx < size) {
128-
// evaluate operator
129-
return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { return op.template operator()<EPT>(call_args...); }, indices);
134+
return cuda::std::apply([&](auto &&...call_args) -> return_t {
135+
return op.template operator()<EPT>(call_args...);
136+
}, indices);
130137
} else {
131-
// otherwise remove this operator and recurse
132138
indices[axis_] -= size;
133-
return GetVal<EPT, I+1, N>(indices);
139+
return get_impl<EPT, I+1, N>(indices);
134140
}
135141
}
136142
}
@@ -140,13 +146,15 @@ namespace matx
140146
{
141147
if constexpr (EPT == ElementsPerThread::ONE) {
142148
cuda::std::array<index_t, sizeof...(Is)> indices = {{is...}};
143-
return GetVal<EPT, 0, sizeof...(Ts)>(indices);
149+
return get_impl<EPT, 0, sizeof...(Ts)>(indices);
144150
}
145151
else {
146152
return Vector<value_type, static_cast<index_t>(EPT)>{};
147153
}
148154
}
149155

156+
157+
150158
template <typename... Is>
151159
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... is) const
152160
{
@@ -158,7 +166,7 @@ namespace matx
158166
{
159167
if constexpr (EPT == ElementsPerThread::ONE) {
160168
cuda::std::array<index_t, sizeof...(Is)> indices = {{is...}};
161-
return GetVal<EPT, 0, sizeof...(Ts)>(indices);
169+
return get_impl<EPT, 0, sizeof...(Ts)>(indices);
162170
}
163171
else {
164172
return Vector<value_type, static_cast<index_t>(EPT)>{};

include/matx/operators/set.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class set : public BaseOp<set<T, Op>> {
109109
// functions, so we have to make a separate one.
110110
template <ElementsPerThread EPT, typename... Ts>
111111
__MATX_DEVICE__ __MATX_HOST__ inline auto _internal_mapply(Ts&&... args) const noexcept {
112-
auto r = detail::get_value<EPT>(op_, args...);
112+
const auto r = detail::get_value<EPT>(op_, args...);
113113
out_(args...) = r;
114114
return r;
115115
}

0 commit comments

Comments
 (0)