Skip to content

Commit

Permalink
Remove model meta data and context from factory
Browse files Browse the repository at this point in the history
Signed-off-by: Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 16, 2025
1 parent 6ed6fd0 commit afa94b6
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
}

protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return VectorTransformerFactory.getVectorTransformer(knnMethodContext);
return VectorTransformerFactory.getVectorTransformer(knnMethodContext.getKnnEngine(), knnMethodContext.getSpaceType());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,15 @@ public void transformQueryVector(float[] vector) {
}
final Optional<KNNMethodContext> knnMethodContext = knnMappingConfig.getKnnMethodContext();
if (knnMethodContext.isPresent()) {
VectorTransformerFactory.getVectorTransformer(knnMethodContext.get()).transform(vector);
KNNMethodContext context = knnMethodContext.get();
VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector);
return;
}
final Optional<String> modelId = knnMappingConfig.getModelId();
if (modelId.isPresent()) {
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
final ModelMetadata metadata = modelDao.getMetadata(modelId.get());
VectorTransformerFactory.getVectorTransformer(metadata).transform(vector);
VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector);
return;
}
throw new IllegalStateException("Either KNN method context or Model Id should be configured");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ private void initVectorTransformer() {
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
log.debug("Method Context not available - falling back to Model Metadata to determine VectorTransformer instance");
vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata);
vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType());
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.indices.ModelMetadata;

/**
* Factory class responsible for creating appropriate vector transformers based on the KNN method context.
Expand All @@ -25,23 +23,6 @@ public final class VectorTransformerFactory {
private final static VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() {
};

/**
* Returns a vector transformer based on the provided KNN method context.
* For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer
* since FAISS doesn't natively support cosine space type. For all other cases,
* returns a no-operation transformer.
*
* @param context The KNN method context containing engine and space type information
* @return VectorTransformer An appropriate vector transformer instance
* @throws IllegalArgumentException if the context parameter is null
*/
public static VectorTransformer getVectorTransformer(final KNNMethodContext context) {
if (context == null) {
throw new IllegalArgumentException("KNNMethod context cannot be null");
}
return getVectorTransformer(context.getKnnEngine(), context.getSpaceType());
}

/**
* Returns a vector transformer instance for vector transformations.
* This method provides access to the default no-operation vector transformer
Expand All @@ -56,28 +37,6 @@ public static VectorTransformer getVectorTransformer() {
return NOOP_VECTOR_TRANSFORMER;
}

/**
* Creates a VectorTransformer based on the provided model metadata.
*
* @param metadata The model metadata containing KNN engine and space type configuration.
* This parameter must not be null.
* @return A VectorTransformer instance configured according to the model metadata
* @throws IllegalArgumentException if metadata is null
*
* The factory determines the appropriate transformer implementation based on:
* - The KNN engine (e.g., FAISS, NMSLIB)
* - The space type (e.g., L2, COSINE)
*
* The returned transformer can be used to modify vectors in-place according to
* the specified engine and space type requirements.
*/
public static VectorTransformer getVectorTransformer(final ModelMetadata metadata) {
if (metadata == null) {
throw new IllegalArgumentException("ModelMetadata cannot be null");
}
return getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType());
}

/**
* Returns a vector transformer based on the provided KNN engine and space type.
* For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer
Expand All @@ -88,7 +47,7 @@ public static VectorTransformer getVectorTransformer(final ModelMetadata metadat
* @param spaceType The space type
* @return VectorTransformer An appropriate vector transformer instance
*/
private static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) {
public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) {
return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : getVectorTransformer();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,19 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.indices.ModelMetadata;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class VectorTransformerFactoryTests extends KNNTestCase {

public void testGetVectorTransformer_withNullModelMetadata() {
// Test case for null context
assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((ModelMetadata) null));
}

public void testAllSpaceTypes_usingModelMetadata_withFaiss() {
for (SpaceType spaceType : SpaceType.values()) {
ModelMetadata metaData = mock(ModelMetadata.class);
when(metaData.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(metaData.getSpaceType()).thenReturn(spaceType);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData);
validateTransformer(spaceType, KNNEngine.FAISS, transformer);
}
}

public void testAllEngines_usingModelMetadata_withCosine() {
for (KNNEngine engine : KNNEngine.values()) {
ModelMetadata metaData = mock(ModelMetadata.class);
when(metaData.getKnnEngine()).thenReturn(engine);
when(metaData.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData);
validateTransformer(SpaceType.COSINESIMIL, engine, transformer);
}
}

public void testGetVectorTransformer_withNullContext() {
// Test case for null context
assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((KNNMethodContext) null));
}

public void testAllSpaceTypes_usingContext_withFaiss() {
public void testAllSpaceTypes_withFaiss() {
for (SpaceType spaceType : SpaceType.values()) {
KNNMethodContext context = mock(KNNMethodContext.class);
when(context.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(context.getSpaceType()).thenReturn(spaceType);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType);
validateTransformer(spaceType, KNNEngine.FAISS, transformer);
}
}

public void testAllEngines_usingContext_withCosine() {
// Test all engines with COSINESIMIL space type
public void testAllEngines_withCosine() {
for (KNNEngine engine : KNNEngine.values()) {
KNNMethodContext context = mock(KNNMethodContext.class);
when(context.getKnnEngine()).thenReturn(engine);
when(context.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context);
VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL);
validateTransformer(SpaceType.COSINESIMIL, engine, transformer);
}
}
Expand Down

0 comments on commit afa94b6

Please sign in to comment.