Summary
`GraphMode.WARP_STAGED` crashes with `wp_cuda_graph_update_memcpy_batch: CUDA error 1: invalid argument` when used with JAX multi-device sharding. Single-device `WARP_STAGED`, as well as multi-device `WARP` and `WARP_STAGED_EX`, all work correctly.
| Mode | Single GPU | Multi GPU (sharded) |
|---|---|---|
| `WARP` (default) | ✅ | ✅ |
| `WARP_STAGED` | ✅ | ❌ |
| `WARP_STAGED_EX` | ✅ | ✅ |
Error
```
Warp CUDA error 1: invalid argument (in function wp_cuda_graph_update_memcpy_batch,
/builds/omniverse/warp/warp/native/warp.cu:3018)
RuntimeError: Failed to update graph memcpy batch: Warp CUDA error 1: invalid argument
(in function wp_cuda_graph_update_memcpy_batch, /builds/omniverse/warp/warp/native/warp.cu:3018)
```
The error originates from `mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py:789` during graph replay on the second device partition.
Minimal reproduction
```python
import os; os.environ["MUJOCO_GL"] = "egl"

import jax, jax.numpy as jnp, numpy as np, mujoco
from mujoco import mjx
import mujoco.mjx.warp as mjxw

devices = [d for d in jax.devices() if d.platform == "gpu"]
assert len(devices) >= 2

xml = """
<mujoco>
  <worldbody>
    <body>
      <joint type="free"/>
      <geom type="sphere" size="0.1" mass="1"/>
    </body>
    <body pos="0 0 -1">
      <geom type="plane" size="5 5 0.1"/>
    </body>
  </worldbody>
</mujoco>
"""

mj_model = mujoco.MjModel.from_xml_string(xml)
model = mjx.put_model(mj_model, impl="warp", graph_mode=mjxw.types.GraphMode.WARP_STAGED)

mesh = jax.sharding.Mesh(np.array(devices), ("fsdp",))
sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("fsdp"))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

num_envs = len(devices) * 2
batch_data = jax.vmap(lambda _: mjx.make_data(mj_model, impl="warp"))(jnp.arange(num_envs))

def get_sharding(x):
    # Shard leading (batch) axes that divide evenly across devices; replicate the rest.
    if hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] % len(devices) == 0:
        return sharded
    return replicated

data_sharding = jax.tree.map(get_sharding, batch_data)
batch_data = jax.device_put(batch_data, data_sharding)

@jax.jit(out_shardings=data_sharding)
def batched_step(data):
    return jax.vmap(mjx.step, in_axes=(None, 0))(model, data)

result = batched_step(batch_data)  # Crashes here
result.qpos.block_until_ready()
```
Environment
- mujoco: 3.6.0
- warp-lang: 1.11.1
- jax/jaxlib: 0.8.2
- GPU: 2x NVIDIA H100 80GB HBM3
- Driver: 550.163.01
- OS: Linux 6.8.0 (Ubuntu)
Notes
- Changing `WARP_STAGED` → `WARP_STAGED_EX` is a working alternative.
- The error occurs on graph replay (second partition), not during initial graph capture.
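For clarity, the workaround is a one-line change to the reproduction above — a fragment, assuming the same `mj_model` and `mjxw` import; everything else stays the same:

```python
# Workaround: use the _EX staged graph mode instead of WARP_STAGED.
model = mjx.put_model(
    mj_model,
    impl="warp",
    graph_mode=mjxw.types.GraphMode.WARP_STAGED_EX,  # was: WARP_STAGED
)
```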