
Commit 8c3615e

fix(vlm): align n_images_per_sample with batch_size in kimi_k25 collate (#2175)
fix(vlm): keep n_images_per_sample length aligned with batch_size in kimi_k25

In kimi_k25_vl_collate_fn, n_images_per_sample was derived from all_grid_thws only, which is appended conditionally:

* text-only samples have no grid_thws → not appended
* samples whose image region got orphaned by truncation (when drop_overlong=False, the default) also skip the append, but the sample itself stays in all_expanded with its image tokens replaced by pad

Result: len(n_images_per_sample) < batch_size on mixed batches, while input_ids has shape [batch_size, max_len]. The downstream PP _chunk_vlm_media indexes cumsum_images by sample index up to batch_size and raises IndexError when the cumsum is shorter.

The fix tracks the per-sample image count in lockstep with all_expanded (zeros for text-only and orphaned samples), so the resulting tensor has length batch_size in all cases. There is no behavior change for batches where every sample has an intact image, since the per-sample count then equals what the old derivation produced.

Adds two regression tests covering (1) a text-only + image mixed batch and (2) an intact-image + truncation-orphaned mixed batch.

Signed-off-by: khazic <khazzz1c@gmail.com>
1 parent: 41786e2
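For context, the failure mode reduces to a few lines. The sketch below is hedged: image_slice_for_sample is a hypothetical stand-in for the indexing that _chunk_vlm_media performs in the pipeline-parallel path, not the real implementation, but it shows why a counts tensor shorter than batch_size crashes.

import torch

# Mixed batch of 2: sample 0 is text-only, sample 1 has one intact image.
batch_size = 2

# Old derivation: counts came from all_grid_thws only, so the text-only
# sample contributed no entry and the tensor is shorter than batch_size.
old_counts = torch.tensor([1])      # len 1 < batch_size
# Fixed derivation: one entry per sample, zero for text-only/orphaned.
new_counts = torch.tensor([0, 1])   # len == batch_size

def image_slice_for_sample(n_images_per_sample, sample_idx):
    # Hypothetical reduction of the PP chunking step: index a cumulative
    # sum by sample index to locate each sample's span of images.
    cumsum = torch.cumsum(n_images_per_sample, dim=0)
    start = 0 if sample_idx == 0 else int(cumsum[sample_idx - 1])
    end = int(cumsum[sample_idx])   # out of bounds when cumsum is short
    return start, end

for i in range(batch_size):
    print(image_slice_for_sample(new_counts, i))    # (0, 0) then (0, 1)

try:
    image_slice_for_sample(old_counts, 1)
except IndexError as exc:
    print("old derivation fails:", exc)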

2 files changed

Lines changed: 153 additions & 3 deletions


nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 11 additions & 3 deletions
@@ -871,6 +871,10 @@ def kimi_k25_vl_collate_fn(
     all_expanded = []
     all_pixel_values = []
     all_grid_thws = []
+    # Per-sample image counts, kept in lockstep with all_expanded so that
+    # n_images_per_sample length matches batch_size downstream. Samples that
+    # are text-only or whose image region was orphaned by truncation get 0.
+    per_sample_image_count: List[int] = []
 
     for i, conversation in enumerate(conversations):
         # Collect medias for this conversation
@@ -923,12 +927,14 @@ def kimi_k25_vl_collate_fn(
 
         # Only include image data if all expanded image tokens survived truncation.
         # Partial truncation into image regions would cause a mismatch in the model forward.
+        sample_image_count = 0
         if grid_thws is not None:
             merge_h, merge_w = _DEFAULT_MERGE_KERNEL
             expected_image_tokens = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
             actual_image_tokens = (input_ids == media_token_id).sum().item()
             if actual_image_tokens == expected_image_tokens:
                 all_grid_thws.append(grid_thws)
+                sample_image_count = int(grid_thws.shape[0])
                 if "pixel_values" in sample_batch:
                     all_pixel_values.append(sample_batch["pixel_values"])
             else:
@@ -943,6 +949,7 @@ def kimi_k25_vl_collate_fn(
                 "attention_mask": attention_mask,
             }
         )
+        per_sample_image_count.append(sample_image_count)
 
     if not all_expanded:
         raise ValueError(
@@ -990,9 +997,10 @@ def kimi_k25_vl_collate_fn(
         result["grid_thws"] = torch.cat(all_grid_thws, dim=0)
         # Also add as image_grid_hws for PP chunking in finetune.py
         result["image_grid_hws"] = result["grid_thws"][:, 1:]  # [N, 3] -> [N, 2] (drop temporal dim, keep H,W)
-        # Per-sample image counts for PP chunking
-        image_counts = [g.shape[0] for g in all_grid_thws]
-        result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)
+        # Per-sample image counts for PP chunking. Length must equal batch_size,
+        # so include zeros for text-only samples and for samples whose image
+        # region was orphaned by truncation.
+        result["n_images_per_sample"] = torch.tensor(per_sample_image_count, dtype=torch.long)
 
     # Build labels
     labels = build_labels_from_template(
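
For readers skimming the diff, the fix reduces to the following toy loop. It is a sketch: the samples list and image_survived flag are illustrative stand-ins for the collate function's per-sample state, not its real variables.

from typing import List

# Three collate outcomes: intact image, text-only, truncation-orphaned.
samples = [
    {"n_images": 1, "image_survived": True},
    {"n_images": 0, "image_survived": False},  # text-only
    {"n_images": 1, "image_survived": False},  # orphaned by truncation
]

all_grid_counts: List[int] = []         # old source of truth: conditional append
per_sample_image_count: List[int] = []  # fixed: appended once per sample

for s in samples:
    survived = s["image_survived"] and s["n_images"] > 0
    if survived:
        all_grid_counts.append(s["n_images"])
    # Lockstep with the batch: zero for text-only and orphaned samples.
    per_sample_image_count.append(s["n_images"] if survived else 0)

assert len(per_sample_image_count) == len(samples)  # always batch_size
print(per_sample_image_count)  # [1, 0, 0]
print(len(all_grid_counts))    # 1, shorter than batch_size (the old bug)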

tests/unit_tests/datasets/vlm/test_collate_fns.py

Lines changed: 142 additions & 0 deletions
@@ -1410,6 +1410,148 @@ def fake_build_labels(input_ids, conversations, processor_arg):
     assert (batch["input_ids"] == MEDIA_TOKEN_ID).sum().item() == 0
 
 
+def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix(
+    collate_mod, monkeypatch
+):
+    """Mixed batch (text-only + image): n_images_per_sample length must equal batch_size.
+
+    Regression: previously image_counts was derived from all_grid_thws only, so
+    text-only samples were skipped and the resulting tensor was shorter than
+    batch_size. Downstream PP _chunk_vlm_media indexes cumsum_images by
+    sample index and would raise an IndexError out of bounds.
+    """
+    MEDIA_TOKEN_ID = 163605
+
+    class MixedProcessor:
+        def __init__(self):
+            self.tokenizer = DummyTokenizer(pad_token_id=0)
+            self.media_placeholder_token_id = MEDIA_TOKEN_ID
+
+        def apply_chat_template(self, conversation, **kwargs):
+            return "chat:processed"
+
+        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
+            if medias:
+                input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
+                attention_mask = torch.ones_like(input_ids)
+                return {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "grid_thws": torch.tensor([[1, 4, 4]]),
+                    "pixel_values": torch.randn(1, 3, 14, 14),
+                }
+            input_ids = torch.tensor([[10, 11, 12, 13, 14]])
+            attention_mask = torch.ones_like(input_ids)
+            return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    processor = MixedProcessor()
+
+    def fake_build_labels(input_ids, conversations, processor_arg):
+        batch_size, seq_len = input_ids.shape
+        return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)
+
+    text_only = [
+        {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
+    ]
+    with_image = [
+        {"role": "user", "content": [{"type": "image", "image": "x.jpg"}, {"type": "text", "text": "What?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "Cat."}]},
+    ]
+    examples = [{"conversation": text_only}, {"conversation": with_image}]
+
+    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor)
+
+    assert "n_images_per_sample" in batch
+    assert batch["n_images_per_sample"].shape == (2,), (
+        f"n_images_per_sample length must equal batch_size=2, "
+        f"got shape {batch['n_images_per_sample'].shape}"
+    )
+    # text-only sample → 0; image sample → 1
+    assert batch["n_images_per_sample"].tolist() == [0, 1]
+
+
+def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan(
+    collate_mod, monkeypatch
+):
+    """Mixed batch (truncated image + intact image): n_images_per_sample length must equal batch_size.
+
+    Regression: a sample whose image region got orphaned by truncation was
+    correctly excluded from all_grid_thws but still kept in all_expanded.
+    Without the fix, n_images_per_sample length would be smaller than the
+    final batch and downstream PP indexing would crash.
+    """
+    MEDIA_TOKEN_ID = 163605
+
+    class MaybeOrphanProcessor:
+        """Returns a small grid on the first call and a large grid on the
+        second; the second call's tokens are truncated past the image region
+        by the max_length used below."""
+
+        def __init__(self):
+            self.tokenizer = DummyTokenizer(pad_token_id=0)
+            self.media_placeholder_token_id = MEDIA_TOKEN_ID
+            self._call_idx = 0
+
+        def apply_chat_template(self, conversation, **kwargs):
+            return "chat:processed"
+
+        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
+            self._call_idx += 1
+            if self._call_idx == 1:
+                # Small grid that fits within max_length after expansion
+                input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
+                attention_mask = torch.ones_like(input_ids)
+                grid_thws = torch.tensor([[1, 4, 4]])  # 4 image tokens
+                return {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "grid_thws": grid_thws,
+                    "pixel_values": torch.randn(1, 3, 14, 14),
+                }
+            # Second sample: 5 text + 16 image tokens = 21 post-expansion;
+            # max_length=15 truncates into the image region → orphan path.
+            input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4, 5]])
+            attention_mask = torch.ones_like(input_ids)
+            grid_thws = torch.tensor([[1, 8, 8]])  # 16 image tokens after expansion
+            return {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "grid_thws": grid_thws,
+                "pixel_values": torch.randn(1, 3, 64, 64),
+            }
+
+    processor = MaybeOrphanProcessor()
+
+    def fake_build_labels(input_ids, conversations, processor_arg):
+        batch_size, seq_len = input_ids.shape
+        return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)
+
+    conv_intact = [
+        {"role": "user", "content": [{"type": "image", "image": "a.jpg"}, {"type": "text", "text": "?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
+    ]
+    conv_orphan = [
+        {"role": "user", "content": [{"type": "image", "image": "b.jpg"}, {"type": "text", "text": "?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
+    ]
+    examples = [{"conversation": conv_intact}, {"conversation": conv_orphan}]
+
+    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor, max_length=15)
+
+    assert batch["input_ids"].shape[0] == 2
+    assert "n_images_per_sample" in batch
+    assert batch["n_images_per_sample"].shape == (2,), (
+        f"n_images_per_sample length must equal batch_size=2, "
+        f"got shape {batch['n_images_per_sample'].shape}"
+    )
+    # First sample's image survives → 1; second sample is orphaned → 0
+    assert batch["n_images_per_sample"].tolist() == [1, 0]
+
+
 def test_kimi_k25_vl_collate_fn_multiple_examples(collate_mod, monkeypatch):
     """Test kimi_k25_vl_collate_fn handles multiple examples with padding."""
     # Processor that produces variable length sequences
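
To run just these regressions locally (assuming the repository's standard pytest layout):

pytest tests/unit_tests/datasets/vlm/test_collate_fns.py -k n_images_per_sample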
