Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

# Mirrors WinBase.h (unfortunately not defined already elsewhere)
WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH = 0x00000008
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000

Expand Down Expand Up @@ -45,6 +46,17 @@
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p # DLL_DIRECTORY_COOKIE

# SearchPathW - find a file in the system search path.
# Declaring argtypes/restype explicitly makes ctypes marshal the wide-string
# and pointer arguments according to the types listed here instead of
# relying on its default argument conversion.
kernel32.SearchPathW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpPath (NULL to use standard search)
    ctypes.wintypes.LPCWSTR,  # lpFileName: name of the file to look for
    ctypes.wintypes.LPCWSTR,  # lpExtension: callers in this module pass NULL
    ctypes.wintypes.DWORD,  # nBufferLength: size of lpBuffer, in characters
    ctypes.wintypes.LPWSTR,  # lpBuffer: output buffer that receives the found path
    ctypes.POINTER(ctypes.wintypes.LPWSTR),  # lpFilePart: callers in this module pass NULL
]
# DWORD result; a value of 0 is treated as "not found" by the caller in this module.
kernel32.SearchPathW.restype = ctypes.wintypes.DWORD


def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int:
"""Convert ctypes HMODULE to unsigned int."""
Expand Down Expand Up @@ -113,6 +125,31 @@ def check_if_already_loaded_from_elsewhere(libname: str, have_abs_path: bool) ->
return None


def _search_path_for_dll(dll_name: str) -> str | None:
    """Search for a DLL using Windows SearchPathW.

    The lookup passes NULL for ``lpPath`` so that Windows applies its
    standard search order, and NULL for ``lpExtension`` and ``lpFilePart``
    because only the full path is needed.

    Args:
        dll_name: The name of the DLL to find

    Returns:
        The absolute path to the DLL if found, None otherwise
    """
    buffer_len = 260  # MAX_PATH; sufficient for the common case

    # SearchPathW reports a too-small buffer by returning the required size
    # (including the terminating NUL) rather than failing. The search result
    # can in principle change between calls, so re-check after every resize
    # instead of trusting a single retry. The loop is bounded so a
    # pathological environment cannot make this spin forever.
    for _ in range(4):
        buffer = ctypes.create_unicode_buffer(buffer_len)
        length = kernel32.SearchPathW(None, dll_name, None, buffer_len, buffer, None)

        if length == 0:
            # Not found (or the call failed).
            return None

        if length < buffer_len:
            # Success: the returned length excludes the terminating NUL, so
            # it is strictly smaller than the buffer size.
            return buffer.value

        # Buffer was too small: grow to the reported size, plus one character
        # of slack in case the reported size equals the current buffer size.
        buffer_len = length + 1

    return None


def load_with_system_search(libname: str) -> LoadedDL | None:
"""Try to load a DLL using system search paths.

Expand All @@ -124,10 +161,15 @@ def load_with_system_search(libname: str) -> LoadedDL | None:
"""
# Reverse tabulated names to achieve new → old search order.
for dll_name in reversed(SUPPORTED_WINDOWS_DLLS.get(libname, ())):
handle = kernel32.LoadLibraryExW(dll_name, None, 0)
if handle:
abs_path = abs_path_for_dynamic_library(libname, handle)
return LoadedDL(abs_path, False, ctypes_handle_to_unsigned_int(handle), "system-search")
# First, find the DLL's full path using SearchPathW
found_path = _search_path_for_dll(dll_name)
if found_path:
# Load with LOAD_WITH_ALTERED_SEARCH_PATH so Windows searches for
# dependencies from the DLL's directory (required for CUDA DLLs
# whose dependencies are co-located)
handle = kernel32.LoadLibraryExW(found_path, None, WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH)
if handle:
return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(handle), "system-search")

return None

Expand Down