diff --git a/contrib/tici b/contrib/tici index ee7809180a1..db0a4054f26 160000 --- a/contrib/tici +++ b/contrib/tici @@ -1 +1 @@ -Subproject commit ee7809180a1193d37f892960f7d4ae1580cea13c +Subproject commit db0a4054f26d115d21a3f215ea6fe71961041e9d diff --git a/contrib/tici-search-lib/CMakeLists.txt b/contrib/tici-search-lib/CMakeLists.txt index 413a1886d2c..427d85827b4 100644 --- a/contrib/tici-search-lib/CMakeLists.txt +++ b/contrib/tici-search-lib/CMakeLists.txt @@ -7,7 +7,7 @@ file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/cxxbridge) add_custom_command( OUTPUT ${TICI_LIB} - COMMAND cargo build -p tici-search-lib --release --target-dir ${CMAKE_CURRENT_BINARY_DIR} --manifest-path ${TICI_PROJECT_DIR}/Cargo.toml + COMMAND ${CMAKE_COMMAND} -E env "CC=gcc" cargo build -p tici-search-lib --release --target-dir ${CMAKE_CURRENT_BINARY_DIR} --manifest-path ${TICI_PROJECT_DIR}/Cargo.toml WORKING_DIRECTORY ${TICI_PROJECT_DIR} DEPENDS ${LIB_SOURCE_FILES} COMMENT "Build Rust lib"${CMAKE_CURRENT_BINARY_DIR} @@ -26,13 +26,16 @@ target_include_directories(tici_search_lib_static INTERFACE add_library(tici_search_lib SHARED "${TiFlash_SOURCE_DIR}/libs/libclara-cmake/dummy.cpp") target_compile_options(tici_search_lib PRIVATE -pthread) target_link_options(tici_search_lib PRIVATE -pthread) -target_link_libraries(tici_search_lib PRIVATE "$") if(APPLE) - target_link_libraries(tici_search_lib PRIVATE + target_link_options(tici_search_lib PRIVATE "LINKER:-force_load,${TICI_LIB}") + target_link_libraries(tici_search_lib PRIVATE tici_search_lib_static "-framework Security" "-framework CoreFoundation" "-framework IOKit" ) +else() + target_link_libraries(tici_search_lib PRIVATE + -Wl,--whole-archive tici_search_lib_static -Wl,--no-whole-archive) endif() target_include_directories(tici_search_lib INTERFACE diff --git a/contrib/tipb b/contrib/tipb index 1852f9829ce..3b13f136b03 160000 --- a/contrib/tipb +++ b/contrib/tipb @@ -1 +1 @@ -Subproject commit 
1852f9829ce3d3962895fec43f908b31fbdc58fb +Subproject commit 3b13f136b03e82412a4d8426422f0beb1cf57d69 diff --git a/dbms/src/Flash/Coprocessor/TiCIScan.cpp b/dbms/src/Flash/Coprocessor/TiCIScan.cpp index 1e65076e67a..2b2e22b19e1 100644 --- a/dbms/src/Flash/Coprocessor/TiCIScan.cpp +++ b/dbms/src/Flash/Coprocessor/TiCIScan.cpp @@ -22,6 +22,15 @@ #include namespace DB { + +static TiCIQueryMode resolveQueryMode(const tipb::IndexScan & idx_scan) +{ + if (idx_scan.has_tici_vector_query_info()) + return TiCIQueryMode::Vector; + RUNTIME_CHECK_MSG(idx_scan.has_fts_query_info(), "IndexScan must have either fts_query_info or tici_vector_query_info"); + return TiCIQueryMode::FTS; +} + TiCIScan::TiCIScan(const tipb::Executor * tici_scan_, const String & executor_id_, const DAGContext & dag_context) : tici_scan(tici_scan_) , executor_id(executor_id_) @@ -29,15 +38,24 @@ TiCIScan::TiCIScan(const tipb::Executor * tici_scan_, const String & executor_id , table_id(tici_scan->idx_scan().table_id()) , index_id(tici_scan->idx_scan().index_id()) , return_columns(TiDB::toTiDBColumnInfos(tici_scan->idx_scan().columns())) - , query_type(tici_scan->idx_scan().fts_query_info().query_type()) + , query_mode(resolveQueryMode(tici_scan->idx_scan())) , shard_infos(dag_context.query_shard_infos.getTableShardInfosByExecutorID(tici_scan_->executor_id())) - , limit(tici_scan->idx_scan().fts_query_info().top_k()) + , limit( + query_mode == TiCIQueryMode::Vector + ? tici_scan->idx_scan().tici_vector_query_info().top_k() + : tici_scan->idx_scan().fts_query_info().top_k()) , sort_column_ids( - tici_scan->idx_scan().fts_query_info().sort_column_ids().begin(), - tici_scan->idx_scan().fts_query_info().sort_column_ids().end()) + query_mode == TiCIQueryMode::FTS + ? 
std::vector( + tici_scan->idx_scan().fts_query_info().sort_column_ids().begin(), + tici_scan->idx_scan().fts_query_info().sort_column_ids().end()) + : std::vector()) , sort_column_asc( - tici_scan->idx_scan().fts_query_info().sort_column_asc().begin(), - tici_scan->idx_scan().fts_query_info().sort_column_asc().end()) + query_mode == TiCIQueryMode::FTS + ? std::vector( + tici_scan->idx_scan().fts_query_info().sort_column_asc().begin(), + tici_scan->idx_scan().fts_query_info().sort_column_asc().end()) + : std::vector()) {} void TiCIScan::constructTiCIScanForRemoteRead(tipb::IndexScan * tipb_index_scan) const diff --git a/dbms/src/Flash/Coprocessor/TiCIScan.h b/dbms/src/Flash/Coprocessor/TiCIScan.h index a9229dcfb98..378086aee0d 100644 --- a/dbms/src/Flash/Coprocessor/TiCIScan.h +++ b/dbms/src/Flash/Coprocessor/TiCIScan.h @@ -21,6 +21,12 @@ namespace DB { class DAGContext; +enum class TiCIQueryMode +{ + FTS, + Vector, +}; + class TiCIScan { public: @@ -38,13 +44,22 @@ class TiCIScan const int & getLimit() const { return limit; } const tipb::Executor * getTiCIScan() const { return tici_scan; } + TiCIQueryMode getQueryMode() const { return query_mode; } + void constructTiCIScanForRemoteRead(tipb::IndexScan * tipb_index_scan) const; const ::google::protobuf::RepeatedPtrField<::tipb::Expr> & getMatchExpr() const { + RUNTIME_CHECK(query_mode == TiCIQueryMode::FTS); return tici_scan->idx_scan().fts_query_info().match_expr(); } + const tipb::TiCIVectorQueryInfo & getVectorQueryInfo() const + { + RUNTIME_CHECK(query_mode == TiCIQueryMode::Vector); + return tici_scan->idx_scan().tici_vector_query_info(); + } + bool isCount() const { return is_count_agg; } void setIsCountAgg(bool v) { is_count_agg = v; } @@ -65,7 +80,7 @@ class TiCIScan const int index_id; TiDB::ColumnInfos return_columns; NamesAndTypes names_and_types; - [[maybe_unused]] tipb::FTSQueryType query_type; + TiCIQueryMode query_mode; const TableShardInfos shard_infos; const int limit; std::vector sort_column_ids; 
diff --git a/dbms/src/Flash/Planner/PhysicalPlan.cpp b/dbms/src/Flash/Planner/PhysicalPlan.cpp index bb4f2158a38..57b63cabe03 100644 --- a/dbms/src/Flash/Planner/PhysicalPlan.cpp +++ b/dbms/src/Flash/Planner/PhysicalPlan.cpp @@ -98,20 +98,43 @@ void PhysicalPlan::buildTableScan(const String & executor_id, const tipb::Execut void PhysicalPlan::buildTiCIScan(const String & executor_id, const tipb::Executor * executor) { - RUNTIME_ASSERT(executor->idx_scan().has_fts_query_info()); + RUNTIME_ASSERT( + executor->idx_scan().has_fts_query_info() || executor->idx_scan().has_tici_vector_query_info(), + "IndexScan must have either fts_query_info or tici_vector_query_info"); TiCIScan tici_scan(executor, executor_id, dagContext()); - LOG_INFO( - log, - "tici scan: keyspace_id={} table_id={} index_id={} limit={} shard_count={} match_expr_size={} query_type={} " - "start_ts={}", - tici_scan.getKeyspaceID(), - tici_scan.getTableId(), - tici_scan.getIndexId(), - tici_scan.getLimit(), - tici_scan.getShardInfos().shard_info_list.size(), - tici_scan.getMatchExpr().size(), - tipb::FTSQueryType_Name(executor->idx_scan().fts_query_info().query_type()), - context.getSettingsRef().read_tso); + if (tici_scan.getQueryMode() == TiCIQueryMode::Vector) + { + const auto & vqi = executor->idx_scan().tici_vector_query_info(); + LOG_INFO( + log, + "tici vector scan: keyspace_id={} table_id={} index_id={} col_id={} distance_metric={} top_k={} " + "dimension={} filter_expr_size={} shard_count={} start_ts={}", + tici_scan.getKeyspaceID(), + tici_scan.getTableId(), + tici_scan.getIndexId(), + vqi.column_id(), + tipb::VectorDistanceMetric_Name(vqi.distance_metric()), + vqi.top_k(), + vqi.dimension(), + vqi.filter_expr_size(), + tici_scan.getShardInfos().shard_info_list.size(), + context.getSettingsRef().read_tso); + } + else + { + LOG_INFO( + log, + "tici scan: keyspace_id={} table_id={} index_id={} limit={} shard_count={} match_expr_size={} " + "query_type={} start_ts={}", + tici_scan.getKeyspaceID(), 
+ tici_scan.getTableId(), + tici_scan.getIndexId(), + tici_scan.getLimit(), + tici_scan.getShardInfos().shard_info_list.size(), + tici_scan.getMatchExpr().size(), + tipb::FTSQueryType_Name(executor->idx_scan().fts_query_info().query_type()), + context.getSettingsRef().read_tso); + } pushBack(PhysicalTiCIScan::build(executor_id, log, tici_scan)); dagContext().table_scan_executor_id = executor_id; } diff --git a/dbms/src/Storages/StorageTantivy.cpp b/dbms/src/Storages/StorageTantivy.cpp index fd066742c61..1ff75fd267c 100644 --- a/dbms/src/Storages/StorageTantivy.cpp +++ b/dbms/src/Storages/StorageTantivy.cpp @@ -85,21 +85,41 @@ void StorageTantivy::read( auto shards_snapshot = std::move(*local_shards_snapshot); local_shards_snapshot.reset(); - auto tici_task_pool = std::make_shared( - log, - tici_scan.getKeyspaceID(), - tici_scan.getTableId(), - tici_scan.getIndexId(), - local_read, - return_columns, - tici_scan.getLimit(), - tici_scan.getSortColumnIds(), - tici_scan.getSortColumnAsc(), - context.getSettingsRef().read_tso, - tici_scan.getMatchExpr(), - tici_scan.isCount(), - context.getTimezoneInfo(), - std::move(shards_snapshot)); + std::shared_ptr tici_task_pool; + if (tici_scan.getQueryMode() == TiCIQueryMode::Vector) + { + auto vector_state = TS::TiCIReadTaskPool::buildVectorState( + tici_scan.getVectorQueryInfo(), + context.getTimezoneInfo()); + tici_task_pool = std::make_shared( + log, + tici_scan.getKeyspaceID(), + tici_scan.getTableId(), + tici_scan.getIndexId(), + local_read, + return_columns, + context.getSettingsRef().read_tso, + std::move(vector_state), + std::move(shards_snapshot)); + } + else + { + tici_task_pool = std::make_shared( + log, + tici_scan.getKeyspaceID(), + tici_scan.getTableId(), + tici_scan.getIndexId(), + local_read, + return_columns, + tici_scan.getLimit(), + tici_scan.getSortColumnIds(), + tici_scan.getSortColumnAsc(), + context.getSettingsRef().read_tso, + tici_scan.getMatchExpr(), + tici_scan.isCount(), + context.getTimezoneInfo(), + 
std::move(shards_snapshot)); + } num_streams = std::max(1, std::min(num_streams, local_read.size())); // local read diff --git a/dbms/src/Storages/Tantivy/TantivyInputStream.h b/dbms/src/Storages/Tantivy/TantivyInputStream.h index df593e218ea..38fa2af3e73 100644 --- a/dbms/src/Storages/Tantivy/TantivyInputStream.h +++ b/dbms/src/Storages/Tantivy/TantivyInputStream.h @@ -19,14 +19,18 @@ #include #include #include +#include #include #include #include #include #include +#include +#include #include #include #include +#include #include #include #include @@ -46,6 +50,17 @@ inline UInt64 convertPackedU64WithTimezone(UInt64 from_time, const TimezoneInfo return result_time; } +/// Holds vector query parameters extracted from TiCIVectorQueryInfo proto. +struct VectorQueryState +{ + Int64 col_id; + Int32 distance_metric; + UInt32 top_k; + std::vector query_vector; + bool has_filter; + ::Expr filter_expr; // converted from tipb filter_expr, empty if no filter +}; + class TantivyInputStream : public IProfilingBlockInputStream { static constexpr auto NAME = "TantivyInputStream"; @@ -53,6 +68,7 @@ class TantivyInputStream : public IProfilingBlockInputStream static constexpr auto version_column_name = "column_-1024"; public: + // FTS constructor TantivyInputStream( LoggerPtr log_, UInt32 keyspace_id_, @@ -77,11 +93,37 @@ class TantivyInputStream : public IProfilingBlockInputStream , sort_column_ids(sort_column_ids_) , sort_column_asc(sort_column_asc_) , read_ts(read_ts_) + , query_mode(TiCIQueryMode::FTS) , match_expr(match_expr_) , is_count(is_count) , shards_snapshot(std::move(shards_snapshot_)) {} + // Vector constructor + TantivyInputStream( + LoggerPtr log_, + UInt32 keyspace_id_, + Int64 table_id_, + Int64 index_id_, + ShardInfo query_shard_info_, + NamesAndTypes return_columns_, + UInt64 read_ts_, + VectorQueryState vector_state_, + std::shared_ptr> shards_snapshot_) + : log(log_) + , keyspace_id(keyspace_id_) + , table_id(table_id_) + , index_id(index_id_) + , 
query_shard_info(query_shard_info_) + , return_columns(return_columns_) + , limit(vector_state_.top_k) + , read_ts(read_ts_) + , query_mode(TiCIQueryMode::Vector) + , vector_state(std::move(vector_state_)) + , is_count(false) + , shards_snapshot(std::move(shards_snapshot_)) + {} + String getName() const override { return NAME; } Block getHeader() const override { return header; } @@ -92,12 +134,51 @@ class TantivyInputStream : public IProfilingBlockInputStream { return {}; } - Block ret = readFromS3(is_count); + Block ret = (query_mode == TiCIQueryMode::Vector) ? readVector() : readFromS3(is_count); done = true; return ret; } protected: + Block readVector() + { + auto return_fields = getFields(return_columns); + auto shard_info = query_shard_info; + LOG_DEBUG(log, "vector shard info: {}", shard_info.toString()); + auto key_ranges = getKeyRanges(shard_info.key_ranges); + + rust::Vec query_vec; + query_vec.reserve(vector_state.query_vector.size()); + for (auto v : vector_state.query_vector) + query_vec.push_back(v); + + VectorSearchParam vsp{ + .limit = static_cast(vector_state.top_k), + .col_id = vector_state.col_id, + .distance_metric = vector_state.distance_metric, + .query_vector = std::move(query_vec), + .has_filter = vector_state.has_filter, + }; + + RUNTIME_CHECK(shards_snapshot != nullptr); + SearchResult search_result = search_vector( + **shards_snapshot, + { + .keyspace_id = keyspace_id, + .table_id = table_id, + .index_id = index_id, + .shard_id = shard_info.shard_id, + .shard_epoch = shard_info.shard_epoch, + }, + key_ranges, + return_fields, + vector_state.filter_expr, + vsp, + read_ts); + + return buildBlockFromResult(search_result); + } + Block readFromS3(bool is_count) { auto return_fields = getFields(return_columns); @@ -149,6 +230,47 @@ class TantivyInputStream : public IProfilingBlockInputStream return res; } + return buildBlockFromResult(search_result); + } + +private: + static bool isVectorFloat32Type(const DataTypePtr & type) + { + const auto * 
type_array = typeid_cast(type.get()); + return type_array != nullptr && type_array->getNestedType()->isFloatingPoint() + && type_array->getNestedType()->getSizeOfValueInMemory() == sizeof(Float32); + } + + static Field decodeVectorFloat32Field(const String & raw_value, const String & column_name) + { + RUNTIME_CHECK_MSG( + raw_value.size() >= sizeof(UInt32), + "Malformed TiCI vector payload for column {}: payload is too short ({} bytes)", + column_name, + raw_value.size()); + const auto element_count = readLittleEndian(raw_value.data()); + const auto expected_size = sizeof(UInt32) + static_cast(element_count) * sizeof(Float32); + RUNTIME_CHECK_MSG( + raw_value.size() == expected_size, + "Malformed TiCI vector payload for column {}: expected {} bytes for {} elements, got {} bytes", + column_name, + expected_size, + element_count, + raw_value.size()); + + size_t cursor = 0; + auto field = DecodeVectorFloat32(cursor, raw_value); + RUNTIME_CHECK_MSG( + cursor == raw_value.size(), + "Malformed TiCI vector payload for column {}: {} trailing bytes remain", + column_name, + raw_value.size() - cursor); + return field; + } + + Block buildBlockFromResult(SearchResult & search_result) + { + Block res(return_columns); auto documents = search_result.rows; if (documents.empty()) { @@ -156,6 +278,7 @@ class TantivyInputStream : public IProfilingBlockInputStream } for (auto & name_and_type : return_columns) { + const auto nested_type = removeNullable(name_and_type.type); int idx = -1; for (size_t j = 0; j < documents[0].fieldValues.size(); j++) { @@ -176,6 +299,10 @@ class TantivyInputStream : public IProfilingBlockInputStream } continue; } + RUNTIME_CHECK_MSG( + !isVectorFloat32Type(nested_type), + "TiCI query did not materialize requested vector column {}", + name_and_type.name); for (size_t j = 0; j < documents.size(); j++) { // Insert default value for missing fields @@ -184,9 +311,10 @@ class TantivyInputStream : public IProfilingBlockInputStream continue; } + const auto * 
result_type = nested_type.get(); auto col = res.getByName(name_and_type.name).column->assumeMutable(); bool has_null = false; - if (removeNullable(name_and_type.type)->isStringOrFixedString()) + if (result_type->isStringOrFixedString()) { for (auto & doc : documents) { @@ -203,7 +331,7 @@ class TantivyInputStream : public IProfilingBlockInputStream } } } - if (removeNullable(name_and_type.type)->isInteger()) + else if (result_type->isInteger()) { for (auto & doc : documents) { @@ -219,7 +347,40 @@ class TantivyInputStream : public IProfilingBlockInputStream } } } - if (removeNullable(name_and_type.type)->isDateOrDateTime()) + else if (result_type->isFloatingPoint()) + { + for (auto & doc : documents) + { + const auto & field_value = doc.fieldValues[idx]; + if (field_value.is_null) + { + has_null = true; + col->insert(Field()); + } + else + { + col->insert(Field(field_value.float_value)); + } + } + } + else if (isVectorFloat32Type(nested_type)) + { + for (auto & doc : documents) + { + const auto & field_value = doc.fieldValues[idx]; + if (field_value.is_null) + { + has_null = true; + col->insert(Field()); + } + else + { + const String raw_value(field_value.string_value.begin(), field_value.string_value.end()); + col->insert(decodeVectorFloat32Field(raw_value, name_and_type.name)); + } + } + } + else if (result_type->isDateOrDateTime()) { for (auto & doc : documents) { @@ -236,6 +397,14 @@ class TantivyInputStream : public IProfilingBlockInputStream } } } + else + { + RUNTIME_CHECK_MSG( + false, + "Unsupported TiCI result column type {} for column {}", + nested_type->getName(), + name_and_type.name); + } if (has_null) { RUNTIME_CHECK_MSG( @@ -247,7 +416,6 @@ class TantivyInputStream : public IProfilingBlockInputStream return res; } -private: Block header; bool done = false; LoggerPtr log; @@ -260,7 +428,9 @@ class TantivyInputStream : public IProfilingBlockInputStream std::vector sort_column_ids; std::vector sort_column_asc; UInt64 read_ts; - ::Expr match_expr; + 
TiCIQueryMode query_mode; + ::Expr match_expr; // FTS mode + VectorQueryState vector_state; // Vector mode bool is_count; std::shared_ptr> shards_snapshot; diff --git a/dbms/src/Storages/Tantivy/TiCIReadTaskPool.h b/dbms/src/Storages/Tantivy/TiCIReadTaskPool.h index 7d6df3745c1..438adaf9c6c 100644 --- a/dbms/src/Storages/Tantivy/TiCIReadTaskPool.h +++ b/dbms/src/Storages/Tantivy/TiCIReadTaskPool.h @@ -14,6 +14,9 @@ #pragma once +#include +#include + #include #include #include @@ -27,6 +30,7 @@ struct TiCIReadTask : shard_info(shard_info_) {} + // FTS init void initInputStream( LoggerPtr log_, UInt32 keyspace_id_, @@ -58,6 +62,30 @@ struct TiCIReadTask shards_snapshot_); } + // Vector init + void initInputStreamVector( + LoggerPtr log_, + UInt32 keyspace_id_, + Int64 table_id_, + Int64 index_id_, + ShardInfo query_shard_info_, + NamesAndTypes return_columns_, + UInt64 read_ts_, + VectorQueryState vector_state_, + const std::shared_ptr> & shards_snapshot_) + { + input_stream = std::make_shared( + log_, + keyspace_id_, + table_id_, + index_id_, + query_shard_info_, + return_columns_, + read_ts_, + std::move(vector_state_), + shards_snapshot_); + } + bool isInitialized() const { return input_stream != nullptr; } BlockInputStreamPtr getInputStream() const @@ -79,6 +107,7 @@ struct TiCIReadTaskPool public: using TiCIReadTasks = std::vector>; + // FTS constructor TiCIReadTaskPool( LoggerPtr log_, UInt32 keyspace_id_, @@ -103,6 +132,7 @@ struct TiCIReadTaskPool , sort_column_ids(sort_column_ids_) , sort_column_asc(sort_column_asc_) , read_ts(read_ts_) + , query_mode(TiCIQueryMode::FTS) , is_count(is_count) , shards_snapshot(std::make_shared>(std::move(shards_snapshot_))) { @@ -121,6 +151,43 @@ struct TiCIReadTaskPool LOG_DEBUG(log, "columns: [{}], match columns: {}", buf.toString(), cids); } + // Vector constructor + TiCIReadTaskPool( + LoggerPtr log_, + UInt32 keyspace_id_, + Int64 table_id_, + Int64 index_id_, + const ShardInfoList & shard_infos, + NamesAndTypes 
return_columns_, + UInt64 read_ts_, + VectorQueryState vector_state_, + rust::Box shards_snapshot_) + : log(log_) + , keyspace_id(keyspace_id_) + , table_id(table_id_) + , index_id(index_id_) + , return_columns(return_columns_) + , limit(vector_state_.top_k) + , read_ts(read_ts_) + , query_mode(TiCIQueryMode::Vector) + , vector_state(std::move(vector_state_)) + , is_count(false) + , shards_snapshot(std::make_shared>(std::move(shards_snapshot_))) + { + for (const auto & shard_info : shard_infos) + { + tasks.emplace_back(std::make_shared(shard_info)); + } + LOG_DEBUG( + log, + "vector query: col_id={} distance_metric={} top_k={} dim={} has_filter={}", + vector_state.col_id, + vector_state.distance_metric, + vector_state.top_k, + vector_state.query_vector.size(), + vector_state.has_filter); + } + TiCIReadTaskPtr getNextTask() { std::lock_guard lock(mutex); @@ -136,20 +203,36 @@ struct TiCIReadTaskPool RUNTIME_CHECK(task != nullptr); if (!task->isInitialized()) { - task->initInputStream( - log, - keyspace_id, - table_id, - index_id, - task->getShardInfo(), - return_columns, - limit, - sort_column_ids, - sort_column_asc, - read_ts, - match_expr, - is_count, - shards_snapshot); + if (query_mode == TiCIQueryMode::Vector) + { + task->initInputStreamVector( + log, + keyspace_id, + table_id, + index_id, + task->getShardInfo(), + return_columns, + read_ts, + vector_state, + shards_snapshot); + } + else + { + task->initInputStream( + log, + keyspace_id, + table_id, + index_id, + task->getShardInfo(), + return_columns, + limit, + sort_column_ids, + sort_column_asc, + read_ts, + match_expr, + is_count, + shards_snapshot); + } } return task->getInputStream(); } @@ -167,7 +250,9 @@ struct TiCIReadTaskPool std::vector sort_column_ids; std::vector sort_column_asc; UInt64 read_ts; - ::Expr match_expr; + TiCIQueryMode query_mode; + ::Expr match_expr; // FTS mode + VectorQueryState vector_state; // Vector mode bool is_count; std::shared_ptr> shards_snapshot; @@ -351,6 +436,78 @@ struct 
TiCIReadTaskPool } return {ret, cids}; } + +public: + /// Build a VectorQueryState from TiCIVectorQueryInfo: validate top_k, column_id, dimension, distance metric and query_vector, then convert any tipb filter expressions to a TiCI Expr. + /// If there are no filter expressions, has_filter is false and filter_expr is left default-constructed. + static VectorQueryState buildVectorState( + const tipb::TiCIVectorQueryInfo & info, + const TimezoneInfo & timezone_info) + { + VectorQueryState state; + RUNTIME_CHECK_MSG(info.top_k() > 0, "TiCI vector query top_k must be greater than 0"); + RUNTIME_CHECK_MSG(info.column_id() != 0, "TiCI vector query column_id must not be 0"); + RUNTIME_CHECK_MSG(info.dimension() > 0, "TiCI vector query dimension must be greater than 0"); + RUNTIME_CHECK_MSG( + info.distance_metric() == tipb::VectorDistanceMetric::L2 + || info.distance_metric() == tipb::VectorDistanceMetric::COSINE, + "Unsupported TiCI vector distance metric: {}", + static_cast(info.distance_metric())); + + state.col_id = info.column_id(); + state.distance_metric = static_cast(info.distance_metric()); + state.top_k = info.top_k(); + + // TiDB currently serializes VectorFloat32 with a 4-byte little-endian + // dimension prefix. Keep accepting the original raw-float payload too + // so TiFlash stays compatible with both callers during rollout. 
+ const auto & qv = info.query_vector(); + RUNTIME_CHECK_MSG(!qv.empty(), "TiCI vector query_vector must not be empty"); + RUNTIME_CHECK_MSG( + qv.size() % sizeof(float) == 0, + "Malformed TiCI query_vector payload: {} bytes is not a multiple of {}", + qv.size(), + sizeof(float)); + const auto expected_bytes = static_cast(info.dimension()) * sizeof(float); + size_t qv_offset = 0; + size_t qv_bytes = qv.size(); + if (qv.size() == expected_bytes + sizeof(UInt32)) + { + UInt32 encoded_dimension = 0; + std::memcpy(&encoded_dimension, qv.data(), sizeof(UInt32)); + if (encoded_dimension == info.dimension()) + { + qv_offset = sizeof(UInt32); + qv_bytes = expected_bytes; + } + } + RUNTIME_CHECK_MSG( + qv_bytes == expected_bytes, + "TiCI query_vector length mismatch: expected {} bytes for dimension {}, got {} bytes", + expected_bytes, + info.dimension(), + qv.size()); + state.query_vector.reserve(info.dimension()); + for (size_t offset = qv_offset; offset < qv.size(); offset += sizeof(float)) + { + float value = 0; + std::memcpy(&value, qv.data() + offset, sizeof(float)); + RUNTIME_CHECK_MSG( + std::isfinite(value), + "TiCI query_vector contains non-finite value at offset {}", + offset / sizeof(float)); + state.query_vector.push_back(value); + } + + // Convert filter expressions. + state.has_filter = (info.filter_expr_size() > 0); + if (state.has_filter) + { + auto [expr, _cids] = tipbToTiCIExpr(info.filter_expr(), timezone_info); + state.filter_expr = std::move(expr); + } + return state; + } }; using TiCIReadTaskPoolPtr = std::shared_ptr; diff --git a/docs/tici-hybrid-vector-plan.md b/docs/tici-hybrid-vector-plan.md new file mode 100644 index 00000000000..7571e2bd74f --- /dev/null +++ b/docs/tici-hybrid-vector-plan.md @@ -0,0 +1,98 @@ +# TiFlash Plan: Hybrid Vector Query on TiCI + +## Goal + +Execute hybrid vector queries through the TiCI read path instead of the legacy ANN path. 
+ +Current rollout scope: + +- vector-only TiCI queries first +- pushed-down `filter + vector` deferred to a later phase +- runtime validation should currently avoid add-index-on-existing-data. + `import into` / backfill for hybrid-vector data is not adapted yet; use + empty-table DDL plus CDC writes for e2e. +- the local validation harness should now reuse the playground-managed TiCDC + changefeed instead of creating a second one manually. + +TiFlash should support: + +- fulltext / expression TiCI queries +- vector-only TiCI queries +- vector + pushed-down filter TiCI queries later + +## Current State + +- The current TiCI path already handles `FULLTEXT INDEX`. +- The same expression path already carries hybrid inverted/scalar predicates. +- TiCI vector search is now wired in the TiFlash local/remote TiCI path. +- The TiCI executor path now accepts either `IndexScan.fts_query_info` or `IndexScan.tici_vector_query_info`. +- The legacy ANN path still exists in DeltaMerge, but it is not the target for hybrid vector. +- Current upstream rollout does not populate `filter_expr` for TiCI vector queries yet. +- Local macOS build validation passed on `2026-03-19` in `cmake-build-codex-release`; the `dbms/src/Server/tiflash` binary was built successfully. +- Runtime smoke/e2e validation passed on `2026-03-20` with + `playground:v1.16.2-feature.fts`, `upstream/vector@db0a4054`, and the + playground-managed changefeed. + +## Chosen Direction + +Treat TiCI scan as two execution modes: + +- expression/fulltext mode + - current `FTSQueryInfo` + - calls Rust `search(...)` +- vector mode + - new `TiCIVectorQueryInfo` + - calls Rust `search_vector(...)` + +For vector mode, `filter_expr` stays part of the long-term payload shape, but the current rollout only requires the vector-only subset. 
+ +## Prerequisites + +- ✅ TiCI `search_vector` FFI ready on `upstream/vector@db0a4054` +- ✅ tipb `TiCIVectorQueryInfo` proto merged (commit `a25a67b`) +- ✅ TiDB planner populates `TiCIVectorQueryInfo` on `tipb.IndexScan` (PR #67103) +- ✅ Update `contrib/tici` submodule to vector-capable upstream + +## TiFlash Work Items + +1. ✅ Generalize TiCI scan parsing. (PR7) + - `TiCIScan` parses either `FTSQueryInfo` or `TiCIVectorQueryInfo` via `TiCIQueryMode` enum. + - `PhysicalPlan::buildTiCIScan()` accepts both modes. + +2. ✅ Generalize TiCI read task creation. (PR7) + - `TiCIReadTaskPool` has separate FTS/Vector constructors. + - `VectorQueryState` holds col_id, distance_metric, top_k, query_vector, and optional filter_expr. + - Reuses `tipbToTiCIExpr` for vector filter conversion. + +3. ✅ Add vector FFI call path. (PR7) + - `TantivyInputStream` branches to `search(...)` or `search_vector(...)`. + - `VectorSearchParam` uses `(col_id, distance_metric)` matching TiCI FFI. + +4. ✅ Basic hardening for vector-only rollout. + - Validate `top_k`, `column_id`, `dimension`, metric, and `query_vector` byte length before FFI call. + - Reject non-finite query vector values early in TiFlash. + - Decode `FLOAT/DOUBLE` result fields correctly. + - Decode `Array(Float32)` vector payloads when TiCI materializes them. + - Fail fast if a TiCI vector query requests a vector column but TiCI does not return the payload. + +5. 🔲 Add tests / integration validation. (PR8) + - local TiCI vector query + - remote TiCI vector query + - malformed vector payload rejection + - returned float/vector column materialization + - vector + filter query later + +## Performance/Stability Requirements + +- Avoid reading/materializing full rows before TiCI shard-level pruning. +- Keep top-k reduction shard-local as long as possible. +- Ensure remote-read behavior matches local-read semantics. +- Reject invalid or ambiguous payloads early. 
+- Add metrics for TiCI vector local/remote reads, filter selectivity, and shard query latency. + +## Expected PR Split For TiFlash + +Recommended TiFlash split: 2 PRs. + +1. Query mode, executor, and FFI wiring. +2. Hardening, tests, and optional later filter pushdown.