diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 147e260b92..5d90071e84 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -77,11 +77,6 @@ public float scoreTranslation(float rawScore) { return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } - @Override - public float scoreToDistanceTranslation(float score) { - return score; - } - @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.COSINE; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index d4222dc8d2..fae1184f88 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary { // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + Function>builder() + .put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1) + .put(SpaceType.COSINESIMIL, score -> 2 - 2 * score) + .build(); private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build(); + Function>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build(); // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index e1a6164338..8604d55062 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -675,7 +675,7 @@ protected void validatePreparse() { protected abstract VectorValidator getVectorValidator(); /** - * Getter for per dimension validator during vector parsing + * Getter for per dimension validator during vector parsing, and before any transformation * * @return PerDimensionValidator */ @@ -688,6 +688,11 @@ protected void validatePreparse() { */ protected abstract PerDimensionProcessor getPerDimensionProcessor(); + /** + * Getter for vector transformer after vector parsing and validation + * + * @return VectorTransformer + */ protected abstract VectorTransformer getVectorTransformer(); protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { @@ -700,8 +705,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - final byte[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForByteVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -710,8 +715,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - final float[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForFloatVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index a0832c1d08..91a9801412 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -16,7 +16,9 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -99,4 +101,21 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) Mode mode = knnMappingConfig.getMode(); return compressionLevel.getDefaultRescoreContext(mode, dimension); } + + /** + * Transforms a query vector based on the specified engine and space type. + * Only performs transformations for FLOAT type vectors, leaving other types unchanged. + * + * @param vector The float array to be transformed in place. Must not be null. + * @param engine The KNN engine (e.g., FAISS, NMSLIB) to be used for the transformation + * @param spaceType The space type (e.g., L2, COSINE) that determines the transformation method + * @throws IllegalArgumentException if the vector is null, empty, or zero vector + */ + public void transformQueryVector(float[] vector, KNNEngine engine, SpaceType spaceType) { + if (vectorDataType != VectorDataType.FLOAT) { + return; + } + VectorTransformerFactory.getVectorTransformer(engine, spaceType).transform(vector); + } + } diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java index 6a9642435a..a348cd1b88 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -2,30 +2,25 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; import org.apache.lucene.util.VectorUtil; /** - * Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures - * that the vector's magnitude becomes 1 while preserving its directional properties. + * Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's + * magnitude becomes 1 while preserving its directional properties. */ public class NormalizeVectorTransformer implements VectorTransformer { - /** - * Transforms the input vector into unit vector by applying L2 normalization. - * - * @param vector The input vector to be normalized. Must not be null. - * @return A new float array containing the L2-normalized version of the input vector. - * Each component is divided by the Euclidean norm of the vector. - * @throws IllegalArgumentException if the input vector is null, empty, or a zero vector - */ @Override - public float[] transform(float[] vector) { + public void transform(float[] vector) { + validateVector(vector); + VectorUtil.l2normalize(vector); + } + + private void validateVector(float[] vector) { if (vector == null || vector.length == 0) { throw new IllegalArgumentException("Vector cannot be null or empty"); } - return VectorUtil.l2normalize(vector); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java index f02df13ef9..ac6a9b1ac4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -2,11 +2,8 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; -import java.util.Arrays; - /** * Defines operations for transforming vectors in the k-NN search context. * Implementations can modify vectors while preserving their dimensional properties @@ -15,50 +12,31 @@ public interface VectorTransformer { /** - * Transforms a float vector into a new vector of the same type. - * - * Example: - *
{@code
-     * float[] input = {1.0f, 2.0f, 3.0f};
-     * float[] transformed = transformer.transform(input);
-     * }
+ * Transforms a float vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default float[] transform(final float[] vector) { + default void transform(final float[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - return Arrays.copyOf(vector, vector.length); } /** - * Transforms a byte vector into a new vector of the same type. - * - * Example: - *
{@code
-     * byte[] input = {1, 2, 3};
-     * byte[] transformed = transformer.transform(input);
-     * }
+ * Transforms a byte vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default byte[] transform(final byte[] vector) { + default void transform(final byte[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - // return copy of vector to avoid side effects - return Arrays.copyOf(vector, vector.length); - } /** * A no-operation transformer that returns vector values unchanged. - * This constant can be used when no transformation is needed. */ VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { }; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index c2998df6c6..2fa3891a91 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,7 +12,6 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.util.VectorUtil; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; @@ -429,6 +428,7 @@ protected Query doToQuery(QueryShardContext context) { SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); + knnVectorFieldType.transformQueryVector(vector, knnEngine, spaceType); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); @@ -542,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) @@ -559,7 +559,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .radius(radius) @@ -612,13 +612,7 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } - private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) { - - // Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize - // query vector before applying search - if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) { - return VectorUtil.l2normalize(this.vector); - } + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { return this.vector; } diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java index 532985232a..1c4237d7b8 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -21,13 +21,13 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { public void testNormalizeTransformer_withValidVector_thenSuccess() { float[] input = { -3.0f, 4.0f }; - float[] normalized = transformer.transform(input); + transformer.transform(input); - assertEquals(-0.6f, normalized[0], DELTA); - assertEquals(0.8f, normalized[1], DELTA); + assertEquals(-0.6f, input[0], DELTA); + assertEquals(0.8f, input[1], DELTA); // Verify the magnitude is 1 - assertEquals(1.0f, calculateMagnitude(normalized), DELTA); + assertEquals(1.0f, calculateMagnitude(input), DELTA); } private float calculateMagnitude(float[] vector) {