Description
I've been trying to get AlphaFold3 working on an AMD MI300X box we have on loan from HPE, and I've run into some issues caused by the test here:
jax/jaxlib/gpu_triton.py, line 20 in 1f93b4b:
_hip_triton = import_from_plugin("rocm", "_triton")
failing on the "community" docker images for JAX, which results in triton complaining that we don't have the GPU version of JAX installed.
I've tried:
docker.io/rocm/jax-community latest 193ba487b999
docker.io/rocm/jax-community rocm6.2.4-jax0.4.35-py3.11.10 ef50d5181ba5
docker.io/rocm/jax-community rocm6.2.3-jax0.4.34-py3.11.10 b229479e4af8
As an example, you can test this in a container built from these images:
Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
>>>
For comparison with the "non community" but older JAX:
docker.io/rocm/jax latest d949265c6ac2
Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
<module 'jax_rocm60_plugin._triton' from '/opt/venv/lib/python3.10/site-packages/jax_rocm60_plugin/_triton.so'>
>>>
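To make the comparison between images easier to repeat, here is a minimal sketch of the same check as a script. It lists every installed distribution with "jax" in its name and reports whether `jaxlib.gpu_triton._hip_triton` is set. The helper names `jax_related_distributions` and `hip_triton_status` are my own, not part of JAX, and the sketch assumes `jaxlib.gpu_triton` still exists in the installed jaxlib (it does in the versions above).

```python
import importlib
from importlib import metadata


def jax_related_distributions():
    """Return {name: version} for installed distributions with 'jax' in the name."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name and "jax" in name.lower():
            found[name] = dist.version
    return found


def hip_triton_status():
    """Describe whether jaxlib.gpu_triton exposes a ROCm Triton module."""
    try:
        gpu_triton = importlib.import_module("jaxlib.gpu_triton")
    except ImportError as exc:
        # jaxlib is missing, or this jaxlib version has no gpu_triton module.
        return f"jaxlib.gpu_triton not importable: {exc}"
    module = getattr(gpu_triton, "_hip_triton", None)
    if module is None:
        # Same symptom as the empty REPL output on the community images.
        return "_hip_triton is None"
    return f"_hip_triton loaded from {getattr(module, '__file__', '<unknown>')}"


if __name__ == "__main__":
    for name, version in sorted(jax_related_distributions().items()):
        print(f"{name}=={version}")
    print(hip_triton_status())
```

Running this inside each container should show at a glance which plugin packages are installed and whether the `_triton` module was picked up.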
I'm also finding it difficult to install from wheels as per https://github.com/rocm/jax/tree/main/build/rocm because pip can't find the wheels for the ROCm features.
e.g.
root@ea47cc7f8ef9:/# pip install jax[rocm]==0.4.38
Collecting jax[rocm]==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 52.2 MB/s eta 0:00:00
WARNING: jax 0.4.38 does not provide the extra 'rocm'
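The warning can be cross-checked against the metadata of the installed package. This is a small sketch using the standard library's `importlib.metadata`; the helper name `provided_extras` is hypothetical, and it only reports what the installed wheel declares, not what exists on any package index.

```python
from importlib import metadata


def provided_extras(dist_name):
    """Return the sorted extras an installed distribution declares, or None if not installed."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    # "Provides-Extra" appears once per extra in the package metadata.
    return sorted(dist.metadata.get_all("Provides-Extra") or [])


if __name__ == "__main__":
    print(provided_extras("jax"))
```

If `rocm` is missing from the printed list, the installed `jax` wheel simply does not define that extra, which matches the pip warning.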
I distantly remember from installing things on NVIDIA GPUs that there is a "secret" repository you sometimes have to use to get the JAX CUDA plugins to work (https://storage.googleapis.com/jax-releases/jax_cuda_releases.html). Is there a similar thing for the ROCm plugins?
System info (python version, jaxlib version, accelerator, etc.)
docker.io/rocm/jax-community containers, tags rocm6.2.3-jax0.4.34-py3.11.10, rocm6.2.4-jax0.4.35-py3.11.10, latest
Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0]
device info: AMD Instinct MI300X-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='5ca490f32332', release='5.14.0-503.34.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Thu Mar 27 06:00:50 EDT 2025', machine='x86_64')