@@ -541,5 +541,66 @@ def test_with_partition_filter_rejects_non_partition_field(self):
541541 self .assertIn ("non-partition" , str (ctx .exception ))
542542
543543
544+ class VectorSearchManySplitsTest (unittest .TestCase ):
545+
546+ def test_vector_search_with_many_splits (self ):
547+ from pypaimon .globalindex .vector_search_result import (
548+ DictBasedScoredIndexResult ,
549+ )
550+ from pypaimon .table .source .vector_search_read import VectorSearchReadImpl
551+ from pypaimon .table .source .vector_search_split import VectorSearchSplit
552+
553+ num_splits = 1200
554+ embedding_field = _field (1 , "embedding" , "FLOAT" )
555+ entries = [
556+ _entry (None , field_id = 1 , index_type = "lumina-vector-ann" ,
557+ file_name = "vec-%d.index" % i ,
558+ row_range_start = i , row_range_end = i )
559+ for i in range (num_splits )
560+ ]
561+ table = _StubTable (fields = [embedding_field ], entries = entries )
562+ _patch_snapshot (self , entries )
563+
564+ def _fake_create (index_type , file_io , index_path ,
565+ index_io_meta_list , options = None ):
566+ row_id = index_io_meta_list [0 ].file_name
567+ row_id = int (row_id .split ("-" )[1 ].split ("." )[0 ])
568+
569+ class _FakeReader :
570+ def visit_vector_search (self_inner , vs ):
571+ return DictBasedScoredIndexResult ({row_id : float (row_id )})
572+
573+ def close (self_inner ):
574+ pass
575+
576+ def __enter__ (self_inner ):
577+ return self_inner
578+
579+ def __exit__ (self_inner , * a ):
580+ return False
581+ return _FakeReader ()
582+
583+ splits = [
584+ VectorSearchSplit (
585+ row_range_start = i , row_range_end = i ,
586+ vector_index_files = [entries [i ].index_file ])
587+ for i in range (num_splits )
588+ ]
589+
590+ with mock .patch (
591+ "pypaimon.table.source.vector_search_read._create_vector_reader" ,
592+ side_effect = _fake_create ):
593+ reader = VectorSearchReadImpl (
594+ table , limit = 10 , vector_column = embedding_field ,
595+ query_vector = [1.0 ], filter_ = None )
596+ result = reader .read (splits )
597+
598+ self .assertGreater (result .results ().cardinality (), 0 )
599+ self .assertIsNotNone (result .score_getter ()(0 ))
600+
601+ def tearDown (self ):
602+ mock .patch .stopall ()
603+
604+
544605if __name__ == "__main__" :
545606 unittest .main ()
0 commit comments