-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsetup.py
More file actions
93 lines (77 loc) · 2.51 KB
/
setup.py
File metadata and controls
93 lines (77 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging
import os
import sys

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
# Absolute path of the directory containing this setup.py; used to locate the
# project's own "include" directory for the extensions.
root_dir = os.path.dirname(os.path.abspath(__file__))

# Locate the MPI installation. MPI_HOME, when set, is used as-is; otherwise the
# first existing directory among these common OpenMPI install locations is used.
_MPI_FALLBACK_DIRS = (
    "/usr/lib/openmpi/",
    "/usr/lib/x86_64-linux-gnu/openmpi/",
)
mpi_home = os.environ.get("MPI_HOME")
if mpi_home is None:
    mpi_home = next((d for d in _MPI_FALLBACK_DIRS if os.path.exists(d)), None)
    if mpi_home is None:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        # (The original called sys.exit without importing sys -> NameError.)
        sys.exit("Couldn't find MPI install dir, please set MPI_HOME env variable")
# NCCL is optional: the NCCL extension is only built when NCCL_HOME is set to
# a path that exists. Otherwise nccl_home is normalised to None and a warning
# is logged.
nccl_home = os.environ.get("NCCL_HOME")
if nccl_home is None or not os.path.exists(nccl_home):
    nccl_home = None
    # logging.warn is a deprecated alias; logging.warning is the supported API.
    logging.warning("Couldn't find NCCL install dir, please set NCCL_HOME to enable NCCL build")
# Expose the installed torch major/minor version to the C++ sources as
# preprocessor defines (presumably for version-specific code paths — confirm
# against the sources in src/).
torch_version = torch.__version__.split(".")
_torch_major = torch_version[0]
_torch_minor = torch_version[1]
torch_version_defines = [
    f"-DTORCH_MAJOR={_torch_major}",
    f"-DTORCH_MINOR={_torch_minor}",
]
# setup() command overrides; the build_ext entry is registered further down.
cmdclass = {}

# The MPI-backed process group extension is always built. The previous
# `extensions = []` assignment was immediately overwritten and has been removed.
extensions = [
    CppExtension(
        name="torch_pg._C",
        sources=[
            "src/Bindings.cpp",
            "src/ProcessGroupMPI.cpp",
        ],
        include_dirs=[
            os.path.join(root_dir, "include"),
            os.path.join(mpi_home, "include"),
        ],
        library_dirs=[
            os.path.join(mpi_home, "lib"),
        ],
        libraries=["mpi"],
        # OMPI_SKIP_MPICXX presumably keeps OpenMPI's mpi.h from pulling in its
        # C++ bindings; the torch version defines are shared with the NCCL build.
        extra_compile_args=["-DOMPI_SKIP_MPICXX=1"] + torch_version_defines,
    ),
]
# The NCCL/CUDA process-group extension is compiled only when NCCL was found.
if nccl_home is not None:
    extensions.append(
        CUDAExtension(
            name="torch_pg._CUDA",
            sources=[
                "src/CUDABindings.cpp",
                "src/NCCLUtils.cpp",
                "src/ProcessGroupNCCL.cpp",
            ],
            include_dirs=[
                os.path.join(root_dir, "include"),
                os.path.join(nccl_home, "include"),
            ],
            library_dirs=[os.path.join(nccl_home, "lib")],
            libraries=["nccl"],
            extra_compile_args=["-DENABLE_NCCL_P2P_SUPPORT"] + torch_version_defines,
        )
    )

# All extensions here are torch extensions (CppExtension/CUDAExtension), so
# torch's BuildExtension drives build_ext.
cmdclass["build_ext"] = BuildExtension
# Trove classifiers published with the distribution metadata.
_CLASSIFIERS = [
    "Programming Language :: Python :: 3.6",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "License :: OSI Approved :: BSD License",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Operating System :: OS Independent",
]

# Package metadata for torch-pg. The native extensions and the build_ext
# override come from the module-level `extensions` and `cmdclass`.
setup(
    name="torch-pg",
    version="0.0.0",
    python_requires=">=3.6",
    packages=find_packages(),
    ext_modules=extensions,
    cmdclass=cmdclass,
    classifiers=_CLASSIFIERS,
)