Skip to content

Commit 470b4a1

Browse files
committed
[Disco] Support setting workers' CPU affinity
This PR adds support for setting the CPU affinity of disco workers. Specifically, it adds a global function `runtime.disco.bind_worker_to_cpu_core` that accepts a list of CPU ids and sets the CPU affinity of each worker accordingly. This can potentially reduce OS scheduling overhead, which otherwise lengthens the time a disco worker spends waiting on a pthread condition variable before it is woken up.
1 parent f8b9a5f commit 470b4a1

2 files changed

Lines changed: 54 additions & 31 deletions

File tree

src/runtime/disco/builtin.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t
129129
// Global function "runtime.disco.device": returns the `default_device` field of
// the DiscoWorker obtained from DiscoWorker::ThreadLocal().
TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
130130
return DiscoWorker::ThreadLocal()->default_device;
131131
});
132+
// Global function "runtime.disco.bind_worker_to_cpu_core": sets the CPU affinity
// of the calling worker to a single CPU id taken from `cpu_ids`.
//
// \param cpu_ids List of CPU ids indexed by worker id; the caller uses the entry
//                at index WorkerId(), so the list needs more than WorkerId()
//                entries (enforced by the ICHECK_LT below).
TVM_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core").set_body_typed([](IntTuple cpu_ids) {
133+
int worker_id = WorkerId();
134+
// Only the upper bound is checked.
// NOTE(review): a negative worker_id would pass this check and index out of
// range below -- presumably WorkerId() is never negative; confirm.
ICHECK_LT(worker_id, static_cast<int>(cpu_ids.size()));
135+
// The affinity setter is looked up by name in the global registry rather than
// called directly.
const PackedFunc* f_set_thread_affinity =
136+
Registry::Get("tvm.runtime.threading.set_current_thread_affinity");
137+
ICHECK_NOTNULL(f_set_thread_affinity);
138+
// Pass a one-element tuple: this worker is bound to exactly one CPU id.
(*f_set_thread_affinity)(IntTuple{cpu_ids[worker_id]});
139+
});
132140

133141
} // namespace runtime
134142
} // namespace tvm

src/runtime/threading_backend.cc

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
* \brief Native threading backend
2323
*/
2424
#include <tvm/runtime/logging.h>
25+
#include <tvm/runtime/registry.h>
2526
#include <tvm/runtime/threading_backend.h>
2627

