@@ -1410,6 +1410,148 @@ def fake_build_labels(input_ids, conversations, processor_arg):
14101410 assert (batch ["input_ids" ] == MEDIA_TOKEN_ID ).sum ().item () == 0
14111411
14121412
def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix(
    collate_mod, monkeypatch
):
    """Mixed batch (text-only + image): n_images_per_sample length must equal batch_size.

    Regression: previously image_counts was derived from all_grid_thws only, so
    text-only samples were skipped and the resulting tensor was shorter than
    batch_size. Downstream PP _chunk_vlm_media indexes cumsum_images by
    sample index and would IndexError out of bounds.
    """
    MEDIA_TOKEN_ID = 163605

    class MixedProcessor:
        """Fake processor: returns image outputs only when medias are supplied."""

        def __init__(self):
            self.tokenizer = DummyTokenizer(pad_token_id=0)
            self.media_placeholder_token_id = MEDIA_TOKEN_ID

        def apply_chat_template(self, conversation, **kwargs):
            return "chat:processed"

        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
            # Text-only path: no grid/pixel outputs at all.
            if not medias:
                ids = torch.tensor([[10, 11, 12, 13, 14]])
                return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}
            # Image path: one media placeholder plus a single 1x4x4 grid.
            ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
            return {
                "input_ids": ids,
                "attention_mask": torch.ones_like(ids),
                "grid_thws": torch.tensor([[1, 4, 4]]),
                "pixel_values": torch.randn(1, 3, 14, 14),
            }

    processor = MixedProcessor()

    # Deterministic label stub so the collate fn doesn't need the real template.
    def fake_build_labels(input_ids, conversations, processor_arg):
        n_rows, n_cols = input_ids.shape
        return torch.arange(n_cols).repeat(n_rows, 1)

    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)

    text_only = [
        {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
    ]
    with_image = [
        {"role": "user", "content": [{"type": "image", "image": "x.jpg"}, {"type": "text", "text": "What?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Cat."}]},
    ]
    examples = [{"conversation": text_only}, {"conversation": with_image}]

    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor)

    assert "n_images_per_sample" in batch
    got_shape = batch["n_images_per_sample"].shape
    assert got_shape == (2,), (
        f"n_images_per_sample length must equal batch_size=2, "
        f"got shape {got_shape} "
    )
    # text-only sample → 0; image sample → 1
    assert batch["n_images_per_sample"].tolist() == [0, 1]
1475+
def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan(
    collate_mod, monkeypatch
):
    """Mixed batch (truncated image + intact image): n_images_per_sample length must equal batch_size.

    Regression: a sample whose image region got orphaned by truncation was
    correctly excluded from all_grid_thws but still kept in all_expanded.
    Without the fix, n_images_per_sample length would be smaller than the
    final batch and downstream PP indexing would crash.
    """
    MEDIA_TOKEN_ID = 163605

    class MaybeOrphanProcessor:
        """Returns the same large grid for both calls; the second call's tokens
        will be truncated past the image region by max_length below."""

        def __init__(self):
            self.tokenizer = DummyTokenizer(pad_token_id=0)
            self.media_placeholder_token_id = MEDIA_TOKEN_ID
            self._call_idx = 0

        def apply_chat_template(self, conversation, **kwargs):
            return "chat:processed"

        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
            self._call_idx += 1
            first_call = self._call_idx == 1
            if first_call:
                # Small grid that fits within max_length after expansion
                ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
                grids = torch.tensor([[1, 4, 4]])  # 4 image tokens
                pixels = torch.randn(1, 3, 14, 14)
            else:
                # Second sample: 5 text + 16 image tokens = 21 post-expansion;
                # max_length=15 truncates into the image region → orphan path.
                ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4, 5]])
                grids = torch.tensor([[1, 8, 8]])  # 16 image tokens after expansion
                pixels = torch.randn(1, 3, 64, 64)
            return {
                "input_ids": ids,
                "attention_mask": torch.ones_like(ids),
                "grid_thws": grids,
                "pixel_values": pixels,
            }

    processor = MaybeOrphanProcessor()

    # Deterministic label stub so the collate fn doesn't need the real template.
    def fake_build_labels(input_ids, conversations, processor_arg):
        n_rows, n_cols = input_ids.shape
        return torch.arange(n_cols).repeat(n_rows, 1)

    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)

    conv_intact = [
        {"role": "user", "content": [{"type": "image", "image": "a.jpg"}, {"type": "text", "text": "?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
    ]
    conv_orphan = [
        {"role": "user", "content": [{"type": "image", "image": "b.jpg"}, {"type": "text", "text": "?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
    ]
    examples = [{"conversation": conv_intact}, {"conversation": conv_orphan}]

    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor, max_length=15)

    assert batch["input_ids"].shape[0] == 2
    assert "n_images_per_sample" in batch
    got_shape = batch["n_images_per_sample"].shape
    assert got_shape == (2,), (
        f"n_images_per_sample length must equal batch_size=2, "
        f"got shape {got_shape} "
    )
    # First sample's image survives → 1; second sample is orphaned → 0
    assert batch["n_images_per_sample"].tolist() == [1, 0]
1554+
14131555def test_kimi_k25_vl_collate_fn_multiple_examples (collate_mod , monkeypatch ):
14141556 """Test kimi_k25_vl_collate_fn handles multiple examples with padding."""
14151557 # Processor that produces variable length sequences
0 commit comments