|
24 | 24 | #ifndef TVM_RUNTIME_THREADING_BACKEND_H_ |
25 | 25 | #define TVM_RUNTIME_THREADING_BACKEND_H_ |
26 | 26 |
|
| 27 | +#include <tvm/runtime/c_backend_api.h> |
| 28 | + |
| 29 | +#include <algorithm> |
27 | 30 | #include <functional> |
28 | 31 | #include <memory> |
29 | 32 | #include <vector> |
@@ -147,6 +150,73 @@ TVM_DLL void Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode, |
147 | 150 | int32_t NumThreads(); |
148 | 151 |
|
149 | 152 | } // namespace threading |
| 153 | + |
| 154 | +/*! |
| 155 | + * \brief Execute the given lambda function in parallel with |
| 156 | + * threading backend in TVM. |
| 157 | + * \tparam T The type of the lambda: "void (int i)". |
| 158 | + * \param flambda The lambda to be executed in parallel. |
| 159 | + * It should have the signature "void (int i)". |
| 160 | + * \param begin The start index of this parallel loop (inclusive). |
| 161 | + * \param end The end index of this parallel loop (exclusive). |
| 162 | + * \example |
| 163 | + * |
| 164 | + * The for loop |
| 165 | + * for (int i = 0; i < 10; i++) { |
| 166 | + * a[i] = i; |
| 167 | + * } |
| 168 | + * should work the same as: |
| 169 | + * parallel_for_with_threading_backend([&a](int i) { |
| 170 | + * a[i] = i; |
| 171 | + * }, 0, 10); |
| 172 | + */ |
| 173 | +template <typename T> |
| 174 | +inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end); |
| 175 | + |
| 176 | +namespace detail { |
| 177 | + |
| 178 | +// The detailed implementation of `parallel_for_with_threading_backend`. |
| 179 | +// To avoid template expansion, the implementation cannot be placed |
| 180 | +// in .cc files. |
| 181 | + |
| 182 | +template <typename T> |
| 183 | +struct ParallelForWithThreadingBackendLambdaInvoker { |
| 184 | + static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) { |
| 185 | + int num_task = penv->num_task; |
| 186 | + // Convert void* back to lambda type. |
| 187 | + T* lambda_ptr = static_cast<T*>(cdata); |
| 188 | + // Invoke the lambda with the task id (thread id). |
| 189 | + (*lambda_ptr)(task_id, num_task); |
| 190 | + return 0; |
| 191 | + } |
| 192 | +}; |
| 193 | + |
| 194 | +template <typename T> |
| 195 | +inline void parallel_launch_with_threading_backend(T flambda) { |
| 196 | + // Launch the lambda by passing its address. |
| 197 | + void* cdata = &flambda; |
| 198 | + TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke, |
| 199 | + cdata, /*num_task=*/0); |
| 200 | +} |
| 201 | + |
| 202 | +} // namespace detail |
| 203 | + |
| 204 | +template <typename T> |
| 205 | +inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) { |
| 206 | + auto flaunch = [begin, end, flambda](int task_id, int num_task) { |
| 207 | + // For each thread, do static division and call into flambda. |
| 208 | + int64_t total_len = end - begin; |
| 209 | + int64_t step = (total_len + num_task - 1) / num_task; |
| 210 | + int64_t local_begin = std::min(begin + step * task_id, end); |
| 211 | + int64_t local_end = std::min(local_begin + step, end); |
| 212 | + for (int64_t i = local_begin; i < local_end; ++i) { |
| 213 | + flambda(i); |
| 214 | + } |
| 215 | + }; |
| 216 | + // Launch with all threads. |
| 217 | + detail::parallel_launch_with_threading_backend(flaunch); |
| 218 | +} |
| 219 | + |
150 | 220 | } // namespace runtime |
151 | 221 | } // namespace tvm |
152 | 222 |
|
|
0 commit comments