Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),


### Bug Fixes
* Preserve original (unnormalized) vectors in doc values for Faiss + cosine similarity so that derived source returns the user-indexed vector [#3083](https://github.com/opensearch-project/k-NN/issues/3083)

### Refactoring

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
Expand Down Expand Up @@ -67,14 +71,26 @@ private boolean isKNNBinaryFieldRequired(FieldInfo field) {

/**
 * Feeds the vectors of {@code field} to the native index writer.
 *
 * <p>Vectors are read from {@code valuesProducer} through {@link KNNVectorValuesFactory}, wrapped by
 * the field's {@link VectorTransformer}, and passed to {@link NativeIndexWriter} via
 * {@code mergeIndex} when {@code isMerge} is {@code true}, or {@code flushIndex} otherwise.
 *
 * @param field          the field whose vectors are written
 * @param valuesProducer producer from which the field's binary doc values are obtained
 * @param isMerge        {@code true} to call {@code mergeIndex}; {@code false} to call {@code flushIndex}
 * @throws IOException propagated from the underlying doc-values and index-writer calls
 */
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(field);
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field));
KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field));

// Apply the field-configured VectorTransformer to the stream of vectors fed into the native
// builder. For Faiss + cosine this yields a normalized stream while BinaryDocValues keep the
// original vectors untouched. For all other combinations the factory returns a no-op and
// knnVectorValues is passed through unchanged. See issue #3083.
// The second argument (false) asks the factory not to resolve a MethodComponentContext.
final VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(field, false);
if (knnVectorValues instanceof KNNFloatVectorValues floatVectorValues) {
knnVectorValues = transformer.wrap(floatVectorValues);
} else if (knnVectorValues instanceof KNNByteVectorValues byteVectorValues) {
knnVectorValues = transformer.wrap(byteVectorValues);
}

// knnVectorValues is reassigned above, so the lambdas below capture this final alias instead.
final KNNVectorValues<?> finalVectorValues = knnVectorValues;
// For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total
// live docs
if (isMerge) {
NativeIndexWriter.getWriter(field, state).mergeIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs());
NativeIndexWriter.getWriter(field, state).mergeIndex(() -> finalVectorValues, (int) finalVectorValues.totalLiveDocs());
} else {
NativeIndexWriter.getWriter(field, state).flushIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs());
NativeIndexWriter.getWriter(field, state).flushIndex(() -> finalVectorValues, (int) finalVectorValues.totalLiveDocs());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,9 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
getVectorTransformer().transform(array, true);
// Intentionally NOT normalizing the vector here even for Faiss+cosine. Normalization for Faiss
// cosine is deferred to the native index build path so that BinaryDocValues keeps the original
// (unnormalized) vectors. See issue #3083 for context.
context.doc().addAll(getFieldsForFloatVector(array, isDerivedEnabled(context)));
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package org.opensearch.knn.index.mapper;

import org.apache.lucene.util.VectorUtil;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.NormalizingKNNFloatVectorValues;

import java.util.Arrays;

Expand Down Expand Up @@ -38,6 +40,15 @@ public void transform(byte[] vector) {
throw new UnsupportedOperationException("Byte array normalization is not supported");
}

/**
 * Returns a {@link NormalizingKNNFloatVectorValues} that L2-normalizes vectors on demand
 * without mutating the underlying data.
 *
 * @param delegate the float vector stream to wrap
 * @return a new {@link NormalizingKNNFloatVectorValues} constructed around {@code delegate}
 */
@Override
public KNNFloatVectorValues wrap(final KNNFloatVectorValues delegate) {
return new NormalizingKNNFloatVectorValues(delegate);
}

