@@ -87,6 +87,39 @@ def PTX(arch, ptx_version):
8787 return PTX_TEMPLATE .format (PTX_VERSION = ptx_version , ARCH = arch )
8888
8989
90+ @pytest .fixture
91+ def nvcc_smoke (tmpdir ) -> str :
92+ # TODO: Use cuda-pathfinder to locate nvcc on system.
93+ nvcc = shutil .which ("nvcc" )
94+ if nvcc is None :
95+ pytest .skip ("nvcc not found on PATH" )
96+
97+ # Smoke test: make sure nvcc is actually usable (toolkit + host compiler are set up),
98+ # not merely present on PATH.
99+ src = tmpdir / "nvcc_smoke.cu"
100+ out = tmpdir / "nvcc_smoke.o"
101+ with open (src , "w" ) as f :
102+ f .write ("" )
103+ try :
104+ subprocess .run ( # noqa: S603
105+ [nvcc , "-c" , str (src ), "-o" , str (out )],
106+ check = True ,
107+ capture_output = True ,
108+ )
109+ except subprocess .CalledProcessError as e :
110+ stdout = (e .stdout or b"" ).decode (errors = "replace" )
111+ stderr = (e .stderr or b"" ).decode (errors = "replace" )
112+ pytest .skip (
113+ "nvcc found on PATH but failed to compile a trivial input.\n "
114+ f"command: { [nvcc , '-c' , str (src ), '-o' , str (out )]!r} \n "
115+ f"exit_code: { e .returncode } \n "
116+ f"stdout:\n { stdout } \n "
117+ f"stderr:\n { stderr } \n "
118+ )
119+
120+ return nvcc
121+
122+
90123@pytest .fixture
91124def CUBIN (arch ):
92125 def CHECK_NVRTC (err ):
@@ -132,23 +165,32 @@ def CHECK_NVRTC(err):
132165
133166
134167@pytest .fixture
135- def OBJECT (arch , tmpdir ):
136- empty_cplusplus_kernel = "__global__ void A() {} int main() { return 0; } "
168+ def OBJECT (arch , tmpdir , nvcc_smoke ):
169+ empty_cplusplus_kernel = "__global__ void A() {}"
137170 with open (tmpdir / "object.cu" , "w" ) as f :
138171 f .write (empty_cplusplus_kernel )
139172
140- # TODO: Use cuda-pathfinder to locate nvcc on system.
141- nvcc = shutil .which ("nvcc" )
142- if nvcc is None :
143- pytest .skip ("nvcc not found on PATH" )
173+ nvcc = nvcc_smoke
144174
145175 # This is a test fixture that intentionally invokes a trusted tool (`nvcc`) to
146176 # compile a temporary CUDA translation unit.
147- subprocess .run ( # noqa: S603
148- [nvcc , "-arch" , arch , "-o" , str (tmpdir / "object.o" ), str (tmpdir / "object.cu" )],
149- check = True ,
150- capture_output = True ,
151- )
177+ cmd = [nvcc , "-c" , "-arch" , arch , "-o" , str (tmpdir / "object.o" ), str (tmpdir / "object.cu" )]
178+ try :
179+ subprocess .run ( # noqa: S603
180+ cmd ,
181+ check = True ,
182+ capture_output = True ,
183+ )
184+ except subprocess .CalledProcessError as e :
185+ stdout = (e .stdout or b"" ).decode (errors = "replace" )
186+ stderr = (e .stderr or b"" ).decode (errors = "replace" )
187+ raise RuntimeError (
188+ "nvcc smoke test passed, but nvcc failed while compiling the test object.\n "
189+ f"command: { cmd !r} \n "
190+ f"exit_code: { e .returncode } \n "
191+ f"stdout:\n { stdout } \n "
192+ f"stderr:\n { stderr } \n "
193+ ) from e
152194 with open (tmpdir / "object.o" , "rb" ) as f :
153195 object = f .read ()
154196
0 commit comments