diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 29fefbf9bb..d4333d2e91 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # This should match the owning team set up in https://github.com/orgs/opensearch-project/teams -* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @ryanbogan @luyuncheng @shatejas @0ctopus13prime +* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @ryanbogan @luyuncheng @shatejas @0ctopus13prime @Vikasht34 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b65aea6e9a..f844161d0b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -38,7 +38,7 @@ jobs: Build-k-NN-Linux: strategy: matrix: - java: [21] + java: [21, 23] name: Build and Test k-NN Plugin on Linux runs-on: ubuntu-latest @@ -101,7 +101,7 @@ jobs: Build-k-NN-MacOS: strategy: matrix: - java: [ 21 ] + java: [21, 23] name: Build and Test k-NN Plugin on MacOS needs: Get-CI-Image-Tag @@ -144,7 +144,7 @@ jobs: Build-k-NN-Windows: strategy: matrix: - java: [ 21 ] + java: [21, 23] name: Build and Test k-NN Plugin on Windows needs: Get-CI-Image-Tag diff --git a/.github/workflows/backwards_compatibility_tests_workflow.yml b/.github/workflows/backwards_compatibility_tests_workflow.yml index b5335d8685..6c9f262727 100644 --- a/.github/workflows/backwards_compatibility_tests_workflow.yml +++ b/.github/workflows/backwards_compatibility_tests_workflow.yml @@ -35,7 +35,7 @@ jobs: matrix: java: [ 21 ] os: [ubuntu-latest] - bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0", "2.15.0", "2.16.0", "2.17.0", "2.18.0", "2.19.0-SNAPSHOT"] + bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0", "2.15.0", "2.16.0", "2.17.0", 
"2.18.0","2.19.0-SNAPSHOT", "2.20.0-SNAPSHOT"] opensearch_version : [ "3.0.0-SNAPSHOT" ] exclude: - os: windows-latest @@ -130,7 +130,7 @@ jobs: matrix: java: [ 21 ] os: [ubuntu-latest] - bwc_version: [ "2.19.0-SNAPSHOT" ] + bwc_version: [ "2.20.0-SNAPSHOT" ] opensearch_version: [ "3.0.0-SNAPSHOT" ] name: k-NN Rolling-Upgrade BWC Tests diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml index 644342c78b..eb55a29c35 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -32,4 +32,7 @@ jobs: export SONATYPE_PASSWORD=$(aws secretsmanager get-secret-value --secret-id maven-snapshots-password --query SecretString --output text) echo "::add-mask::$SONATYPE_USERNAME" echo "::add-mask::$SONATYPE_PASSWORD" + # For zip ./gradlew publishPluginZipPublicationToSnapshotsRepository + # For jar + ./gradlew publishNebulaPublicationToSnapshotsRepository diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cb98978d7..862073af22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] - Add a new build mode, `FAISS_OPT_LEVEL=avx512_spr`, which enables the use of advanced AVX-512 instructions introduced with Intel(R) Sapphire Rapids (#2404)[https://github.com/opensearch-project/k-NN/pull/2404] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] +- Add derived source feature for vector fields (#2449)[https://github.com/opensearch-project/k-NN/pull/2449] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. 
(#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] @@ -30,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357] - Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398] - Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397] +- Enabled concurrent graph creation for Lucene engine with index thread qty settings (#2480)[https://github.com/opensearch-project/k-NN/pull/2480] - Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] @@ -44,6 +46,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op (#2399)[https://github.com/opensearch-project/k-NN/pull/2399] * Fixing bug where mapping accepts both dimension and model-id (#2410)[https://github.com/opensearch-project/k-NN/pull/2410] * Fixing bug where cmake condition to detect flag was broken and fix build path for JNI (#2442)[https://github.com/opensearch-project/k-NN/pull/2442] +* Add version check for full field name validation (#2477)[https://github.com/opensearch-project/k-NN/pull/2477] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos 
[#2279](https://github.com/opensearch-project/k-NN/pull/2279) @@ -56,4 +59,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Upgrade jsonpath from 2.8.0 to 2.9.0[2325](https://github.com/opensearch-project/k-NN/pull/2325) * Bump Faiss commit from 1f42e81 to 0cbc2a8 to accelerate hamming distance calculation using _mm512_popcnt_epi64 intrinsic and also add avx512-fp16 instructions to boost performance [#2381](https://github.com/opensearch-project/k-NN/pull/2381) * Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/) +* Fixing Lucene912Codec Issue with BWC for Lucene 10.0.1 upgrade [#2429](https://github.com/opensearch-project/k-NN/pull/2429) ### Refactoring diff --git a/MAINTAINERS.md b/MAINTAINERS.md index fd18eee5cd..35ebf26730 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -6,7 +6,7 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Maintainer | GitHub ID | Affiliation | |-------------------------|-------------------------------------------------------|-------------| -| Doo Yong Kim | [0ctopus13prime](https://github.com/0ctopus13prime) | Amazon | +| Doo Yong Kim | [0ctopus13prime](https://github.com/0ctopus13prime) | Amazon | | Heemin Kim | [heemin32](https://github.com/heemin32) | Amazon | | Jack Mazanec | [jmazanec15](https://github.com/jmazanec15) | Amazon | | Junqiu Lei | [junqiu-lei](https://github.com/junqiu-lei) | Amazon | @@ -17,5 +17,6 @@ This document contains a list of maintainers in this repo. 
See [opensearch-proje | Tejas Shah | [shatejas](https://github.com/shatejas) | Amazon | | Vamshi Vijay Nakkirtha | [vamshin](https://github.com/vamshin) | Amazon | | Vijayan Balasubramanian | [VijayanB](https://github.com/VijayanB) | Amazon | +| Vikash Tiwari | [Vikasht34](https://github.com/Vikasht34) | Amazon | | Yuncheng Lu | [luyuncheng](https://github.com/luyuncheng) | Bytedance | diff --git a/build-tools/knnplugin-coverage.gradle b/build-tools/knnplugin-coverage.gradle index eb3582dabf..b589ddf3ae 100644 --- a/build-tools/knnplugin-coverage.gradle +++ b/build-tools/knnplugin-coverage.gradle @@ -6,7 +6,7 @@ apply plugin: 'jacoco' jacoco { - toolVersion = "0.8.10" + toolVersion = "0.8.12" } /** diff --git a/build.gradle b/build.gradle index b4fdeaf2c0..075c1e9ef2 100644 --- a/build.gradle +++ b/build.gradle @@ -15,8 +15,8 @@ buildscript { ext { // build.version_qualifier parameter applies to knn plugin artifacts only. OpenSearch version must be set // explicitly as 'opensearch.version' property, for instance opensearch.version=2.0.0-rc1-SNAPSHOT - opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") - version_qualifier = System.getProperty("build.version_qualifier", "") + opensearch_version = System.getProperty("opensearch.version", "3.0.0-alpha1-SNAPSHOT") + version_qualifier = System.getProperty("build.version_qualifier", "alpha1") opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") avx2_enabled = System.getProperty("avx2.enabled", "true") @@ -399,6 +399,7 @@ integTest { systemProperty("https", is_https) systemProperty("user", user) systemProperty("password", password) + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -462,6 +463,7 @@ task integTestRemote(type: RestIntegTestTask) { systemProperty 'cluster.number_of_nodes', "${_numNodes}" systemProperty 
'tests.security.manager', 'false' + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) // Run tests with remote cluster only if rest case is defined if (System.getProperty("tests.rest.cluster") != null) { diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index b42703b944..893c61f477 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -5,8 +5,9 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=f2b9ed0faf8472cbe469255ae6c86eddb77076c75191741b4a462f33128dd419 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-all.zip +distributionSha256Sum=2ab88d6de2c23e6adae7363ae6e29cbdd2a709e992929b48b6530fd0c7133bd6 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-all.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index 29496cff90..c6caff306d 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -53,7 +53,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase { private static final int DELAY_MILLI_SEC = 1000; private static final int MIN_NUM_OF_MODELS = 2; private static final int K = 5; - private static final int NUM_DOCS = 10; + private static final int NUM_DOCS = 1001; private static final int NUM_DOCS_TEST_MODEL_INDEX = 100; private static final int NUM_DOCS_TEST_MODEL_INDEX_DEFAULT = 100; private static final int NUM_DOCS_TEST_MODEL_INDEX_FOR_NON_KNN_INDEX = 100; diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 4479099e82..14e95887c7 100644 --- 
a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -161,4 +161,8 @@ public class KNNConstants { public static final String MODE_PARAMETER = "mode"; public static final String COMPRESSION_LEVEL_PARAMETER = "compression_level"; + + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY = "knn-derived-source-enabled"; + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE = "true"; + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_FALSE_VALUE = "false"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index d60b500c10..035bddd814 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -44,6 +44,7 @@ import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; import static org.opensearch.common.settings.Setting.Property.Final; +import static org.opensearch.common.settings.Setting.Property.UnmodifiableOnRestore; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.getFeatureFlags; @@ -92,6 +93,7 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; + public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; /** * Default setting values @@ -268,6 +270,14 @@ public class KNNSettings { Setting.Property.Dynamic ); + public static 
final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( + KNN_DERIVED_SOURCE_ENABLED, + false, + IndexScope, + Final, + UnmodifiableOnRestore + ); + /** * This setting identifies KNN index. */ @@ -511,6 +521,9 @@ private Setting getSetting(String key) { if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; } + if (KNN_DERIVED_SOURCE_ENABLED.equals(key)) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING; + } throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -536,7 +549,8 @@ public List> getSettings() { KNN_FAISS_AVX512_SPR_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, + KNN_DERIVED_SOURCE_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -574,6 +588,14 @@ public static boolean isFaissAVX2Disabled() { } } + /** + * check this index enabled/disabled derived source + * @param settings Settings + */ + public static boolean isKNNDerivedSourceEnabled(Settings settings) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING.get(settings); + } + public static boolean isFaissAVX512Disabled() { return Booleans.parseBoolean( Objects.requireNonNullElse( @@ -709,6 +731,14 @@ public void onIndexModule(IndexModule module) { }); } + /** + * Get the index thread quantity setting value from cluster setting. 
+ * @return int + */ + public static int getIndexThreadQty() { + return KNNSettings.state().getSettingValue(KNN_ALGO_PARAM_INDEX_THREAD_QTY); + } + private static String percentageAsString(Integer percentage) { return percentage + "%"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 7053e6151d..eee2808a6b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -7,6 +7,7 @@ import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; @@ -45,22 +46,21 @@ public ScriptDocValues getScriptValues() { if (fieldInfo == null) { return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); } - - DocIdSetIterator values; + KnnVectorValues knnVectorValues; if (fieldInfo.hasVectorValues()) { switch (fieldInfo.getVectorEncoding()) { case FLOAT32: - values = reader.getFloatVectorValues(fieldName); + knnVectorValues = reader.getFloatVectorValues(fieldName); break; case BYTE: - values = reader.getByteVectorValues(fieldName); + knnVectorValues = reader.getByteVectorValues(fieldName); break; default: throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); } - } else { - values = DocValues.getBinary(reader, fieldName); + return KNNVectorScriptDocValues.create(knnVectorValues, fieldName, vectorDataType); } + DocIdSetIterator values = DocValues.getBinary(reader, fieldName); return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); diff --git 
a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 55ff655167..b174959c9e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; @@ -32,9 +33,7 @@ public void setNextDocId(int docId) throws IOException { if (docId < lastDocID) { throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId); } - lastDocID = docId; - int curDocID = vectorValues.docID(); if (lastDocID > curDocID) { curDocID = vectorValues.advance(docId); @@ -81,12 +80,13 @@ public float[] get(int i) { * @return A KNNVectorScriptDocValues object based on the type of the values. * @throws IllegalArgumentException If the type of values is unsupported. 
*/ - public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + public static KNNVectorScriptDocValues create(Object values, String fieldName, VectorDataType vectorDataType) { Objects.requireNonNull(values, "values must not be null"); - if (values instanceof ByteVectorValues) { - return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); - } else if (values instanceof FloatVectorValues) { + + if (values instanceof FloatVectorValues) { return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); } else if (values instanceof BinaryDocValues) { return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); } else { @@ -96,34 +96,53 @@ public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fi private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { private final ByteVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { - super(values, field, type); + super(values.iterator(), field, type); this.values = values; + this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator + ? 
(KnnVectorValues.DocIndexIterator) super.vectorValues + : values.iterator(); } @Override protected float[] doGetValue() throws IOException { - byte[] bytes = values.vectorValue(); + int docId = this.iterator.index(); + if (docId == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { + throw new IllegalStateException("No more ordinals to retrieve vector values."); + } + + // Use the correct method to retrieve the byte vector for the current ordinal + byte[] bytes = values.vectorValue(docId); float[] value = new float[bytes.length]; for (int i = 0; i < bytes.length; i++) { value[i] = (float) bytes[i]; } return value; } + } private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { private final FloatVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { - super(values, field, type); + super(values.iterator(), field, type); this.values = values; + this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator + ? 
(KnnVectorValues.DocIndexIterator) super.vectorValues + : values.iterator(); } @Override protected float[] doGetValue() throws IOException { - return values.vectorValue(); + int ord = iterator.index(); // Fetch ordinal (index of vector) + if (ord == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { + throw new IllegalStateException("No more ordinals to retrieve vector values."); + } + return values.vectorValue(ord); } } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index e97bd2dbf7..6b3649bb82 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -9,7 +9,7 @@ import lombok.Getter; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -100,7 +100,7 @@ public void freeNativeMemory(long memoryAddress) { @Override public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) { - return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); + return KnnFloatVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java new file mode 100644 index 0000000000..97848bb350 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.knn.index.codec.KNN10010Codec; + +import lombok.Builder; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN9120Codec.DerivedSourceStoredFieldsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; + +/** + * KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1 + */ + +public class KNN10010Codec extends FilterCodec { + + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_01_0; + private final KNNFormatFacade knnFormatFacade; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; + + private final MapperService mapperService; + + /** + * No arg constructor that uses Lucene99 as the delegate + */ + public KNN10010Codec() { + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null); + } + + /** + * Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec + * and a unique name to this ctor. 
+ * + * @param delegate codec that will perform all operations this codec does not override + * @param knnVectorsFormat per field format for KnnVector + */ + @Builder + protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) { + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); + perFieldKnnVectorsFormat = knnVectorsFormat; + this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { + if (segmentReadState.fieldInfos.hasVectorValues()) { + return knnVectorsFormat().fieldsReader(segmentReadState); + } + return null; + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasDocValues()) { + return docValuesFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasPostings()) { + return postingsFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState -> { + if (segmentReadState.fieldInfos.hasNorms()) { + return normsFormat().normsProducer(segmentReadState); + } + return null; + })); + return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java 
b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index 24dbfb78b8..2d0ee349a4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -40,8 +40,8 @@ public KNN80CompoundFormat(CompoundFormat delegate) { } @Override - public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { - return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si, context), dir); + public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si) throws IOException { + return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si), dir); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java index 23c9f31051..192ab101db 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java @@ -13,17 +13,10 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.*; import java.io.IOException; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.index.SortedSetDocValues; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -66,6 +59,18 @@ public SortedSetDocValues getSortedSet(FieldInfo 
field) throws IOException { return delegate.getSortedSet(field); } + /** + * @param fieldInfo + * @return Returns a DocValuesSkipper for this field. The returned instance need not be thread-safe: + * it will only be used by a single thread. + * The return value is undefined if FieldInfo. docValuesSkipIndexType() returns DocValuesSkipIndexType. NONE. + * @throws IOException + */ + @Override + public DocValuesSkipper getSkipper(FieldInfo fieldInfo) throws IOException { + return delegate.getSkipper(fieldInfo); + } + @Override public void checkIntegrity() throws IOException { delegate.checkIntegrity(); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java new file mode 100644 index 0000000000..e60b82b2e3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.opensearch.common.Nullable; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; + +import java.io.IOException; +import java.util.ArrayList; 
+import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + +@AllArgsConstructor +public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { + + private final StoredFieldsFormat delegate; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + // IMPORTANT Do not rely on this for the reader, it will be null if SPI is used + @Nullable + private final MapperService mapperService; + + @Override + public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) + throws IOException { + List derivedVectorFields = null; + for (FieldInfo fieldInfo : fieldInfos) { + if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) { + // Lazily initialize the list of fields + if (derivedVectorFields == null) { + derivedVectorFields = new ArrayList<>(); + } + derivedVectorFields.add(fieldInfo); + } + } + // If no fields have it enabled, we can just short-circuit and return the delegate's fieldReader + if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { + return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); + } + return new DerivedSourceStoredFieldsReader( + delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), + derivedVectorFields, + derivedSourceReadersSupplier, + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + ); + } + + @Override + public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentInfo, IOContext ioContext) throws IOException { + StoredFieldsWriter delegateWriter = delegate.fieldsWriter(directory, segmentInfo, ioContext); + if (mapperService != null && KNNSettings.isKNNDerivedSourceEnabled(mapperService.getIndexSettings().getSettings())) { + List vectorFieldTypes 
= new ArrayList<>(); + for (MappedFieldType fieldType : mapperService.fieldTypes()) { + if (fieldType instanceof KNNVectorFieldType) { + vectorFieldTypes.add(fieldType.name()); + } + } + if (vectorFieldTypes.isEmpty() == false) { + return new DerivedSourceStoredFieldsWriter(delegateWriter, vectorFieldTypes); + } + } + return delegateWriter; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java new file mode 100644 index 0000000000..24900eb19e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -0,0 +1,142 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.StoredFieldVisitor; +import org.apache.lucene.util.IOUtils; +import org.opensearch.index.fieldvisitor.FieldsVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; + +import java.io.IOException; +import java.util.List; + +public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { + private final StoredFieldsReader delegate; + private final List derivedVectorFields; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final SegmentReadState segmentReadState; + private final boolean shouldInject; + + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + /** + * + * @param delegate delegate StoredFieldsReader + * @param derivedVectorFields List of fields that 
are derived source fields + * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param segmentReadState SegmentReadState for the segment + * @throws IOException in case of I/O error + */ + public DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState + ) throws IOException { + this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + } + + private DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + boolean shouldInject + ) throws IOException { + this.delegate = delegate; + this.derivedVectorFields = derivedVectorFields; + this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.segmentReadState = segmentReadState; + this.shouldInject = shouldInject; + this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); + } + + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { + return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + } + + @Override + public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { + // If the visitor has explicitly indicated it does not need the fields, we should not inject them + boolean isVisitorNeedFields = true; + if (storedFieldVisitor instanceof FieldsVisitor) { + isVisitorNeedFields = derivedSourceVectorInjector.shouldInject( + ((FieldsVisitor) storedFieldVisitor).includes(), + ((FieldsVisitor) storedFieldVisitor).excludes() + ); + } + if (shouldInject && isVisitorNeedFields) { + delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector)); + return; + } + delegate.document(docId, 
storedFieldVisitor); + } + + @Override + public StoredFieldsReader clone() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.clone(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + shouldInject + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public void close() throws IOException { + IOUtils.close(delegate, derivedSourceVectorInjector); + } + + /** + * For merging, we need to tell the derived source stored fields reader to skip injecting the source. Otherwise, + * on merge we will end up just writing the source to disk. We cant override + * {@link StoredFieldsReader#getMergeInstance()} because it is used elsewhere than just merging. + * + * @return Merged instance that wont inject by default + */ + private StoredFieldsReader cloneForMerge() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.getMergeInstance(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + false + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * For merging, we need to tell the derived source stored fields reader to skip injecting the source. 
Otherwise, + * on merge we will end up just writing the source to disk + * + * @param storedFieldsReader stored fields reader to wrap + * @return wrapped stored fields reader + */ + public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) { + if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) { + return ((DerivedSourceStoredFieldsReader) storedFieldsReader).cloneForMerge(); + } + return storedFieldsReader; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java new file mode 100644 index 0000000000..0c43f6a493 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@RequiredArgsConstructor +public class DerivedSourceStoredFieldsWriter extends StoredFieldsWriter { + + 
private final StoredFieldsWriter delegate; + private final List vectorFieldTypes; + + @Override + public void startDocument() throws IOException { + delegate.startDocument(); + } + + @Override + public void writeField(FieldInfo fieldInfo, int i) throws IOException { + delegate.writeField(fieldInfo, i); + } + + @Override + public void writeField(FieldInfo fieldInfo, long l) throws IOException { + delegate.writeField(fieldInfo, l); + } + + @Override + public void writeField(FieldInfo fieldInfo, float v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo fieldInfo, double v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo info, DataInput value, int length) throws IOException { + delegate.writeField(info, value, length); + } + + @Override + public int merge(MergeState mergeState) throws IOException { + // We have to wrap these here to avoid storing the vectors during merge + for (int i = 0; i < mergeState.storedFieldsReaders.length; i++) { + mergeState.storedFieldsReaders[i] = DerivedSourceStoredFieldsReader.wrapForMerge(mergeState.storedFieldsReaders[i]); + } + return delegate.merge(mergeState); + } + + @Override + public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOException { + // Parse out the vectors from the source + if (Objects.equals(fieldInfo.name, SourceFieldMapper.NAME) && !vectorFieldTypes.isEmpty()) { + // Reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)), + true, + MediaTypeRegistry.JSON + ); + Map filteredSource = XContentMapValues.filter(null, vectorFieldTypes.toArray(new String[0])) + .apply(mapTuple.v2()); + BytesStreamOutput bStream = new BytesStreamOutput(); + 
MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(filteredSource); + builder.close(); + BytesReference bytesReference = bStream.bytes(); + delegate.writeField(fieldInfo, bytesReference.toBytesRef()); + return; + } + delegate.writeField(fieldInfo, bytesRef); + } + + @Override + public void writeField(FieldInfo fieldInfo, String s) throws IOException { + delegate.writeField(fieldInfo, s); + } + + @Override + public void finishDocument() throws IOException { + delegate.finishDocument(); + } + + @Override + public void finish(int i) throws IOException { + delegate.finish(i); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public long ramBytesUsed() { + return delegate.ramBytesUsed(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java index 2b3723439e..16fd2ad436 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java @@ -8,7 +8,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.opensearch.knn.index.KNNVectorSimilarityFunction; @@ -22,10 +23,10 @@ public class KNN9120BinaryVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction vectorSimilarityFunction, - 
RandomAccessVectorValues randomAccessVectorValues + KnnVectorValues randomAccessVectorValues ) throws IOException { - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { - return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues); + if (randomAccessVectorValues instanceof ByteVectorValues) { + return new BinaryRandomVectorScorerSupplier((ByteVectorValues) randomAccessVectorValues); } throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); } @@ -33,7 +34,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, + KnnVectorValues randomAccessVectorValues, float[] queryVector ) throws IOException { throw new IllegalArgumentException("binary vectors do not support float[] targets"); @@ -42,20 +43,20 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, + KnnVectorValues randomAccessVectorValues, byte[] queryVector ) throws IOException { - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { - return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector); + if (randomAccessVectorValues instanceof ByteVectorValues) { + return new BinaryRandomVectorScorer((ByteVectorValues) randomAccessVectorValues, queryVector); } throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); } static class BinaryRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues.Bytes vectorValues; + private final ByteVectorValues vectorValues; private final byte[] queryVector; - 
BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + BinaryRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { this.queryVector = query; this.vectorValues = vectorValues; } @@ -82,11 +83,11 @@ public Bits getAcceptOrds(Bits acceptDocs) { } static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - protected final RandomAccessVectorValues.Bytes vectorValues; - protected final RandomAccessVectorValues.Bytes vectorValues1; - protected final RandomAccessVectorValues.Bytes vectorValues2; + protected final ByteVectorValues vectorValues; + protected final ByteVectorValues vectorValues1; + protected final ByteVectorValues vectorValues2; - public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException { + public BinaryRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; this.vectorValues1 = vectorValues.copy(); this.vectorValues2 = vectorValues.copy(); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index a370197ecc..5e40faf1ab 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -11,9 +11,12 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; /** * KNN Codec that wraps the Lucene Codec which is part of Lucene 
9.12 @@ -22,12 +25,15 @@ public class KNN9120Codec extends FilterCodec { private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; + + private final MapperService mapperService; /** * No arg constructor that uses Lucene99 as the delegate */ public KNN9120Codec() { - this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null); } /** @@ -38,10 +44,12 @@ public KNN9120Codec() { * @param knnVectorsFormat per field format for KnnVector */ @Builder - protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) { super(VERSION.getCodecName(), delegate); knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; + this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); } @Override @@ -58,4 +66,36 @@ public CompoundFormat compoundFormat() { public KnnVectorsFormat knnVectorsFormat() { return perFieldKnnVectorsFormat; } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { + if (segmentReadState.fieldInfos.hasVectorValues()) { + return knnVectorsFormat().fieldsReader(segmentReadState); + } + return null; + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasDocValues()) { + return docValuesFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasPostings()) { + return 
postingsFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState -> { + if (segmentReadState.fieldInfos.hasNorms()) { + return normsFormat().normsProducer(segmentReadState); + } + return null; + })); + return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java index 6e8fc767ec..afebae2e6f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java @@ -7,18 +7,22 @@ import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.opensearch.common.collect.Tuple; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; import org.opensearch.knn.index.engine.KNNEngine; import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Class provides per field format implementation for Lucene Knn vector type */ public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { - private static final int NUM_MERGE_WORKERS = 1; + private static final Tuple DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE = Tuple.tuple(1, null); public KNN9120PerFieldKnnVectorsFormat(final Optional mapperService) { super( @@ -27,37 +31,67 @@ public KNN9120PerFieldKnnVectorsFormat(final Optional mapperServi Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, Lucene99HnswVectorsFormat::new, knnVectorsFormatParams -> { + final Tuple 
mergeThreadCountAndExecutorService = getMergeThreadCountAndExecutorService(); // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that // changes in the future. if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { return new KNN9120HnswBinaryVectorsFormat( knnVectorsFormatParams.getMaxConnections(), - knnVectorsFormatParams.getBeamWidth() + knnVectorsFormatParams.getBeamWidth(), + // number of merge threads + mergeThreadCountAndExecutorService.v1(), + // executor service + mergeThreadCountAndExecutorService.v2() ); } else { - return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + return new Lucene99HnswVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth(), + // number of merge threads + mergeThreadCountAndExecutorService.v1(), + // executor service + mergeThreadCountAndExecutorService.v2() + ); } }, - knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - NUM_MERGE_WORKERS, - knnScalarQuantizedVectorsFormatParams.getBits(), - knnScalarQuantizedVectorsFormatParams.isCompressFlag(), - knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), - null - ) + knnScalarQuantizedVectorsFormatParams -> { + final Tuple mergeThreadCountAndExecutorService = getMergeThreadCountAndExecutorService(); + return new Lucene99HnswScalarQuantizedVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + // Number of merge threads + mergeThreadCountAndExecutorService.v1(), + knnScalarQuantizedVectorsFormatParams.getBits(), + knnScalarQuantizedVectorsFormatParams.isCompressFlag(), + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + // Executor service + 
mergeThreadCountAndExecutorService.v2() + ); + } ); } - @Override /** * This method returns the maximum dimension allowed from KNNEngine for Lucene codec * * @param fieldName Name of the field, ignored * @return Maximum constant dimension set by KNNEngine */ + @Override public int getMaxDimensions(String fieldName) { return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE); } + + private static Tuple getMergeThreadCountAndExecutorService() { + // To ensure that only once we are fetching the settings per segment, we are fetching the num threads once while + // creating the executors + int mergeThreadCount = KNNSettings.getIndexThreadQty(); + // We need to return null whenever the merge threads are <=1, as lucene assumes that if number of threads are 1 + // then we should be giving a null value of the executor + if (mergeThreadCount <= 1) { + return DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE; + } else { + return Tuple.tuple(mergeThreadCount, Executors.newFixedThreadPool(mergeThreadCount)); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index d9b73d621f..c6b6c6268f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -54,7 +54,7 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); - try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { + try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.DEFAULT)) { CodecUtil.retrieveChecksum(input); int numFields = 
getNumFields(input); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index dd326123e5..17304c1462 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -20,6 +20,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.KNNEngine; import java.io.IOException; @@ -71,6 +72,15 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce return new NativeEngines990KnnVectorsReader(state, flatVectorsFormat.fieldsReader(state)); } + /** + * @param s + * @return + */ + @Override + public int getMaxDimensions(String s) { + return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE); + } + @Override public String toString() { return "NativeEngines99KnnVectorsFormat(name=" diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index efabc3a70f..2366a6d579 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -199,14 +199,6 @@ public void close() throws IOException { } } - /** - * Return the memory usage of this object in bytes. Negative values are illegal. 
- */ - @Override - public long ramBytesUsed() { - return flatVectorsReader.ramBytesUsed(); - } - private void loadCacheKeyMap() { quantizationStateCacheKeyPerField = new HashMap<>(); for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 4343c845b0..0f03170c25 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -13,9 +13,11 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; -import org.apache.lucene.codecs.lucene912.Lucene912Codec; +import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; @@ -126,11 +128,27 @@ public enum KNNCodecVersion { (userCodec, mapperService) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .mapperService(mapperService) .build(), KNN9120Codec::new + ), + V_10_01_0( + "KNN10010Codec", + new Lucene101Codec(), + new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN10010Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new 
KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .mapperService(mapperService) + .build(), + KNN10010Codec::new ); - private static final KNNCodecVersion CURRENT = V_9_12_0; + private static final KNNCodecVersion CURRENT = V_10_01_0; private final String codecName; private final Codec defaultCodecDelegate; diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java new file mode 100644 index 0000000000..123b718a46 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +@FunctionalInterface +public interface DerivedSourceReaderSupplier { + R apply(SegmentReadState segmentReadState) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java new file mode 100644 index 0000000000..c7e472e601 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.util.IOUtils; +import org.opensearch.common.Nullable; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Class 
holds the readers necessary to implement derived source. Important to note that if a segment does not have + * any of these fields, the values will be null. Caller needs to check if these are null before using. + */ +@RequiredArgsConstructor +@Getter +public class DerivedSourceReaders implements Closeable { + @Nullable + private final KnnVectorsReader knnVectorsReader; + @Nullable + private final DocValuesProducer docValuesProducer; + @Nullable + private final FieldsProducer fieldsProducer; + @Nullable + private final NormsProducer normsProducer; + + @Override + public void close() throws IOException { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java new file mode 100644 index 0000000000..2dafa3af94 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +/** + * Class encapsulates the suppliers to give the {@link DerivedSourceReaders} from particular formats needed to implement + * derived source. More specifically, given a {@link org.apache.lucene.index.SegmentReadState}, this class will provide + * the correct format reader for that segment. 
+ */ +@RequiredArgsConstructor +public class DerivedSourceReadersSupplier { + private final DerivedSourceReaderSupplier knnVectorsReaderSupplier; + private final DerivedSourceReaderSupplier docValuesProducerSupplier; + private final DerivedSourceReaderSupplier fieldsProducerSupplier; + private final DerivedSourceReaderSupplier normsProducer; + + /** + * Get the readers for the segment + * + * @param state SegmentReadState + * @return DerivedSourceReaders + * @throws IOException in case of I/O error + */ + public DerivedSourceReaders getReaders(SegmentReadState state) throws IOException { + return new DerivedSourceReaders( + knnVectorsReaderSupplier.apply(state), + docValuesProducerSupplier.apply(state), + fieldsProducerSupplier.apply(state), + normsProducer.apply(state) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java new file mode 100644 index 0000000000..9610eff683 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.AllArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; + +/** + * Custom {@link StoredFieldVisitor} that wraps an upstream delegate visitor in order to transparently inject derived + * source vector fields into the document. After the source is modified, it is forwarded to the delegate. 
+ */ +@AllArgsConstructor +public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor { + + private final StoredFieldVisitor delegate; + private final Integer documentId; + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + @Override + public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { + if (fieldInfo.name.equals(SourceFieldMapper.NAME)) { + delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value)); + return; + } + delegate.binaryField(fieldInfo, value); + } + + @Override + public Status needsField(FieldInfo fieldInfo) throws IOException { + return delegate.needsField(fieldInfo); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java new file mode 100644 index 0000000000..d3b1fe8469 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.util.IOUtils; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import 
java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative + * format readers and information about the fields to inject vectors into the source. + */ +@Log4j2 +public class DerivedSourceVectorInjector implements Closeable { + + private final DerivedSourceReaders derivedSourceReaders; + private final List perFieldDerivedVectorInjectors; + private final Set fieldNames; + + /** + * Constructor for DerivedSourceVectorInjector. + * + * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param segmentReadState Segment read state + * @param fieldsToInjectVector List of fields to inject vectors into + */ + public DerivedSourceVectorInjector( + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + List fieldsToInjectVector + ) throws IOException { + this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + this.perFieldDerivedVectorInjectors = new ArrayList<>(); + this.fieldNames = new HashSet<>(); + for (FieldInfo fieldInfo : fieldsToInjectVector) { + this.perFieldDerivedVectorInjectors.add( + PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState) + ); + this.fieldNames.add(fieldInfo.name); + } + } + + /** + * Given a docId and the source of that doc as bytes, add all the necessary vector fields into the source. 
+ * + * @param docId doc id of the document + * @param sourceAsBytes source of document as bytes + * @return byte array of the source with the vector fields added + * @throws IOException if there is an issue reading from the formats + */ + public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException { + // Reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 + // Deserialize the source into a modifiable map + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + // Have to create a copy of the map here to ensure that is mutable + Map sourceAsMap = new HashMap<>(mapTuple.v2()); + + // For each vector field, add in the source. The per field injectors are responsible for skipping if + // the field is not present. + for (PerFieldDerivedVectorInjector vectorInjector : perFieldDerivedVectorInjectors) { + vectorInjector.inject(docId, sourceAsMap); + } + + // At this point, we can serialize the modified source map + // Setting to 1024 based on + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourcePhase.java#L106 + BytesStreamOutput bStream = new BytesStreamOutput(1024); + MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap); + builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } + + /** + * Whether or not to inject vectors based on what fields are explicitly required + * + * @param includes List of fields that are required to be injected + * @param excludes List of fields that are not required to be injected + * @return true if vectors should be injected, false otherwise + */ + public boolean shouldInject(String[] includes, String[] 
excludes) { + // If any of the vector fields are explicitly required we should inject + if (includes != null && includes != Strings.EMPTY_ARRAY) { + for (String includedField : includes) { + if (fieldNames.contains(includedField)) { + return true; + } + } + } + + // If all of the vector fields are explicitly excluded we should not inject + if (excludes != null && excludes != Strings.EMPTY_ARRAY) { + int excludedVectorFieldCount = 0; + for (String excludedField : excludes) { + if (fieldNames.contains(excludedField)) { + excludedVectorFieldCount++; + } + } + // Inject if we havent excluded all of the fields + return excludedVectorFieldCount < fieldNames.size(); + } + return true; + } + + @Override + public void close() throws IOException { + IOUtils.close(derivedSourceReaders); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..7e28156703 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -0,0 +1,276 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BytesRef; +import org.opensearch.index.mapper.FieldNamesFieldMapper; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; 
+import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +@Log4j2 +@AllArgsConstructor +public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + + private final FieldInfo childFieldInfo; + private final DerivedSourceReaders derivedSourceReaders; + private final SegmentReadState segmentReadState; + + @Override + public void inject(int parentDocId, Map sourceAsMap) throws IOException { + // If the parent has the field, then it is just an object field. + int lowestDocIdForFieldWithParentAsOffset = getLowestDocIdForField(childFieldInfo.name, parentDocId); + if (lowestDocIdForFieldWithParentAsOffset == parentDocId) { + injectObject(parentDocId, sourceAsMap); + return; + } + + // Setup the iterator. Return if no parent + String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name); + String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name); + if (parentFieldName == null) { + return; + } + NestedPerFieldParentToDocIdIterator nestedPerFieldParentToDocIdIterator = new NestedPerFieldParentToDocIdIterator( + childFieldInfo, + segmentReadState, + derivedSourceReaders, + parentDocId + ); + + if (nestedPerFieldParentToDocIdIterator.numChildren() == 0) { + return; + } + + // Initializes the parent field so that there is a list to put each of the children + Object originalParentValue = sourceAsMap.get(parentFieldName); + List> reconstructedSource; + if (originalParentValue instanceof Map) { + reconstructedSource = new ArrayList<>(List.of((Map) originalParentValue)); + } else { + reconstructedSource = (List>) originalParentValue; + } + + // Contains the docIds of existing objects in the map in order. 
This is used to help figure out the best play + // to put back the vectors + List positions = mapObjectsToPositionInNestedList( + reconstructedSource, + nestedPerFieldParentToDocIdIterator.firstChild(), + parentDocId + ); + + // Finally, inject children for the document into the source. This code is non-trivial because filtering out + // the vectors during write could mean that children docs disappear from the source. So, to properly put + // everything back, we need to figure out where the existing fields in the original map to + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( + childFieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + int offsetPositionsIndex = 0; + while (nestedPerFieldParentToDocIdIterator.nextChild() != NO_MORE_DOCS) { + // If the child does not have a vector, vectValues advance will advance past child to the next matching + // docId. So, we need to ensure that doing this does not pass the parent docId. 
+ if (nestedPerFieldParentToDocIdIterator.childId() > vectorValues.docId()) { + vectorValues.advance(nestedPerFieldParentToDocIdIterator.childId()); + } + if (vectorValues.docId() != nestedPerFieldParentToDocIdIterator.childId()) { + continue; + } + + int docId = nestedPerFieldParentToDocIdIterator.childId(); + boolean isInsert = true; + int position = positions.size(); // by default we insert it at the end + for (int i = offsetPositionsIndex; i < positions.size(); i++) { + if (docId < positions.get(i)) { + position = i; + break; + } + if (docId == positions.get(i)) { + isInsert = false; + position = i; + break; + } + } + + if (isInsert) { + reconstructedSource.add(position, new HashMap<>()); + positions.add(position, docId); + } + reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); + offsetPositionsIndex = position + 1; + } + sourceAsMap.put(parentFieldName, reconstructedSource); + } + + private void injectObject(int docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( + childFieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + if (vectorValues.docId() != docId && vectorValues.advance(docId) != docId) { + return; + } + String[] fields = ParentChildHelper.splitPath(childFieldInfo.name); + Map currentMap = sourceAsMap; + for (int i = 0; i < fields.length - 1; i++) { + String field = fields[i]; + currentMap = (Map) currentMap.computeIfAbsent(field, k -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectorValues.getVector()); + } + + /** + * Given a list of maps, map each map to a position in the nested list. This is used to help figure out where to put + * the vectors back in the source. 
+ * + * @param originals list of maps + * @param firstChild first child docId + * @param parent parent docId + * @return list of positions in the nested list + * @throws IOException if there is an issue reading from the formats + */ + private List mapObjectsToPositionInNestedList(List> originals, int firstChild, int parent) + throws IOException { + List positions = new ArrayList<>(); + int offset = firstChild; + for (Map docWithFields : originals) { + int fieldMapping = mapToDocId(docWithFields, offset, parent); + assert fieldMapping != -1; + positions.add(fieldMapping); + offset = fieldMapping + 1; + } + return positions; + } + + /** + * Given a doc as a map and the offset it has to be, find the ordinal of the first field that is greater than the + * offset. + * + * @param doc doc to find the ordinal for + * @param offset offset to start searching from + * @return id of the first field that is greater than the offset + * @throws IOException if there is an issue reading from the formats + */ + private int mapToDocId(Map doc, int offset, int parent) throws IOException { + // For all the fields, we look for the first doc that matches any of the fields. + int position = NO_MORE_DOCS; + for (String key : doc.keySet()) { + position = getLowestDocIdForField(ParentChildHelper.constructSiblingField(childFieldInfo.name, key), offset); + if (position < parent) { + break; + } + } + + // Advancing past the parent means something went wrong + assert position < parent; + return position; + } + + /** + * Get the lowest docId for a field that is greater than the offset. + * + * @param fieldToMatch field to find the lowest docId for + * @param offset offset to start searching from + * @return lowest docId for the field that is greater than the offset. 
Returns {@link DocIdSetIterator#NO_MORE_DOCS} if doc cannot be found + * @throws IOException if there is an issue reading from the formats + */ + private int getLowestDocIdForField(String fieldToMatch, int offset) throws IOException { + // This method implementation is inspired by the FieldExistsQuery in Lucene and the FieldNamesMapper in + // Opensearch. We first mimic the logic in the FieldExistsQuery in order to identify the docId of the nested + // doc. If that fails, we rely on + // References: + // 1. https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java#L170-L218. + // 2. + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/mapper/FieldMapper.java#L316-L324 + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(fieldToMatch); + + if (fieldInfo == null) { + return NO_MORE_DOCS; + } + + DocIdSetIterator iterator = null; + if (fieldInfo.hasNorms() && derivedSourceReaders.getNormsProducer() != null) { // the field indexes norms + iterator = derivedSourceReaders.getNormsProducer().getNorms(fieldInfo); + } else if (fieldInfo.getVectorDimension() != 0 && derivedSourceReaders.getKnnVectorsReader() != null) { // the field indexes vectors + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + iterator = derivedSourceReaders.getKnnVectorsReader().getFloatVectorValues(fieldInfo.name).iterator(); + break; + case BYTE: + iterator = derivedSourceReaders.getKnnVectorsReader().getByteVectorValues(fieldInfo.name).iterator(); + break; + } + } else if (fieldInfo.getDocValuesType() != DocValuesType.NONE && derivedSourceReaders.getDocValuesProducer() != null) { // the field + // indexes + // doc + // values + switch (fieldInfo.getDocValuesType()) { + case NUMERIC: + iterator = derivedSourceReaders.getDocValuesProducer().getNumeric(fieldInfo); + break; + case BINARY: + iterator = derivedSourceReaders.getDocValuesProducer().getBinary(fieldInfo); + break; + case 
SORTED: + iterator = derivedSourceReaders.getDocValuesProducer().getSorted(fieldInfo); + break; + case SORTED_NUMERIC: + iterator = derivedSourceReaders.getDocValuesProducer().getSortedNumeric(fieldInfo); + break; + case SORTED_SET: + iterator = derivedSourceReaders.getDocValuesProducer().getSortedSet(fieldInfo); + break; + case NONE: + default: + throw new AssertionError(); + } + } + if (iterator != null) { + return iterator.advance(offset); + } + + // Check the field names field type for matches + if (derivedSourceReaders.getFieldsProducer() == null) { + return NO_MORE_DOCS; + } + Terms terms = derivedSourceReaders.getFieldsProducer().terms(FieldNamesFieldMapper.NAME); + if (terms == null) { + return NO_MORE_DOCS; + } + TermsEnum fieldNameFieldsTerms = terms.iterator(); + BytesRef fieldToMatchRef = new BytesRef(fieldInfo.name); + PostingsEnum postingsEnum = null; + while (fieldNameFieldsTerms.next() != null) { + BytesRef currentTerm = fieldNameFieldsTerms.term(); + if (currentTerm.bytesEquals(fieldToMatchRef)) { + postingsEnum = fieldNameFieldsTerms.postings(null); + break; + } + } + if (postingsEnum == null) { + return NO_MORE_DOCS; + } + return postingsEnum.advance(offset); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java new file mode 100644 index 0000000000..d2bc1a32fd --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.Terms; +import 
org.apache.lucene.index.TermsEnum; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Iterator over the children documents of a particular parent + */ +public class NestedPerFieldParentToDocIdIterator { + + private final FieldInfo childFieldInfo; + private final SegmentReadState segmentReadState; + private final DerivedSourceReaders derivedSourceReaders; + private final int parentDocId; + private final int previousParentDocId; + private final List children; + private int currentChild; + + /** + * + * @param childFieldInfo FieldInfo for the child field + * @param segmentReadState SegmentReadState for the segment + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @param parentDocId Parent docId of the parent + * @throws IOException if there is an error reading the parent docId + */ + public NestedPerFieldParentToDocIdIterator( + FieldInfo childFieldInfo, + SegmentReadState segmentReadState, + DerivedSourceReaders derivedSourceReaders, + int parentDocId + ) throws IOException { + this.childFieldInfo = childFieldInfo; + this.segmentReadState = segmentReadState; + this.derivedSourceReaders = derivedSourceReaders; + this.parentDocId = parentDocId; + this.previousParentDocId = previousParent(); + this.children = getChildren(); + this.currentChild = -1; + } + + /** + * For the given parent get its first child offset + * + * @return the first child offset. If there are no children, just return NO_MORE_DOCS + */ + public int firstChild() { + if (parentDocId - previousParentDocId == 1) { + return NO_MORE_DOCS; + } + return previousParentDocId + 1; + } + + /** + * Get the next child for this parent + * + * @return the next child docId. If this has not been set, return -1. 
If there are no more children, return + * NO_MORE_DOCS + */ + public int nextChild() { + currentChild++; + if (currentChild >= children.size()) { + return NO_MORE_DOCS; + } + return children.get(currentChild); + } + + /** + * Get the current child for this parent + * + * @return the current child docId. If this has not been set, return -1 + */ + public int childId() { + return children.get(currentChild); + } + + /** + * + * @return the number of children for this parent + */ + public int numChildren() { + return children.size(); + } + + /** + * For parentDocId of this class, find the one just before it to be used for matching children. + * + * @return the parent docId just before the parentDocId. -1 if none exist + * @throws IOException if there is an error reading the parent docId + */ + private int previousParent() throws IOException { + // TODO: In the future this needs to be generalized to handle multiple levels of nesting + // For now, for non-nested docs, the primary_term field can be used to identify root level docs. 
For reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/SeqNoPrimaryTermPhase.java#L72 + // https://github.com/opensearch-project/OpenSearch/blob/3032bef54d502836789ea438f464ae0b1ba978b2/server/src/main/java/org/opensearch/index/mapper/SeqNoFieldMapper.java#L206-L230 + // We use it here to identify the previous parent to the current parent to get a range on the children documents + FieldInfo seqTermsFieldInfo = segmentReadState.fieldInfos.fieldInfo("_primary_term"); + NumericDocValues numericDocValues = derivedSourceReaders.getDocValuesProducer().getNumeric(seqTermsFieldInfo); + int previousParentDocId = -1; + while (numericDocValues.nextDoc() != NO_MORE_DOCS) { + if (numericDocValues.docID() >= parentDocId) { + break; + } + previousParentDocId = numericDocValues.docID(); + } + return previousParentDocId; + } + + /** + * Get all the children that match the parent path for the _nested_field + * + * @return list of children that match the parent path + * @throws IOException if there is an error reading the children + */ + private List getChildren() throws IOException { + if (this.parentDocId - this.previousParentDocId <= 1) { + return Collections.emptyList(); + } + + // First, we need to get the currect PostingsEnum for the key as _nested_path and the value the actual parent + // path. 
+ String childField = childFieldInfo.name; + String parentField = ParentChildHelper.getParentField(childField); + + Terms terms = derivedSourceReaders.getFieldsProducer().terms("_nested_path"); + if (terms == null) { + return Collections.emptyList(); + } + TermsEnum nestedFieldsTerms = terms.iterator(); + BytesRef childPathRef = new BytesRef(parentField); + PostingsEnum postingsEnum = null; + while (nestedFieldsTerms.next() != null) { + BytesRef currentTerm = nestedFieldsTerms.term(); + if (currentTerm.bytesEquals(childPathRef)) { + postingsEnum = nestedFieldsTerms.postings(null); + break; + } + } + + // Next, get all the children that match this parent path. If none were found, return an empty list + if (postingsEnum == null) { + return Collections.emptyList(); + } + List children = new ArrayList<>(); + postingsEnum.advance(previousParentDocId + 1); + while (postingsEnum.docID() != NO_MORE_DOCS && postingsEnum.docID() < parentDocId) { + if (postingsEnum.freq() > 0) { + children.add(postingsEnum.docID()); + } + postingsEnum.nextDoc(); + } + + return children; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java new file mode 100644 index 0000000000..ae249b1b39 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +/** + * Helper class for working with nested fields. + */ +public class ParentChildHelper { + + /** + * Given a nested field path, return the path of the parent field. For instance if the field is "parent.to.child", + * this would return "parent.to". + * + * @param field nested field path + * @return parent field path without the child. 
/**
 * Static helpers for working with dot-delimited nested field paths such as "parent.to.child".
 */
public class ParentChildHelper {

    // Utility class — prevent instantiation.
    private ParentChildHelper() {}

    /**
     * Given a nested field path, return the path of the parent field. For instance if the field is "parent.to.child",
     * this would return "parent.to".
     *
     * @param field nested field path
     * @return parent field path without the child. Null if no parent exists
     */
    public static String getParentField(String field) {
        if (field == null) {
            return null;
        }
        int lastDot = field.lastIndexOf('.');
        if (lastDot == -1) {
            return null;
        }
        return field.substring(0, lastDot);
    }

    /**
     * Given a nested field path, return the child field. For instance if the field is "parent.to.child", this would
     * return "child".
     *
     * @param field nested field path
     * @return child field path without the parent path. Null if no child exists
     */
    public static String getChildField(String field) {
        if (field == null) {
            return null;
        }
        int lastDot = field.lastIndexOf('.');
        if (lastDot == -1) {
            return null;
        }
        return field.substring(lastDot + 1);
    }

    /**
     * Construct a sibling field path. For instance, if the field is "parent.to.child" and the sibling is "sibling",
     * this would return "parent.to.sibling".
     *
     * @param field nested field path
     * @param sibling sibling field
     * @return sibling field path
     */
    public static String constructSiblingField(String field, String sibling) {
        String parent = getParentField(field);
        if (parent == null) {
            return sibling;
        }
        return parent + "." + sibling;
    }

    /**
     * Split a nested field path on '.'. For instance, if the field is "parent.to.child", this would return
     * ["parent", "to", "child"].
     *
     * @param field nested field path
     * @return array of strings representing the nested field path
     */
    public static String[] splitPath(String field) {
        return field.split("\\.");
    }
}
+ * + * @param docId Document ID + * @param sourceAsMap Source as map + * @throws IOException if there is an issue reading from the formats + */ + void inject(int docId, Map sourceAsMap) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java new file mode 100644 index 0000000000..d31d000837 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; + +/** + * Factory for creating {@link PerFieldDerivedVectorInjector} instances. + */ +class PerFieldDerivedVectorInjectorFactory { + + /** + * Create a {@link PerFieldDerivedVectorInjector} instance based on information in field info. 
+ * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @return PerFieldDerivedVectorInjector instance + */ + public static PerFieldDerivedVectorInjector create( + FieldInfo fieldInfo, + DerivedSourceReaders derivedSourceReaders, + SegmentReadState segmentReadState + ) { + // Nested case + if (ParentChildHelper.getParentField(fieldInfo.name) != null) { + return new NestedPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders, segmentReadState); + } + + // Non-nested case + return new RootPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..430fd24ae1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.CheckedSupplier; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; +import java.util.Map; + +/** + * {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields). + */ +class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + + private final FieldInfo fieldInfo; + private final CheckedSupplier, IOException> vectorValuesSupplier; + + /** + * Constructor for RootPerFieldDerivedVectorInjector. 
+ * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + */ + public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReaders derivedSourceReaders) { + this.fieldInfo = fieldInfo; + this.vectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues( + fieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + } + + @Override + public void inject(int docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = vectorValuesSupplier.get(); + if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) { + sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 7078645e54..de535c39e8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -226,7 +226,7 @@ private Map getParameters(FieldInfo fieldInfo, VectorDataType ve maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); return parameters; } @@ -258,7 +258,7 @@ private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { Map parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + 
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 9f1ebcf018..68ea25a1fc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -14,6 +14,9 @@ import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + /** * Mapper used when you dont want to build an underlying KNN struct - you just want to * store vectors as doc values @@ -32,7 +35,8 @@ public static FlatVectorFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -49,7 +53,8 @@ public static FlatVectorFieldMapper createFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -62,7 +67,8 @@ private FlatVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -73,13 +79,17 @@ private FlatVectorFieldMapper( stored, hasDocValues, indexCreatedVersion, - 
originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.setDocValuesType(DocValuesType.BINARY); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.fieldType.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7a0ee00af7..338a913d5f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -58,6 +58,7 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createKNNMethodContextFromLegacy; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.useFullFieldNameValidation; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfKNNPluginEnabled; import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER; @@ -93,6 +94,7 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { */ public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected final Parameter stored = Parameter.storeParam(m 
-> toType(m).stored, false); protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); @@ -200,13 +202,15 @@ public Builder( ModelDao modelDao, Version indexCreatedVersion, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalParameters + OriginalMappingParameters originalParameters, + boolean isDerivedSourceEnabled ) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; this.knnMethodConfigContext = knnMethodConfigContext; this.originalParameters = originalParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } @Override @@ -237,7 +241,9 @@ protected Explicit ignoreMalformed(BuilderContext context) { @Override public KNNVectorFieldMapper build(BuilderContext context) { - validateFullFieldName(context); + if (useFullFieldNameValidation(indexCreatedVersion)) { + validateFullFieldName(context); + } final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); final CopyTo copyToBuilder = copyTo.build(); @@ -258,7 +264,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { modelDao, indexCreatedVersion, originalParameters, - knnMethodConfigContext + knnMethodConfigContext, + isDerivedSourceEnabled ); } @@ -280,7 +287,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -301,7 +309,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { metaValue, knnMethodConfigContext, createLuceneFieldMapperInput, - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -315,7 +324,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.getValue(), hasDocValues.getValue(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -363,7 +373,8 @@ public Mapper.Builder parse(String name, Map node, ParserCont 
modelDaoSupplier.get(), parserContext.indexVersionCreated(), null, - null + null, + KNNSettings.isKNNDerivedSourceEnabled(parserContext.getSettings()) ); builder.parse(name, parserContext, node); builder.setOriginalParameters(new OriginalMappingParameters(builder)); @@ -578,6 +589,7 @@ static boolean useKNNMethodContextFromLegacy(Builder builder, Mapper.TypeParser. // values of KNN engine Algorithms hyperparameters. protected Version indexCreatedVersion; protected Explicit ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected boolean stored; protected boolean hasDocValues; protected VectorDataType vectorDataType; @@ -598,7 +610,8 @@ public KNNVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; @@ -608,6 +621,7 @@ public KNNVectorFieldMapper( updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; this.originalMappingParameters = originalMappingParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } public KNNVectorFieldMapper clone() { @@ -840,7 +854,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { modelDao, indexCreatedVersion, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ).init(this); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index a9ca56d4b5..1240098191 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -146,6 +146,16 @@ static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) { return 
indexCreatedVersion.onOrAfter(Version.V_2_17_0); } + /** + * Determines if full field name validation should be applied based on the index creation version. + * + * @param indexCreatedVersion The version when the index was created + * @return true if the index version is 2.17.0 or later, false otherwise + */ + static boolean useFullFieldNameValidation(final Version indexCreatedVersion) { + return indexCreatedVersion != null && indexCreatedVersion.onOrAfter(Version.V_2_17_0); + } + public static SpaceType getSpaceType(final Settings indexSettings) { String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); if (spaceType == null) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 461c6f7c8c..b0bead693c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -8,7 +8,7 @@ import lombok.Getter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; import org.opensearch.index.fielddata.IndexFieldData; @@ -81,7 +81,7 @@ public String typeName() { @Override public Query existsQuery(QueryShardContext context) { - return new DocValuesFieldExistsQuery(name()); + return new FieldExistsQuery(name()); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4ceb9b4b23..49cd02d5b0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -27,6 +27,8 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; 
import org.opensearch.knn.index.engine.KNNMethodContext; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; @@ -48,7 +50,8 @@ static LuceneFieldMapper createFieldMapper( Map metaValue, KNNMethodConfigContext knnMethodConfigContext, CreateLuceneFieldMapperInput createLuceneFieldMapperInput, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -82,14 +85,21 @@ public Version getIndexCreatedVersion() { } ); - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext, originalMappingParameters); + return new LuceneFieldMapper( + mappedFieldType, + createLuceneFieldMapperInput, + knnMethodConfigContext, + originalMappingParameters, + isDerivedSourceEnabled + ); } private LuceneFieldMapper( final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( input.getName(), @@ -100,7 +110,8 @@ private LuceneFieldMapper( input.isStored(), input.isHasDocValues(), knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext resolvedKnnMethodContext = 
originalMappingParameters.getResolvedKnnMethodContext(); @@ -117,6 +128,12 @@ private LuceneFieldMapper( this.vectorFieldType = null; } + if (isDerivedSourceEnabled) { + this.fieldType = new FieldType(this.fieldType); + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + this.fieldType.freeze(); + } + KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 814bc4f639..a2635b1953 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -51,7 +53,8 @@ public static MethodFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); @@ -104,7 +107,8 @@ public Version getIndexCreatedVersion() { stored, hasDocValues, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -117,7 +121,8 @@ 
private MethodFieldMapper( boolean stored, boolean hasDocValues, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( @@ -129,7 +134,8 @@ private MethodFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); @@ -151,7 +157,9 @@ private MethodFieldMapper( this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); - + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } try { this.fieldType.putAttribute( PARAMETERS, diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index d472090fc3..ae912aa415 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -28,6 +28,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; @@ -60,7 +62,8 @@ public static ModelFieldMapper createFieldMapper( ModelDao modelDao, Version indexCreatedVersion, OriginalMappingParameters originalMappingParameters, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext 
knnMethodConfigContext, + boolean isDerivedSourceEnabled ) { final KNNMethodContext knnMethodContext = originalMappingParameters.getKnnMethodContext(); @@ -134,7 +137,8 @@ private void initFromModelMetadata() { hasDocValues, modelDao, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -148,7 +152,8 @@ private ModelFieldMapper( boolean hasDocValues, ModelDao modelDao, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -159,7 +164,8 @@ private ModelFieldMapper( stored, hasDocValues, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); @@ -174,6 +180,9 @@ private ModelFieldMapper( this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 99962d3074..26b26ec11f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -28,7 +28,7 @@ public class KNNScorer extends Scorer { private final float boost; public KNNScorer(Weight weight, DocIdSetIterator docIdsIter, Map scores, float boost) { - super(weight); + super(); this.docIdsIter = docIdsIter; this.scores = scores; this.boost = boost; @@ -60,40 
+60,44 @@ public int docID() { /** * Returns the Empty Scorer implementation. We use this scorer to short circuit the actual search when it is not * required. - * @param knnWeight {@link KNNWeight} * @return {@link KNNScorer} */ - public static Scorer emptyScorer(KNNWeight knnWeight) { - return new Scorer(knnWeight) { - private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); - - @Override - public DocIdSetIterator iterator() { - return docIdsIter; - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return 0; - } - - @Override - public float score() throws IOException { - assert docID() != DocIdSetIterator.NO_MORE_DOCS; - return 0; - } - - @Override - public int docID() { - return docIdsIter.docID(); - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof Scorer)) return false; - return getWeight().equals(((Scorer) obj).getWeight()); - } - }; - + public static Scorer emptyScorer() { + return EMPTY_SCORER_INSTANCE; } + + private static final Scorer EMPTY_SCORER_INSTANCE = new Scorer() { + private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); + + @Override + public DocIdSetIterator iterator() { + return docIdsIter; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return 0; + } + + @Override + public float score() throws IOException { + assert docID() != DocIdSetIterator.NO_MORE_DOCS; + return 0; + } + + @Override + public int docID() { + return docIdsIter.docID(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; // Singleton ensures only one instance exists + } + + @Override + public int hashCode() { + return System.identityHashCode(this); // Consistent hash for singleton + } + }; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad1..dc12fd473c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ 
b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilteredDocIdSetIterator; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; @@ -110,13 +111,27 @@ public Explanation explain(LeafReaderContext context, int doc) { } @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult(); - if (docIdToScoreMap.isEmpty()) { - return KNNScorer.emptyScorer(this); - } - final int maxDoc = Collections.max(docIdToScoreMap.keySet()) + 1; - return new KNNScorer(this, ResultUtil.resultMapToDocIds(docIdToScoreMap, maxDoc), docIdToScoreMap, boost); + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { + long cost = -1L; + + @Override + public Scorer get(long leadCost) throws IOException { + final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult(); + cost = docIdToScoreMap.size(); + if (docIdToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(); + } + final int maxDoc = Collections.max(docIdToScoreMap.keySet()) + 1; + return new KNNScorer(KNNWeight.this, ResultUtil.resultMapToDocIds(docIdToScoreMap, maxDoc), docIdToScoreMap, boost); + } + + @Override + public long cost() { + // Estimate the cost of the scoring operation, if applicable. + return cost == -1L ? 
knnQuery.getK() : cost; + } + }; } /** diff --git a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index f38cc96c64..0aee8dd8b7 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -13,7 +13,9 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import org.opensearch.knn.index.query.KNNScorer; import java.io.IOException; import java.util.Arrays; @@ -21,9 +23,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -/** - * This is the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery} - */ final class DocAndScoreQuery extends Query { private final int k; @@ -62,92 +61,104 @@ public int count(LeafReaderContext context) { } @Override - public Scorer scorer(LeafReaderContext context) { - if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { - return null; - } - return new Scorer(this) { - final int lower = segmentStarts[context.ord]; - final int upper = segmentStarts[context.ord + 1]; - int upTo = -1; - + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { @Override - public DocIdSetIterator iterator() { - return new DocIdSetIterator() { + public Scorer get(long leadCost) throws IOException { + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return KNNScorer.emptyScorer(); + } + return new Scorer() { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + @Override - public int docID() { - return docIdNoShadow(); + public DocIdSetIterator iterator() { + return new 
DocIdSetIterator() { + @Override + public int docID() { + return docIdNoShadow(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return docIdNoShadow(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; } @Override - public int nextDoc() { - if (upTo == -1) { - upTo = lower; - } else { - ++upTo; + public float getMaxScore(int docId) { + docId += context.docBase; + float maxScore = 0; + for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { + maxScore = Math.max(maxScore, scores[idx]); } - return docIdNoShadow(); + return maxScore * boost; } @Override - public int advance(int target) throws IOException { - return slowAdvance(target); + public float score() { + return scores[upTo] * boost; } @Override - public long cost() { - return upper - lower; + public int advanceShallow(int docid) { + int start = Math.max(upTo, lower); + int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); + if (docidIndex < 0) { + docidIndex = -1 - docidIndex; + } + if (docidIndex >= upper) { + return NO_MORE_DOCS; + } + return docs[docidIndex]; } - }; - } - - @Override - public float getMaxScore(int docId) { - docId += context.docBase; - float maxScore = 0; - for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { - maxScore = Math.max(maxScore, scores[idx]); - } - return maxScore * boost; - } - @Override - public float score() { - return scores[upTo] * boost; - } + /** + * move the implementation of docID() into a differently-named method so we can call it + * from DocIDSetIterator.docID() even though this class is anonymous + * + * @return the current docid + */ + private int docIdNoShadow() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo] - context.docBase; + } - @Override - 
public int advanceShallow(int docid) { - int start = Math.max(upTo, lower); - int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); - if (docidIndex < 0) { - docidIndex = -1 - docidIndex; - } - if (docidIndex >= upper) { - return NO_MORE_DOCS; - } - return docs[docidIndex]; - } + @Override + public int docID() { + return docIdNoShadow(); + } + }; - /** - * move the implementation of docID() into a differently-named method so we can call it - * from DocIDSetIterator.docID() even though this class is anonymous - * - * @return the current docid - */ - private int docIdNoShadow() { - if (upTo == -1) { - return -1; - } - if (upTo >= upper) { - return NO_MORE_DOCS; - } - return docs[upTo] - context.docBase; } @Override - public int docID() { - return docIdNoShadow(); + public long cost() { + // Estimate the cost of the scoring operation, if applicable. + return docs.length == 0 ? k : docs.length; } }; } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index 5da093fd54..b113509cfc 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -9,6 +9,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -34,7 +35,7 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } return vector; diff 
--git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index 1ebc509707..374adea208 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -9,6 +9,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -34,7 +35,7 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index dffdd8f0d9..ad9f32b77f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -8,6 +8,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -32,7 +33,7 @@ public float[] getVector() throws IOException { @Override public float[] conditionalCloneVector() throws IOException { float[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof 
KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } return vector; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e2172..835425b2a2 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,11 +5,14 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.VectorDataType; @@ -26,9 +29,13 @@ public final class KNNVectorValuesFactory { * Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and {@link VectorDataType} * * @param vectorDataType {@link VectorDataType} - * @param docIdSetIterator {@link DocIdSetIterator} + * @param knnVectorValues {@link KnnVectorValues} * @return {@link KNNVectorValues} */ + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final KnnVectorValues knnVectorValues) { + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(knnVectorValues)); + } + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) { return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator)); } @@ -57,19 +64,60 @@ public static KNNVectorValues getVectorValues( */ public static 
KNNVectorValues getVectorValues(final FieldInfo fieldInfo, final LeafReader leafReader) throws IOException { final DocIdSetIterator docIdSetIterator; - if (fieldInfo.hasVectorValues()) { + if (!fieldInfo.hasVectorValues()) { + docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName()); + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(leafReader.getByteVectorValues(fieldInfo.getName())) + ); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(leafReader.getFloatVectorValues(fieldInfo.getName())) + ); + } else { + throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); + } + } + + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader} + * + * @param fieldInfo {@link FieldInfo} + * @param docValuesProducer {@link DocValuesProducer} + * @param knnVectorsReader {@link KnnVectorsReader} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues( + final FieldInfo fieldInfo, + final DocValuesProducer docValuesProducer, + final KnnVectorsReader knnVectorsReader + ) throws IOException { + if (fieldInfo.hasVectorValues() && knnVectorsReader != null) { + final KnnVectorValues knnVectorValues; if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { - docIdSetIterator = leafReader.getByteVectorValues(fieldInfo.getName()); + knnVectorValues = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); } else if 
(fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { - docIdSetIterator = leafReader.getFloatVectorValues(fieldInfo.getName()); + knnVectorValues = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); } else { throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); } + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(knnVectorValues) + ); + } else if (docValuesProducer != null) { + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(docValuesProducer.getBinary(fieldInfo)) + ); } else { - docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName()); + throw new IllegalArgumentException("Field does not have vector values and DocValues"); } - final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); - return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); } @SuppressWarnings("unchecked") diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java index 4f1445c1cb..bf9c0bef3a 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java @@ -5,19 +5,19 @@ package org.opensearch.knn.index.vectorvalues; +import lombok.Getter; import lombok.NonNull; +import lombok.Setter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.search.DocIdSetIterator; +import 
org.apache.lucene.index.KnnVectorValues; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import java.io.IOException; -import java.util.List; import java.util.Map; -import java.util.function.Function; /** * An abstract class that provides an iterator to iterate over KNNVectors, as KNNVectors are stored as different @@ -71,18 +71,28 @@ public interface KNNVectorValuesIterator { * {@link DocIdSetIterator} interface. Example: {@link BinaryDocValues}, {@link FloatVectorValues} etc. */ class DocIdsIteratorValues implements KNNVectorValuesIterator { - protected DocIdSetIterator docIdSetIterator; - private static final List> VALID_ITERATOR_INSTANCE = List.of( - (itr) -> itr instanceof BinaryDocValues, - (itr) -> itr instanceof FloatVectorValues, - (itr) -> itr instanceof ByteVectorValues - ); - - DocIdsIteratorValues(@NonNull final DocIdSetIterator docIdSetIterator) { - validateIteratorType(docIdSetIterator); + private final DocIdSetIterator docIdSetIterator; + private KnnVectorValues knnVectorValues = null; // Added reference to KnnVectorValues + @Getter + @Setter + private int lastOrd = -1; + @Getter + @Setter + private Object lastAccessedVector = null; + + DocIdsIteratorValues(@NonNull final KnnVectorValues knnVectorValues) { + this.docIdSetIterator = knnVectorValues.iterator(); + this.knnVectorValues = knnVectorValues; + } + + DocIdsIteratorValues(final DocIdSetIterator docIdSetIterator) { this.docIdSetIterator = docIdSetIterator; } + public KnnVectorValues getKnnVectorValues() { + return knnVectorValues; + } + @Override public int docId() { return docIdSetIterator.docID(); @@ -107,7 +117,7 @@ public DocIdSetIterator getDocIdSetIterator() { public long liveDocs() { if (docIdSetIterator instanceof BinaryDocValues) { return KNNCodecUtil.getTotalLiveDocsCount((BinaryDocValues) docIdSetIterator); - } else if (docIdSetIterator instanceof FloatVectorValues || docIdSetIterator instanceof ByteVectorValues) { + } else if (docIdSetIterator instanceof 
KnnVectorValues.DocIndexIterator) { return docIdSetIterator.cost(); } throw new IllegalArgumentException( @@ -119,18 +129,6 @@ public long liveDocs() { public VectorValueExtractorStrategy getVectorExtractorStrategy() { return new VectorValueExtractorStrategy.DISIVectorExtractor(); } - - private void validateIteratorType(final DocIdSetIterator docIdSetIterator) { - VALID_ITERATOR_INSTANCE.stream() - .map(v -> v.apply(docIdSetIterator)) - .filter(Boolean::booleanValue) - .findFirst() - .orElseThrow( - () -> new IllegalArgumentException( - "DocIdSetIterator present is not of valid type. Valid types are: BinaryDocValues, FloatVectorValues and ByteVectorValues" - ) - ); - } } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f64..7aafae308f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -8,6 +8,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; @@ -20,7 +21,7 @@ /** * Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator} */ -interface VectorValueExtractorStrategy { +public interface VectorValueExtractorStrategy { /** * Extract a float vector from KNNVectorValuesIterator. 
@@ -69,31 +70,54 @@ class DISIVectorExtractor implements VectorValueExtractorStrategy { @Override public T extract(final VectorDataType vectorDataType, final KNNVectorValuesIterator vectorValuesIterator) throws IOException { final DocIdSetIterator docIdSetIterator = vectorValuesIterator.getDocIdSetIterator(); - switch (vectorDataType) { - case FLOAT: - if (docIdSetIterator instanceof BinaryDocValues) { - final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; - return (T) getFloatVectorFromByteRef(values.binaryValue()); - } else if (docIdSetIterator instanceof FloatVectorValues) { - return (T) ((FloatVectorValues) docIdSetIterator).vectorValue(); - } - throw new IllegalArgumentException( - "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and FloatVectorValues" - ); - case BYTE: - case BINARY: - if (docIdSetIterator instanceof BinaryDocValues) { - final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; - final BytesRef bytesRef = values.binaryValue(); - return (T) ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length); - } else if (docIdSetIterator instanceof ByteVectorValues) { - return (T) ((ByteVectorValues) docIdSetIterator).vectorValue(); - } - throw new IllegalArgumentException( - "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and ByteVectorValues" - ); + + if (docIdSetIterator instanceof BinaryDocValues) { + return extractFromBinaryDocValues(vectorDataType, (BinaryDocValues) docIdSetIterator); + } else if (docIdSetIterator instanceof KnnVectorValues.DocIndexIterator) { + return extractFromKnnVectorValues( + vectorDataType, + (KNNVectorValuesIterator.DocIdsIteratorValues) vectorValuesIterator, + (KnnVectorValues.DocIndexIterator) docIdSetIterator + ); + } else { + throw new IllegalArgumentException( + "VectorValuesIterator is not of a valid type. 
Valid Types are: BinaryDocValues and KnnVectorValues.DocIndexIterator" + ); + } + } + + private T extractFromBinaryDocValues(VectorDataType vectorDataType, BinaryDocValues values) throws IOException { + BytesRef bytesRef = values.binaryValue(); + if (vectorDataType == VectorDataType.FLOAT) { + return (T) getFloatVectorFromByteRef(bytesRef); + } else if (vectorDataType == VectorDataType.BYTE || vectorDataType == VectorDataType.BINARY) { + return (T) ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length); + } + throw new IllegalArgumentException("Invalid vector data type for BinaryDocValues"); + } + + private T extractFromKnnVectorValues( + VectorDataType vectorDataType, + KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues, + KnnVectorValues.DocIndexIterator docIdSetIterator + ) throws IOException { + int ord = docIdSetIterator.index(); + if (ord == docIdsIteratorValues.getLastOrd()) { + return (T) docIdsIteratorValues.getLastAccessedVector(); } - throw new IllegalArgumentException("Valid Vector data type not passed to extract vector from DISIVectorExtractor strategy"); + docIdsIteratorValues.setLastOrd(ord); + + if (vectorDataType == VectorDataType.FLOAT) { + FloatVectorValues knnVectorValues = (FloatVectorValues) docIdsIteratorValues.getKnnVectorValues(); + docIdsIteratorValues.setLastAccessedVector(knnVectorValues.vectorValue(ord)); + } else if (vectorDataType == VectorDataType.BYTE || vectorDataType == VectorDataType.BINARY) { + ByteVectorValues byteVectorValues = (ByteVectorValues) docIdsIteratorValues.getKnnVectorValues(); + docIdsIteratorValues.setLastAccessedVector(byteVectorValues.vectorValue(ord)); + } else { + throw new IllegalArgumentException("Invalid vector data type for KnnVectorValues"); + } + + return (T) docIdsIteratorValues.getLastAccessedVector(); } private float[] getFloatVectorFromByteRef(final BytesRef bytesRef) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java 
b/src/main/java/org/opensearch/knn/indices/ModelDao.java index d0abe86129..387a23587d 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -35,7 +35,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.cluster.health.ClusterHealthStatus; import org.opensearch.cluster.health.ClusterIndexHealth; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 78f3769c59..4ad6227d59 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -127,7 +127,7 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques searchSourceBuilder.terminateAfter(DEFAULT_TERMINATE_AFTER); client.search(countRequest, ActionListener.wrap(searchResponse -> { - long trainingVectors = searchResponse.getHits().getTotalHits().value; + long trainingVectors = searchResponse.getHits().getTotalHits().value(); // If there are more docs in the index than what the user wants to use for training, take the min if (trainingModelRequest.getMaximumVectorCount() < trainingVectors) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java index 216efa78e8..d374b4610f 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java +++ 
b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java @@ -6,7 +6,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.core.common.io.stream.Writeable; /** diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java index 887f5d7a2b..2cbda7b2e8 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java @@ -7,7 +7,7 @@ import lombok.Getter; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.action.support.clustermanager.AcknowledgedRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java index 7d5750c2be..d9a26e1e04 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java @@ -10,7 +10,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.action.support.clustermanager.TransportClusterManagerNodeAction; import 
org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateTaskConfig; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java index 756a32575c..dbc2d6c7f5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java @@ -12,7 +12,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.core.common.io.stream.Writeable; /** diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java index af063ad271..56aac2510e 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java @@ -12,7 +12,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.action.support.clustermanager.AcknowledgedRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.indices.ModelMetadata; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java index ec909f443f..01e4cbf36f 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java +++ 
b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java @@ -15,7 +15,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.action.support.clustermanager.TransportClusterManagerNodeAction; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateTaskConfig; diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index b479192e8d..605a19660a 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -179,10 +179,7 @@ public void run() { .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); Map trainParameters = libraryIndexingContext.getLibraryParameters(); - trainParameters.put( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + trainParameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); if (libraryIndexingContext.getQuantizationConfig() != QuantizationConfig.EMPTY) { trainParameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 7a8916981e..e0ed615f70 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -8,4 +8,5 @@ org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec org.opensearch.knn.index.codec.KNN950Codec.KNN950Codec org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec 
org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec +org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index 24990dd364..961164a766 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -205,6 +205,17 @@ public void testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenVal assertEquals(expectedKNNFaissAVX2Disabled, actualKNNFaissAVX2Disabled); } + @SneakyThrows + public void testGetIndexThreadQty_WithDifferentValues_thenSuccess() { + Node mockNode = createMockNode(Map.of(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY, 3)); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + KNNSettings.state().setClusterService(clusterService); + int threadQty = KNNSettings.getIndexThreadQty(); + mockNode.close(); + assertEquals(3, threadQty); + } + private Node createMockNode(Map configSettings) throws IOException { Path configDir = createTempDir(); File configFile = configDir.resolve("opensearch.yml").toFile(); diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index cbe11dd6b9..6f1998faf4 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -5,11 +5,11 @@ package org.opensearch.knn.index; +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import 
org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; @@ -18,6 +18,7 @@ import org.apache.lucene.store.Directory; import org.opensearch.index.fielddata.ScriptDocValues; import org.junit.Before; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import java.io.IOException; @@ -42,12 +43,8 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, new float[] { 1.0f, 2.0f }, new FieldType()).binaryValue() - ) - ); + byte[] vectorBinary = KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(new float[] { 1.0f, 2.0f }); + knnDocument.add(new BinaryDocValuesField(MOCK_INDEX_FIELD_NAME, new BytesRef(vectorBinary))); knnDocument.add(new NumericDocValuesField(MOCK_NUMERIC_INDEX_FIELD_NAME, 1000)); writer.addDocument(knnDocument); writer.commit(); diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 66e2893c0e..8817fae035 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -8,23 +8,18 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.search.DocIdSetIterator; +import 
org.apache.lucene.index.*; +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; import org.junit.Assert; import org.junit.Before; +import org.junit.After; +import org.junit.Test; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import java.io.IOException; @@ -33,7 +28,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; - private KNNVectorScriptDocValues scriptDocValues; + private Directory directory; private DirectoryReader reader; @@ -41,71 +36,116 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); - createKNNVectorDocument(directory, valuesClass); - reader = DirectoryReader.open(directory); - LeafReader leafReader = reader.getContext().leaves().get(0).reader(); - DocIdSetIterator vectorValues; - if (BinaryDocValues.class.equals(valuesClass)) { - vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); - } else if (ByteVectorValues.class.equals(valuesClass)) { - vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); - } else { - vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); - } - - scriptDocValues = 
KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { - IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); - IndexWriter writer = new IndexWriter(directory, conf); - Document knnDocument = new Document(); - Field field; - if (BinaryDocValues.class.equals(valuesClass)) { - field = new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ); - } else if (ByteVectorValues.class.equals(valuesClass)) { - field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); - } else { - field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + @After + public void tearDown() throws Exception { + super.tearDown(); + if (reader != null) { + reader.close(); } + if (directory != null) { + directory.close(); + } + } - knnDocument.add(field); - writer.addDocument(knnDocument); - writer.commit(); - writer.close(); + /** Test for Float Vector Values */ + @Test + public void testFloatVectorValues() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + // Separate scriptDocValues instance for this test + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); } - @Override - public void tearDown() throws Exception { - super.tearDown(); - reader.close(); - directory.close(); + /** Test for Byte Vector Values */ + @Test + public void testByteVectorValues() throws IOException { + createKNNVectorDocument(directory, 
ByteVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.BYTE + ); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(new float[] { SAMPLE_BYTE_VECTOR_DATA[0], SAMPLE_BYTE_VECTOR_DATA[1] }, scriptDocValues.getValue(), 0.1f); } - public void testGetValue() throws IOException { + /** Test for Binary Vector Values */ + @Test + public void testBinaryVectorValues() throws IOException { + createKNNVectorDocument(directory, BinaryDocValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.BINARY + ); + scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + Assert.assertNotNull(scriptDocValues.getValue()); // Just checking it's non-null } - // Test getValue without calling setNextDocId + /** Ensure getValue() fails without setNextDocId */ + @Test public void testGetValueFails() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + expectThrows(IllegalStateException.class, () -> scriptDocValues.getValue()); } + /** Ensure size() returns expected values */ + @Test public void testSize() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = 
DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + Assert.assertEquals(0, scriptDocValues.size()); scriptDocValues.setNextDocId(0); Assert.assertEquals(1, scriptDocValues.size()); } - public void testGet() throws IOException { - expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + /** Ensure get() throws UnsupportedOperationException */ + @Test + public void testGet() { + expectThrows(UnsupportedOperationException.class, () -> { + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + scriptDocValues.get(0); + }); } + /** Test unsupported values type */ + @Test public void testUnsupportedValues() throws IOException { expectThrows( IllegalArgumentException.class, @@ -113,10 +153,31 @@ public void testUnsupportedValues() throws IOException { ); } + /** Ensure empty values case */ + @Test public void testEmptyValues() throws IOException { KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); assertEquals(0, values.size()); - scriptDocValues.setNextDocId(0); - assertEquals(0, values.size()); + } + + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + Field field; + + if (BinaryDocValues.class.equals(valuesClass)) { + byte[] vectorBinary = KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(SAMPLE_VECTOR_DATA); + field = new BinaryDocValuesField(MOCK_INDEX_FIELD_NAME, new BytesRef(vectorBinary)); + } else if 
(ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 73af608c1e..3753fdef22 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -8,7 +8,6 @@ import lombok.SneakyThrows; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -18,6 +17,7 @@ import org.apache.lucene.util.BytesRef; import org.junit.Assert; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import java.io.IOException; @@ -82,12 +82,8 @@ private void createKNNFloatVectorDocument(Directory directory) throws IOExceptio IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_FLOAT_INDEX_FIELD_NAME, - new VectorField(MOCK_FLOAT_INDEX_FIELD_NAME, SAMPLE_FLOAT_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + BytesRef bytesRef = new BytesRef(KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(SAMPLE_FLOAT_VECTOR_DATA)); + knnDocument.add(new BinaryDocValuesField(MOCK_FLOAT_INDEX_FIELD_NAME, bytesRef)); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -97,12 +93,7 @@ private void 
createKNNByteVectorDocument(Directory directory) throws IOException IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_BYTE_INDEX_FIELD_NAME, - new VectorField(MOCK_BYTE_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + knnDocument.add(new BinaryDocValuesField(MOCK_BYTE_INDEX_FIELD_NAME, new BytesRef(SAMPLE_BYTE_VECTOR_DATA))); writer.addDocument(knnDocument); writer.commit(); writer.close(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java new file mode 100644 index 0000000000..2170276136 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecTestCase; +import org.opensearch.knn.index.codec.KNNCodecVersion; + +import java.util.Optional; +import java.util.function.Function; + +public class KNN10010CodecTests extends KNNCodecTestCase { + + @SneakyThrows + public void testMultiFieldsKnnIndex() { + testMultiFieldsKnnIndex(KNN10010Codec.builder().delegate(KNNCodecVersion.V_10_01_0.getDefaultCodecDelegate()).build()); + } + + @SneakyThrows + public void testBuildFromModelTemplate() { + testBuildFromModelTemplate(KNN10010Codec.builder().delegate(KNNCodecVersion.V_10_01_0.getDefaultCodecDelegate()).build()); + } + + // 
Ensure that the codec is able to return the correct per field knn vectors format for codec + public void testCodecSetsCustomPerFieldKnnVectorsFormat() { + final Codec codec = new KNN10010Codec(); + assertTrue(codec.knnVectorsFormat() instanceof KNN9120PerFieldKnnVectorsFormat); + } + + // IMPORTANT: When this Codec is moved to a backwards Codec, this test needs to be removed, because it attempts to + // write with a read-only codec, which will fail + @SneakyThrows + public void testKnnVectorIndex() { + Function perFieldKnnVectorsFormatProvider = ( + mapperService) -> new KNN9120PerFieldKnnVectorsFormat(Optional.of(mapperService)); + + Function knnCodecProvider = (knnVectorFormat) -> KNN10010Codec.builder() + .delegate(KNNCodecVersion.V_10_01_0.getDefaultCodecDelegate()) + .knnVectorsFormat(knnVectorFormat) + .build(); + + testKnnVectorIndex(knnCodecProvider, perFieldKnnVectorsFormatProvider); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java index 6001a97290..b9aca7620a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java @@ -47,9 +47,9 @@ public static void closeStaticVariables() throws IOException { public void testGetCompoundReader() throws IOException { CompoundDirectory dir = mock(CompoundDirectory.class); CompoundFormat delegate = mock(CompoundFormat.class); - when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); + when(delegate.getCompoundReader(null, null)).thenReturn(dir); KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); - CompoundDirectory knnDir = knn80CompoundFormat.getCompoundReader(null, null, null); + CompoundDirectory knnDir = knn80CompoundFormat.getCompoundReader(null, null); assertTrue(knnDir instanceof KNN80CompoundDirectory); 
assertEquals(dir, ((KNN80CompoundDirectory) knnDir).getDelegate()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java new file mode 100644 index 0000000000..2953539adb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase { + + @SneakyThrows + public void testWriteField() { + StoredFieldsWriter delegate = mock(StoredFieldsWriter.class); + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(); + List fields = List.of("test"); + + DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields); + + Map source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value"); + BytesStreamOutput bStream = new BytesStreamOutput(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source); + builder.close(); + byte[] originalBytes = bStream.bytes().toBytesRef().bytes; + byte[] shiftedBytes = new byte[originalBytes.length + 2]; + 
System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length); + derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 90ed18d0df..86ed6b3aee 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -24,21 +24,7 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.NoMergePolicy; -import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.SerialMergeScheduler; -import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.*; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; @@ -236,20 +222,24 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc } final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD); - floatVectorValues.nextDoc(); - assertArrayEquals(floatVector, 
floatVectorValues.vectorValue(), 0.0f); + floatVectorValues.iterator().nextDoc(); + assertArrayEquals(floatVector, floatVectorValues.vectorValue(floatVectorValues.iterator().index()), 0.0f); assertEquals(1, floatVectorValues.size()); assertEquals(3, floatVectorValues.dimension()); final ByteVectorValues byteVectorValues = leafReader.getByteVectorValues(BYTE_VECTOR_FIELD); - byteVectorValues.nextDoc(); - assertArrayEquals(byteVector, byteVectorValues.vectorValue()); + byteVectorValues.iterator().nextDoc(); + assertArrayEquals(byteVector, byteVectorValues.vectorValue(byteVectorValues.iterator().index())); assertEquals(1, byteVectorValues.size()); assertEquals(2, byteVectorValues.dimension()); final FloatVectorValues floatVectorValuesForBinaryQuantization = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); - floatVectorValuesForBinaryQuantization.nextDoc(); - assertArrayEquals(floatVectorForBinaryQuantization_1, floatVectorValuesForBinaryQuantization.vectorValue(), 0.0f); + floatVectorValuesForBinaryQuantization.iterator().nextDoc(); + assertArrayEquals( + floatVectorForBinaryQuantization_1, + floatVectorValuesForBinaryQuantization.vectorValue(floatVectorValuesForBinaryQuantization.iterator().index()), + 0.0f + ); assertEquals(2, floatVectorValuesForBinaryQuantization.size()); assertEquals(8, floatVectorValuesForBinaryQuantization.dimension()); @@ -296,8 +286,9 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce } final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); - floatVectorValues.nextDoc(); - assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(), 0.0f); + KnnVectorValues.DocIndexIterator docIndexIterator = floatVectorValues.iterator(); + docIndexIterator.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(docIndexIterator.index()), 0.0f); assertEquals(1, floatVectorValues.size()); assertEquals(8, 
floatVectorValues.dimension()); indexReader.close(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 315693a653..18c9e96674 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -9,7 +9,6 @@ import com.google.common.collect.ImmutableSet; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; @@ -411,9 +410,9 @@ public void testKnnVectorIndex( iwc1.setMergeScheduler(new SerialMergeScheduler()); iwc1.setCodec(codec); writer = new RandomIndexWriter(random(), dir, iwc1); - final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); + final FieldType luceneFieldType1 = KnnFloatVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); float[] array1 = { 6.0f, 14.0f }; - KnnVectorField vectorField1 = new KnnVectorField(FIELD_NAME_TWO, array1, luceneFieldType1); + KnnFloatVectorField vectorField1 = new KnnFloatVectorField(FIELD_NAME_TWO, array1, luceneFieldType1); Document doc1 = new Document(); doc1.add(vectorField1); writer.addDocument(doc1); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index d6f22ca7f5..64c5371dbf 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -10,13 +10,7 @@ import lombok.Builder; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.index.DocValuesType; -import 
org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.index.*; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; @@ -172,6 +166,8 @@ public FieldInfo build() { storePayloads, indexOptions, docValuesType, + DocValuesSkipIndexType.NONE, + dvGen, attributes, pointDimensionCount, @@ -191,7 +187,7 @@ public static void assertFileInCorrectLocation(SegmentWriteState state, String e } public static void assertValidFooter(Directory dir, String filename) throws IOException { - ChecksumIndexInput indexInput = dir.openChecksumInput(filename, IOContext.DEFAULT); + ChecksumIndexInput indexInput = dir.openChecksumInput(filename); indexInput.seek(indexInput.length() - CodecUtil.footerLength()); CodecUtil.checkFooter(indexInput); indexInput.close(); @@ -205,7 +201,7 @@ public static void assertLoadableByEngine( SpaceType spaceType, int dimension ) { - try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long indexPtr = JNIService.loadIndex( indexInputWithBuffer, @@ -230,7 +226,7 @@ public static void assertBinaryIndexLoadableByEngine( int dimension, VectorDataType vectorDataType ) { - try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long indexPtr = JNIService.loadIndex( indexInputWithBuffer, diff --git 
a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java new file mode 100644 index 0000000000..943146a914 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase { + + public void testBinaryField() throws Exception { + StoredFieldVisitor delegate = mock(StoredFieldVisitor.class); + doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any()); + DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class); + when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]); + DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor( + delegate, + 0, + derivedSourceVectorInjector + ); + + // When field is not _source, then do not call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null); + verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any()); + verify(delegate, times(1)).binaryField(any(), any()); + + // When field is _source, then do call the injector + 
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null); + verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any()); + verify(delegate, times(2)).binaryField(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java new file mode 100644 index 0000000000..1fa4b9364a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; + +public class DerivedSourceVectorInjectorTests extends KNNTestCase { + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testInjectVectors() { + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + Map fieldToVector = 
Collections.unmodifiableMap(new HashMap<>() { + { + put("test1", new float[] { 1.0f, 2.0f, 3.0f }); + put("test2", new float[] { 4.0f, 5.0f, 6.0f, 7.0f }); + put("test3", new float[] { 7.0f, 8.0f, 9.0f, 1.0f, 3.0f, 4.0f }); + put("test4", null); + } + }); + + try (MockedStatic factory = Mockito.mockStatic(PerFieldDerivedVectorInjectorFactory.class)) { + factory.when(() -> PerFieldDerivedVectorInjectorFactory.create(any(), any(), any())).thenAnswer(invocation -> { + FieldInfo fieldInfo = invocation.getArgument(0); + return (PerFieldDerivedVectorInjector) (docId, sourceAsMap) -> { + float[] vector = fieldToVector.get(fieldInfo.name); + if (vector != null) { + sourceAsMap.put(fieldInfo.name, vector); + } + }; + }); + + DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ); + + int docId = 2; + String existingFieldKey = "existingField"; + String existingFieldValue = "existingField"; + Map source = Map.of(existingFieldKey, existingFieldValue); + byte[] originalSourceBytes = mapToBytes(source); + byte[] modifiedSourceByttes = derivedSourceVectorInjector.injectVectors(docId, originalSourceBytes); + Map modifiedSource = bytesToMap(modifiedSourceByttes); + + assertEquals(existingFieldValue, modifiedSource.get(existingFieldKey)); + + assertArrayEquals(fieldToVector.get("test1"), toFloatArray((List) modifiedSource.get("test1")), 0.000001f); + assertArrayEquals(fieldToVector.get("test2"), toFloatArray((List) modifiedSource.get("test2")), 0.000001f); + assertArrayEquals(fieldToVector.get("test3"), toFloatArray((List) modifiedSource.get("test3")), 0.000001f); + assertFalse(modifiedSource.containsKey("test4")); + } + } + + @SneakyThrows + private byte[] mapToBytes(Map map) { + + BytesStreamOutput bStream = new BytesStreamOutput(1024); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(map); + 
builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } + + private float[] toFloatArray(List list) { + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i).floatValue(); + } + return array; + } + + private Map bytesToMap(byte[] bytes) { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + + return mapTuple.v2(); + } + + @SneakyThrows + public void testShouldInject() { + + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + try ( + DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ) + ) { + assertTrue(vectorInjector.shouldInject(null, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null)); + assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); + assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java new file mode 100644 index 0000000000..0852224604 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; + +public class ParentChildHelperTests extends KNNTestCase { + + public void testGetParentField() { + assertEquals("parent.to", ParentChildHelper.getParentField("parent.to.child")); + assertEquals("parent", ParentChildHelper.getParentField("parent.to")); + assertNull(ParentChildHelper.getParentField("child")); + assertNull(ParentChildHelper.getParentField("")); + assertNull(ParentChildHelper.getParentField(null)); + } + + public void testGetChildField() { + assertEquals("child", ParentChildHelper.getChildField("parent.to.child")); + assertNull(ParentChildHelper.getChildField(null)); + assertNull(ParentChildHelper.getChildField("child")); + } + + public void testConstructSiblingField() { + assertEquals("parent.to.sibling", ParentChildHelper.constructSiblingField("parent.to.child", "sibling")); + assertEquals("sibling", ParentChildHelper.constructSiblingField("parent", "sibling")); + } + + public void testSplitPath() { + String[] path = ParentChildHelper.splitPath("parent.to.child"); + assertEquals(3, path.length); + assertEquals("parent", path[0]); + assertEquals("to", path[1]); + assertEquals("child", path[2]); + + path = ParentChildHelper.splitPath("parent"); + assertEquals(1, path.length); + assertEquals("parent", path[0]); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java new file mode 100644 index 0000000000..f117db8c85 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; 
+import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +public class PerFieldDerivedVectorInjectorFactoryTests extends KNNTestCase { + public void testCreate() { + // Non-nested case + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof RootPerFieldDerivedVectorInjector); + + // Nested case + perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("parent.test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof NestedPerFieldDerivedVectorInjector); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java new file mode 100644 index 0000000000..3e015b09c8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.search.DocIdSetIterator; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; +import org.opensearch.knn.index.vectorvalues.VectorValueExtractorStrategy; + +import java.util.HashMap; +import 
java.util.Map; + +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; + +public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { + public static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + @SneakyThrows + public void testInject() { + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder(FIELD_NAME).build(); + try (MockedStatic mockedKnnVectorValues = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) + .thenReturn(new KNNVectorValues(new KNNVectorValuesIterator() { + @Override + public int docId() { + return 0; + } + + @Override + public int advance(int docId) { + return 0; + } + + @Override + public int nextDoc() { + return 0; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return null; + } + + @Override + public long liveDocs() { + return 0; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return null; + } + }) { + + @Override + public float[] getVector() { + return TEST_VECTOR; + } + + @Override + public float[] conditionalCloneVector() { + return TEST_VECTOR; + } + }); + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = new RootPerFieldDerivedVectorInjector( + fieldInfo, + new DerivedSourceReaders(null, null, null, null) + ); + + Map source = new HashMap<>(); + perFieldDerivedVectorInjector.inject(0, source); + assertArrayEquals(TEST_VECTOR, (float[]) source.get(FIELD_NAME), 0.0001f); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 5d4629fecf..b0c3d2ffe7 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -123,7 +123,8 @@ public void testBuilder_getParameters() { modelDao, 
CURRENT, null, - new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()) + new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()), + false ); assertEquals(10, builder.getParameters().size()); @@ -357,7 +358,7 @@ public void testTypeParser_withSpaceTypeAndMode_thenSuccess() throws IOException public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -1176,6 +1177,7 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT when(parseContext.parser()).thenReturn(createXContentParser(dataType)); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(true); OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( dataType, @@ -1197,7 +1199,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1220,6 +1223,7 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT ); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); + utilMockedStatic.when(() -> 
KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(false); document = new ParseContext.Document(); contentPath = new ContentPath(); @@ -1236,7 +1240,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1284,6 +1289,7 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy when(parseContext.parser()).thenReturn(createXContentParser(dataType)); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(true); OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( VectorDataType.DEFAULT, @@ -1308,7 +1314,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false ); modelFieldMapper.parseCreateField(parseContext); @@ -1331,6 +1338,7 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy ); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(true); document = new ParseContext.Document(); contentPath = new ContentPath(); @@ -1350,7 +1358,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false ); modelFieldMapper.parseCreateField(parseContext); @@ -1396,7 +1405,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { 
Collections.emptyMap(), knnMethodConfigContext, inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1455,7 +1465,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { Collections.emptyMap(), knnMethodConfigContext, inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1502,7 +1513,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1552,7 +1564,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1667,7 +1680,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithSQ_thenException() throws IOEx public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1716,7 +1729,7 @@ public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { // IllegalArgumentException should be thrown. 
Exception e = assertThrows(IllegalArgumentException.class, () -> { - new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null).build(builderContext); + new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null, false).build(builderContext); }); assertTrue(e.getMessage(), e.getMessage().contains("Vector field name must not include")); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index d0fa150a55..8bcf9fdbeb 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -77,4 +77,13 @@ public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0)); Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); } + + /** + * Test useFullFieldNameValidation method for different OpenSearch versions + */ + public void testUseFullFieldNameValidation() { + Assert.assertFalse(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_16_0)); + Assert.assertTrue(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_17_0)); + Assert.assertTrue(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_18_0)); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 8011cc08ca..7a1da87814 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -99,6 +99,8 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static 
org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT; public class KNNWeightTests extends KNNTestCase { private static final String FIELD_NAME = "target_field"; @@ -146,6 +148,12 @@ public static void setUpClass() throws Exception { knnSettingsMockedStatic.when(KNNSettings::getCircuitBreakerLimit).thenReturn(v); knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); knnSettingsMockedStatic.when(KNNSettings::isKNNPluginEnabled).thenReturn(true); + ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue("1024kb", QUANTIZATION_STATE_CACHE_SIZE_LIMIT); // Setting 1MB as an + // example + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_SIZE_LIMIT))).thenReturn(cacheSize); + // Mock QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES setting + TimeValue mockTimeValue = TimeValue.timeValueMinutes(10); + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES))).thenReturn(mockTimeValue); nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); @@ -371,7 +379,7 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() // When no knn fields are available , field info for vector field will be null when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); + assertEquals(KNNScorer.emptyScorer(), knnScorer); } @SneakyThrows @@ -415,7 +423,7 @@ public void testEmptyQueryResults() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); + assertEquals(KNNScorer.emptyScorer(), knnScorer); } @SneakyThrows diff --git 
a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index a3b8c69893..d75a9a7bc6 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -72,7 +72,7 @@ public void testResultMapToTopDocs() { } private void assertResultMapToTopDocs(Map perLeafResults, TopDocs topDocs, int k, int offset) { - assertEquals(k, topDocs.totalHits.value); + assertEquals(k, topDocs.totalHits.value()); float previousScore = Float.MAX_VALUE; for (ScoreDoc scoreDoc : topDocs.scoreDocs) { assertTrue(perLeafResults.containsKey(scoreDoc.doc - offset)); diff --git a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index b32496138c..607699b56e 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -6,8 +6,12 @@ package org.opensearch.knn.index.query.common; import lombok.SneakyThrows; +import org.apache.lucene.document.Document; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -15,9 +19,14 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.apache.lucene.store.ByteBuffersDirectory; + +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.mockito.Mock; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; + import static 
org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; @@ -27,9 +36,6 @@ public class DocAndScoreQueryTests extends OpenSearchTestCase { private LeafReaderContext leaf1; @Mock private IndexSearcher indexSearcher; - @Mock - private IndexReader reader; - @Mock private IndexReaderContext readerContext; private DocAndScoreQuery objectUnderTest; @@ -39,9 +45,9 @@ public void setUp() throws Exception { super.setUp(); openMocks(this); + IndexReader reader = createTestIndexReader(); when(indexSearcher.getIndexReader()).thenReturn(reader); - when(reader.getContext()).thenReturn(readerContext); - when(readerContext.id()).thenReturn(1); + readerContext = reader.getContext(); } // Note: cannot test with multi leaf as there LeafReaderContext is readonly with no getters for some fields to mock @@ -50,7 +56,7 @@ public void testScorer() throws Exception { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); // When Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); @@ -85,7 +91,7 @@ public void testWeight() { Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); // When - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); Explanation explanation = weight.explain(leaf1, 1); @@ -96,4 +102,13 @@ public void testWeight() { assertEquals(expectedExplanation, explanation); assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); } + + private IndexReader 
createTestIndexReader() throws IOException { + ByteBuffersDirectory directory = new ByteBuffersDirectory(); + IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(new MockAnalyzer(random()))); + writer.addDocument(new Document()); + writer.close(); + return DirectoryReader.open(directory); + } + } diff --git a/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java index 55a110f6a8..ecda53e1b8 100644 --- a/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java @@ -7,8 +7,9 @@ import junit.framework.TestCase; import lombok.SneakyThrows; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; @@ -17,6 +18,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; import org.apache.lucene.util.Bits; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -49,17 +52,38 @@ public void setUp() throws Exception { @SneakyThrows public void testCreateWeight_whenCalled_thenSucceed() { - LeafReaderContext leafReaderContext1 = mock(LeafReaderContext.class); - LeafReaderContext leafReaderContext2 = mock(LeafReaderContext.class); - List leafReaderContexts = Arrays.asList(leafReaderContext1, leafReaderContext2); + Directory directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer 
= new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + + IndexReader reader = DirectoryReader.open(directory); - IndexReader indexReader = mock(IndexReader.class); - when(indexReader.leaves()).thenReturn(leafReaderContexts); + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + LeafReaderContext leaf1 = leaves.get(0); + LeafReaderContext leaf2 = leaves.get(1); Weight filterWeight = mock(Weight.class); IndexSearcher indexSearcher = mock(IndexSearcher.class); - when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(indexSearcher.getIndexReader()).thenReturn(reader); when(indexSearcher.getTaskExecutor()).thenReturn(taskExecutor); when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE_NO_SCORES), eq(1.0F))).thenReturn(filterWeight); @@ -97,10 +121,10 @@ public void testCreateWeight_whenCalled_thenSucceed() { when(finalQuery.createWeight(indexSearcher, scoreMode, boost)).thenReturn(expectedWeight); QueryUtils queryUtils = mock(QueryUtils.class); - when(queryUtils.doSearch(indexSearcher, leafReaderContexts, queryWeight)).thenReturn(perLeafResults); + when(queryUtils.doSearch(indexSearcher, reader.leaves(), queryWeight)).thenReturn(perLeafResults); when(queryUtils.createBits(any(), any())).thenReturn(queryFilterBits); when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); - 
when(queryUtils.createDocAndScoreQuery(eq(indexReader), any())).thenReturn(finalQuery); + when(queryUtils.createDocAndScoreQuery(eq(reader), any())).thenReturn(finalQuery); // Run ExpandNestedDocsQuery query = new ExpandNestedDocsQuery(internalQuery, queryUtils); @@ -108,12 +132,12 @@ public void testCreateWeight_whenCalled_thenSucceed() { // Verify assertEquals(expectedWeight, finalWeigh); - verify(queryUtils).createBits(leafReaderContext1, filterWeight); - verify(queryUtils).createBits(leafReaderContext2, filterWeight); - verify(queryUtils).getAllSiblings(leafReaderContext1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); - verify(queryUtils).getAllSiblings(leafReaderContext2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).createBits(leaf1, filterWeight); + verify(queryUtils).createBits(leaf2, filterWeight); + verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); - verify(queryUtils).createDocAndScoreQuery(eq(indexReader), topDocsCaptor.capture()); + verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture()); TopDocs capturedTopDocs = topDocsCaptor.getValue(); assertEquals(topK.totalHits, capturedTopDocs.totalHits); for (int i = 0; i < topK.scoreDocs.length; i++) { diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 87c4a50148..82dc43bf35 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -6,10 +6,9 @@ package org.opensearch.knn.index.query.nativelib; import 
lombok.SneakyThrows; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexReaderContext; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; @@ -21,6 +20,9 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.util.Bits; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -38,6 +40,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -48,34 +51,26 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.openMocks; public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @Mock private IndexSearcher searcher; - @Mock private IndexReader reader; + private Directory directory; + private DirectoryReader directoryReader; @Mock private KNNQuery knnQuery; @Mock private KNNWeight knnWeight; @Mock private TaskExecutor taskExecutor; - @Mock private IndexReaderContext indexReaderContext; - @Mock private 
LeafReaderContext leaf1; - @Mock private LeafReaderContext leaf2; - @Mock private LeafReader leafReader1; - @Mock private LeafReader leafReader2; @Mock @@ -90,12 +85,10 @@ public void setUp() throws Exception { super.setUp(); openMocks(this); objectUnderTest = new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, false); - when(leaf1.reader()).thenReturn(leafReader1); - when(leaf2.reader()).thenReturn(leafReader2); - + reader = createTestIndexReader(); + indexReaderContext = reader.getContext(); when(searcher.getIndexReader()).thenReturn(reader); when(knnQuery.createWeight(searcher, scoreMode, 1)).thenReturn(knnWeight); - when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { List> callables = invocationOnMock.getArgument(0); @@ -105,11 +98,7 @@ public void setUp() throws Exception { } return results; }); - - when(reader.getContext()).thenReturn(indexReaderContext); - when(clusterService.state()).thenReturn(mock(ClusterState.class)); // Mock ClusterState - // Set ClusterService in KNNSettings KNNSettings.state().setClusterService(clusterService); when(knnQuery.getQueryVector()).thenReturn(new float[] { 1.0f, 2.0f, 3.0f }); // Example vector @@ -118,29 +107,71 @@ public void setUp() throws Exception { @SneakyThrows public void testMultiLeaf() { - // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new 
FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + + // Initialize DirectoryReader and IndexSearcher + // Open the real DirectoryReader + DirectoryReader originalReader = DirectoryReader.open(directory); + // Define liveDocs for each segment + Bits liveDocs1 = new Bits() { + @Override + public boolean get(int index) { + return index != 1 && index != 2; // Document 1 and 2 are deleted + } + + @Override + public int length() { + return originalReader.leaves().get(0).reader().maxDoc(); + } + }; + + Bits liveDocs2 = null; // No deletions in the second segment + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); + // Given PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); PerLeafResult leaf2Result = new PerLeafResult(null, new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leaf1Result); when(knnWeight.searchLeaf(leaf2, 4)).thenReturn(leaf2Result); - - // Making sure there is deleted docs in one of the segments - Bits liveDocs = mock(Bits.class); - when(leafReader1.getLiveDocs()).thenReturn(liveDocs); - when(leafReader2.getLiveDocs()).thenReturn(null); - - when(liveDocs.get(anyInt())).thenReturn(true); - when(liveDocs.get(2)).thenReturn(false); - when(liveDocs.get(1)).thenReturn(false); + 
when(searcher.getIndexReader()).thenReturn(reader); // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves when(knnQuery.getK()).thenReturn(4); - when(indexReaderContext.id()).thenReturn(1); - Map leaf1ResultLive = Map.of(0, 1.2f); TopDocs[] topDocs = { ResultUtil.resultMapToTopDocs(leaf1ResultLive, leaf1.docBase), @@ -157,9 +188,48 @@ public void testMultiLeaf() { @SneakyThrows public void testRescoreWhenShardLevelRescoringEnabled() { - // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 
= leaf1.reader(); + leafReader2 = leaf2.reader(); int k = 2; PerLeafResult initialLeaf1Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f))); @@ -176,6 +246,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { when(knnWeight.searchLeaf(leaf2, firstPassK)).thenReturn(initialLeaf2Results); when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + when(searcher.getIndexReader()).thenReturn(reader); try ( MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); @@ -205,11 +276,10 @@ public void testSingleLeaf() { int k = 4; float boost = 1; PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); - List leaves = List.of(leaf1); - when(reader.leaves()).thenReturn(leaves); + List leaves = reader.leaves(); + leaf1 = leaves.get(0); when(knnWeight.searchLeaf(leaf1, k)).thenReturn(leaf1Result); when(knnQuery.getK()).thenReturn(k); - when(indexReaderContext.id()).thenReturn(1); TopDocs expectedTopDocs = ResultUtil.resultMapToTopDocs(leaf1Result.getResult(), leaf1.docBase); // When @@ -223,8 +293,8 @@ public void testSingleLeaf() { @SneakyThrows public void testNoMatch() { // Given - List leaves = List.of(leaf1); - when(reader.leaves()).thenReturn(leaves); + List leaves = reader.leaves(); + leaf1 = leaves.get(0); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(PerLeafResult.EMPTY_RESULT); when(knnQuery.getK()).thenReturn(4); @@ -238,8 +308,47 @@ public void testNoMatch() { @SneakyThrows public void testRescore() { // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 
2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); int k = 2; int firstPassK = 100; @@ -249,7 +358,6 @@ public void testRescore() { Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(Map.of(1, 20f), 0); TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(Map.of(0, 21f), 4); - when(indexReaderContext.id()).thenReturn(1); when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); when(knnQuery.getK()).thenReturn(k); when(knnWeight.getQuery()).thenReturn(knnQuery); @@ -258,6 +366,7 @@ public void testRescore() { when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + when(searcher.getIndexReader()).thenReturn(reader); try ( MockedStatic 
mockedKnnSettings = mockStatic(KNNSettings.class); @@ -286,8 +395,47 @@ public void testRescore() { @SneakyThrows public void testExpandNestedDocs() { - List leafReaderContexts = Arrays.asList(leaf1, leaf2); - when(reader.leaves()).thenReturn(leafReaderContexts); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); Bits queryFilterBits = mock(Bits.class); PerLeafResult initialLeaf1Results = new PerLeafResult(queryFilterBits, new HashMap<>(Map.of(0, 19f, 1, 20f, 2, 17f, 3, 15f))); PerLeafResult initialLeaf2Results = new PerLeafResult(queryFilterBits, new 
HashMap<>(Map.of(0, 21f, 1, 18f, 2, 16f, 3, 14f))); @@ -296,9 +444,10 @@ public void testExpandNestedDocs() { Map exactSearchLeaf1Result = new HashMap<>(Map.of(1, 20f)); Map exactSearchLeaf2Result = new HashMap<>(Map.of(0, 21f)); - TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(exactSearchLeaf1Result, 0); - TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(exactSearchLeaf2Result, 0); + TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(exactSearchLeaf1Result, leaf1.docBase); + TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(exactSearchLeaf2Result, leaf2.docBase); TopDocs topK = TopDocs.merge(2, new TopDocs[] { topDocs1, topDocs2 }); + when(searcher.getIndexReader()).thenReturn(reader); int k = 2; when(knnQuery.getRescoreContext()).thenReturn(null); @@ -308,7 +457,8 @@ public void testExpandNestedDocs() { when(knnQuery.getParentsFilter()).thenReturn(parentFilter); when(knnWeight.searchLeaf(leaf1, k)).thenReturn(initialLeaf1Results); when(knnWeight.searchLeaf(leaf2, k)).thenReturn(initialLeaf2Results); - when(knnWeight.exactSearch(any(), any())).thenReturn(exactSearchLeaf1Result, exactSearchLeaf2Result); + when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(exactSearchLeaf1Result); + when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(exactSearchLeaf2Result); Weight filterWeight = mock(Weight.class); when(knnWeight.getFilterWeight()).thenReturn(filterWeight); @@ -350,4 +500,99 @@ public void testExpandNestedDocs() { assertEquals(2, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); assertEquals(DocIdSetIterator.NO_MORE_DOCS, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); } + + private IndexReader createTestIndexReader() throws IOException { + ByteBuffersDirectory directory = new ByteBuffersDirectory(); + IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(new MockAnalyzer(random()))); + writer.addDocument(new Document()); + writer.close(); + return DirectoryReader.open(directory); + } +} + +class 
CustomFilterDirectoryReader extends FilterDirectoryReader { + + private final Bits liveDocs1; + private final Bits liveDocs2; + + protected CustomFilterDirectoryReader(DirectoryReader in, Bits liveDocs1, Bits liveDocs2) throws IOException { + super(in, getWrapper(liveDocs1, liveDocs2)); + this.liveDocs1 = liveDocs1; + this.liveDocs2 = liveDocs2; + } + + @Override + protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException { + return new CustomFilterDirectoryReader(in, liveDocs1, liveDocs2); + } + + private static SubReaderWrapper getWrapper(Bits liveDocs1, Bits liveDocs2) { + return new SubReaderWrapper() { + @Override + public LeafReader wrap(LeafReader reader) { + if (reader.getContext().ord == 0) { // First segment + return new FilterLeafReader(reader) { + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } + + /** + * @return + */ + @Override + public CacheHelper getCoreCacheHelper() { + return null; + } + + @Override + public Bits getLiveDocs() { + return liveDocs1; + } + }; + } else if (reader.getContext().ord == 1) { // Second segment + return new FilterLeafReader(reader) { + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } + + /** + * @return + */ + @Override + public CacheHelper getCoreCacheHelper() { + return null; + } + + @Override + public Bits getLiveDocs() { + return liveDocs2; + } + }; + } else { + return reader; // Default case + } + } + }; + } + + // Remove the static modifier to fix the error + public static DirectoryReader wrap(DirectoryReader reader, Bits liveDocs1, Bits liveDocs2) throws IOException { + return new CustomFilterDirectoryReader(reader, liveDocs1, liveDocs2); + } + + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } } diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java 
b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java index 0b631ab416..99cb383e50 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java @@ -8,7 +8,6 @@ import lombok.SneakyThrows; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.search.DocIdSetIterator; -import org.junit.Assert; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; @@ -37,7 +36,6 @@ public void testFloatVectorValues_whenValidInput_thenSuccess() { vectorsMap ); new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, floatArray, 8, dimension, false); - final TestVectorValues.PredefinedFloatVectorBinaryDocValues preDefinedFloatVectorValues = new TestVectorValues.PredefinedFloatVectorBinaryDocValues(floatArray); final KNNVectorValues knnFloatVectorValuesBinaryDocValues = KNNVectorValuesFactory.getVectorValues( @@ -101,13 +99,6 @@ public void testBinaryVectorValues_whenValidInput_thenSuccess() { new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, 3, dimension, false); } - public void testDocIdsIteratorValues_whenInvalidDisi_thenThrowException() { - Assert.assertThrows( - IllegalArgumentException.class, - () -> new KNNVectorValuesIterator.DocIdsIteratorValues(new TestVectorValues.NotBinaryDocValues()) - ); - } - private DocsWithFieldSet getDocIdSetIterator(int numberOfDocIds) { final DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); for (int i = 0; i < numberOfDocIds; i++) { diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 0f15d5240c..337ab6c489 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java 
@@ -5,14 +5,7 @@ package org.opensearch.knn.index.vectorvalues; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.*; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -116,6 +109,18 @@ public SortedSetDocValues getSortedSet(FieldInfo field) { return null; } + /** + * Returns a {@link DocValuesSkipper} for this field. The returned instance need not be + * thread-safe: it will only be used by a single thread. The return value is undefined if {@link + * FieldInfo#docValuesSkipIndexType()} returns {@link DocValuesSkipIndexType#NONE}. + * + * @param field + */ + @Override + public DocValuesSkipper getSkipper(FieldInfo field) throws IOException { + return null; + } + @Override public void checkIntegrity() { @@ -204,7 +209,6 @@ public int size() { return count; } - @Override public float[] vectorValue() throws IOException { // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an @@ -213,29 +217,49 @@ public float[] vectorValue() throws IOException { return vector; } + @Override + public float[] vectorValue(int ordId) throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. 
+ System.arraycopy(vectors.get(ordId), 0, vector, 0, dimension); + return vector; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + /** + * @return + * @throws IOException + */ + @Override + public FloatVectorValues copy() throws IOException { + return null; + } + @Override public VectorScorer scorer(float[] query) throws IOException { throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); } - @Override public int docID() { if (this.current > this.count) { - return FloatVectorValues.NO_MORE_DOCS; + return DocIndexIterator.NO_MORE_DOCS; } return this.current; } - @Override public int nextDoc() throws IOException { return advance(current + 1); } - @Override public int advance(int target) throws IOException { current = target; if (current >= count) { - current = NO_MORE_DOCS; + current = DocIndexIterator.NO_MORE_DOCS; } return current; } @@ -267,7 +291,6 @@ public int size() { return count; } - @Override public byte[] vectorValue() throws IOException { // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an @@ -276,29 +299,49 @@ public byte[] vectorValue() throws IOException { return vector; } + @Override + public byte[] vectorValue(int ordId) throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. 
+ System.arraycopy(vectors.get(ordId), 0, vector, 0, dimension); + return vector; + } + + /** + * @return + * @throws IOException + */ + @Override + public ByteVectorValues copy() throws IOException { + return null; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public VectorScorer scorer(byte[] query) throws IOException { throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); } - @Override public int docID() { if (this.current > this.count) { - return FloatVectorValues.NO_MORE_DOCS; + return DocIndexIterator.NO_MORE_DOCS; } return this.current; } - @Override public int nextDoc() throws IOException { return advance(current + 1); } - @Override public int advance(int target) throws IOException { current = target; if (current >= count) { - current = NO_MORE_DOCS; + current = DocIndexIterator.NO_MORE_DOCS; } return current; } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 1edb5cff20..e011ba57f4 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -29,7 +29,7 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.knn.KNNSingleNodeTestCase; diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java new file mode 100644 index 0000000000..ad5ef811b5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -0,0 +1,1252 
@@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.primitives.Floats; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; + +/** + * Integration tests for derived source feature for vector fields. Currently, with derived source, there are + * a few gaps in functionality. Ignoring tests for now as feature is experimental. 
+ */ +public class DerivedSourceIT extends KNNRestTestCase { + + private final static String NESTED_NAME = "test_nested"; + private final static String FIELD_NAME = "test_vector"; + private final int TEST_DIMENSION = 128; + private final int DOCS = 50; + + private static final Settings DERIVED_ENABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, true) + .build(); + private static final Settings DERIVED_DISABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, false) + .build(); + + /** + * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for + * random documents to ensure it works robustly. To ensure correctness, we repeat same operations against an + * index without derived source enabled (baseline). 
+ * Test mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + */ + @SneakyThrows + public void testFlatBaseCase() { + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + 
.build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing multiple flat fields. 
+ * Test mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector2": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector2": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + */ + @SneakyThrows + public void testMultiFlatFields() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject("text") + .field(TYPE, "text") + .endObject() + .endObject() + .endObject(); + String multiFieldMapping = builder.toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + 
randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + 
.indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing single nested doc per parent doc. + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ + public void testNestedSingleDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing multiple nested docs per parent doc. + * Test mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn": true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testNestedMultiDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Test object (non-nested field) + * Test + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + * } + * } + * Baseline + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testObjectFieldTypes() { + String PATH_1_NAME = "path_1"; + String PATH_2_NAME = "path_2"; + + String objectFieldTypeMapping = XContentFactory.jsonBuilder() + .startObject() // 1-open + .startObject(PROPERTIES_FIELD) // 2-open + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + + .startObject(PATH_1_NAME) + .startObject(PROPERTIES_FIELD) + + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(PATH_2_NAME) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "3") + .field(TYPE, 
TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." 
+ FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Single method for running end to end tests for different index configurations for derived source. 
In general, + * flow of operations are + * + * @param indexConfigContexts {@link IndexConfigContext} + */ + @SneakyThrows + private void testDerivedSourceE2E(List indexConfigContexts) { + // Make sure there are 6 + assertEquals(6, indexConfigContexts.size()); + + // Prepare the indices by creating them and ingesting data into them + prepareOriginalIndices(indexConfigContexts); + + // Merging + testMerging(indexConfigContexts); + + // Update. Skipping update tests for nested docs for now. Will add in the future. + if (indexConfigContexts.get(0).isNested() == false) { + testUpdate(indexConfigContexts); + } + + // Delete + testDelete(indexConfigContexts); + + // Search + testSearch(indexConfigContexts); + + // Reindex + testReindex(indexConfigContexts); + } + + @SneakyThrows + private void prepareOriginalIndices(List indexConfigContexts) { + assertEquals(6, indexConfigContexts.size()); + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + createKnnIndex(derivedSourceEnabledContext.indexName, derivedSourceEnabledContext.settings, derivedSourceEnabledContext.mapping); + createKnnIndex(derivedSourceDisabledContext.indexName, derivedSourceDisabledContext.settings, derivedSourceDisabledContext.mapping); + derivedSourceEnabledContext.indexIngestor.accept(derivedSourceEnabledContext); + derivedSourceDisabledContext.indexIngestor.accept(derivedSourceDisabledContext); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + derivedSourceDisabledContext.indexName, + derivedSourceEnabledContext.indexName + ); + } + + @SneakyThrows + private void testMerging(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String 
originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testUpdate(List indexConfigContexts) { + // Random variables + int docWithVectorUpdate = DOCS - 4; + int docWithVectorRemoval = 1; + int docWithVectorUpdateFromAPI = 2; + int docWithUpdateByQuery = 7; + + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + float[] updateVector = randomFloatVector(derivedSourceDisabledContext.dimension); + + // Update via POST //_doc/ + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceDisabled, + 
String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Sets the doc to an empty doc + setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); + setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Use update API + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Update by query + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + 
originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testSearch(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + + // Default - all fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, true, null, null); + + // Default - no fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, false, null, null); + + // Exclude all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + null, + derivedSourceEnabledContext.vectorFieldNames + ); + + // Include all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + derivedSourceEnabledContext.vectorFieldNames, + null + ); + } + + @SneakyThrows + private void validateSearch(String indexName, int size, boolean isSourceEnabled, List includes, List excludes) { + // TODO: We need to figure out a way to enhance validation + QueryBuilder qb = new MatchAllQueryBuilder(); + Request request = new Request("POST", "/" + indexName + "/_search"); + + request.addParameter("size", Integer.toString(size)); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("query", qb); + if (isSourceEnabled == false) { + builder.field("_source", false); + } + if (includes != null) { + builder.startObject("_source"); + builder.startArray("includes"); + for (String include : includes) { + builder.value(include); + } + builder.endArray(); + builder.endObject(); + } + if (excludes != null) { + builder.startObject("_source"); + builder.startArray("excludes"); + for (String exclude : excludes) { + builder.value(exclude); + } + builder.endArray(); + builder.endObject(); + } + + builder.endObject(); + 
request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + List hits = parseSearchResponseHits(responseBody); + + assertNotEquals(0, hits.size()); + } + + @SneakyThrows + private void testDelete(List indexConfigContexts) { + int docToDelete = 8; + int docToDeleteByQuery = 11; + + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + + // Delete by API + deleteKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDelete)); + deleteKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDelete)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Delete by query + deleteKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDeleteByQuery)); + deleteKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDeleteByQuery)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testReindex(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + IndexConfigContext reindexFromEnabledToEnabledContext = indexConfigContexts.get(2); + IndexConfigContext reindexFromEnabledToDisabledContext = 
indexConfigContexts.get(3); + IndexConfigContext reindexFromDisabledToEnabledContext = indexConfigContexts.get(4); + IndexConfigContext reindexFromDisabledToDisabledContext = indexConfigContexts.get(5); + + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + String reindexFromEnabledToEnabledIndexName = reindexFromEnabledToEnabledContext.indexName; + String reindexFromEnabledToDisabledIndexName = reindexFromEnabledToDisabledContext.indexName; + String reindexFromDisabledToEnabledIndexName = reindexFromDisabledToEnabledContext.indexName; + String reindexFromDisabledToDisabledIndexName = reindexFromDisabledToDisabledContext.indexName; + + createKnnIndex( + reindexFromEnabledToEnabledIndexName, + reindexFromEnabledToEnabledContext.getSettings(), + reindexFromEnabledToEnabledContext.getMapping() + ); + createKnnIndex( + reindexFromEnabledToDisabledIndexName, + reindexFromEnabledToDisabledContext.getSettings(), + reindexFromEnabledToDisabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToEnabledIndexName, + reindexFromDisabledToEnabledContext.getSettings(), + reindexFromDisabledToEnabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToDisabledIndexName, + reindexFromDisabledToDisabledContext.getSettings(), + reindexFromDisabledToDisabledContext.getMapping() + ); + refreshAllIndices(); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToDisabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); + + // Need to forcemerge before comparison + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + 
forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToDisabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToDisabledIndexName + ); + } + + @Builder + @Data + private static class IndexConfigContext { + String indexName; + List vectorFieldNames; + int dimension; + Settings settings; + String mapping; + boolean isNested; + int docCount; + CheckedConsumer indexIngestor; + } + + @SneakyThrows + private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { + if (isExhaustive()) { + logger.info("Checking index bigger assertion because running in exhaustive mode"); + int expectedSmaller = indexSizeInBytes(expectedSmallerIndex); + int expectedBigger = indexSizeInBytes(expectedBiggerIndex); + assertTrue( + "Expected smaller index " + expectedSmaller + " was bigger than the expected bigger index:" + expectedBigger, + expectedSmaller < expectedBigger + ); + } else { + logger.info("Skipping index bigger assertion because not 
running in exhaustive mode"); + } + } + + private void assertDocsMatch(int docCount, String index1, String index2) { + for (int i = 0; i < docCount; i++) { + assertDocMatches(i + 1, index1, index2); + } + } + + @SneakyThrows + private void assertDocMatches(int docId, String index1, String index2) { + Map response1 = getKnnDoc(index1, String.valueOf(docId)); + Map response2 = getKnnDoc(index2, String.valueOf(docId)); + assertEquals("Docs do not match: " + docId, response1, response2); + } + + @SneakyThrows + private String createVectorNonNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + + @SneakyThrows + private String createVectorNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(NESTED_NAME) + .field(TYPE, "nested") + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index 42a4f2b956..6288e2076a 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -856,7 +856,7 @@ private void createIndexAndAssertScriptScore( dense, vectorDataType ); - final float[] dummyVector = new float[1]; + float[] dummyVector = new float[1]; dataset.forEach((k, v) -> { final float[] vector = (v != null) ? 
v.getVector() : dummyVector; ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); @@ -886,4 +886,8 @@ private void createIndexAndAssertScriptScore( deleteKNNIndex(INDEX_NAME); } } + + private float[] dummyFloatArrayBasedOnDimension(int dimension) { + return new float[dimension]; + } } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index e116ef3c60..37c00a104b 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -601,7 +601,7 @@ public void testLoadIndex_faiss_sqfp16_valid() { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -669,7 +669,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -877,7 +877,7 @@ public void testLoadIndex_nmslib_valid() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1,
IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -938,7 +938,7 @@ public void testLoadIndex_nmslib_valid_with_stream() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -985,7 +985,7 @@ public void testLoadIndex_faiss_valid() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1024,7 +1024,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1065,7 +1065,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { 
final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1108,7 +1108,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1137,7 +1137,7 @@ public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOEx assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1174,7 +1174,7 @@ public void testQueryIndex_faiss_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1310,7 +1310,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try 
(IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1425,7 +1425,7 @@ public void testQueryBinaryIndex_faiss_valid() { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1561,7 +1561,7 @@ public void testFree_nmslib_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1595,7 +1595,7 @@ public void testFree_faiss_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1800,7 +1800,7 @@ public void createIndexFromTemplate() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = 
directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1862,7 +1862,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { String indexIVFPQPath = createFaissIVFPQIndex(directory, ivfNlist, pqM, pqCodeSize, SpaceType.L2); final long indexIVFPQIndexTest1; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest1 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, indexIVFPQIndexTest1); @@ -1871,7 +1871,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { throw e; } final long indexIVFPQIndexTest2; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest2 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, indexIVFPQIndexTest2); @@ -1891,7 +1891,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { JNIService.free(indexIVFPQIndexTest1, KNNEngine.FAISS); final long indexIVFPQIndexTest3; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest3 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, 
indexIVFPQIndexTest3); @@ -1919,7 +1919,7 @@ public void testIsIndexIVFPQL2() { Path tempDirPath = createTempDir(); try (Directory directory = newFSDirectory(tempDirPath)) { String faissIVFPQL2Index = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.L2); - try (IndexInput indexInput = directory.openInput(faissIVFPQL2Index, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissIVFPQL2Index, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissIVFPQL2Address = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertTrue(JNIService.isSharedIndexStateRequired(faissIVFPQL2Address, KNNEngine.FAISS)); @@ -1927,7 +1927,7 @@ public void testIsIndexIVFPQL2() { } String faissIVFPQIPIndex = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.INNER_PRODUCT); - try (IndexInput indexInput = directory.openInput(faissIVFPQIPIndex, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissIVFPQIPIndex, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissIVFPQIPAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertFalse(JNIService.isSharedIndexStateRequired(faissIVFPQIPAddress, KNNEngine.FAISS)); @@ -1935,7 +1935,7 @@ public void testIsIndexIVFPQL2() { } String faissHNSWIndex = createFaissHNSWIndex(directory, SpaceType.L2); - try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissHNSWAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertFalse(JNIService.isSharedIndexStateRequired(faissHNSWAddress, KNNEngine.FAISS)); diff --git 
a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 2cc20c8f98..a8dd7b1c16 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -7,21 +7,21 @@ import java.util.Arrays; import java.util.Locale; +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.VectorField; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.store.Directory; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import java.io.IOException; import java.math.BigInteger; @@ -354,7 +354,8 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam IndexWriter writer = new IndexWriter(directory, conf); conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test Document knnDocument = new Document(); - knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); + BytesRef vector = new BytesRef(KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(content)); + knnDocument.add(new BinaryDocValuesField(fieldName, vector)); writer.addDocument(knnDocument); writer.commit(); writer.close(); diff --git 
a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index cac5c1b9c2..1bbb388fdf 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -9,7 +9,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index 48b93653f0..f94b661bbe 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -13,7 +13,7 @@ import org.opensearch.Version; import org.opensearch.core.action.ActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.io.stream.BytesStreamOutput; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 7dd1ec237d..ed2ffc54a6 100644 --- 
a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -26,6 +26,7 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.derivedsource.ParentChildHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -69,6 +70,7 @@ import java.util.Objects; import java.util.Optional; import java.util.PriorityQueue; +import java.util.Random; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -168,6 +170,15 @@ public void cleanUpCache() throws Exception { clearCache(); } + /** + * Gives the ability for certain, more exhaustive checks, to be disabled by default + * + * @return If the test is running in exhaustive mode + */ + protected boolean isExhaustive() { + return Boolean.parseBoolean(System.getProperty("test.exhaustive", "false")); + } + /** * Create KNN Index with default settings */ @@ -271,6 +282,11 @@ protected Response performSearch(final String indexName, final String query, fin return response; } + protected List parseSearchResponseHits(String responseBody) throws IOException { + return (List) ((Map) createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map() + .get("hits")).get("hits"); + } + /** * Parse the response of KNN search into a List of KNNResults */ @@ -374,10 +390,6 @@ protected Double parseAggregationResponse(String responseBody, String aggregatio return Double.valueOf(String.valueOf(values.get("value"))); } - /** - * Parse the score from the KNN search response - */ - /** * Delete KNN index */ @@ -693,6 +705,28 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, 
RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void addDocWithNestedNumericField(String index, String docId, String nestedFieldPath, long val) throws IOException { + String[] fieldParts = nestedFieldPath.split("\\."); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], val); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(builder.toString()); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Add a single KNN Doc to an index with multiple fields */ @@ -759,15 +793,90 @@ protected void addDocWithBinaryField(String index, String docId, String fieldNam */ protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + String parent = ParentChildHelper.getParentField(fieldName); + if (parent != null) { + builder.startObject(parent).field(fieldName, vector).endObject(); + } else { + builder.field(fieldName, vector); + } + builder.endObject(); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, 
RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Update a KNN Doc using the POST /\/_update/\. Only the vector field will be updated. + */ + protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("doc"); + String parent = ParentChildHelper.getParentField(fieldName); + if (parent != null) { + builder.startObject(parent).field(fieldName, vector).endObject(); + } else { + builder.field(fieldName, vector); + } + builder.endObject().endObject(); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void updateKnnDocByQuery(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request("POST", "/" + index + "/_update_by_query?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("term") + .field("id", docId) + .endObject() + .endObject() + .startObject("script") + .field("source", "ctx._source." 
+ fieldName + " = params.newValue") + .field("lang", "painless") + .startObject("params") + .field("newValue", vector) + .endObject() + .endObject() + .endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void deleteKnnDocByQuery(String index, String docId) throws IOException { + // Delete all docs whose "id" field matches docId via the _delete_by_query API + Request request = new Request("POST", "/" + index + "/_delete_by_query?refresh"); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("term") + .field("id", docId) + .endObject() + .endObject() + .endObject(); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Overwrite a doc's source with an empty object, removing all of its fields + */ + protected void setDocToEmpty(String index, String docId) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().endObject(); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Delete Knn Doc */ @@ -784,6 +893,7 @@ protected void deleteKnnDoc(String index, String docId) throws IOException { */ protected Map getKnnDoc(final String index, final String docId) throws Exception { final Request request = new Request("GET", "/" + index + "/_doc/" + docId); + request.addParameter("ignore", "404"); final Response response = client().performRequest(request); final Map responseMap =
createParser( @@ -792,8 +902,8 @@ protected Map getKnnDoc(final String index, final String docId) ).map(); assertNotNull(responseMap); - assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); - assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); + // assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); + // assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); final Map docMap = (Map) responseMap.get(DOCUMENT_FIELD_SOURCE); @@ -816,6 +926,22 @@ protected void updateClusterSettings(String settingKey, Object value) throws Exc assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void reindex(String source, Object destination) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("source") + .field("index", source) + .endObject() + .startObject("dest") + .field("index", destination) + .endObject() + .endObject(); + Request request = new Request("POST", "_reindex"); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Return default index settings for index creation */ @@ -893,6 +1019,22 @@ protected Response executeKnnStatRequest(List nodeIds, List stat return response; } + @SneakyThrows + protected int indexSizeInBytes(String indexName) throws IOException { + Request request = new Request("GET", indexName + "/_stats" + "/store"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String responseBody = EntityUtils.toString(response.getEntity()); + + @SuppressWarnings("unchecked") + Integer sizeInBytes = (Integer) ((Map) ((Map) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("_all")).get("primaries")).get("store")).get("size_in_bytes"); + + 
return sizeInBytes; + } + @SneakyThrows protected void doKnnWarmup(List indices) { Response response = knnWarmup(indices); @@ -1240,15 +1382,205 @@ public Map xContentBuilderToMap(XContentBuilder xContentBuilder) } public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + // TODO: Do better on this one + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); for (int i = 0; i < numVectors; i++) { - float[] vector = new float[dimension]; - for (int j = 0; j < dimension; j++) { - vector[j] = randomFloat(); + float[] vector = vectors[i]; + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } + } + + public void bulkIngestRandomVectorsWithSkips(String indexName, String fieldName, int numVectors, int dimension, float skipProb) + throws IOException { + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector = vectors[i]; + if (random.nextFloat() > skipProb) { + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); } + } + } - addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + public void bulkIngestRandomVectorsWithSkipsAndMultFields( + String indexName, + String fieldName1, + String fieldName2, + String fieldName3, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + float[][] vectors1 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + float[][] vectors2 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 8); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector1 = vectors1[i]; + float[] vector2 = vectors2[i]; 
+ + boolean includeFieldOne = random.nextFloat() > skipProb; + boolean includeFieldTwo = random.nextFloat() > skipProb; + boolean includeFieldThree = random.nextFloat() > skipProb; + + if (includeFieldOne || includeFieldTwo || includeFieldThree) { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + if (includeFieldOne) { + xContentBuilder.field(fieldName1, vector1); + } + if (includeFieldTwo) { + xContentBuilder.field(fieldName2, vector2); + } + if (includeFieldThree) { + xContentBuilder.field(fieldName3, "test-test"); + } + xContentBuilder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), xContentBuilder.toString()); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } } + } + + @SneakyThrows + public void bulkIngestRandomVectorsMultiFieldsWithSkips( + String indexName, + List vectorFields, + List textFields, + int numVectors, + int dimension, + float skipProb + ) { + List vectors = new ArrayList<>(); + int seed = 1; + for (String ignored : vectorFields) { + vectors.add(TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, seed++)); + } + + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + + List includeVectorFields = new ArrayList<>(); + for (String ignored : vectorFields) { + includeVectorFields.add(random.nextFloat() > skipProb); + } + List includeTextFields = new ArrayList<>(); + for (String ignored : textFields) { + includeTextFields.add(random.nextFloat() > skipProb); + } + + // If all are skipped, just add a random field + if (includeVectorFields.stream().allMatch((t) -> !t) && includeTextFields.stream().allMatch((t) -> !t)) { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } else { + Map source = new HashMap<>(); + for (int j = 0; j < includeVectorFields.size(); j++) { + if (includeVectorFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(vectorFields.get(j)); + Map 
currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectors.get(j)[i]); + } + } + for (int j = 0; j < includeTextFields.size(); j++) { + if (includeTextFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(textFields.get(j)); + Map currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], "test-test"); + } + } + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + mapToBuilder(builder, source); + builder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); + } + } + } + + @SneakyThrows + void mapToBuilder(XContentBuilder xContentBuilder, Map source) { + for (Map.Entry entry : source.entrySet()) { + if (entry.getValue() instanceof Map) { + xContentBuilder.startObject(entry.getKey()); + mapToBuilder(xContentBuilder, (Map) entry.getValue()); + xContentBuilder.endObject(); + } else { + xContentBuilder.field(entry.getKey(), entry.getValue()); + } + } + } + + public void bulkIngestRandomVectorsWithSkipsAndNested( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + indexName, + nestedFieldName, + nestedNumericPath, + numVectors, + dimension, + skipProb, + 1 + ); + } + + public void bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numDocs, + int dimension, + float skipProb, + int maxDoc + ) throws IOException { + Random random = new Random(); + random.setSeed(2); + float[][] 
vectors = TestUtils.randomlyGenerateStandardVectors(numDocs * maxDoc, dimension, 1); + for (int i = 0; i < numDocs; i++) { + int nestedDocs = random.nextInt(maxDoc) + 1; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startArray(ParentChildHelper.getParentField(nestedFieldName)); + for (int j = 0; j < nestedDocs; j++) { + builder.startObject(); + if (random.nextFloat() > skipProb) { + builder.field(ParentChildHelper.getChildField(nestedFieldName), vectors[i + j]); + } else { + builder.field(ParentChildHelper.getChildField(nestedNumericPath), 1); + } + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); + } + } + public float[] randomFloatVector(int dimension) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomFloat(); + } + return vector; } /**