|
11 | 11 | from cuda.core.experimental._utils import ( |
12 | 12 | _handle_boolean_option, |
13 | 13 | check_or_create_options, |
| 14 | + driver, |
14 | 15 | handle_return, |
15 | 16 | is_nested_sequence, |
16 | 17 | is_sequence, |
@@ -413,6 +414,21 @@ def __init__(self, code, code_type, options: ProgramOptions = None): |
413 | 414 | raise TypeError |
414 | 415 | # TODO: support pre-loaded headers & include names |
415 | 416 | # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved |
| 417 | + |
| 418 | + supported_archs = handle_return(nvrtc.nvrtcGetSupportedArchs()) |
| 419 | + |
| 420 | + if options is not None: |
| 421 | + arch_not_supported = options.arch is not None and options.arch not in supported_archs |
| 422 | + default_arch_not_supported = ( |
| 423 | + options.arch is None |
| 424 | + and 10 * Device().compute_capability[0] + Device().compute_capability[1] not in supported_archs |
| 425 | + ) |
| 426 | + |
| 427 | + if arch_not_supported or default_arch_not_supported: |
| 428 | + raise ValueError( |
| 429 | + f"The provided arch, or default arch (that of the current device) " |
| 430 | + f"is not supported by the current backend. Supported architectures: {supported_archs}" |
| 431 | + ) |
416 | 432 | self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], [])) |
417 | 433 | self._backend = "nvrtc" |
418 | 434 | else: |
@@ -448,6 +464,12 @@ def compile(self, target_type, name_expressions=(), logs=None): |
448 | 464 | raise NotImplementedError |
449 | 465 |
|
450 | 466 | if self._backend == "nvrtc": |
| 467 | + version = handle_return(nvrtc.nvrtcVersion()) |
| 468 | + if handle_return(driver.cuDriverGetVersion()) > version[0] * 1000 + version[1] * 10: |
| 469 | + raise RuntimeError( |
| 470 | + "The CUDA driver version is newer than the NVRTC version. " |
| 471 | + "Please update your NVRTC library to match the CUDA driver version." |
| 472 | + ) |
451 | 473 | if name_expressions: |
452 | 474 | for n in name_expressions: |
453 | 475 | handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle) |
|
0 commit comments