Allow ROCm arch override #1896
|
Thanks @ailuntz. |
|
Thanks for the feedback. I reviewed the code and agree the override should be a fallback only when rocminfo fails, and keeping both env vars is unnecessary. I've updated the PR to use only BNB_ROCM_GPU_ARCH as a fallback and pushed the changes. |
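The fallback-only ordering described in this comment could look roughly like the sketch below. It is an illustration rather than the PR's code: `resolve_arch` is a hypothetical helper, and `detected` stands for whatever the rocminfo parsing returned (`None` when detection failed).

```python
import os
from typing import Mapping, Optional


def resolve_arch(detected: Optional[str], env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper: prefer the detected arch; use the env var only as a fallback."""
    env = os.environ if env is None else env
    if detected:
        # Detection succeeded, so BNB_ROCM_GPU_ARCH is ignored.
        return detected
    # Detection failed: fall back to the override, or "unknown" if it is unset or blank.
    override = env.get("BNB_ROCM_GPU_ARCH", "").strip()
    return override or "unknown"
```

With this ordering, a value set in BNB_ROCM_GPU_ARCH never masks a working rocminfo result.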
|
Thanks. Can you update the existing error message to mention BNB_ROCM_GPU_ARCH as a workaround, so users know what to do when rocminfo fails? Ask users to install rocminfo first; if that's not possible, ask them to set the env variable. |
|
Updated this PR based on the feedback. Changes in the latest push:
|
|
Thanks @ailuntz. But I think we can keep everything in the function:

def get_rocm_gpu_arch() -> str:
    """Get ROCm GPU architecture."""
    logger = logging.getLogger(__name__)
    try:
        if not torch.version.hip:
            return "unknown"
        if platform.system() == "Windows":
            cmd = ["hipinfo.exe"]
            arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)"
        else:
            cmd = ["rocminfo"]
            arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"
        result = subprocess.run(cmd, capture_output=True, text=True)
        match = re.search(arch_pattern, result.stdout)
        if match:
            return "gfx" + match.group(1)
    except Exception as e:
        logger.error(f"Could not detect ROCm GPU architecture: {e}")
    # Fallback: check BNB_ROCM_GPU_ARCH env variable
    env_arch = os.environ.get("BNB_ROCM_GPU_ARCH", "").strip()
    if env_arch:
        if re.fullmatch(r"gfx[a-zA-Z0-9]+", env_arch):
            logger.info(f"Using ROCm GPU architecture from BNB_ROCM_GPU_ARCH: {env_arch}")
            return env_arch
        else:
            logger.warning(
                f"BNB_ROCM_GPU_ARCH='{env_arch}' is not a valid architecture name. "
                "Expected format: gfx followed by alphanumeric characters (e.g. gfx942, gfx90a)."
            )
    if torch.cuda.is_available():
        logger.warning(
            "ROCm GPU architecture detection failed despite ROCm being available. "
            "Please ensure 'rocminfo' is installed and available in your PATH. "
            "If that is not possible, you can set the BNB_ROCM_GPU_ARCH environment variable "
            "manually (e.g. export BNB_ROCM_GPU_ARCH=gfx942)."
        )
    return "unknown"

And no need to add tests for this. |
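For reference, the format check in the suggestion above accepts and rejects values roughly as sketched here. `is_valid_arch` is a hypothetical name used only for illustration; the regular expression is the one from the suggested code.

```python
import re

# Same pattern as the suggested check: "gfx" followed by one or more alphanumerics.
ARCH_RE = re.compile(r"gfx[a-zA-Z0-9]+")


def is_valid_arch(value: str) -> bool:
    """Hypothetical helper: True if value looks like a ROCm arch name such as gfx942."""
    # Surrounding whitespace is stripped first, mirroring the .strip() in the suggestion.
    return ARCH_RE.fullmatch(value.strip()) is not None
```

Because `re.fullmatch` must consume the whole string, values such as `gfx-942` or a bare `gfx` are rejected rather than partially matched.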
Fixes #1444.

Support BNB_ROCM_GPU_ARCH override when rocminfo is unavailable.