feat(search): add multi-query batch search support #1685

LHT129 wants to merge 1 commit into antgroup:main.
Conversation
Summary of Changes (Gemini Code Assist)

This pull request enhances the system's search capabilities by introducing multi-query batch processing: users can submit multiple query vectors in a single request, streamlining operations and potentially improving efficiency for applications that issue concurrent searches. The changes refactor the core search methods in HGraph, SparseIndex, and SINDI to iterate over multiple queries and concatenate their results, while preserving backward compatibility for single-query operations.
Activity
Pull request overview
Adds multi-query (batch) support to search operations by allowing query datasets with multiple elements and concatenating results across queries.
Changes:
- Batch-enable sparse (SparseIndex, SINDI) KNN and range search by looping over query elements
- Batch-enable HGraph `RangeSearch` and `SearchWithRequest` with per-query iteration and concatenated outputs
- Update argument checks / messages to reflect single-query vs multi-query constraints
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| src/algorithm/sparse_index.cpp | Removes single-query restriction and concatenates results across multiple sparse queries |
| src/algorithm/sindi/sindi.cpp | Adds per-query loop for sparse SINDI KNN/range search and concatenates results |
| src/algorithm/hgraph.cpp | Enables batch range search and batch SearchWithRequest; keeps iterator-based KNN single-query |
```cpp
for (int64_t q_idx = 0; q_idx < query_count; ++q_idx) {
    auto results = std::make_shared<StandardHeap<true, false>>(allocator_, -1);
    auto [sorted_ids, sorted_vals] = sort_sparse_vector(sparse_vectors[q_idx]);
    for (int j = 0; j < cur_element_count_; ++j) {
        auto distance = CalDistanceByIdUnsafe(sorted_ids, sorted_vals, j);
        auto label = label_table_->GetLabelById(j);
        if (not filter || filter->CheckValid(label)) {
            results->Push(distance, label);
            if (results->Size() > k) {
                results->Pop();
            }
        }
    }
    for (auto j = static_cast<int64_t>(results->Size() - 1); j >= 0; --j) {
        all_dists.push_back(results->Top().first);
        all_ids.push_back(results->Top().second);
        results->Pop();
    }
}
```
static_cast<int64_t>(results->Size() - 1) can underflow when results->Size() == 0 because the subtraction happens before the cast (likely on an unsigned type). This can produce an implementation-defined large value and lead to out-of-bounds behavior (or an unexpectedly long loop). Fix by avoiding Size() - 1 on an unsigned value (e.g., cast before subtract, use for (int64_t j = static_cast<int64_t>(results->Size()); j-- > 0;), or just while (!results->Empty()) { ... }). Also, the loop variable j is unused—if the intent was to reverse heap order, write into a pre-sized buffer using j as the index.
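A minimal standalone sketch of the safe drain pattern, using `std::priority_queue` as a stand-in for the project's `StandardHeap` (an assumption about its max-heap semantics): the `j-- > 0` form casts to a signed type before any subtraction, so an empty heap yields zero iterations instead of wrapping around.

```cpp
#include <cstdint>
#include <queue>
#include <vector>

// Drain a max-heap into ascending order without the unsigned-underflow
// trap: cast size() to int64_t BEFORE decrementing, never compute
// size() - 1 on the unsigned value.
std::vector<float> drain_ascending(std::priority_queue<float> heap) {
    std::vector<float> out(heap.size());
    for (int64_t j = static_cast<int64_t>(heap.size()); j-- > 0;) {
        out[j] = heap.top();  // largest distance lands at the back
        heap.pop();
    }
    return out;
}
```

With an empty heap the loop body never runs; with a populated heap the reverse fill also fixes the ordering concern raised elsewhere in this review.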
```cpp
int64_t total_count = static_cast<int64_t>(all_dists.size());
auto [result, dists, ids] = create_fast_dataset(total_count, allocator_);
if (total_count == 0) {
    result->Dim(0)->NumElements(query_count);
    return result;
}
for (int64_t j = 0; j < total_count; ++j) {
    dists[j] = all_dists[j];
    ids[j] = all_ids[j];
}
return result;
```
The PR description states results will be concatenated as N queries × k results = N*k elements, but this code sizes the output to all_dists.size() (which can be < query_count * k when filters exclude items or when fewer than k are collected). If the contract is fixed-size N*k, consider allocating query_count * k and padding missing entries (similar to HGraph::SearchWithRequest). If the contract is variable-size, the PR description (and any downstream consumer expectations) should be updated accordingly.
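A sketch of what the fixed-size contract would look like, under the assumption that missing slots are padded with -1 (mirroring the padding that HGraph::SearchWithRequest uses); the helper name and input shape are hypothetical:

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Flatten variable-length per-query results into a fixed N*k layout,
// padding short result lists with -1 so query q's slots are always
// exactly [q * k, (q + 1) * k).
struct Flat {
    std::vector<float> dists;
    std::vector<int64_t> ids;
};

Flat flatten_padded(
    const std::vector<std::vector<std::pair<float, int64_t>>>& per_query,
    int64_t k) {
    auto n = static_cast<int64_t>(per_query.size());
    Flat out{std::vector<float>(n * k, -1.0F), std::vector<int64_t>(n * k, -1)};
    for (int64_t q = 0; q < n; ++q) {
        int64_t offset = q * k;
        auto count = std::min<int64_t>(k, per_query[q].size());
        for (int64_t j = 0; j < count; ++j) {
            out.dists[offset + j] = per_query[q][j].first;
            out.ids[offset + j] = per_query[q][j].second;
        }
    }
    return out;
}
```

Downstream consumers can then index query q's block without needing per-query counts, at the cost of always allocating N*k slots.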
```cpp
for (auto j = static_cast<int64_t>(results->Size() - 1); j >= 0; --j) {
    all_dists.push_back(results->Top().first);
    all_ids.push_back(results->Top().second);
    results->Pop();
}
```
Same underflow risk as in KnnSearch: results->Size() - 1 can underflow when Size() == 0 (subtraction occurs before the cast). Use a safe reverse loop pattern or while (!results->Empty()) to drain the heap.
```cpp
Vector<float> all_dists(allocator_);
Vector<int64_t> all_ids(allocator_);
```
This function takes vsag::Allocator* allocator but allocates intermediate vectors and the returned dataset using allocator_ (member) instead of the provided allocator. That can break caller expectations (e.g., arena lifetimes, tracking, thread-local allocators) and can also mismatch the allocator used inside search_impl(...) (which still receives allocator). Allocate the output dataset (and ideally intermediate buffers) with the allocator parameter for consistency.
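The allocator-threading discipline the comment asks for can be sketched in isolation. The `CountingAllocator` type below is a toy stand-in invented for illustration (the real `vsag::Allocator` interface may differ); the point is that every buffer tied to the call's lifetime comes from the allocator parameter, never from a member `allocator_`:

```cpp
#include <cstddef>
#include <cstdint>
#include <new>

// Toy stand-in for an allocator interface; tracks how many allocations
// were routed through it so callers can verify the threading.
struct CountingAllocator {
    size_t calls = 0;
    void* Allocate(size_t n) { ++calls; return ::operator new(n); }
    void Deallocate(void* p) { ::operator delete(p); }
};

// Correct pattern: the result buffer is allocated from the allocator
// PARAMETER, so arena lifetimes and tracking behave as the caller expects.
float* make_result_buffer(CountingAllocator* allocator, int64_t total_count) {
    return static_cast<float*>(
        allocator->Allocate(sizeof(float) * static_cast<size_t>(total_count)));
}
```

If a member allocator were used instead, a caller passing a per-request arena would see its buffers silently come from (and outlive into) the wrong pool.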
```cpp
int64_t total_count = static_cast<int64_t>(all_dists.size());
auto [dataset_results, dists, ids] = create_fast_dataset(total_count, allocator_);
for (int64_t j = 0; j < total_count; ++j) {
    dists[j] = all_dists[j];
    ids[j] = all_ids[j];
}
return dataset_results;
```
This function takes vsag::Allocator* allocator but allocates intermediate vectors and the returned dataset using allocator_ (member) instead of the provided allocator. That can break caller expectations (e.g., arena lifetimes, tracking, thread-local allocators) and can also mismatch the allocator used inside search_impl(...) (which still receives allocator). Allocate the output dataset (and ideally intermediate buffers) with the allocator parameter for consistency.
```cpp
Vector<int64_t> all_counts(allocator_);
Vector<float> all_dists(allocator_);
Vector<int64_t> all_ids(allocator_);
Vector<char> all_extra_infos(allocator_);
```
RangeSearch now allocates extra_infos but never populates it, so callers may observe uninitialized extra info data (regression vs the previous per-result fill). Also all_counts is collected but never attached to the returned dataset, making batch range results ambiguous (without per-query counts/offsets, consumers can’t reliably segment concatenated results). Fix by (1) collecting per-result extra infos alongside all_dists/all_ids and copying into the allocated extra_infos buffer, and (2) encoding per-query counts/offsets into the returned dataset in the same way single-query range search previously did (or by extending the dataset metadata to include all_counts).
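Why per-query counts matter: with variable-length range results, a consumer cannot segment the concatenated arrays without counts or offsets. A hypothetical helper (names invented for illustration) shows the slicing that `all_counts` would enable if it were attached to the dataset:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Given per-query result counts, query q's slice of the concatenated
// ids/dists arrays is [slice_offset(counts, q),
//                      slice_offset(counts, q) + counts[q]).
int64_t slice_offset(const std::vector<int64_t>& counts, int64_t q) {
    return std::accumulate(counts.begin(), counts.begin() + q, int64_t{0});
}
```

Without this metadata an empty result set for one query is indistinguishable from a shift of every later query's results.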
```diff
 int64_t total_count = static_cast<int64_t>(all_dists.size());
 auto [dataset_results, dists, ids] = create_fast_dataset(total_count, allocator_);
 char* extra_infos = nullptr;
 if (extra_info_size_ > 0) {
-    extra_infos =
-        static_cast<char*>(allocator_->Allocate(extra_info_size_ * search_result->Size()));
+    extra_infos = static_cast<char*>(allocator_->Allocate(extra_info_size_ * total_count));
     dataset_results->ExtraInfos(extra_infos);
 }
-for (int64_t j = count - 1; j >= 0; --j) {
-    dists[j] = search_result->Top().first;
-    ids[j] = this->label_table_->GetLabelById(search_result->Top().second);
-    if (extra_infos != nullptr) {
-        this->extra_infos_->GetExtraInfoById(search_result->Top().second,
-                                             extra_infos + extra_info_size_ * j);
-    }
-    search_result->Pop();
-}
+for (int64_t j = 0; j < total_count; ++j) {
+    dists[j] = all_dists[j];
+    ids[j] = all_ids[j];
+}
```
```cpp
int64_t total_result_count = query_count * k;
auto [dataset_results, dists, ids] = create_fast_dataset(total_result_count, ctx.alloc);
```
SearchWithRequest previously returned an empty dataset when the search produced no results; it now always returns query_count * k elements and pads missing entries with -1. That is a behavioral change for single-query calls and conflicts with the PR’s “Backward compatible (single query unchanged)” claim. If fixed-size output is the new intended contract, it should be reflected consistently across all KnnSearch APIs (including sparse/SINDI) and documented; otherwise, consider preserving the old behavior at least for query_count == 1 (or when count == 0) to avoid breaking existing consumers.
```cpp
int64_t offset = q_idx * k;
auto count = static_cast<int64_t>(search_result->Size());
for (int64_t j = count - 1; j >= 0; --j) {
    dists[offset + j] = search_result->Top().first;
    ids[offset + j] = this->label_table_->GetLabelById(search_result->Top().second);
```
```cpp
for (int64_t j = count; j < k; ++j) {
    dists[offset + j] = -1.0F;
    ids[offset + j] = -1;
    if (extra_infos != nullptr) {
        memset(extra_infos + extra_info_size_ * (offset + j), 0, extra_info_size_);
    }
}
```
Code Review
This pull request introduces multi-query batch search support across several components. The changes in HGraph::SearchWithRequest and sindi.cpp are well-implemented. However, I've identified some issues in HGraph::RangeSearch and SparseIndex's search methods related to result ordering and data handling. Specifically, HGraph::RangeSearch fails to populate extra information for results, and both it and the SparseIndex search methods produce results in an incorrect order. I've provided detailed comments and suggestions to address these points.
```cpp
Vector<int64_t> all_counts(allocator_);
Vector<float> all_dists(allocator_);
Vector<int64_t> all_ids(allocator_);
Vector<char> all_extra_infos(allocator_);
for (int64_t q_idx = 0; q_idx < query_count; ++q_idx) {
    const auto* raw_query = get_data(query, q_idx);

    InnerSearchParam search_param;
    search_param.ep = this->entry_point_id_;
    search_param.topk = 1;
    search_param.ef = 1;
    for (auto i = static_cast<int64_t>(this->route_graphs_.size() - 1); i >= 0; --i) {
        auto result = this->search_one_graph(raw_query,
                                             this->route_graphs_[i],
                                             this->basic_flatten_codes_,
                                             search_param,
                                             (VisitedListPtr) nullptr,
                                             &ctx);
        search_param.ep = result->Top().second;
    }

    search_param.ef = std::max(params.ef_search, limited_size);
    search_param.is_inner_id_allowed = ft;
    search_param.radius = radius;
    search_param.search_mode = RANGE_SEARCH;
    search_param.consider_duplicate = true;
    search_param.range_search_limit_size = static_cast<int>(limited_size);
    search_param.parallel_search_thread_count = params.parallel_search_thread_count;

    auto search_result = this->search_one_graph(raw_query,
                                                this->bottom_graph_,
                                                this->basic_flatten_codes_,
                                                search_param,
                                                (VisitedListPtr) nullptr,
                                                &ctx);

    if (use_reorder_) {
        this->reorder(
            raw_query, this->high_precise_codes_, search_result, limited_size, nullptr, ctx);
    }

    if (limited_size > 0) {
        while (search_result->Size() > limited_size) {
            search_result->Pop();
        }
    }

    auto count = static_cast<const int64_t>(search_result->Size());
    all_counts.push_back(count);
    for (int64_t j = count - 1; j >= 0; --j) {
        all_dists.push_back(search_result->Top().first);
        all_ids.push_back(this->label_table_->GetLabelById(search_result->Top().second));
        search_result->Pop();
    }
}
int64_t total_count = static_cast<int64_t>(all_dists.size());
auto [dataset_results, dists, ids] = create_fast_dataset(total_count, allocator_);
char* extra_infos = nullptr;
if (extra_info_size_ > 0) {
    extra_infos = static_cast<char*>(allocator_->Allocate(extra_info_size_ * total_count));
    dataset_results->ExtraInfos(extra_infos);
}
for (int64_t j = 0; j < total_count; ++j) {
    dists[j] = all_dists[j];
    ids[j] = all_ids[j];
}
```
The implementation for batch range search has a few issues:

- Missing extra info: the `extra_infos` buffer is allocated but never populated with data. This is a critical bug that will lead to incorrect or empty extra information in the search results.
- Incorrect result order: the results for each query are collected in descending order of distance (worst first), because `push_back` is used while popping from the max-heap. This is inconsistent with the single-query behavior and with the implementation in `SearchWithRequest`, both of which return results sorted by ascending distance.
- Unused variables: the `all_counts` and `all_extra_infos` vectors are declared but not used for their intended purpose, leaving dead code.

I suggest refactoring this section to fix these issues. A more efficient approach is to resize the result vectors and fill them in the correct order, similar to how `SearchWithRequest` handles it, but adapted for the variable number of results in a range search.
```cpp
Vector<float> all_dists(allocator_);
Vector<int64_t> all_ids(allocator_);
Vector<char> all_extra_infos(allocator_);
for (int64_t q_idx = 0; q_idx < query_count; ++q_idx) {
    const auto* raw_query = get_data(query, q_idx);
    InnerSearchParam search_param;
    search_param.ep = this->entry_point_id_;
    search_param.topk = 1;
    search_param.ef = 1;
    for (auto i = static_cast<int64_t>(this->route_graphs_.size() - 1); i >= 0; --i) {
        auto result = this->search_one_graph(raw_query,
                                             this->route_graphs_[i],
                                             this->basic_flatten_codes_,
                                             search_param,
                                             (VisitedListPtr) nullptr,
                                             &ctx);
        search_param.ep = result->Top().second;
    }
    search_param.ef = std::max(params.ef_search, limited_size);
    search_param.is_inner_id_allowed = ft;
    search_param.radius = radius;
    search_param.search_mode = RANGE_SEARCH;
    search_param.consider_duplicate = true;
    search_param.range_search_limit_size = static_cast<int>(limited_size);
    search_param.parallel_search_thread_count = params.parallel_search_thread_count;
    auto search_result = this->search_one_graph(raw_query,
                                                this->bottom_graph_,
                                                this->basic_flatten_codes_,
                                                search_param,
                                                (VisitedListPtr) nullptr,
                                                &ctx);
    if (use_reorder_) {
        this->reorder(
            raw_query, this->high_precise_codes_, search_result, limited_size, nullptr, ctx);
    }
    if (limited_size > 0) {
        while (search_result->Size() > limited_size) {
            search_result->Pop();
        }
    }
    auto count = static_cast<const int64_t>(search_result->Size());
    if (count > 0) {
        size_t current_offset = all_dists.size();
        all_dists.resize(current_offset + count);
        all_ids.resize(current_offset + count);
        size_t current_extra_offset = all_extra_infos.size();
        if (extra_info_size_ > 0) {
            all_extra_infos.resize(current_extra_offset + count * extra_info_size_);
        }
        for (int64_t j = count - 1; j >= 0; --j) {
            auto inner_id = search_result->Top().second;
            all_dists[current_offset + j] = search_result->Top().first;
            all_ids[current_offset + j] = this->label_table_->GetLabelById(inner_id);
            if (extra_info_size_ > 0) {
                this->extra_infos_->GetExtraInfoById(
                    inner_id,
                    all_extra_infos.data() + current_extra_offset + extra_info_size_ * j);
            }
            search_result->Pop();
        }
    }
}
int64_t total_count = static_cast<int64_t>(all_dists.size());
auto [dataset_results, dists, ids] = create_fast_dataset(total_count, allocator_);
char* extra_infos = nullptr;
if (extra_info_size_ > 0 && total_count > 0) {
    extra_infos = static_cast<char*>(allocator_->Allocate(extra_info_size_ * total_count));
    dataset_results->ExtraInfos(extra_infos);
    memcpy(extra_infos, all_extra_infos.data(), total_count * extra_info_size_);
}
for (int64_t j = 0; j < total_count; ++j) {
    dists[j] = all_dists[j];
    ids[j] = all_ids[j];
}
```

```cpp
for (auto j = static_cast<int64_t>(results->Size() - 1); j >= 0; --j) {
    all_dists.push_back(results->Top().first);
    all_ids.push_back(results->Top().second);
    results->Pop();
}
```
The current implementation for collecting results from the heap will store them in descending order of distance (worst to best) for each query. This is because you are using push_back while popping from a max-heap. This is inconsistent with the typical expectation of results sorted by ascending distance. I suggest resizing the vectors and filling them backwards to maintain the correct order.
```diff
-for (auto j = static_cast<int64_t>(results->Size() - 1); j >= 0; --j) {
-    all_dists.push_back(results->Top().first);
-    all_ids.push_back(results->Top().second);
-    results->Pop();
-}
+auto count = results->Size();
+if (count > 0) {
+    size_t current_offset = all_dists.size();
+    all_dists.resize(current_offset + count);
+    all_ids.resize(current_offset + count);
+    for (auto j = static_cast<int64_t>(count - 1); j >= 0; --j) {
+        all_dists[current_offset + j] = results->Top().first;
+        all_ids[current_offset + j] = results->Top().second;
+        results->Pop();
+    }
+}
```
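The two orderings can be checked with a tiny standalone sketch (`std::priority_queue` standing in for the project's max-heap, which is an assumption about its semantics):

```cpp
#include <cstdint>
#include <queue>
#include <vector>

// push_back while popping a max-heap emits worst-first (descending
// distance); pre-sizing the buffer and filling it backwards emits
// best-first (ascending distance), matching single-query behavior.
std::vector<float> drain_push_back(std::priority_queue<float> heap) {
    std::vector<float> out;
    while (!heap.empty()) {
        out.push_back(heap.top());
        heap.pop();
    }
    return out;
}

std::vector<float> drain_reverse_fill(std::priority_queue<float> heap) {
    std::vector<float> out(heap.size());
    for (int64_t j = static_cast<int64_t>(heap.size()); j-- > 0;) {
        out[j] = heap.top();
        heap.pop();
    }
    return out;
}
```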
```cpp
for (auto j = static_cast<int64_t>(results->Size() - 1); j >= 0; --j) {
    all_dists.push_back(results->Top().first);
    all_ids.push_back(results->Top().second);
    results->Pop();
}
```
Same ordering issue as the previous drain loop: using `push_back` while popping the max-heap stores each query's results in descending order of distance (worst to best). The same fix applies: resize `all_dists`/`all_ids` and fill them backwards from the heap so each query's block comes out sorted by ascending distance.
- Modify HGraph::SearchWithRequest to support multiple queries in a single call
- Update HGraph::RangeSearch for batch range search
- Update SparseIndex::KnnSearch and RangeSearch for multi-query
- Update SINDI::KnnSearch and RangeSearch for multi-query
- Iterator-based search remains single-query only due to state tracking
- Results are concatenated: query0 results, query1 results, etc.

This leverages existing DatasetPtr multi-element capability.

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
Co-authored-by: Kimi-K2.5 <assistant@example.com>
Summary
Add multi-query batch search support to VSAG search operations, enabling multiple query vectors in a single call by leveraging the existing DatasetPtr multi-element capability.
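As an illustration of the concatenated layout (the helper below is hypothetical, not part of the VSAG API), a caller holding a fixed-k batch result can slice out each query's block like this:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Under the fixed-size KNN contract (query_count * k results, short
// lists padded with id -1), query q owns the slice [q * k, (q + 1) * k)
// of the flat ids/dists arrays returned by a batch search.
std::vector<std::pair<int64_t, float>> results_for_query(
    const std::vector<int64_t>& ids,
    const std::vector<float>& dists,
    int64_t k, int64_t q) {
    std::vector<std::pair<int64_t, float>> out;
    for (int64_t j = q * k; j < (q + 1) * k; ++j) {
        if (ids[j] == -1) break;  // padding marks the end of real results
        out.emplace_back(ids[j], dists[j]);
    }
    return out;
}
```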
Changes
Technical Details
- Removed `GetNumElements() == 1` restrictions from search methods
- Iterate over queries via `get_data(query, q_idx)`

Files Changed
Testing
Related Issues
Checklist