Skip ann_test test_pmap on ROCm due to IndivisibleError #773

Open
AratiGanesh wants to merge 1 commit into rocm-jaxlib-v0.9.1 from araganes/skip-pmap-rocm-v0.9.1

Motivation

AnnTest::test_pmap fails on ROCm with jax._src.sharding.IndivisibleError: shape=[1, 2, 4] is incompatible with mesh_shape=... after the JAX 0.8.2 → 0.9.1 bump.

Inside jax.pmap, lax.approx_min_k(..., aggregate_to_topk=False) produces the XLA SPMD tiling {devices=[1,2,4]<=[8] last_tile_dim_replicate}. JAX 0.9.x's new _gspmd_to_named_sharding_via_mesh / parse_flatten_op_sharding path then tries to translate this 3-D tile assignment back onto the 1-D pmap mesh of size 8 and fails: the mesh has a single axis of size 8, so the tile dimensions [1, 2, 4] cannot be factored onto it.
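
The mismatch can be illustrated with a small toy model. This is a simplified sketch and not JAX's actual translation logic: the helper tile_fits_mesh is hypothetical, and it assumes each tile dimension must equal the product of a consecutive run of mesh axes, taken in order.

```python
def tile_fits_mesh(tile_dims, mesh_shape):
    """Toy check (not JAX's algorithm): can each tile dimension be built
    as the product of a consecutive run of mesh axes, in order?"""
    axes = list(mesh_shape)
    i = 0  # index of the next unused mesh axis
    for dim in tile_dims:
        size = 1
        # Consume mesh axes until their product reaches this tile dimension.
        while size < dim and i < len(axes):
            size *= axes[i]
            i += 1
        if size != dim:
            # Overshot or ran out of axes: this dimension cannot be expressed.
            return False
    # Any axes left over must be trivial (size 1).
    return all(a == 1 for a in axes[i:])


# The tile from this PR against the 1-D pmap mesh of size 8.
fits_pmap_mesh = tile_fits_mesh([1, 2, 4], (8,))
# The same tile against a hypothetical 2-D mesh, for contrast.
fits_2d_mesh = tile_fits_mesh([1, 2, 4], (2, 4))
```

Under this model the single axis of size 8 overshoots the tile dimension of size 2, so the 1-D pmap mesh is rejected, while a 2 x 4 mesh would be accepted.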

Technical Details

Cherry-pick of upstream commit 596683446be555baf6dbf7803d45e51eb33bdcd1 onto rocm-jaxlib-v0.9.1. Adds a ROCm-only skipTest guard in AnnTest.test_pmap (tests/ann_test.py):

if jtu.is_device_rocm():
  self.skipTest("IndivisibleError: SPMD tiling incompatible with 1D pmap mesh on ROCm")
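
For reviewers who want to see how the guard behaves at runtime, here is a minimal standalone sketch using only unittest. The names is_device_rocm (a hard-coded stand-in for jtu.is_device_rocm), AnnTestSketch, and run_sketch are hypothetical and exist only for this illustration; the real test body in tests/ann_test.py is not reproduced.

```python
import unittest


def is_device_rocm():
    # Hypothetical stand-in for jtu.is_device_rocm(); hard-coded so the
    # sketch runs without JAX or a GPU.
    return True


class AnnTestSketch(unittest.TestCase):
    def test_pmap(self):
        # Same guard shape as in the PR: bail out before the pmap body runs.
        if is_device_rocm():
            self.skipTest(
                "IndivisibleError: SPMD tiling incompatible with 1D pmap mesh on ROCm"
            )
        # Placeholder for the real test body, which is not shown here.
        self.fail("test body would run on non-ROCm devices")


def run_sketch():
    # Run the single test case and return the collected result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AnnTestSketch)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Because skipTest raises unittest.SkipTest, nothing after the guard executes, and the test is reported as skipped rather than failed.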
