-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpu_parallel.py
More file actions
22 lines (19 loc) · 759 Bytes
/
Copy pathgpu_parallel.py
File metadata and controls
22 lines (19 loc) · 759 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import multiprocessing as mp
from train_functions import train_on_gpu, del_active_training_csv, rename_results_csv, get_timestamp, move_to_logs
from train_config import work_list
# Number of GPU ids handed to the workers (ids 0 .. num_gpus - 1).
# Hard-coded; torch.cuda.device_count() is the commented-out alternative.
num_gpus = 2
# Number of worker processes started for each GPU id.
processes_per_gpu = 12


def main():
    """Run every entry of ``work_list`` through a pool of worker processes.

    Each entry of ``work_list`` is put on a manager-backed queue. A pool of
    ``num_gpus * processes_per_gpu`` processes is started and
    ``train_on_gpu(work_queue, gpu_id, queue_lock, gpu_log)`` is invoked once
    per (GPU id, slot) pair. After all workers return, the post-run helpers
    ``del_active_training_csv``, ``rename_results_csv`` and ``move_to_logs``
    are called in that order.

    NOTE(review): how ``train_on_gpu`` consumes the queue/lock and what the
    post-run helpers do on disk is defined in ``train_functions`` and is not
    visible here -- confirm there.

    Raises:
        Whatever exception a worker raised inside ``train_on_gpu``; it is
        re-raised by ``starmap`` and the post-run helpers are then skipped.
    """
    timestamp = get_timestamp()
    gpu_log = f"GPU_{timestamp}.txt"

    # The context manager shuts the manager's server process down once the
    # workers no longer need the shared queue and lock.
    with mp.Manager() as manager:
        queue_lock = manager.Lock()
        work_queue = manager.Queue()
        for work in work_list:
            work_queue.put(work)

        # One argument tuple per worker: processes_per_gpu tuples per GPU id.
        worker_args = [
            (work_queue, gpu_id, queue_lock, gpu_log)
            for gpu_id in range(num_gpus)
            for _ in range(processes_per_gpu)
        ]

        pool = mp.Pool(processes=num_gpus * processes_per_gpu)
        try:
            pool.starmap(train_on_gpu, worker_args)
        finally:
            # close() + join() (rather than the pool's own context manager,
            # which calls terminate()) lets running workers finish cleanly
            # even when starmap re-raises a worker exception.
            pool.close()
            pool.join()

    del_active_training_csv()
    saved_results = rename_results_csv(timestamp)
    move_to_logs(gpu_log, saved_results)


# Guard the entry point: with the "spawn"/"forkserver" start methods every
# child process re-imports this module, and must not start a pool of its own.
if __name__ == "__main__":
    main()