Skip to content

Commit 748882a

Browse files
authored
[Runtime] Parallel-for with threading backend (#16133)
This PR introduces the runtime parallel-for helper function in C++ with the threading backend in TVM. Right now the existing [parallel-for](https://github.com/apache/tvm/blob/bd67d2e5ebde1aec18bcfa74c087516579bda1ae/include/tvm/support/parallel_for.h#L48-L68) in TVM is not thread persistent, in which case we cannot get persistent TLS for each thread. The introduced parallel-for-with-threading-backend function leverages the threading backend in TVM and persists threads.
1 parent bd67d2e commit 748882a

2 files changed

Lines changed: 79 additions & 0 deletions

File tree

include/tvm/runtime/threading_backend.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
2525
#define TVM_RUNTIME_THREADING_BACKEND_H_
2626

27+
#include <tvm/runtime/c_backend_api.h>
28+
29+
#include <algorithm>
2730
#include <functional>
2831
#include <memory>
2932
#include <vector>
@@ -147,6 +150,73 @@ TVM_DLL void Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode,
147150
int32_t NumThreads();
148151

149152
} // namespace threading
153+
154+
/*!
155+
* \brief Execute the given lambda function in parallel with
156+
* threading backend in TVM.
157+
* \tparam T The type of the lambda: "void (int i)".
158+
* \param flambda The lambda to be executed in parallel.
159+
* It should have the signature "void (int i)".
160+
* \param begin The start index of this parallel loop (inclusive).
161+
* \param end The end index of this parallel loop (exclusive).
162+
* \example
163+
*
164+
* The for loop
165+
* for (int i = 0; i < 10; i++) {
166+
* a[i] = i;
167+
* }
168+
* should work the same as:
169+
* parallel_for_with_threading_backend([&a](int i) {
170+
* a[i] = i;
171+
* }, 0, 10);
172+
*/
173+
template <typename T>
174+
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end);
175+
176+
namespace detail {
177+
178+
// The detailed implementation of `parallel_for_with_threading_backend`.
179+
// To avoid template expansion, the implementation cannot be placed
180+
// in .cc files.
181+
182+
template <typename T>
183+
struct ParallelForWithThreadingBackendLambdaInvoker {
184+
static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
185+
int num_task = penv->num_task;
186+
// Convert void* back to lambda type.
187+
T* lambda_ptr = static_cast<T*>(cdata);
188+
// Invoke the lambda with the task id (thread id).
189+
(*lambda_ptr)(task_id, num_task);
190+
return 0;
191+
}
192+
};
193+
194+
template <typename T>
195+
inline void parallel_launch_with_threading_backend(T flambda) {
196+
// Launch the lambda by passing its address.
197+
void* cdata = &flambda;
198+
TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
199+
cdata, /*num_task=*/0);
200+
}
201+
202+
} // namespace detail
203+
204+
template <typename T>
205+
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) {
206+
auto flaunch = [begin, end, flambda](int task_id, int num_task) {
207+
// For each thread, do static division and call into flambda.
208+
int64_t total_len = end - begin;
209+
int64_t step = (total_len + num_task - 1) / num_task;
210+
int64_t local_begin = std::min(begin + step * task_id, end);
211+
int64_t local_end = std::min(local_begin + step, end);
212+
for (int64_t i = local_begin; i < local_end; ++i) {
213+
flambda(i);
214+
}
215+
};
216+
// Launch with all threads.
217+
detail::parallel_launch_with_threading_backend(flaunch);
218+
}
219+
150220
} // namespace runtime
151221
} // namespace tvm
152222

tests/cpp/threading_backend_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,12 @@ TEST(ThreadingBackend, TVMBackendAffinityConfigure) {
185185
t->join();
186186
}
187187
}
188+
189+
TEST(ThreadingBackend, TVMBackendParallelForWithThreadingBackend) {
190+
int n = 100;
191+
std::vector<int> vec(/*size=*/n, /*value=*/0);
192+
tvm::runtime::parallel_for_with_threading_backend([&vec](int i) { vec[i] = i; }, 0, n);
193+
for (int i = 0; i < n; ++i) {
194+
EXPECT_EQ(vec[i], i);
195+
}
196+
}

0 commit comments

Comments
 (0)