diff --git a/CHANGELOG.md b/CHANGELOG.md index 70ec3ae510..ac51ceafdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,9 +21,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x) ### Features ### Enhancements +* Add more detailed error messages for KNN model training [#2378](https://github.com/opensearch-project/k-NN/pull/2378) ### Bug Fixes * Fix derived source for binary and byte vectors [#2533](https://github.com/opensearch-project/k-NN/pull/2533/) * Fix the put mapping issue for already created index with flat mapper [#2542](https://github.com/opensearch-project/k-NN/pull/2542) +* Fix bug to prevent index.knn setting from being modified or removed on restore snapshot [#2445](https://github.com/opensearch-project/k-NN/pull/2445) ### Infrastructure ### Documentation ### Maintenance diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java index 8df068b739..a22f456351 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java @@ -226,7 +226,7 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex // Add training data createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION); XContentBuilder builder = XContentFactory.jsonBuilder() @@ -279,7 +279,7 @@ public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Ra // Add training data createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension); XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c33f3ea63c..dffe3009ed 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -283,7 +283,13 @@ public class KNNSettings { /** * This setting identifies KNN index. */ - public static final Setting<Boolean> IS_KNN_INDEX_SETTING = Setting.boolSetting(KNN_INDEX, false, IndexScope, Final); + public static final Setting<Boolean> IS_KNN_INDEX_SETTING = Setting.boolSetting( + KNN_INDEX, + false, + IndexScope, + Final, + UnmodifiableOnRestore + ); /** * index_thread_quantity - the parameter specifies how many threads the nms library should use to create the graph.
diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index bfd908a099..822f0e2ca1 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -22,6 +22,12 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.function.Function; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; /** * Abstract class for KNN methods. This class provides the common functionality for all KNN methods. @@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( return PerDimensionProcessor.NOOP_PROCESSOR; } + protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() { + return (trainingConfigValidationInput) -> { + + KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); + KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext(); + Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); + + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + + // validate that the vector dimension is evenly divisible by ENCODER_PARAMETER_PQ_M + if (knnMethodContext != null && knnMethodConfigContext != null) { + if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M) + && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(ENCODER_PARAMETER_PQ_M) != 0) { + builder.valid(false); + return builder.build(); + } else { + builder.valid(true); + } + } + + // validate that the number of training points meets the minimum clustering criteria defined in faiss + if (knnMethodContext != null && trainingVectors != null) { + long minTrainingVectorCount = 1000; + + MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER); + + if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST) + && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { + + int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST)); + int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); + minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size)); + } + + if (trainingVectors < minTrainingVectorCount) { + builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); + return builder.build(); + } else { + builder.valid(true); + } + } + return builder.build(); + }; + } + protected VectorTransformer getVectorTransformer(SpaceType spaceType) { return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; } @@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType())) + .trainingConfigValidationSetup(doGetTrainingConfigValidationSetup()) + .build(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index 1ff677cd61..9bef9e2e47 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -12,6 +12,7 @@ import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; +import java.util.function.Function; /** * Context a library gives to build one of its indices @@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext { */ PerDimensionProcessor getPerDimensionProcessor(); + /** + * + * @return Function that validates training model parameters + */ + Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup(); + /** * Get the vector transformer that will be used to transform the vector before indexing. * This will be applied at vector level once entire vector is parsed and validated. diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index 9822033b72..46b5cb215b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -14,6 +14,7 @@ import java.util.Collections; import java.util.Map; +import java.util.function.Function; /** * Simple implementation of {@link KNNLibraryIndexingContext} @@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext private Map<String, Object> parameters = Collections.emptyMap(); @Builder.Default private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY; + private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup; @Override public Map<String, Object> getLibraryParameters() { @@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() { public PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + + @Override + public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() { + return trainingConfigValidationSetup; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java new file mode 100644 index 0000000000..5070173f68 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * This object provides the input of the validation checks for training model inputs. + * The values in this object need to be dynamically set and calling code needs to handle + * the possibility that the values have not been set.
+ */ +@Setter +@Getter +@Builder +@AllArgsConstructor +public class TrainingConfigValidationInput { + private Long trainingVectorsCount; + private KNNMethodContext knnMethodContext; + private KNNMethodConfigContext knnMethodConfigContext; +} diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java new file mode 100644 index 0000000000..0cbe6cad5d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * This object provides the output of the validation checks for training model inputs. + * The values in this object need to be dynamically set and calling code needs to handle + * the possibility that the values have not been set. + */ +@Setter +@Getter +@Builder +@AllArgsConstructor +public class TrainingConfigValidationOutput { + private boolean valid; + private long minTrainingVectorCount; +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 3e59cc6621..f58c14c082 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -23,12 +23,18 @@ import org.opensearch.common.ValidationException; import org.opensearch.common.inject.Inject; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; import java.util.Map; +import java.util.function.Function; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; @@ -134,6 +140,29 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques trainingVectors = trainingModelRequest.getMaximumVectorCount(); } + KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext(); + KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext(); + + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + + Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext + .getTrainingConfigValidationSetup(); + + TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder(); + + TrainingConfigValidationOutput validation = validateTrainingConfig.apply( + inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build() + ); + if (!validation.isValid()) { + ValidationException exception = new
ValidationException(); + exception.addValidationError( + String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount()) + ); + listener.onFailure(exception); + return; + } + listener.onResponse( estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType()) ); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 9906ab490b..bd2c883477 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -30,10 +30,14 @@ import org.opensearch.knn.index.engine.EngineResolver; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import java.io.IOException; +import java.util.function.Function; /** * Request to train and serialize a model @@ -283,6 +287,21 @@ public ActionRequestValidationException validate() { exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters"); } + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext + .getTrainingConfigValidationSetup(); + TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder(); + TrainingConfigValidationOutput validation = validateTrainingConfig.apply( + inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build() + ); + + // Check that the vector dimension is evenly divisible by ENCODER_PARAMETER_PQ_M + if (!validation.isValid()) { + exception = exception == null ? new ActionRequestValidationException() : exception; + exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M does not evenly divide the vector dimension"); + } + // Validate training index exists IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex); if (indexMetadata == null) { diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 605a19660a..d3731e5aa9 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -200,9 +200,7 @@ public void run() { } catch (Exception e) { logger.error("Failed to run training job for model \"" + modelId + "\": ", e); modelMetadata.setState(ModelState.FAILED); - modelMetadata.setError( - "Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training." - ); + modelMetadata.setError("Failed to execute training. 
" + e.getMessage()); KNNCounter.TRAINING_ERRORS.increment(); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 4113579bf3..32ae18ea0b 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -304,7 +304,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN // training data needs to be at least equal to the number of centroids for PQ // which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ - int trainingDataCount = 256; + int trainingDataCount = 1100; SpaceType spaceType = SpaceType.L2; @@ -468,7 +468,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { // training data needs to be at least equal to the number of centroids for PQ // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ - int trainingDataCount = 256; + int trainingDataCount = 1100; SpaceType spaceType = SpaceType.L2; @@ -736,7 +736,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() { // Add training data createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); XContentBuilder builder = XContentFactory.jsonBuilder() @@ -960,7 +960,7 @@ public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { // Add training data createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1064,7 +1064,7 @@ public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenS // Add training data createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1144,7 +1144,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( // training data needs to be at least equal to the number of centroids for PQ // which is 2^8 = 256. 
8 because thats the only valid code_size for HNSWPQ - int trainingDataCount = 256; + int trainingDataCount = 1100; SpaceType spaceType = SpaceType.L2; @@ -1414,7 +1414,7 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() { // Add training data createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); // Call train API - IVF with nlists = 1 is brute force, but will require training @@ -1769,7 +1769,7 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { createKnnIndex(trainingIndexName, trainIndexMapping); - int trainingDataCount = 40; + int trainingDataCount = 1100; bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/RestoreSnapshotIT.java b/src/test/java/org/opensearch/knn/index/RestoreSnapshotIT.java new file mode 100644 index 0000000000..a2c48a2210 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/RestoreSnapshotIT.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.junit.Before; +import org.junit.Test; +import lombok.SneakyThrows; +import static org.hamcrest.Matchers.containsString; + +public class RestoreSnapshotIT extends KNNRestTestCase { + + private String index = "test-index";; + private String snapshot = "snapshot-" + index; + private String repository = "repo"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + setupSnapshotRestore(index, snapshot, repository); + } + + @Test + @SneakyThrows + public void testKnnSettingIsModifiable_whenRestore_thenSuccess() { + // valid restore + XContentBuilder restoreCommand = JsonXContent.contentBuilder().startObject(); + restoreCommand.field("indices", index); + restoreCommand.field("rename_pattern", index); + restoreCommand.field("rename_replacement", "restored-" + index); + restoreCommand.startObject("index_settings"); + { + restoreCommand.field("knn.model.index.number_of_shards", 1); + } + restoreCommand.endObject(); + restoreCommand.endObject(); + Request restoreRequest = new Request("POST", "/_snapshot/" + repository + "/" + snapshot + "/_restore"); + restoreRequest.addParameter("wait_for_completion", "true"); + restoreRequest.setJsonEntity(restoreCommand.toString()); + + final Response restoreResponse = client().performRequest(restoreRequest); + assertEquals(200, restoreResponse.getStatusLine().getStatusCode()); + } + + @Test + @SneakyThrows + public void testKnnSettingIsUnmodifiable_whenRestore_thenFailure() { + // invalid restore + XContentBuilder restoreCommand = JsonXContent.contentBuilder().startObject(); + restoreCommand.field("indices", index); + restoreCommand.field("rename_pattern", index); + restoreCommand.field("rename_replacement", "restored-" + index); + restoreCommand.startObject("index_settings"); + { + restoreCommand.field("index.knn", false); + } + restoreCommand.endObject(); + restoreCommand.endObject(); + Request restoreRequest = new 
Request("POST", "/_snapshot/" + repository + "/" + snapshot + "/_restore"); + restoreRequest.addParameter("wait_for_completion", "true"); + restoreRequest.setJsonEntity(restoreCommand.toString()); + final ResponseException error = expectThrows(ResponseException.class, () -> client().performRequest(restoreRequest)); + assertThat(error.getMessage(), containsString("cannot modify UnmodifiableOnRestore setting [index.knn]" + " on restore")); + } + + @Test + @SneakyThrows + public void testKnnSettingCanBeIgnored_whenRestore_thenSuccess() { + // valid restore + XContentBuilder restoreCommand = JsonXContent.contentBuilder().startObject(); + restoreCommand.field("indices", index); + restoreCommand.field("rename_pattern", index); + restoreCommand.field("rename_replacement", "restored-" + index); + restoreCommand.field("ignore_index_settings", "knn.model.index.number_of_shards"); + restoreCommand.endObject(); + Request restoreRequest = new Request("POST", "/_snapshot/" + repository + "/" + snapshot + "/_restore"); + restoreRequest.addParameter("wait_for_completion", "true"); + restoreRequest.setJsonEntity(restoreCommand.toString()); + final Response restoreResponse = client().performRequest(restoreRequest); + assertEquals(200, restoreResponse.getStatusLine().getStatusCode()); + } + + @Test + @SneakyThrows + public void testKnnSettingCannotBeIgnored_whenRestore_thenFailure() { + // invalid restore + XContentBuilder restoreCommand = JsonXContent.contentBuilder().startObject(); + restoreCommand.field("indices", index); + restoreCommand.field("rename_pattern", index); + restoreCommand.field("rename_replacement", "restored-" + index); + restoreCommand.field("ignore_index_settings", "index.knn"); + restoreCommand.endObject(); + Request restoreRequest = new Request("POST", "/_snapshot/" + repository + "/" + snapshot + "/_restore"); + restoreRequest.addParameter("wait_for_completion", "true"); + restoreRequest.setJsonEntity(restoreCommand.toString()); + final ResponseException error = expectThrows(ResponseException.class, () -> client().performRequest(restoreRequest)); + assertThat(error.getMessage(), containsString("cannot remove UnmodifiableOnRestore setting [index.knn] on restore")); + } +} diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 6e0f954a7a..14305dbb81 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -603,7 +603,7 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() { .toString(); createKnnIndex(INDEX_NAME, trainIndexMapping); - int trainingDataCount = 100; + int trainingDataCount = 1100; bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension); XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index 6288e2076a..5a630b5d68 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -629,7 +629,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { int dimensions = randomIntBetween(2, 10); String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions); createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping); - bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 
3, dimensions); + bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, 1100, dimensions); XContentBuilder methodBuilder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index ef2cee8f22..e05573609f 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -52,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase { private static final String TRAINING_INDEX_NAME = "training_index"; private static final String TRAINING_FIELD_NAME = "training_field"; - private static final int TRAINING_VECS = 20; + private static final int TRAINING_VECS = 1100; private static final int DIMENSION = 16; private static final int NUM_DOCS = 20; diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index e8da155892..1324616092 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -70,7 +70,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase { private static final String FIELD_LUCENE_NAME = "lucene_test_field"; private static final int DIMENSION = 4; private static int DOC_ID = 0; - private static final int NUM_DOCS = 10; + private static final int NUM_DOCS = 1100; private static final int DELAY_MILLI_SEC = 1000; private static final int NUM_OF_ATTEMPTS = 30; diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index dc6e2e0ee5..c1d0d9f9a7 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -97,28 +97,11 @@ public void testTrainModel_fail_notEnoughData() throws Exception { .endObject(); Map method = xContentBuilderToMap(builder); - Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); - - assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); - - // Grab the model id from the response - String trainResponseBody = EntityUtils.toString(trainResponse.getEntity()); - assertNotNull(trainResponseBody); - - Map trainResponseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), trainResponseBody).map(); - String modelId = (String) trainResponseMap.get(MODEL_ID); - assertNotNull(modelId); - - // Confirm that the model fails to create - Response getResponse = getModel(modelId, null); - String responseBody = EntityUtils.toString(getResponse.getEntity()); - assertNotNull(responseBody); - - Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); - - assertEquals(modelId, responseMap.get(MODEL_ID)); - - assertTrainingFails(modelId, 30, 1000); + ResponseException exception = expectThrows( + ResponseException.class, + () -> trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description") + ); + assertTrue(exception.getMessage().contains("Number of training points should be greater than")); } public void testTrainModel_fail_tooMuchData() throws Exception { @@ -132,7 +115,7 @@ public void 
testTrainModel_fail_tooMuchData() throws Exception { // Create a training index and randomly ingest data into it createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); - int trainingDataCount = 20; // 20 * 16 * 4 ~= 10 kb + int trainingDataCount = 128; bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); // Call the train API with this definition: @@ -491,7 +474,7 @@ public void testTrainModel_success_methodOverrideWithCompressionMode() throws Ex // Create a training index and randomly ingest data into it String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath); createKnnIndex(trainingIndexName, mapping); - int trainingDataCount = 200; + int trainingDataCount = 1100; bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension); // Call the train API with this definition: diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 0507788f0f..cfa362ca6d 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -36,6 +36,7 @@ import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -344,6 +345,55 @@ public void testTrainingIndexSize() { transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); } + public void testTrainingIndexSizeFailure() { + + String trainingIndexName = "training-index"; + int dimension = 133; + int vectorCount = 100; + + // Setup the request + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + null, + getDefaultKNNMethodContextForModel(), + dimension, + trainingIndexName, + "training-field", + null, + "description", + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + // Mock client to return the right number of docs + TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO); + SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + Client client = mock(Client.class); + doAnswer(invocationOnMock -> { + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + // Setup the action + ClusterService clusterService = mock(ClusterService.class); + TransportService transportService = mock(TransportService.class); + TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); + + ActionListener listener = ActionListener.wrap( + size -> size.intValue(), + e -> assertThat(e.getMessage(), containsString("Number of training points should be greater than")) + ); + + transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); + } + public void testTrainIndexSize_whenDataTypeIsBinary() { String trainingIndexName = "training-index"; int dimension = 8; diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java 
b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 6fd3994349..fdffc91d02 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -621,11 +621,61 @@ public void testValidation_invalid_descriptionToLong() { ActionRequestValidationException exception = trainingModelRequest.validate(); assertNotNull(exception); List validationErrors = exception.validationErrors(); - logger.error("Validation errorsa " + validationErrors); + logger.error("Validation errors " + validationErrors); assertEquals(1, validationErrors.size()); assertTrue(validationErrors.get(0).contains("Description exceeds limit")); } + public void testValidation_invalid_mNotDivisibleByDimension() { + + // Setup the training request + String modelId = "test-model-id"; + int dimension = 10; + String trainingIndex = "test-training-index"; + String trainingField = "test-training-field"; + String trainingFieldModeId = "training-field-model-id"; + + Map parameters = Map.of("m", 3); + + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters); + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext); + + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null, + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + // Mock the model dao to return metadata for modelId to recognize it is a duplicate + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(null); + when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata); + + // Cluster service that wont produce validation exception + ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); + + // Initialize static components with the mocks + TrainingModelRequest.initialize(modelDao, clusterService); + + // Test that validation produces m not divisible by vector dimension error message + ActionRequestValidationException exception = trainingModelRequest.validate(); + assertNotNull(exception); + List validationErrors = exception.validationErrors(); + logger.error("Validation errors " + validationErrors); + assertEquals(2, validationErrors.size()); + assertTrue(validationErrors.get(1).contains("Training request ENCODER_PARAMETER_PQ_M")); + } + public void testValidation_valid_trainingIndexBuiltFromMethod() { // This cluster service will result in no validation exceptions diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index ae162401bb..9cff112710 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -67,7 +67,7 @@ public class RecallTestsIT extends KNNRestTestCase { private final static String TRAIN_FIELD_NAME = "train_field"; private final static String TEST_MODEL_ID = "test_model_id"; private final static int TEST_DIMENSION = 32; - private final static int DOC_COUNT = 500; + private final static int DOC_COUNT = 1100; private final static int QUERY_COUNT = 
100; private final static int TEST_K = 100; private final static double PERFECT_RECALL = 1.0; diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 4706bd0009..8db9d67bc6 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -41,6 +41,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -217,7 +218,6 @@ public void testRun_success() throws IOException, ExecutionException { Model model = trainingJob.getModel(); assertNotNull(model); - assertEquals(ModelState.CREATED, model.getModelMetadata().getState()); // Simple test that creates the index from template and doesnt fail @@ -308,6 +308,10 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept Model model = trainingJob.getModel(); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + assertThat( + "Failed to load training data into memory. " + "Check if there is enough memory to perform the request.", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertFalse(model.getModelMetadata().getError().isEmpty()); } @@ -382,6 +386,10 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce Model model = trainingJob.getModel(); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + assertThat( + "Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request.", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertFalse(model.getModelMetadata().getError().isEmpty()); } @@ -435,7 +443,7 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep when(nativeMemoryAllocation.isClosed()).thenReturn(true); when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0); - // Throw error on getting data + // Throw error on allocation is closed when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); TrainingJob trainingJob = new TrainingJob( @@ -443,7 +451,83 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep knnMethodContext, nativeMemoryCacheManager, trainingDataEntryContext, - mock(NativeMemoryEntryContext.AnonymousEntryContext.class), + modelContext, + knnMethodConfigContext, + "", + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + trainingJob.run(); + + Model model = trainingJob.getModel(); + assertThat( + "Failed to execute training. Unable to load training data into memory: allocation is already closed", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); + assertNotNull(model); + assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + } + + public void testRun_failure_closedModelAnonymousAllocation() throws ExecutionException { + // In this test, the model anonymous allocation should be closed. 
Then, run should fail and update the error of + // the model + String modelId = "test-model-id"; + + // Define the method setup for method that requires training + int nlists = 5; + int dimension = 16; + KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); + + String tdataKey = "t-data-key"; + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); + when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); + + // Setup model manager + NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); + + // Setup mock allocation for model that's closed + NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class); + doAnswer(invocationOnMock -> null).when(modelAllocation).readLock(); + doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock(); + when(modelAllocation.isClosed()).thenReturn(true); + + String modelKey = "model-test-key"; + NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class); + when(modelContext.getKey()).thenReturn(modelKey); + + // Throw error on allocation is closed + when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation); + doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey); + + // Setup mock allocation thats not closed + NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); + doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock(); + doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock(); + when(nativeMemoryAllocation.isClosed()).thenReturn(false); + when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0); + + when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); + + TrainingJob trainingJob = new TrainingJob( + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, knnMethodConfigContext, "", "test-node", @@ -454,6 +538,10 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep trainingJob.run(); Model model = trainingJob.getModel(); + assertThat( + "Failed to execute training. Unable to reserve memory for model: allocation is already closed", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); }
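
A minimal, self-contained Java sketch (separate from the patch itself) of the sizing rule that the new doGetTrainingConfigValidationSetup hook enforces and that the test changes above account for: the minimum training vector count defaults to 1000 and is raised to max(nlist, 2^code_size) when both METHOD_PARAMETER_NLIST and ENCODER_PARAMETER_PQ_CODE_SIZE are supplied, which is why the integration tests now ingest 1100 training vectors. The concrete nlist, code_size, and vector counts below are hypothetical, and the stand-alone Function is only a simplified stand-in for the hook returned by KNNLibraryIndexingContext#getTrainingConfigValidationSetup().

import java.util.function.Function;

public class TrainingSizeRuleSketch {

    // Returns a predicate mirroring the minimum-training-vector check added in AbstractKNNMethod.
    static Function<Long, Boolean> hasEnoughTrainingVectors(Integer nlist, Integer codeSize) {
        return trainingVectors -> {
            long minTrainingVectorCount = 1000; // default floor when nlist/code_size are not both configured
            if (nlist != null && codeSize != null) {
                minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, codeSize));
            }
            return trainingVectors >= minTrainingVectorCount;
        };
    }

    public static void main(String[] args) {
        // Hypothetical IVFPQ-style configuration: nlist = 128, code_size = 8 -> minimum is max(128, 256) = 256
        Function<Long, Boolean> ivfpq = hasEnoughTrainingVectors(128, 8);
        System.out.println(ivfpq.apply(200L));   // false: 200 < 256
        System.out.println(ivfpq.apply(1100L));  // true: 1100 >= 256

        // No nlist/code_size supplied -> the 1000-vector default floor applies
        Function<Long, Boolean> defaults = hasEnoughTrainingVectors(null, null);
        System.out.println(defaults.apply(500L));   // false: 500 < 1000
        System.out.println(defaults.apply(1100L));  // true: 1100 >= 1000
    }
}

When this check fails, TrainingJobRouterTransportAction now rejects the request up front with "Number of training points should be greater than <minimum>" instead of letting the native training job fail later with a generic error.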