Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,30 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Bits;
import org.opensearch.common.UUIDs;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.codec.nativeindex.AbstractNativeEnginesKnnVectorsReader;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.util.WarmupUtil;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.faiss.FaissScorableByteVectorValues;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
Expand Down Expand Up @@ -61,23 +64,34 @@ public NativeEngines990KnnVectorsReader(final SegmentReadState state, final Flat
}

/**
* Returns the {@link ByteVectorValues} for the given field.
* Attempts flat vectors reader first, then falls back to quantized vectors if available.
*
* @param field the vector field name
* @return {@link ByteVectorValues} for the field, never {@code null}
* @throws IOException if an I/O error occurs or no byte vectors are available for the field
* Returns a composite {@link FloatVectorValues} that bundles full-precision float vectors with
* quantized byte vectors when quantization is available. The composite delegates:
* <ul>
* <li>{@code vectorValue()} to full-precision floats (for merge/flush)</li>
* <li>{@code scorer()} quantizes the float query and delegates to quantized byte values</li>
* <li>{@code rescorer()} to full-precision floats (for full-fidelity rescoring)</li>
* </ul>
* Falls back to plain float vector values when quantization is not configured.
*/
@Override
public ByteVectorValues getByteVectorValues(final String field) throws IOException {
public FloatVectorValues getFloatVectorValues(final String field) throws IOException {
final FloatVectorValues rawFloatVectorValues = flatVectorsReader.getFloatVectorValues(field);
final FieldInfo fieldInfo = fieldInfos.fieldInfo(field);
if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
final ByteVectorValues quantizedVectorValues = getQuantizedVectorValues(fieldInfo);
if (quantizedVectorValues != null) {
return quantizedVectorValues;
if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 && hasQuantizationConfig(fieldInfo)) {
final VectorSearcher vectorSearcher = loadMemoryOptimizedSearcherIfRequired(fieldInfo);
if (vectorSearcher != null) {
final ByteVectorValues byteVectorValues = vectorSearcher.getByteVectorValues(rawFloatVectorValues.iterator());
return new QuantizedFloatVectorValues(rawFloatVectorValues, byteVectorValues, field);
}
log.warn("No quantized vectors found for field [{}]", field);
}
return rawFloatVectorValues;
}

/**
* Returns the {@link ByteVectorValues} for the given field by delegating to the flat vectors reader.
*/
@Override
public ByteVectorValues getByteVectorValues(final String field) throws IOException {
return flatVectorsReader.getByteVectorValues(field);
}

Expand Down Expand Up @@ -110,18 +124,7 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
// TODO: This is a temporary hack where we are using KNNCollector to initialize the quantization state.
if (knnCollector instanceof QuantizationConfigKNNCollector) {
String cacheKey = quantizationStateCacheKeyPerField.get(field);
FieldInfo fieldInfo = fieldInfos.fieldInfo(field);
QuantizationState quantizationState = QuantizationStateCacheManager.getInstance()
.getQuantizationState(
new QuantizationStateReadConfig(
segmentReadState,
QuantizationService.getInstance().getQuantizationParams(fieldInfo),
field,
cacheKey
)
);
((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState);
((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(getQuantizationState(field));
return;
}

Expand Down Expand Up @@ -244,21 +247,6 @@ private static List<String> getVectorCacheKeysFromSegmentReaderState(SegmentRead
return cacheKeys;
}

/**
* Retrieves quantized byte vectors from Faiss memory-optimized searcher.
*
* @param fieldInfo the field to retrieve vectors for
* @return quantized byte vectors, or null if not available
* @throws IOException if an I/O error occurs
*/
private ByteVectorValues getQuantizedVectorValues(@NonNull final FieldInfo fieldInfo) throws IOException {
if (hasQuantizationConfig(fieldInfo) == false) {
return null;
}
final VectorSearcher vectorSearcher = loadMemoryOptimizedSearcherIfRequired(fieldInfo);
return vectorSearcher != null ? vectorSearcher.getByteVectorValues(getFloatVectorValues(fieldInfo.getName()).iterator()) : null;
}

/**
* Warms up the on-disk data for the given field by loading the HNSW graph and flat vectors
* into the OS page cache.
Expand All @@ -284,4 +272,97 @@ public void warmUp(final String fieldName) throws IOException {
memoryOptimizedSearcher.warmUp();
}
}

private QuantizationState getQuantizationState(final String field) throws IOException {
final String cacheKey = quantizationStateCacheKeyPerField.get(field);
return QuantizationStateCacheManager.getInstance()
.getQuantizationState(
new QuantizationStateReadConfig(
segmentReadState,
QuantizationService.getInstance().getQuantizationParams(fieldInfos.fieldInfo(field)),
field,
cacheKey
)
);
}

/**
* A composite {@link FloatVectorValues} that bundles full-precision float vectors with
* quantized byte vectors, following the same pattern as Lucene's {@code ScalarQuantizedVectorValues}.
* <ul>
* <li>{@link #vectorValue(int)} returns full-precision floats (for merge/flush)</li>
* <li>{@link #scorer(float[])} quantizes the query and delegates to quantized byte values</li>
* <li>{@link #rescorer(float[])} delegates to raw float values (for full-fidelity rescoring)</li>
* </ul>
*/
private final class QuantizedFloatVectorValues extends FloatVectorValues {
private final FloatVectorValues rawFloatVectorValues;
private final ByteVectorValues quantizedByteVectorValues;
private final String fieldName;

QuantizedFloatVectorValues(FloatVectorValues rawFloatVectorValues, ByteVectorValues quantizedByteVectorValues, String fieldName) {
this.rawFloatVectorValues = rawFloatVectorValues;
this.quantizedByteVectorValues = quantizedByteVectorValues;
this.fieldName = fieldName;
}

@Override
public int dimension() {
return rawFloatVectorValues.dimension();
}

@Override
public int size() {
return rawFloatVectorValues.size();
}

@Override
public float[] vectorValue(int ord) throws IOException {
return rawFloatVectorValues.vectorValue(ord);
}

@Override
public QuantizedFloatVectorValues copy() throws IOException {
return new QuantizedFloatVectorValues(rawFloatVectorValues.copy(), quantizedByteVectorValues.copy(), fieldName);
}

@Override
public DocIndexIterator iterator() {
return rawFloatVectorValues.iterator();
}

@Override
public int ordToDoc(int ord) {
return rawFloatVectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return rawFloatVectorValues.getAcceptOrds(acceptDocs);
}

@SuppressWarnings("unchecked")
@Override
public VectorScorer scorer(float[] target) throws IOException {
if (quantizedByteVectorValues instanceof FaissScorableByteVectorValues scorableByteVectorValues
&& FieldInfoExtractor.isAdc(fieldInfos.fieldInfo(fieldName))) {
// ADC: the FlatVectorsScorer handles float-vs-byte scoring asymmetrically
return scorableByteVectorValues.scorer(target);
}
// Non-ADC: quantize the float query to bytes, then score byte-vs-byte
final QuantizationState quantizationState = getQuantizationState(fieldName);
final QuantizationService quantizationService = QuantizationService.getInstance();
final byte[] quantizedQuery = (byte[]) quantizationService.quantize(
quantizationState,
target,
quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams())
);
return quantizedByteVectorValues.scorer(quantizedQuery);
}

@Override
public VectorScorer rescorer(float[] target) throws IOException {
return rawFloatVectorValues.rescorer(target);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.scorers.VectorScorerMode;
import org.opensearch.knn.index.query.scorers.VectorScorers;
Expand Down Expand Up @@ -384,46 +382,11 @@ private VectorScorer createVectorScorer(
parentBitSet
);
}

// Float vector path
final SegmentLevelQuantizationInfo quantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, context.getField());

if (quantizationInfo == null || scorerMode == VectorScorerMode.RESCORE) {
return VectorScorers.createScorer(
iteratorValues,
context.getFloatQueryVector(),
scorerMode,
spaceType,
fieldInfo,
context.getMatchedDocsIterator(),
parentBitSet
);
}

// Quantized path — need byte vector values
final KNNVectorValues<?> quantizedValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader, true);
final KNNVectorValuesIterator.DocIdsIteratorValues quantizedIteratorValues =
(KNNVectorValuesIterator.DocIdsIteratorValues) quantizedValues.getVectorValuesIterator();

if (SegmentLevelQuantizationUtil.isAdcEnabled(quantizationInfo)) {
SegmentLevelQuantizationUtil.transformVectorWithADC(context.getFloatQueryVector(), quantizationInfo, spaceType);
return VectorScorers.createScorer(
quantizedIteratorValues,
context.getFloatQueryVector(),
scorerMode,
spaceType,
fieldInfo,
context.getMatchedDocsIterator(),
parentBitSet
);
}

final byte[] quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(context.getFloatQueryVector(), quantizationInfo);
return VectorScorers.createScorer(
quantizedIteratorValues,
quantizedQueryVector,
iteratorValues,
context.getFloatQueryVector(),
scorerMode,
SpaceType.HAMMING,
spaceType,
fieldInfo,
context.getMatchedDocsIterator(),
parentBitSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,6 @@ public static <T> KNNVectorValues<T> getVectorValues(
);
} else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(fieldInfo.getName());
// Quantized search path: retrieve quantized byte vectors from codec.
if (shouldRetrieveQuantizedVectors) {
// Bypasses leafReader.getByteVectorValues() which enforces BYTE encoding check.
// This will call getByteVectorValues from NativeEngines990KnnVectorsReader at the end.
final ByteVectorValues byteVectorValues = leafReader.getVectorReader().getByteVectorValues(fieldInfo.getName());
return getVectorValues(
VectorDataType.BINARY, // retrieve binary data from reader
new KNNVectorValuesIterator.DocIdsIteratorValues(floatVectorValues.iterator(), byteVectorValues)
);
}
return getVectorValues(
FieldInfoExtractor.extractVectorDataType(fieldInfo),
new KNNVectorValuesIterator.DocIdsIteratorValues(floatVectorValues)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ public DocIndexIterator iterator() {

// ---- Scorer ----

/**
* Returns a {@link VectorScorer} for a float query against byte vectors.
* Used by ADC (Asymmetric Distance Computation) where the float query is scored
* directly against quantized byte document vectors via the {@link FlatVectorsScorer}.
*/
public VectorScorer scorer(float[] target) throws IOException {
if (size() == 0) return null;

final FaissScorableByteVectorValues scorerCopy = copy();
final RandomVectorScorer rvs = flatVectorsScorer.getRandomVectorScorer(similarityFunction, scorerCopy, target);
final DocIndexIterator iterator = scorerCopy.iterator();

return new VectorScorer() {
@Override
public float score() throws IOException {
return rvs.score(iterator.index());
}

@Override
public DocIdSetIterator iterator() {
return iterator;
}

@Override
public Bulk bulk(final DocIdSetIterator matchingDocs) {
return Bulk.fromRandomScorerSparse(rvs, iterator, matchingDocs);
}
};
}

/**
* Returns a {@link VectorScorer} for {@code target}, or {@code null} for an empty index.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.lucene.store.IndexInput;
import org.junit.Assert;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.nativeindex.AbstractNativeEnginesKnnVectorsReader;
Expand All @@ -31,6 +30,8 @@
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;

import java.io.IOException;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -77,7 +78,7 @@ public void testFlatVectorReaderIsCalled_whenNoQuantization() throws IOException
verify(flatVectorsReader).getByteVectorValues("field1");
}

public void testBinaryVectorValuesIsCalled_whenQuantizationIsAvailable_thenSuccess() throws IOException {
public void testCompositeFloatVectorValues_whenQuantizationIsAvailable_thenSuccess() throws IOException {
FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("field1")
.fieldNumber(0)
.addAttribute(KNN_ENGINE, KNNEngine.FAISS.getName())
Expand All @@ -99,14 +100,22 @@ public void testBinaryVectorValuesIsCalled_whenQuantizationIsAvailable_thenSucce
when(mockFloatValues.iterator()).thenReturn(mock(KnnVectorValues.DocIndexIterator.class));
final FlatVectorsReader flatVectorsReader = mock(FlatVectorsReader.class);
when(flatVectorsReader.getFloatVectorValues("field1")).thenReturn(mockFloatValues);
try (MockedStatic<KNNEngine> mockedStatic = mockStatic(KNNEngine.class)) {
try (
MockedStatic<KNNEngine> mockedStatic = mockStatic(KNNEngine.class);
MockedStatic<QuantizationStateCacheManager> mockedCacheManager = mockStatic(QuantizationStateCacheManager.class)
) {
mockedStatic.when(() -> KNNEngine.getEngine(any())).thenReturn(mockFaiss);
final Set<String> filesInSegment = Set.of("_0_165_field1.faiss");
mockedStatic.when(KNNEngine::getEnginesThatCreateCustomSegmentFiles).thenReturn(ImmutableSet.of(mockFaiss));

QuantizationStateCacheManager mockCacheManager = mock(QuantizationStateCacheManager.class);
when(mockCacheManager.getQuantizationState(any())).thenReturn(mock(QuantizationState.class));
mockedCacheManager.when(QuantizationStateCacheManager::getInstance).thenReturn(mockCacheManager);

final Set<String> filesInSegment = Set.of("_0_165_field1.faiss");
NativeEngines990KnnVectorsReader reader = createReader(fieldInfos, filesInSegment, flatVectorsReader);
reader.getByteVectorValues("field1");
FloatVectorValues result = reader.getFloatVectorValues("field1");
assertNotSame(mockFloatValues, result);
verify(mockSearcher).getByteVectorValues(any());
verify(flatVectorsReader, Mockito.never()).getByteVectorValues("field1");
}
}

Expand Down
Loading