GraphMode.WARP_STAGED fails on multi-device (sharded) execution #3191

@shuhangchen

Description

Summary

GraphMode.WARP_STAGED crashes with wp_cuda_graph_update_memcpy_batch: CUDA error 1: invalid argument when used with JAX multi-device sharding. Single-device WARP_STAGED works, as do multi-device WARP and WARP_STAGED_EX.

Mode             Single GPU   Multi GPU (sharded)
WARP (default)   ✅            ✅
WARP_STAGED      ✅            ❌
WARP_STAGED_EX   ✅            ✅

Error

Warp CUDA error 1: invalid argument (in function wp_cuda_graph_update_memcpy_batch,
  /builds/omniverse/warp/warp/native/warp.cu:3018)

RuntimeError: Failed to update graph memcpy batch: Warp CUDA error 1: invalid argument
  (in function wp_cuda_graph_update_memcpy_batch, /builds/omniverse/warp/warp/native/warp.cu:3018)

The error originates from mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py:789 during graph replay on the second device partition.

Minimal reproduction

import os; os.environ["MUJOCO_GL"] = "egl"
import jax, jax.numpy as jnp, numpy as np, mujoco
from mujoco import mjx
import mujoco.mjx.warp as mjxw

devices = [d for d in jax.devices() if d.platform == "gpu"]
assert len(devices) >= 2

xml = """
<mujoco>
  <worldbody>
    <body>
      <joint type="free"/>
      <geom type="sphere" size="0.1" mass="1"/>
    </body>
    <body pos="0 0 -1">
      <geom type="plane" size="5 5 0.1"/>
    </body>
  </worldbody>
</mujoco>
"""
mj_model = mujoco.MjModel.from_xml_string(xml)
model = mjx.put_model(mj_model, impl="warp", graph_mode=mjxw.types.GraphMode.WARP_STAGED)

mesh = jax.sharding.Mesh(np.array(devices), ("fsdp",))
sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("fsdp"))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

num_envs = len(devices) * 2
batch_data = jax.vmap(lambda _: mjx.make_data(mj_model, impl="warp"))(jnp.arange(num_envs))

def get_sharding(x):
    if hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] % len(devices) == 0:
        return sharded
    return replicated

data_sharding = jax.tree.map(get_sharding, batch_data)
batch_data = jax.device_put(batch_data, data_sharding)

from functools import partial  # jax.jit takes the function as its first argument, so bind kwargs via partial
@partial(jax.jit, out_shardings=data_sharding)
def batched_step(data):
    return jax.vmap(mjx.step, in_axes=(None, 0))(model, data)

result = batched_step(batch_data)  # Crashes here
result.qpos.block_until_ready()
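As an aside, the get_sharding heuristic above can be checked in isolation, without GPUs or MuJoCo: a leaf is sharded only when it has a leading axis whose size divides evenly across the devices; everything else (scalars, non-divisible axes) is replicated. The function name and shapes below are illustrative, not part of the repro:

```python
# Standalone sketch of the get_sharding predicate used in the repro above.
# A leaf is sharded when its leading axis divides evenly across the devices;
# scalars and leaves with a non-divisible leading axis are replicated.
def shard_kind(shape, n_devices):
    if len(shape) > 0 and shape[0] % n_devices == 0:
        return "sharded"
    return "replicated"

n_devices = 2
print(shard_kind((4, 7), n_devices))  # batch axis 4 divides across 2 devices
print(shard_kind((3,), n_devices))    # 3 % 2 != 0, so replicated
print(shard_kind((), n_devices))      # scalar, so replicated
```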

Environment

  • mujoco: 3.6.0
  • warp-lang: 1.11.1
  • jax/jaxlib: 0.8.2
  • GPU: 2x NVIDIA H100 80GB HBM3
  • Driver: 550.163.01
  • OS: Linux 6.8.0 (Ubuntu)

Notes

  • Changing WARP_STAGED → WARP_STAGED_EX is a working alternative.
  • The error occurs on graph replay (second partition), not during initial graph capture.
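The workaround in the first bullet is a one-line change to the repro above; every other line stays the same. This is a configuration sketch, not a fix for the underlying graph-update bug:

```python
# Workaround: build the model with WARP_STAGED_EX instead of WARP_STAGED.
# Uses mj_model and mjxw exactly as defined in the repro above; multi-device
# sharded stepping then runs without the memcpy-batch graph update error.
model = mjx.put_model(mj_model, impl="warp",
                      graph_mode=mjxw.types.GraphMode.WARP_STAGED_EX)
```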
