|
22 | 22 | * \brief Native threading backend |
23 | 23 | */ |
24 | 24 | #include <tvm/runtime/logging.h> |
| 25 | +#include <tvm/runtime/registry.h> |
25 | 26 | #include <tvm/runtime/threading_backend.h> |
26 | 27 |
|
27 | 28 | #if defined(__linux__) || defined(__ANDROID__) |
@@ -106,6 +107,39 @@ class QuRTThread { |
106 | 107 | void* stack_ = nullptr; |
107 | 108 | }; |
108 | 109 | #endif // __hexagon__ |
| 110 | + |
| 111 | +// This is a common function used to set thread affinity. |
| 112 | +void SetThreadAffinity(std::thread::native_handle_type thread, |
| 113 | + const std::vector<unsigned int>& ids) { |
| 114 | +#if defined(__linux__) || defined(__ANDROID__) |
| 115 | + if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) { |
| 116 | + thread = pthread_self(); |
| 117 | + } |
| 118 | + cpu_set_t cpuset; |
| 119 | + CPU_ZERO(&cpuset); |
| 120 | + for (auto id : ids) { |
| 121 | + CPU_SET(id, &cpuset); |
| 122 | + } |
| 123 | +#if defined(__ANDROID__) |
| 124 | +#if __ANDROID_API__ >= 21 |
| 125 | + pid_t tid = pthread_gettid_np(thread); |
| 126 | +#else |
| 127 | + typedef struct { |
| 128 | + void* next; |
| 129 | + void* pred; |
| 130 | + pid_t tid; |
| 131 | + } pthread_internal; |
| 132 | + pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid; |
| 133 | +#endif |
| 134 | + if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) { |
| 135 | + LOG(WARNING) << "sched_setaffinity failed"; |
| 136 | + } |
| 137 | +#else |
| 138 | + pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); |
| 139 | +#endif |
| 140 | +#endif |
| 141 | +} |
| 142 | + |
109 | 143 | thread_local int max_concurrency = 0; |
110 | 144 | class ThreadGroup::Impl { |
111 | 145 | public: |
@@ -158,37 +192,6 @@ class ThreadGroup::Impl { |
158 | 192 | } |
159 | 193 |
|
160 | 194 | private: |
161 | | - void SetThreadAffinity(std::thread::native_handle_type thread, |
162 | | - const std::vector<unsigned int>& ids) { |
163 | | -#if defined(__linux__) || defined(__ANDROID__) |
164 | | - if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) { |
165 | | - thread = pthread_self(); |
166 | | - } |
167 | | - cpu_set_t cpuset; |
168 | | - CPU_ZERO(&cpuset); |
169 | | - for (auto id : ids) { |
170 | | - CPU_SET(id, &cpuset); |
171 | | - } |
172 | | -#if defined(__ANDROID__) |
173 | | -#if __ANDROID_API__ >= 21 |
174 | | - pid_t tid = pthread_gettid_np(thread); |
175 | | -#else |
176 | | - typedef struct { |
177 | | - void* next; |
178 | | - void* pred; |
179 | | - pid_t tid; |
180 | | - } pthread_internal; |
181 | | - pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid; |
182 | | -#endif |
183 | | - if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) { |
184 | | - LOG(WARNING) << "sched_setaffinity failed"; |
185 | | - } |
186 | | -#else |
187 | | - pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); |
188 | | -#endif |
189 | | -#endif |
190 | | - } |
191 | | - |
192 | 195 | // bind worker threads to disjoint cores |
193 | 196 | // if worker 0 is offloaded to main, i.e. exclude_worker0 is true, |
194 | 197 | // the main thread is bound to core 0. |
@@ -326,7 +329,11 @@ class ThreadGroup::Impl { |
326 | 329 | const std::pair<unsigned int, int64_t>& b) { |
327 | 330 | return a.second == b.second ? a.first < b.first : a.second > b.second; |
328 | 331 | }; |
| 332 | +#if defined(__hexagon__) |
329 | 333 | std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq); |
| 334 | +#else |
| 335 | + std::stable_sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq); |
| 336 | +#endif |
330 | 337 | int64_t big_freq = max_freqs.begin()->second; |
331 | 338 | int64_t little_freq = max_freqs.rbegin()->second; |
332 | 339 | for (auto it = max_freqs.begin(); it != max_freqs.end(); it++) { |
@@ -431,6 +438,14 @@ int MaxConcurrency() { |
431 | 438 | return std::max(max_concurrency, 1); |
432 | 439 | } |
433 | 440 |
|
| 441 | +// This global function can be used by disco runtime to bind processes |
| 442 | +// to CPUs. |
| 443 | +TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") |
| 444 | + .set_body_typed([](IntTuple cpu_ids) { |
| 445 | + SetThreadAffinity(CURRENT_THREAD_HANDLE, |
| 446 | + std::vector<unsigned int>{cpu_ids.begin(), cpu_ids.end()}); |
| 447 | + }); |
| 448 | + |
434 | 449 | } // namespace threading |
435 | 450 | } // namespace runtime |
436 | 451 | } // namespace tvm |
0 commit comments