@@ -146,6 +146,7 @@ def test_make_cp_batch_and_ctx_pads_to_cp_load_balance_multiple(monkeypatch):
146146 "input_ids" : torch .tensor ([[1 , 2 , 3 ]]),
147147 "labels" : torch .tensor ([[1 , 2 , 3 ]]),
148148 "mm_token_type_ids" : torch .tensor ([[0 , 1 , 0 ]]),
149+ "_cp_manual_allgather" : True ,
149150 }
150151
151152 _cu .make_cp_batch_and_ctx (device_mesh , batch , padding_token_id = 99 )
@@ -156,6 +157,40 @@ def test_make_cp_batch_and_ctx_pads_to_cp_load_balance_multiple(monkeypatch):
     assert batch["mm_token_type_ids"][0, -1].item() == 0

+def test_make_cp_batch_and_ctx_mm_token_type_ids_do_not_select_manual_allgather(monkeypatch):
+    """VLM metadata alone should not opt non-Gemma4 models into manual all-gather CP."""
+    device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
+    calls = {}
+
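+    # Stub the CP helpers so the test can record which buffers and context get selected.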
+    def fake_create_context_parallel_ctx(**kwargs):
+        calls["cp_buffers"] = kwargs["cp_buffers"]
+        return "cp_ctx"
+
+    def fake_get_train_context(enable_loss_parallel, enable_compiled_autograd, cp_context=None):
+        calls["cp_context"] = cp_context
+        return contextlib.nullcontext
+
+    monkeypatch.setattr(_cu, "create_context_parallel_ctx", fake_create_context_parallel_ctx)
+    monkeypatch.setattr(_cu, "get_train_context", fake_get_train_context)
+
+    batch = {
+        "input_ids": torch.tensor([[1, 2, 3, 4]]),
+        "labels": torch.tensor([[1, 2, 3, 4]]),
+        "mm_token_type_ids": torch.tensor([[0, 1, 1, 0]]),
+    }
+
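+    # No "_cp_manual_allgather" flag in the batch: the standard CP context path should be taken.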
+    ctx_obj, new_batch = _cu.make_cp_batch_and_ctx(device_mesh, batch, padding_token_id=99)
+
+    assert ctx_obj is contextlib.nullcontext
+    assert calls["cp_context"] == "cp_ctx"
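+    # input_ids, labels, and mm_token_type_ids should all be registered as CP buffers,
+    # and the tensors should pass through unsliced.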
+    assert len(calls["cp_buffers"]) == 3
+    assert torch.equal(new_batch["input_ids"], torch.tensor([[1, 2, 3, 4]]))
+    assert torch.equal(new_batch["mm_token_type_ids"], torch.tensor([[0, 1, 1, 0]]))
+
+
 def test_make_cp_batch_and_ctx_supports_inputs_embeds_and_per_layer_inputs(monkeypatch):
     """Gemma4 CP pre-embedding path should shard inputs_embeds side inputs."""
     device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
@@ -167,6 +202,7 @@ def test_make_cp_batch_and_ctx_supports_inputs_embeds_and_per_layer_inputs(monke
167199 "labels" : labels ,
168200 "per_layer_inputs" : per_layer_inputs ,
169201 "mm_token_type_ids" : torch .zeros (1 , 4 , dtype = torch .long ),
202+ "_cp_manual_allgather" : True ,
170203 }
171204
172205 _cu .make_cp_batch_and_ctx (device_mesh , batch )
@@ -184,6 +220,7 @@ def test_make_cp_batch_and_ctx_pads_and_slices_packed_seq_ids(monkeypatch):
184217 "input_ids" : torch .tensor ([[1 , 2 , 3 ]]),
185218 "labels" : torch .tensor ([[1 , 2 , 3 ]]),
186219 "_packed_seq_ids" : torch .tensor ([[1 , 1 , 2 ]]),
220+ "_cp_manual_allgather" : True ,
187221 }
188222
189223 _cu .make_cp_batch_and_ctx (device_mesh , batch , padding_token_id = 99 )