diff --git a/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py b/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py index b9f15ea50b..247cc4d467 100644 --- a/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py +++ b/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py @@ -13,6 +13,7 @@ ) # Mirrors WinBase.h (unfortunately not defined already elsewhere) +WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH = 0x00000008 WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 @@ -45,6 +46,17 @@ kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR] kernel32.AddDllDirectory.restype = ctypes.c_void_p # DLL_DIRECTORY_COOKIE +# SearchPathW - find a file in the system search path +kernel32.SearchPathW.argtypes = [ + ctypes.wintypes.LPCWSTR, # lpPath (NULL to use standard search) + ctypes.wintypes.LPCWSTR, # lpFileName + ctypes.wintypes.LPCWSTR, # lpExtension + ctypes.wintypes.DWORD, # nBufferLength + ctypes.wintypes.LPWSTR, # lpBuffer + ctypes.POINTER(ctypes.wintypes.LPWSTR), # lpFilePart +] +kernel32.SearchPathW.restype = ctypes.wintypes.DWORD + def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int: """Convert ctypes HMODULE to unsigned int.""" @@ -113,6 +125,31 @@ def check_if_already_loaded_from_elsewhere(libname: str, have_abs_path: bool) -> return None +def _search_path_for_dll(dll_name: str) -> str | None: + """Search for a DLL using Windows SearchPathW. + + Args: + dll_name: The name of the DLL to find + + Returns: + The absolute path to the DLL if found, None otherwise + """ + buffer = ctypes.create_unicode_buffer(260) # MAX_PATH + length = kernel32.SearchPathW(None, dll_name, None, len(buffer), buffer, None) + + if length == 0: + return None + + # If buffer was too small, try with larger buffer + if length > len(buffer): + buffer = ctypes.create_unicode_buffer(length) + length = kernel32.SearchPathW(None, dll_name, None, len(buffer), buffer, None) + if length == 0: + return None + + return buffer.value + + def load_with_system_search(libname: str) -> LoadedDL | None: """Try to load a DLL using system search paths. @@ -124,10 +161,15 @@ def load_with_system_search(libname: str) -> LoadedDL | None: """ # Reverse tabulated names to achieve new → old search order. for dll_name in reversed(SUPPORTED_WINDOWS_DLLS.get(libname, ())): - handle = kernel32.LoadLibraryExW(dll_name, None, 0) - if handle: - abs_path = abs_path_for_dynamic_library(libname, handle) - return LoadedDL(abs_path, False, ctypes_handle_to_unsigned_int(handle), "system-search") + # First, find the DLL's full path using SearchPathW + found_path = _search_path_for_dll(dll_name) + if found_path: + # Load with LOAD_WITH_ALTERED_SEARCH_PATH so Windows searches for + # dependencies from the DLL's directory (required for CUDA DLLs + # whose dependencies are co-located) + handle = kernel32.LoadLibraryExW(found_path, None, WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH) + if handle: + return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(handle), "system-search") return None