private void validateVector(float[] vector) {
if (vector == null || vector.length == 0) {
throw new IllegalArgumentException("Vector cannot be null or empty");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
*/
package org.opensearch.knn.index.mapper;

import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

/**
* Defines operations for transforming vectors in the k-NN search context.
* Implementations can modify vectors while preserving their dimensional properties
Expand Down Expand Up @@ -35,4 +38,29 @@ default void transform(final byte[] vector) {
throw new IllegalArgumentException("Input vector cannot be null");
}
}

/**
 * Wraps a {@link KNNFloatVectorValues} stream so that each vector returned by {@code getVector()}
 * is transformed on demand. This default implementation is a pass-through that returns the
 * argument itself; implementations that transform vectors override it.
 *
 * <p>Used by codec-layer components (e.g. {@code KNN80DocValuesConsumer}) to apply vector
 * transformations to the stream of vectors fed into native index builders, without mutating
 * the original vectors stored in {@code BinaryDocValues}.
 *
 * @param delegate the underlying stream of float vectors
 * @return a stream that applies the transformation on the fly; returns {@code delegate} unchanged
 *         for no-op implementations
 */
default KNNFloatVectorValues wrap(final KNNFloatVectorValues delegate) {
return delegate;
}

/**
 * Wraps a {@link KNNByteVectorValues} stream. Default implementation is a pass-through.
 * Kept symmetric with {@link #wrap(KNNFloatVectorValues)} so callers can apply transformations
 * uniformly regardless of the vector element type.
 *
 * @param delegate the underlying stream of byte vectors
 * @return {@code delegate} itself in this default implementation
 */
default KNNByteVectorValues wrap(final KNNByteVectorValues delegate) {
return delegate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,22 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand Down Expand Up @@ -55,6 +67,96 @@ public static VectorTransformer getVectorTransformer(
return shouldNormalizeVector(knnEngine, spaceType, methodComponentContext) ? DEFAULT_VECTOR_TRANSFORMER : NOOP_VECTOR_TRANSFORMER;
}

/**
 * Selects the {@link VectorTransformer} for a field using only the metadata reachable from its
 * {@link FieldInfo}. Meant for codec-layer callers that have no resolved
 * {@link MethodComponentContext} at hand.
 *
 * <p>Resolution steps:
 * <ol>
 *   <li>The engine is extracted with {@link FieldInfoExtractor#extractKNNEngine(FieldInfo)}.</li>
 *   <li>The space type is read from the {@link KNNConstants#SPACE_TYPE} attribute; if that is empty
 *       and the field carries a {@link KNNConstants#MODEL_ID} attribute, it is taken from the
 *       model's {@link ModelMetadata#getSpaceType()}. When no space type can be determined,
 *       {@link #NOOP_VECTOR_TRANSFORMER} is returned.</li>
 *   <li>Only when {@code resolveMethodComponentContext} is {@code true}, a
 *       {@link MethodComponentContext} is rebuilt from the {@link KNNConstants#PARAMETERS}
 *       attribute (falling back to {@link ModelMetadata#getMethodComponentContext()}); otherwise
 *       {@code null} is passed on.</li>
 *   <li>The decision is delegated to the engine/space-type/context overload of this method.</li>
 * </ol>
 *
 * @param fieldInfo field metadata from the Lucene segment
 * @param resolveMethodComponentContext whether to attempt to rebuild the
 *        {@link MethodComponentContext} from {@code fieldInfo}; when {@code false} a {@code null}
 *        context is used
 * @return a {@link VectorTransformer}, possibly {@link #NOOP_VECTOR_TRANSFORMER}
 */
public static VectorTransformer getVectorTransformer(final FieldInfo fieldInfo, boolean resolveMethodComponentContext) {
    final KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
    final SpaceType resolvedSpaceType = resolveSpaceTypeFromFieldInfo(fieldInfo);
    if (resolvedSpaceType == null) {
        return NOOP_VECTOR_TRANSFORMER;
    }
    // Resolve the context lazily: the lookup is skipped entirely unless the caller opted in.
    final MethodComponentContext methodComponentContext = resolveMethodComponentContext
        ? resolveMethodComponentContextFromFieldInfo(fieldInfo)
        : null;
    return getVectorTransformer(knnEngine, resolvedSpaceType, methodComponentContext);
}

/**
 * Determines the field's {@link SpaceType}.
 *
 * <p>The {@link KNNConstants#SPACE_TYPE} attribute wins when present. Otherwise, if the field has
 * a {@link KNNConstants#MODEL_ID} attribute, the space type recorded in that model's metadata is
 * used.
 *
 * @param fieldInfo field metadata from the Lucene segment
 * @return the resolved space type, or {@code null} when neither source yields one
 */
private static SpaceType resolveSpaceTypeFromFieldInfo(final FieldInfo fieldInfo) {
    final String declaredSpaceType = fieldInfo.getAttribute(KNNConstants.SPACE_TYPE);
    if (StringUtils.isNotEmpty(declaredSpaceType)) {
        return SpaceType.getSpace(declaredSpaceType);
    }

    final String modelId = fieldInfo.getAttribute(KNNConstants.MODEL_ID);
    if (StringUtils.isEmpty(modelId)) {
        return null;
    }

    final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(modelId);
    if (modelMetadata == null) {
        return null;
    }
    return modelMetadata.getSpaceType();
}

/**
 * Best-effort reconstruction of the field's {@link MethodComponentContext}.
 *
 * <p>First tries to parse the JSON stored in the {@link KNNConstants#PARAMETERS} attribute. If the
 * attribute is absent, yields no usable keys, or fails to parse, falls back to the
 * {@link MethodComponentContext} stored in the model metadata when the field has a
 * {@link KNNConstants#MODEL_ID} attribute.
 *
 * @param fieldInfo field metadata from the Lucene segment
 * @return the resolved context, or {@code null} when none of the sources provides one
 */
private static MethodComponentContext resolveMethodComponentContextFromFieldInfo(final FieldInfo fieldInfo) {
    final String parametersString = fieldInfo.getAttribute(KNNConstants.PARAMETERS);
    if (StringUtils.isNotEmpty(parametersString)) {
        // The parser is a closeable resource; try-with-resources guarantees it is released even
        // when parsing throws.
        try (
            var parser = XContentHelper.createParser(
                NamedXContentRegistry.EMPTY,
                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                new BytesArray(parametersString),
                MediaTypeRegistry.getDefaultMediaType()
            )
        ) {
            final Map<String, Object> parsed = parser.map();
            // The JSON written by EngineFieldMapper contains top-level keys beyond NAME/PARAMETERS
            // (e.g. space_type, vector_data_type, index_description). MethodComponentContext.parse
            // rejects unknown keys, so we narrow down to the two fields it understands before parsing.
            final Map<String, Object> methodMap = new HashMap<>();
            if (parsed.containsKey(KNNConstants.NAME)) {
                methodMap.put(KNNConstants.NAME, parsed.get(KNNConstants.NAME));
            }
            if (parsed.containsKey(KNNConstants.PARAMETERS)) {
                methodMap.put(KNNConstants.PARAMETERS, parsed.get(KNNConstants.PARAMETERS));
            }
            if (!methodMap.isEmpty()) {
                return MethodComponentContext.parse(methodMap);
            }
        } catch (Exception ignored) {
            // Deliberately best-effort: if parsing fails for any reason, fall through to the
            // model-metadata resolution path below instead of failing the caller.
        }
    }

    final String modelId = fieldInfo.getAttribute(KNNConstants.MODEL_ID);
    if (StringUtils.isNotEmpty(modelId)) {
        final ModelMetadata metadata = ModelUtil.getModelMetadata(modelId);
        if (metadata != null && metadata.getMethodComponentContext() != null) {
            return metadata.getMethodComponentContext();
        }
    }
    return null;
}

private static boolean shouldNormalizeVector(
final KNNEngine knnEngine,
final SpaceType spaceType,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.util.VectorUtil;

import java.io.IOException;
import java.util.Arrays;

/**
 * A {@link KNNFloatVectorValues} view over another instance that hands out L2-normalized copies
 * of the delegate's vectors, leaving the delegate's own arrays untouched.
 *
 * <p>Intended for the Faiss native index build path with cosine similarity: cosine is expressed
 * as inner product over unit vectors, so the builder is fed normalized vectors while doc values
 * keep the original (unnormalized) data for downstream consumers such as derived source
 * reconstruction.
 *
 * <p>The view shares the delegate's vector values iterator, so iteration state is common to both.
 */
public class NormalizingKNNFloatVectorValues extends KNNFloatVectorValues {

    private final KNNFloatVectorValues delegate;

    /**
     * @param delegate the vector values whose vectors are normalized on read; its iterator is
     *                 reused by this view
     */
    public NormalizingKNNFloatVectorValues(final KNNFloatVectorValues delegate) {
        super(delegate.getVectorValuesIterator());
        this.delegate = delegate;
    }

    /**
     * Reads the current vector from the delegate and returns a freshly allocated, L2-normalized
     * copy of it.
     *
     * @return a new unit-length array with the same length as the delegate's vector
     * @throws IOException if the delegate fails to read the vector
     */
    @Override
    public float[] getVector() throws IOException {
        final float[] source = delegate.getVector();
        syncCachedSizesFromDelegate();
        return normalizedCopyOf(source);
    }

    /**
     * Same as {@link #getVector()}: that method already allocates a new array per call, so no
     * additional clone is made.
     */
    @Override
    public float[] conditionalCloneVector() throws IOException {
        return getVector();
    }

    /** Mirrors the delegate's cached {@code dimension} and {@code bytesPerVector} into this view. */
    private void syncCachedSizesFromDelegate() {
        this.dimension = delegate.dimension;
        this.bytesPerVector = delegate.bytesPerVector;
    }

    /** Copies {@code source} and L2-normalizes the copy in place, leaving {@code source} as-is. */
    private static float[] normalizedCopyOf(final float[] source) {
        final float[] unit = Arrays.copyOf(source, source.length);
        VectorUtil.l2normalize(unit);
        return unit;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
package org.opensearch.knn.index.mapper;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.NormalizingKNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;

import java.util.List;

public class NormalizeVectorTransformerTests extends KNNTestCase {
private final NormalizeVectorTransformer transformer = new NormalizeVectorTransformer();
Expand Down Expand Up @@ -52,6 +58,19 @@ public void testNormalizeTransformer_noInplaceUpdate_withValidVector_thenSuccess
assertEquals(1.0f, calculateMagnitude(transformedVector), DELTA);
}

/**
 * Verifies that {@code wrap} produces a {@link NormalizingKNNFloatVectorValues} and that the
 * vector read through it is the unit-length version of the input ({-3, 4} becomes {-0.6, 0.8}).
 */
public void testWrap_returnsNormalizingKNNFloatVectorValues() throws Exception {
    final KNNVectorValues<?> source = TestVectorValues.createKNNFloatVectorValues(List.of(new float[] { -3.0f, 4.0f }));

    final KNNFloatVectorValues normalizing = transformer.wrap((KNNFloatVectorValues) source);
    assertTrue(normalizing instanceof NormalizingKNNFloatVectorValues);

    // Advance to the only document before reading its vector.
    normalizing.nextDoc();
    final float[] unit = normalizing.getVector();

    assertEquals(-0.6f, unit[0], DELTA);
    assertEquals(0.8f, unit[1], DELTA);
    assertEquals(1.0f, calculateMagnitude(unit), DELTA);
}

private float calculateMagnitude(float[] vector) {
float magnitude = 0.0f;
for (float value : vector) {
Expand Down
Loading
Loading