2728
#if defined(__linux__) || defined(__ANDROID__)
@@ -106,6 +107,39 @@ class QuRTThread {
106107
void* stack_ = nullptr;
107108
};
108109
#endif // __hexagon__
110+
111+
// Common helper that restricts a thread to a set of CPUs.
//
// \param thread Native handle of the thread to bind. A handle equal to
//               CURRENT_THREAD_HANDLE is replaced by pthread_self().
// \param ids    CPU ids to place in the affinity mask.
//
// Only does work when built for Linux or Android; under any other target the
// body is compiled out and the call has no effect.
112+
void SetThreadAffinity(std::thread::native_handle_type thread,
113+
const std::vector<unsigned int>& ids) {
114+
#if defined(__linux__) || defined(__ANDROID__)
115+
if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) {
116+
thread = pthread_self();
117+
}
118+
// Build the affinity mask from the requested CPU ids.
// NOTE(review): ids are not range-checked before CPU_SET -- confirm callers
// never pass an id that does not fit in a cpu_set_t.
cpu_set_t cpuset;
119+
CPU_ZERO(&cpuset);
120+
for (auto id : ids) {
121+
CPU_SET(id, &cpuset);
122+
}
123+
#if defined(__ANDROID__)
124+
// Android path: sched_setaffinity is called with a tid, so one is first
// derived from the pthread handle.
#if __ANDROID_API__ >= 21
125+
pid_t tid = pthread_gettid_np(thread);
126+
#else
127+
// Older API levels: read the tid through a locally declared struct layout.
// NOTE(review): presumably mirrors the platform's internal pthread struct
// -- verify.
typedef struct {
128+
void* next;
129+
void* pred;
130+
pid_t tid;
131+
} pthread_internal;
132+
pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid;
133+
#endif
134+
// Failure is reported as a warning only; execution continues.
if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) {
135+
LOG(WARNING) << "sched_setaffinity failed";
136+
}
137+
#else
138+
// NOTE(review): the return value is ignored here, so a failed bind is silent,
// unlike the Android branch above which logs a warning.
pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
139+
#endif
140+
#endif
141+
}
142+
109143
thread_local int max_concurrency = 0;
110144
class ThreadGroup::Impl {
111145
public:
@@ -158,37 +192,6 @@ class ThreadGroup::Impl {
158192
}
159193

160194
private:
161-
// Pre-commit copy of SetThreadAffinity as a private member of ThreadGroup::Impl
// (every line below carries a `-` gutter marker, i.e. it is removed by this
// commit). Restricts `thread` to the CPUs listed in `ids`; only does work when
// built for Linux or Android.
void SetThreadAffinity(std::thread::native_handle_type thread,
162-
const std::vector<unsigned int>& ids) {
163-
#if defined(__linux__) || defined(__ANDROID__)
164-
// A handle equal to CURRENT_THREAD_HANDLE is replaced by pthread_self().
if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) {
165-
thread = pthread_self();
166-
}
167-
// Build the affinity mask from the requested CPU ids (no range check on ids).
cpu_set_t cpuset;
168-
CPU_ZERO(&cpuset);
169-
for (auto id : ids) {
170-
CPU_SET(id, &cpuset);
171-
}
172-
#if defined(__ANDROID__)
173-
// Android path: derive a tid from the pthread handle for sched_setaffinity.
#if __ANDROID_API__ >= 21
174-
pid_t tid = pthread_gettid_np(thread);
175-
#else
176-
// Older API levels: read the tid through a locally declared struct layout.
// NOTE(review): presumably mirrors the platform's internal pthread struct
// -- verify.
typedef struct {
177-
void* next;
178-
void* pred;
179-
pid_t tid;
180-
} pthread_internal;
181-
pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid;
182-
#endif
183-
// Failure is reported as a warning only; execution continues.
if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) {
184-
LOG(WARNING) << "sched_setaffinity failed";
185-
}
186-
#else
187-
// NOTE(review): the return value is ignored here, so a failed bind is silent.
pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
188-
#endif
189-
#endif
190-
}
191-
192195
// bind worker threads to disjoint cores
193196
// if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
194197
// the main thread is bound to core 0.
@@ -326,7 +329,11 @@ class ThreadGroup::Impl {
326329
const std::pair<unsigned int, int64_t>& b) {
327330
return a.second == b.second ? a.first < b.first : a.second > b.second;
328331
};
332+
#if defined(__hexagon__)
329333
std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
334+
#else
335+
std::stable_sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
336+
#endif
330337
int64_t big_freq = max_freqs.begin()->second;
331338
int64_t little_freq = max_freqs.rbegin()->second;
332339
for (auto it = max_freqs.begin(); it != max_freqs.end(); it++) {
@@ -431,6 +438,14 @@ int MaxConcurrency() {
431438
return std::max(max_concurrency, 1);
432439
}
433440

441+
// Global function "tvm.runtime.threading.set_current_thread_affinity": binds the
// calling thread to the CPU ids in `cpu_ids` by forwarding to SetThreadAffinity
// with CURRENT_THREAD_HANDLE.
442+
// It can be used by the disco runtime to bind its workers to CPUs.
443+
TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity")
444+
.set_body_typed([](IntTuple cpu_ids) {
445+
// The ids are copied from the IntTuple into a std::vector<unsigned int>.
// NOTE(review): negative ids would be converted to large unsigned values
// -- confirm callers only pass valid CPU ids.
SetThreadAffinity(CURRENT_THREAD_HANDLE,
446+
std::vector<unsigned int>{cpu_ids.begin(), cpu_ids.end()});
447+
});
448+
434449
} // namespace threading
435450
} // namespace runtime
436451
} // namespace tvm

0 commit comments

Comments
 (0)