diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index d4222dc8d2..fae1184f88 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary { // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + Function>builder() + .put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1) + .put(SpaceType.COSINESIMIL, score -> 2 - 2 * score) + .build(); private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build(); + Function>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build(); // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index e1a6164338..8604d55062 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -675,7 +675,7 @@ protected void validatePreparse() { protected abstract VectorValidator getVectorValidator(); /** - * Getter for per dimension validator during vector parsing + * Getter for per dimension validator during vector parsing, and before any transformation * * @return PerDimensionValidator */ @@ -688,6 +688,11 @@ protected void validatePreparse() { */ protected abstract PerDimensionProcessor getPerDimensionProcessor(); + /** + * Getter for vector transformer after vector parsing and validation + * + * @return VectorTransformer + */ protected abstract VectorTransformer getVectorTransformer(); protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { @@ -700,8 +705,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - final byte[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForByteVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -710,8 +715,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - final float[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForFloatVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index a0832c1d08..964934d1d6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -17,12 +17,14 @@ import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; @@ -99,4 +101,32 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) Mode mode = knnMappingConfig.getMode(); return compressionLevel.getDefaultRescoreContext(mode, dimension); } + + /** + * Transforms a query vector based on the configured vector data type and KNN method context. + * This method only performs transformations on FLOAT type vectors, leaving other types unchanged. + * + * @param vector The float array to be transformed in place. The transformation will modify + * the original array values directly. + * @throws IllegalStateException if the KNN method context is not properly configured + * or is missing from the mapping configuration + * + * The transformation process: + * 1. Checks if the vector is of FLOAT type + * 2. Retrieves the KNN method context from mapping configuration + * 3. Applies the appropriate vector transformation based on the method context + * + * If the vector is not of FLOAT type, this method returns without performing any transformation. + */ + + public void transformQueryVector(float[] vector) { + if (VectorDataType.FLOAT != vectorDataType) { + return; + } + final Optional knnMethodContext = knnMappingConfig.getKnnMethodContext(); + if (knnMethodContext.isEmpty()) { + throw new IllegalStateException("KNN method context is not set"); + } + VectorTransformerFactory.getVectorTransformer(knnMethodContext.get()).transform(vector); + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java index 6a9642435a..a348cd1b88 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -2,30 +2,25 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; import org.apache.lucene.util.VectorUtil; /** - * Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures - * that the vector's magnitude becomes 1 while preserving its directional properties. + * Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's + * magnitude becomes 1 while preserving its directional properties. */ public class NormalizeVectorTransformer implements VectorTransformer { - /** - * Transforms the input vector into unit vector by applying L2 normalization. - * - * @param vector The input vector to be normalized. Must not be null. - * @return A new float array containing the L2-normalized version of the input vector. - * Each component is divided by the Euclidean norm of the vector. - * @throws IllegalArgumentException if the input vector is null, empty, or a zero vector - */ @Override - public float[] transform(float[] vector) { + public void transform(float[] vector) { + validateVector(vector); + VectorUtil.l2normalize(vector); + } + + private void validateVector(float[] vector) { if (vector == null || vector.length == 0) { throw new IllegalArgumentException("Vector cannot be null or empty"); } - return VectorUtil.l2normalize(vector); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java index f02df13ef9..ac6a9b1ac4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -2,11 +2,8 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; -import java.util.Arrays; - /** * Defines operations for transforming vectors in the k-NN search context. * Implementations can modify vectors while preserving their dimensional properties @@ -15,50 +12,31 @@ public interface VectorTransformer { /** - * Transforms a float vector into a new vector of the same type. - * - * Example: - *
{@code
-     * float[] input = {1.0f, 2.0f, 3.0f};
-     * float[] transformed = transformer.transform(input);
-     * }
+ * Transforms a float vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default float[] transform(final float[] vector) { + default void transform(final float[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - return Arrays.copyOf(vector, vector.length); } /** - * Transforms a byte vector into a new vector of the same type. - * - * Example: - *
{@code
-     * byte[] input = {1, 2, 3};
-     * byte[] transformed = transformer.transform(input);
-     * }
+ * Transforms a byte vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default byte[] transform(final byte[] vector) { + default void transform(final byte[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - // return copy of vector to avoid side effects - return Arrays.copyOf(vector, vector.length); - } /** * A no-operation transformer that returns vector values unchanged. - * This constant can be used when no transformation is needed. */ VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { }; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index c2998df6c6..ce23ed9bd9 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,7 +12,6 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.util.VectorUtil; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; @@ -429,6 +428,7 @@ protected Query doToQuery(QueryShardContext context) { SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); + knnVectorFieldType.transformQueryVector(vector); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); @@ -542,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) @@ -612,13 +612,7 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } - private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) { - - // Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize - // query vector before applying search - if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) { - return VectorUtil.l2normalize(this.vector); - } + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { return this.vector; } diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java index 532985232a..1c4237d7b8 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -21,13 +21,13 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { public void testNormalizeTransformer_withValidVector_thenSuccess() { float[] input = { -3.0f, 4.0f }; - float[] normalized = transformer.transform(input); + transformer.transform(input); - assertEquals(-0.6f, normalized[0], DELTA); - assertEquals(0.8f, normalized[1], DELTA); + assertEquals(-0.6f, input[0], DELTA); + assertEquals(0.8f, input[1], DELTA); // Verify the magnitude is 1 - assertEquals(1.0f, calculateMagnitude(normalized), DELTA); + assertEquals(1.0f, calculateMagnitude(input), DELTA); } private float calculateMagnitude(float[] vector) {