Skip to content
121 changes: 121 additions & 0 deletions numba_cuda/numba/cuda/cudadrv/nvvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def is_available():
class NVVM(object):
"""Process-wide singleton."""

_libnvvm_cuda_version = None
_libnvvm_cuda_version_attempted = False

_PROTOTYPES = {
# nvvmResult nvvmVersion(int *major, int *minor)
"nvvmVersion": (nvvm_result, POINTER(c_int), POINTER(c_int)),
Expand Down Expand Up @@ -195,6 +198,115 @@ def get_ir_version(self):
self.check_error(err, "Failed to get IR version.")
return majorIR.value, minorIR.value, majorDbg.value, minorDbg.value

def get_cuda_version(self):
"""
Detect the libNVVM CUDA version by compiling dummy IR and analyzing the PTX output.

Workaround for the lack of direct CUDA version API (nvbugs 5312315).
The approach:
- Compile a small dummy NVVM IR to PTX
- Use PTX version analysis APIs if available to infer CUDA version
- Cache the result for future use
"""

if self._libnvvm_cuda_version_attempted:
return self._libnvvm_cuda_version
self._libnvvm_cuda_version_attempted = True

try:
from cuda.bindings.utils import (
get_minimal_required_cuda_ver_from_ptx_ver,
get_ptx_ver,
)
except ImportError:
return None

precheck_nvvm_ir = """target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define void @dummy_kernel() {{
entry:
ret void
}}

!nvvm.annotations = !{{!0}}
!0 = !{{void ()* @dummy_kernel, !"kernel", i32 1}}

!nvvmir.version = !{{!1}}
!1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
"""

# Create a test program to compile in order to determine
# the CUDA Toolkit version based on the PTX version that
# is generated by libnvvm.
program = c_void_p()
try:
# Create the NVVM program
err = self.nvvmCreateProgram(byref(program))
self.check_error(err, "Failed to create test program.")

# Add the test program to the compilation unit
precheck_nvvm_ir = precheck_nvvm_ir.format(
major=self._majorIR,
minor=self._minorIR,
debug_major=self._majorDbg,
debug_minor=self._minorDbg,
)
precheck_ir_bytes = precheck_nvvm_ir.encode("utf-8")
err = self.nvvmAddModuleToProgram(
program,
precheck_ir_bytes,
len(precheck_ir_bytes),
"precheck.ll".encode("utf-8"),
)
self.check_error(err, "Failed to add test module.")

# Compile the test program
options = ["-arch=compute_90"]
option_ptrs = (c_char_p * len(options))(
*[c_char_p(x.encode("utf-8")) for x in options]
)
err = self.nvvmVerifyProgram(program, len(options), option_ptrs)
self.check_error(err, "Failed to verify test program.")
err = self.nvvmCompileProgram(program, len(options), option_ptrs)
self.check_error(err, "Failed to compile test program.")

# Retrieve the PTX from the compiled program
ptx_size = c_size_t()
err = self.nvvmGetCompiledResultSize(program, byref(ptx_size))
self.check_error(
err, "Failed to get test program compiled result size."
)
ptx_data = (c_char * ptx_size.value)()
err = self.nvvmGetCompiledResult(program, ptx_data)
self.check_error(err, "Failed to get test program compiled result.")
except Exception as exception:
print(f"Exception compiling test program: {exception}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be a warning rather than a print:

Suggested change
print(f"Exception compiling test program: {exception}")
warnings.warn(
f"Exception compiling test program: {exception}",
category=NvvmWarning
)

raise exception
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should re-raise the exception, just let it pass - otherwise I'd expect it to propagate all the way back to the user, which we may not want.

Suggested change
raise exception

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: catching the exception, printing it, and re-raising defeats the purpose of graceful fallback. the code at lines 299-305 expects exceptions to be silently caught, allowing _libnvvm_cuda_version to remain None. this re-raise will prevent the function from returning None on error.

Suggested change
except Exception as exception:
print(f"Exception compiling test program: {exception}")
raise exception
except Exception:
pass

finally:
if program.value:
# Destroy the NVVM program, not fatal if it fails
err = self.nvvmDestroyProgram(byref(program))
try:
self.check_error(err, "Failed to destroy test program.")
except Exception:
pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's no point checking the error if we're going to swallow the exception the check will raise anyway.

Suggested change
try:
self.check_error(err, "Failed to destroy test program.")
except Exception:
pass


# Extract the PTX version and lookup the corresponding
# CUDA Toolkit version. If this fails, the CUDA Toolkit version
# cannot be determined and self._libnvvm_cuda_version will remain None
# as expected.
try:
ptx_version = get_ptx_ver(ptx_data[:].decode("utf-8"))
self._libnvvm_cuda_version = (
get_minimal_required_cuda_ver_from_ptx_ver(ptx_version)
)
except Exception:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the only likely exception we'd expect from the PTX version functions is a ValueError - anything else is a bit more surprising so we should let it manifest to expose the underlying bug instead:

Suggested change
except Exception:
except ValueError:

pass

# Return the CUDA Toolkit version or None if it could not be determined
return self._libnvvm_cuda_version

def check_error(self, error, msg, exit=False):
if error:
exc = NvvmError(msg, RESULT_CODE_NAMES[error])
Expand Down Expand Up @@ -243,6 +355,15 @@ def stringify_option(k, v):

return f"-{k}={v}".encode("utf-8")

# Starting in r13.1, we must pass in the -numba-debug flag to the
# compiler when compiling with a debug build. If the CUDA version
# cannot be determined, assume that a newer version is being used and
# pass in the -numba-debug flag.
if "g" in options:
ctk_version = self.driver.get_cuda_version()
if ctk_version is None or ctk_version >= (13, 1):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we couldn't determine the version of the CTK, that could be because the version is 12.x because the necessary PTX version functions weren't present in the CUDA bindings for 12.x. So I think it'd be safer to assume we don't pass the -numba-debug flag if we can't determine the version:

Suggested change
if ctk_version is None or ctk_version >= (13, 1):
if ctk_version is not None and ctk_version >= (13, 1):

options["numba-debug"] = None

options = [stringify_option(k, v) for k, v in options.items()]
option_ptrs = (c_char_p * len(options))(*[c_char_p(x) for x in options])

Expand Down