
Commit d1a8ff8

restrict Gemma4 manual CP route
1 parent 25ae433 commit d1a8ff8

2 files changed

Lines changed: 38 additions & 4 deletions


nemo_automodel/components/distributed/cp_utils.py

Lines changed: 4 additions & 4 deletions
@@ -503,10 +503,10 @@ def _get_mesh_size(mesh):
     # Gemma4 needs a local-query/global-key attention mask that PyTorch's
     # ring-template CP path cannot represent. Its pre-embed step marks the
     # batch so we use explicit contiguous sequence sharding and let
-    # attach_cp_sdpa_hooks all-gather K/V and token types inside attention.
-    manual_allgather = (
-        bool(batch.pop("_cp_manual_allgather", False)) or "mm_token_type_ids" in batch or "_packed_seq_ids" in batch
-    )
+    # attach_cp_sdpa_hooks all-gather K/V and token metadata inside attention.
+    # Metadata such as mm_token_type_ids or _packed_seq_ids does not select this
+    # path by itself because other VLMs can carry those fields.
+    manual_allgather = bool(batch.pop("_cp_manual_allgather", False))
 
     # Remove attention_mask from the batch so the model does not attempt to
     # build a local 4D mask with the wrong key length. Preserve padding
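
A minimal sketch of the caller side of this contract, assuming the Gemma4 pre-embed step simply sets the batch key that cp_utils.py pops during routing. The helper name below is invented for illustration; only the "_cp_manual_allgather" key and the make_cp_batch_and_ctx routing come from this diff.

from typing import Any, Dict

import torch


def mark_batch_for_manual_cp(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: opt a Gemma4 batch into the manual all-gather CP route."""
    # cp_utils.make_cp_batch_and_ctx pops this key; after this commit, metadata
    # such as mm_token_type_ids or _packed_seq_ids no longer selects the route.
    batch["_cp_manual_allgather"] = True
    return batch


batch = mark_batch_for_manual_cp(
    {
        "input_ids": torch.tensor([[1, 2, 3, 4]]),
        "labels": torch.tensor([[1, 2, 3, 4]]),
        "mm_token_type_ids": torch.tensor([[0, 1, 1, 0]]),  # VLM metadata alone is no longer enough
    }
)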

tests/unit_tests/distributed/test_cp_utils.py

Lines changed: 34 additions & 0 deletions
@@ -146,6 +146,7 @@ def test_make_cp_batch_and_ctx_pads_to_cp_load_balance_multiple(monkeypatch):
         "input_ids": torch.tensor([[1, 2, 3]]),
         "labels": torch.tensor([[1, 2, 3]]),
         "mm_token_type_ids": torch.tensor([[0, 1, 0]]),
+        "_cp_manual_allgather": True,
     }
 
     _cu.make_cp_batch_and_ctx(device_mesh, batch, padding_token_id=99)
@@ -156,6 +157,37 @@ def test_make_cp_batch_and_ctx_pads_to_cp_load_balance_multiple(monkeypatch):
     assert batch["mm_token_type_ids"][0, -1].item() == 0
 
 
+def test_make_cp_batch_and_ctx_mm_token_type_ids_do_not_select_manual_allgather(monkeypatch):
+    """VLM metadata alone should not opt non-Gemma4 models into manual all-gather CP."""
+    device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
+    calls = {}
+
+    def fake_create_context_parallel_ctx(**kwargs):
+        calls["cp_buffers"] = kwargs["cp_buffers"]
+        return "cp_ctx"
+
+    def fake_get_train_context(enable_loss_parallel, enable_compiled_autograd, cp_context=None):
+        calls["cp_context"] = cp_context
+        return contextlib.nullcontext
+
+    monkeypatch.setattr(_cu, "create_context_parallel_ctx", fake_create_context_parallel_ctx)
+    monkeypatch.setattr(_cu, "get_train_context", fake_get_train_context)
+
+    batch = {
+        "input_ids": torch.tensor([[1, 2, 3, 4]]),
+        "labels": torch.tensor([[1, 2, 3, 4]]),
+        "mm_token_type_ids": torch.tensor([[0, 1, 1, 0]]),
+    }
+
+    ctx_obj, new_batch = _cu.make_cp_batch_and_ctx(device_mesh, batch, padding_token_id=99)
+
+    assert ctx_obj is contextlib.nullcontext
+    assert calls["cp_context"] == "cp_ctx"
+    assert len(calls["cp_buffers"]) == 3
+    assert torch.equal(new_batch["input_ids"], torch.tensor([[1, 2, 3, 4]]))
+    assert torch.equal(new_batch["mm_token_type_ids"], torch.tensor([[0, 1, 1, 0]]))
+
+
 def test_make_cp_batch_and_ctx_supports_inputs_embeds_and_per_layer_inputs(monkeypatch):
     """Gemma4 CP pre-embedding path should shard inputs_embeds side inputs."""
     device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
@@ -167,6 +199,7 @@ def test_make_cp_batch_and_ctx_supports_inputs_embeds_and_per_layer_inputs(monkeypatch):
         "labels": labels,
         "per_layer_inputs": per_layer_inputs,
         "mm_token_type_ids": torch.zeros(1, 4, dtype=torch.long),
+        "_cp_manual_allgather": True,
     }
 
     _cu.make_cp_batch_and_ctx(device_mesh, batch)
@@ -184,6 +217,7 @@ def test_make_cp_batch_and_ctx_pads_and_slices_packed_seq_ids(monkeypatch):
         "input_ids": torch.tensor([[1, 2, 3]]),
         "labels": torch.tensor([[1, 2, 3]]),
         "_packed_seq_ids": torch.tensor([[1, 1, 2]]),
+        "_cp_manual_allgather": True,
     }
 
     _cu.make_cp_batch_and_ctx(device_mesh, batch, padding_token_id=99)
