diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index c9de889546..9768e56f79 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -109,7 +109,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( } protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { - return VectorTransformerFactory.getVectorTransformer(knnMethodContext); + return VectorTransformerFactory.getVectorTransformer(knnMethodContext.getKnnEngine(), knnMethodContext.getSpaceType()); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 537317729a..461c6f7c8c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -139,14 +139,15 @@ public void transformQueryVector(float[] vector) { } final Optional knnMethodContext = knnMappingConfig.getKnnMethodContext(); if (knnMethodContext.isPresent()) { - VectorTransformerFactory.getVectorTransformer(knnMethodContext.get()).transform(vector); + KNNMethodContext context = knnMethodContext.get(); + VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector); return; } final Optional modelId = knnMappingConfig.getModelId(); if (modelId.isPresent()) { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); final ModelMetadata metadata = modelDao.getMetadata(modelId.get()); - VectorTransformerFactory.getVectorTransformer(metadata).transform(vector); + VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector); return; } throw new IllegalStateException("Either KNN method context or Model Id should be configured"); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index fc7638fd01..879706aa8a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -218,7 +218,7 @@ private void initVectorTransformer() { // Need to handle BWC case if (knnMethodContext == null || knnMethodConfigContext == null) { log.debug("Method Context not available - falling back to Model Metadata to determine VectorTransformer instance"); - vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata); + vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType()); return; } diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java index 4df5daa329..9d23ff4d20 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -9,8 +9,6 @@ import lombok.NoArgsConstructor; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.indices.ModelMetadata; /** * Factory class responsible for creating appropriate vector transformers based on the KNN method context. @@ -25,23 +23,6 @@ public final class VectorTransformerFactory { private final static VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { }; - /** - * Returns a vector transformer based on the provided KNN method context. - * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer - * since FAISS doesn't natively support cosine space type. For all other cases, - * returns a no-operation transformer. - * - * @param context The KNN method context containing engine and space type information - * @return VectorTransformer An appropriate vector transformer instance - * @throws IllegalArgumentException if the context parameter is null - */ - public static VectorTransformer getVectorTransformer(final KNNMethodContext context) { - if (context == null) { - throw new IllegalArgumentException("KNNMethod context cannot be null"); - } - return getVectorTransformer(context.getKnnEngine(), context.getSpaceType()); - } - /** * Returns a vector transformer instance for vector transformations. * This method provides access to the default no-operation vector transformer @@ -56,28 +37,6 @@ public static VectorTransformer getVectorTransformer() { return NOOP_VECTOR_TRANSFORMER; } - /** - * Creates a VectorTransformer based on the provided model metadata. - * - * @param metadata The model metadata containing KNN engine and space type configuration. - * This parameter must not be null. - * @return A VectorTransformer instance configured according to the model metadata - * @throws IllegalArgumentException if metadata is null - * - * The factory determines the appropriate transformer implementation based on: - * - The KNN engine (e.g., FAISS, NMSLIB) - * - The space type (e.g., L2, COSINE) - * - * The returned transformer can be used to modify vectors in-place according to - * the specified engine and space type requirements. - */ - public static VectorTransformer getVectorTransformer(final ModelMetadata metadata) { - if (metadata == null) { - throw new IllegalArgumentException("ModelMetadata cannot be null"); - } - return getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()); - } - /** * Returns a vector transformer based on the provided KNN engine and space type. * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer @@ -88,7 +47,7 @@ public static VectorTransformer getVectorTransformer(final ModelMetadata metadat * @param spaceType The space type * @return VectorTransformer An appropriate vector transformer instance */ - private static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { + public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : getVectorTransformer(); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java index ef666134ce..d93a836a15 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -8,61 +8,19 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.indices.ModelMetadata; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class VectorTransformerFactoryTests extends KNNTestCase { - public void testGetVectorTransformer_withNullModelMetadata() { - // Test case for null context - assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((ModelMetadata) null)); - } - - public void testAllSpaceTypes_usingModelMetadata_withFaiss() { - for (SpaceType spaceType : SpaceType.values()) { - ModelMetadata metaData = mock(ModelMetadata.class); - when(metaData.getKnnEngine()).thenReturn(KNNEngine.FAISS); - when(metaData.getSpaceType()).thenReturn(spaceType); - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData); - validateTransformer(spaceType, KNNEngine.FAISS, transformer); - } - } - - public void testAllEngines_usingModelMetadata_withCosine() { - for (KNNEngine engine : KNNEngine.values()) { - ModelMetadata metaData = mock(ModelMetadata.class); - when(metaData.getKnnEngine()).thenReturn(engine); - when(metaData.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData); - validateTransformer(SpaceType.COSINESIMIL, engine, transformer); - } - } - - public void testGetVectorTransformer_withNullContext() { - // Test case for null context - assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((KNNMethodContext) null)); - } - - public void testAllSpaceTypes_usingContext_withFaiss() { + public void testAllSpaceTypes_withFaiss() { for (SpaceType spaceType : SpaceType.values()) { - KNNMethodContext context = mock(KNNMethodContext.class); - when(context.getKnnEngine()).thenReturn(KNNEngine.FAISS); - when(context.getSpaceType()).thenReturn(spaceType); - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); validateTransformer(spaceType, KNNEngine.FAISS, transformer); } } - public void testAllEngines_usingContext_withCosine() { - // Test all engines with COSINESIMIL space type + public void testAllEngines_withCosine() { for (KNNEngine engine : KNNEngine.values()) { - KNNMethodContext context = mock(KNNMethodContext.class); - when(context.getKnnEngine()).thenReturn(engine); - when(context.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL); validateTransformer(SpaceType.COSINESIMIL, engine, transformer); } }