Description
Background
While running MuJoCo MJX robotics simulation workloads on AMD RDNA 3 GPUs, the Newton solver (which internally uses `jax.vmap` over a Cholesky decomposition) consistently crashes with:

```
INTERNAL: solver_kernels_ffi.cc:452: operation hipGetLastError() failed: out of memory
```
PR: #717
Minimal Reproducer
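```python
import jax
import jax.numpy as jnp

# Symmetric positive-definite 4x4 matrix, stacked into a batch of 2.
A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()
# vmap turns this into one batched potrf call (batch = 2), which crashes.
r = jax.jit(lambda A: jax.vmap(jnp.linalg.cholesky)(A))(Ab)
r.block_until_ready()
```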
Root Cause
When `batch > 1`, `PotrfDispatch` routes to `PotrfBatchedImpl`, which calls hipSOLVER's batched API. That API internally calls `hipMalloc`, bypassing XLA's BFC allocator. Because XLA preallocates ~75% of GPU VRAM by default, this external allocation fails, even for a tiny 2x4x4 batch.
The non-batched path (`PotrfImpl`) obtains all of its memory through XLA's scratch allocator and works correctly.
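Until the batched path is fixed, one possible code-level workaround is to replace `jax.vmap` with `jax.lax.map`, which lowers to a sequential loop so that each Cholesky call sees a single 4x4 matrix. This is an untested sketch resting on an assumption not verified against `PotrfDispatch`: that the loop body is lowered with `batch == 1` and therefore takes the `PotrfImpl` path. It gives up the parallelism of the batched kernel in exchange for avoiding the external `hipMalloc`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def batched_cholesky_sequential(A):
    # lax.map applies cholesky to one (4, 4) slice at a time inside a loop,
    # instead of vmap's single batched call. Assumption: each per-slice
    # potrf then goes through the non-batched solver path.
    return jax.lax.map(jnp.linalg.cholesky, A)


A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()
L = batched_cholesky_sequential(Ab).block_until_ready()

# Each lower-triangular factor should reconstruct its input: L @ L^T == A.
print(jnp.allclose(L @ jnp.swapaxes(L, -1, -2), Ab, atol=1e-5))
```

On a backend where the batched path works, the result should match `jax.vmap(jnp.linalg.cholesky)(Ab)` up to floating-point tolerance.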
Workaround
```
export XLA_PYTHON_CLIENT_PREALLOCATE=false
```
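When exporting a shell variable is inconvenient, the same setting can be applied from Python, provided it happens before JAX initializes its backend (in practice, before `import jax`). A minimal sketch reusing the reproducer; it has not been run on the ROCm setup listed below:

```python
import os

# Must be set before JAX initializes its backend, i.e. before `import jax`.
# With preallocation disabled, XLA allocates device memory on demand instead
# of reserving ~75% of VRAM up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()
r = jax.jit(lambda A: jax.vmap(jnp.linalg.cholesky)(A))(Ab)
r.block_until_ready()
print(r.shape)
```

If JAX has already been imported earlier in the same process, the variable is set too late to take effect, so the shell `export` is the more reliable form.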
System info (python version, jaxlib version, accelerator, etc.)
- OS: Ubuntu 24.04.3 LTS, Kernel 6.14.0-37-generic x86_64
- CPU: AMD Ryzen 7 5800X 8-Core
- GPU: AMD Radeon RX 7900 XTX (gfx1100 / RDNA 3, 24GB VRAM)
- ROCm: 7.2.0
- Python: 3.12.12
- JAX / JAXlib: 0.8.0
- jax-rocm7-plugin: 0.8.0+rocm7.2.0
- jax-rocm7-pjrt: 0.8.0+rocm7.2.0
- rocSOLVER: 3.32.0.70200
- hipSOLVER: 3.2.0.70200