From f16f2250b082ce8e2c5eda23b38d27a89365fa13 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Mon, 30 Sep 2024 14:56:18 -0700 Subject: [PATCH 1/3] Refactor and Update unit test to include field with no live docs (#2167) Refactored if/else to reduce nesting. Added unit test when one of the field doesn't have live docs. Signed-off-by: Vijayan Balasubramanian --- CHANGELOG.md | 1 + .../NativeEngines990KnnVectorsWriter.java | 32 ++++++++-------- ...eEngines990KnnVectorsWriterFlushTests.java | 38 ++++++++++++------- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 67879bae79..5615509ded 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,3 +31,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance * Remove benchmarks folder from k-NN repo [#2127](https://github.com/opensearch-project/k-NN/pull/2127) ### Refactoring +* Minor refactoring and refactored some unit test [#2167](https://github.com/opensearch-project/k-NN/pull/2167) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 23cd2a4de0..2f22565c98 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -84,24 +84,24 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { final FieldInfo fieldInfo = field.getFieldInfo(); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); int totalLiveDocs = field.getVectors().size(); - if (totalLiveDocs > 0) { - final Supplier> knnVectorValuesSupplier = () -> getVectorValues( - vectorDataType, - field.getDocsWithField(), - field.getVectors() - ); - final QuantizationState quantizationState = train(field.getFieldInfo(), 
knnVectorValuesSupplier, totalLiveDocs); - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - - StopWatch stopWatch = new StopWatch().start(); - writer.flushIndex(knnVectorValues, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); - log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); - } else { + if (totalLiveDocs == 0) { log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); + continue; } + final Supplier> knnVectorValuesSupplier = () -> getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + + StopWatch stopWatch = new StopWatch().start(); + writer.flushIndex(knnVectorValues, totalLiveDocs); + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index dbb5649085..9f74b2c104 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -32,8 +32,11 @@ import java.util.ArrayList; import 
java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static com.carrotsearch.randomizedtesting.RandomizedTest.$; @@ -44,6 +47,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -86,6 +90,7 @@ public static Collection data() { "Multi Field", List.of( Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Collections.emptyMap(), Map.of( 0, new float[] { 1, 2, 3, 4 }, @@ -105,18 +110,16 @@ public static Collection data() { @SneakyThrows public void testFlush() { // Given - List> expectedVectorValues = new ArrayList<>(); - IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final List> expectedVectorValues = vectorsPerField.stream().map(vectors -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(vectorsPerField.get(i).values()) + new ArrayList<>(vectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( VectorDataType.FLOAT, randomVectorValues ); - expectedVectorValues.add(knnVectorValues); - - }); + return knnVectorValues; + }).collect(Collectors.toList()); try ( MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); @@ -172,15 +175,19 @@ public void testFlush() { IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + if (vectorsPerField.get(i).isEmpty()) { + verify(nativeIndexWriter, 
never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } } catch (Exception e) { throw new RuntimeException(e); } }); - + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size()) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } } @@ -264,16 +271,21 @@ public void testFlush_WithQuantization() { IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } } catch (Exception e) { throw new RuntimeException(e); } }); - + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size() * 2) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) ); } } From 6d098cf5266160033eccc7ee2e90bd4e4c1c071c Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: 
Mon, 30 Sep 2024 17:24:28 -0500 Subject: [PATCH 2/3] Fix Faiss efficient filter exact search using byte vector datatype (#2165) * Fix Faiss efficient filter exact search using byte vector datatype Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda --- .../knn/index/query/ExactSearcher.java | 21 ++- .../iterators/BinaryVectorIdsKNNIterator.java | 92 +++++++++++ .../iterators/ByteVectorIdsKNNIterator.java | 34 ++-- .../NestedBinaryVectorIdsKNNIterator.java | 77 +++++++++ .../NestedByteVectorIdsKNNIterator.java | 12 +- .../BinaryVectorIdsKNNIteratorTests.java | 97 +++++++++++ .../ByteVectorIdsKNNIteratorTests.java | 24 +-- ...NestedBinaryVectorIdsKNNIteratorTests.java | 91 ++++++++++ .../NestedByteVectorIdsKNNIteratorTests.java | 24 +-- .../knn/integ/FilteredSearchByteIT.java | 104 ++++++++++++ .../knn/integ/NestedSearchByteIT.java | 156 ++++++++++++++++++ 11 files changed, 690 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java create mode 100644 src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java create mode 100644 src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java create mode 100644 src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java create mode 100644 src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 193cba8c16..8e5849abb6 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -20,12 +20,15 @@ import 
org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; +import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.KNNIterator; import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -111,7 +114,7 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); if (isNestedRequired) { - return new NestedByteVectorIdsKNNIterator( + return new NestedBinaryVectorIdsKNNIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, @@ -119,13 +122,27 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea knnQuery.getParentsFilter().getBitSet(leafReaderContext) ); } - return new ByteVectorIdsKNNIterator( + return new BinaryVectorIdsKNNIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType ); } + + if (VectorDataType.BYTE == knnQuery.getVectorDataType()) { + final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + if (isNestedRequired) { + return new 
NestedByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } + return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType); + } final byte[] quantizedQueryVector; final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java new file mode 100644 index 0000000000..5bab5b573d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.opensearch.common.Nullable; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene + * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 + * + * The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided + */ +public class BinaryVectorIdsKNNIterator implements KNNIterator { + protected final BitSetIterator bitSetIterator; + protected final byte[] queryVector; + protected final KNNBinaryVectorValues binaryVectorValues; + protected final SpaceType spaceType; + protected float currentScore = 
Float.NEGATIVE_INFINITY; + protected int docId; + + public BinaryVectorIdsKNNIterator( + @Nullable final BitSet filterIdsBitSet, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType + ) throws IOException { + this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.queryVector = queryVector; + this.binaryVectorValues = binaryVectorValues; + this.spaceType = spaceType; + // This cannot be moved inside nextDoc() method since it will break when we have nested field, where + // nextDoc should already be referring to next knnVectorValues + this.docId = getNextDocId(); + } + + public BinaryVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType) + throws IOException { + this(null, queryVector, binaryVectorValues, spaceType); + } + + /** + * Advance to the next doc and update score value with score of the next doc. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next doc id + */ + @Override + public int nextDoc() throws IOException { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + currentScore = computeScore(); + int currentDocId = docId; + docId = getNextDocId(); + return currentDocId; + } + + @Override + public float score() { + return currentScore; + } + + protected float computeScore() throws IOException { + final byte[] vector = binaryVectorValues.getVector(); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. 
+ return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + } + + protected int getNextDocId() throws IOException { + if (bitSetIterator == null) { + return binaryVectorValues.nextDoc(); + } + int nextDocID = this.bitSetIterator.nextDoc(); + // For filter case, advance vector values to corresponding doc id from filter bit set + if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { + binaryVectorValues.advance(nextDocID); + } + return nextDocID; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java index b1aea42847..0e80051637 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java @@ -10,7 +10,7 @@ import org.apache.lucene.util.BitSetIterator; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -22,30 +22,30 @@ */ public class ByteVectorIdsKNNIterator implements KNNIterator { protected final BitSetIterator bitSetIterator; - protected final byte[] queryVector; - protected final KNNBinaryVectorValues binaryVectorValues; + protected final float[] queryVector; + protected final KNNByteVectorValues byteVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; public ByteVectorIdsKNNIterator( @Nullable final BitSet filterIdsBitSet, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType ) throws IOException { this.bitSetIterator = filterIdsBitSet == null ? 
null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; - this.binaryVectorValues = binaryVectorValues; + this.byteVectorValues = byteVectorValues; this.spaceType = spaceType; // This cannot be moved inside nextDoc() method since it will break when we have nested field, where // nextDoc should already be referring to next knnVectorValues this.docId = getNextDocId(); } - public ByteVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType) + public ByteVectorIdsKNNIterator(final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType) throws IOException { - this(null, queryVector, binaryVectorValues, spaceType); + this(null, queryVector, byteVectorValues, spaceType); } /** @@ -72,20 +72,30 @@ public float score() { } protected float computeScore() throws IOException { - final byte[] vector = binaryVectorValues.getVector(); + final byte[] vector = byteVectorValues.getVector(); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. - return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + + // The query vector of Faiss byte vector is a Float array because ScalarQuantizer accepts it as float array. + // To compute the score between this query vector and each vector in KNNByteVectorValues we are casting this query vector into byte + // array directly. + // This is safe to do so because float query vector already has validated byte values. Do not reuse this direct cast at any other + // place. 
+ final byte[] byteQueryVector = new byte[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + byteQueryVector[i] = (byte) queryVector[i]; + } + return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector); } protected int getNextDocId() throws IOException { if (bitSetIterator == null) { - return binaryVectorValues.nextDoc(); + return byteVectorValues.nextDoc(); } int nextDocID = this.bitSetIterator.nextDoc(); // For filter case, advance vector values to corresponding doc id from filter bit set if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { - binaryVectorValues.advance(nextDocID); + byteVectorValues.advance(nextDocID); } return nextDocID; } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java new file mode 100644 index 0000000000..97bf3517e7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.common.Nullable; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs. + * However, it dedupes docs per each parent doc + * of which ID is set in parentBitSet and only returns best child doc with the highest score. 
+ */ +public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator { + private final BitSet parentBitSet; + + public NestedBinaryVectorIdsKNNIterator( + @Nullable final BitSet filterIdsArray, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet + ) throws IOException { + super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + this.parentBitSet = parentBitSet; + } + + public NestedBinaryVectorIdsKNNIterator( + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet + ) throws IOException { + super(null, queryVector, binaryVectorValues, spaceType); + this.parentBitSet = parentBitSet; + } + + /** + * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next best child doc id + */ + @Override + public int nextDoc() throws IOException { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + currentScore = Float.NEGATIVE_INFINITY; + int currentParent = parentBitSet.nextSetBit(docId); + int bestChild = -1; + + // In order to traverse all children for given parent, we have to use docId < parentId, because, + // kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child + // and for doc id 5, there are three children. 
In that case knnVectorValues iterator will have [0, 2, 3, 4] + // and parentBitSet will have [1,5] + // Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + float score = computeScore(); + if (score > currentScore) { + bestChild = docId; + currentScore = score; + } + docId = getNextDocId(); + } + + return bestChild; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java index 3c93ec888a..9644b620ff 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java @@ -9,7 +9,7 @@ import org.apache.lucene.util.BitSet; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -23,18 +23,18 @@ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator { public NestedByteVectorIdsKNNIterator( @Nullable final BitSet filterIdsArray, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + super(filterIdsArray, queryVector, byteVectorValues, spaceType); this.parentBitSet = parentBitSet; } public NestedByteVectorIdsKNNIterator( - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues binaryVectorValues, final SpaceType spaceType, final 
BitSet parentBitSet ) throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java new file mode 100644 index 0000000000..6d5dffa98f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.mockito.stubbing.OngoingStubbing; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class BinaryVectorIdsKNNIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 1, 2, 3 }; + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new 
FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); + for (int i = 0; i < filterIds.length; i++) { + assertEquals(filterIds[i], iterator.nextDoc()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } + + @SneakyThrows + public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOException { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final List dataVectors = Arrays.asList( + new byte[] { 11, 12, 13 }, + new byte[] { 14, 15, 16 }, + new byte[] { 17, 18, 19 }, + new byte[] { 20, 21, 22 }, + new byte[] { 23, 24, 25 } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn( + dataVectors.get(0), + dataVectors.get(1), + dataVectors.get(2), + dataVectors.get(3), + dataVectors.get(4) + ); + + // stub return value when nextDoc is called + OngoingStubbing stubbing = when(values.nextDoc()); + for (int i = 0; i < dataVectors.size(); i++) { + stubbing = stubbing.thenReturn(i); + } + // set last return to be Integer.MAX_VALUE to represent no more docs + stubbing.thenReturn(Integer.MAX_VALUE); + + // Execute and verify + BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator(queryVector, values, spaceType); + for (int i = 0; i < dataVectors.size(); i++) { + assertEquals(i, iterator.nextDoc()); + assertEquals(expectedScores.get(i), iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + verify(values, never()).advance(anyInt()); + } +} diff --git 
a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java index 0b1b71286f..60169b95fb 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java @@ -11,7 +11,7 @@ import org.apache.lucene.util.FixedBitSet; import org.mockito.stubbing.OngoingStubbing; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; import java.util.Arrays; @@ -26,16 +26,17 @@ public class ByteVectorIdsKNNIteratorTests extends TestCase { @SneakyThrows - public void testNextDoc_whenCalled_thenIterateAllDocs() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1f, 2f, 3f }; final int[] filterIds = { 1, 2, 3 }; final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); @@ -48,15 +49,16 @@ public void testNextDoc_whenCalled_thenIterateAllDocs() { ByteVectorIdsKNNIterator iterator 
= new ByteVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); - assertEquals(expectedScores.get(i), iterator.score()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); } assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } @SneakyThrows public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOException { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; final List dataVectors = Arrays.asList( new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, @@ -65,10 +67,10 @@ public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOEx new byte[] { 23, 24, 25 } ); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn( dataVectors.get(0), dataVectors.get(1), diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java new file mode 100644 index 0000000000..a39a3b2e92 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import junit.framework.TestCase; +import lombok.SneakyThrows; 
+import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class NestedBinaryVectorIdsKNNIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + 
assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } + + @SneakyThrows + public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + when(values.nextDoc()).thenReturn(0, 2, 3, Integer.MAX_VALUE); + + // Execute and verify + NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator(queryVector, values, spaceType, parentBitSet); + assertEquals(0, iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(3, iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + verify(values, never()).advance(anyInt()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java index eff021234c..08c8597799 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java @@ -11,7 +11,7 @@ import 
org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.util.Arrays; import java.util.List; @@ -26,19 +26,20 @@ public class NestedByteVectorIdsKNNIteratorTests extends TestCase { @SneakyThrows public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; final int[] filterIds = { 0, 2, 3 }; // Parent id for 0 -> 1 // Parent id for 2, 3 -> 4 // In bit representation, it is 10010. In long, it is 18. final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); - final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); @@ -64,18 +65,19 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { @SneakyThrows public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerParent() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final 
SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; // Parent id for 0 -> 1 // Parent id for 2, 3 -> 4 // In bit representation, it is 10010. In long, it is 18. final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); - final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); when(values.nextDoc()).thenReturn(0, 2, 3, Integer.MAX_VALUE); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java new file mode 100644 index 0000000000..fe4dc7db9c --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonIndexMappingsBuilder; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import 
org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class FilteredSearchByteIT extends KNNRestTestCase { + @After + public void cleanUp() { + try { + deleteKNNIndex(INDEX_NAME); + } catch (Exception e) { + log.error(e); + } + } + + @SneakyThrows + public void testFilteredSearchWithFaissHnswByte_whenDoingApproximateSearch_thenReturnCorrectResults() { + validateFilteredSearchWithFaissHnswByte(INDEX_NAME, false); + } + + @SneakyThrows + public void testFilteredSearchWithFaissHnswByte_whenDoingExactSearch_thenReturnCorrectResults() { + validateFilteredSearchWithFaissHnswByte(INDEX_NAME, true); + } + + private void validateFilteredSearchWithFaissHnswByte(final String indexName, final boolean doExactSearch) throws Exception { + String filterFieldName = "parking"; + createKnnByteIndex(indexName, FIELD_NAME, 3, KNNEngine.FAISS); + + for (byte i = 1; i < 4; i++) { + addKnnDocWithAttributes( + indexName, + Integer.toString(i), + FIELD_NAME, + new float[] { i, i, i }, + ImmutableMap.of(filterFieldName, i % 2 == 1 ? "true" : "false") + ); + } + refreshIndex(indexName); + forceMergeKnnIndex(indexName); + + // Set it as 0 for approximate search and 100(larger than number of filtered id) for exact search + updateIndexSettings( + indexName, + Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, doExactSearch ? 
100 : 0) + ); + + Float[] queryVector = { 3f, 3f, 3f }; + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(3) + .filterFieldName(filterFieldName) + .filterValue("true") + .build() + .getQueryString(); + Response response = searchKNNIndex(indexName, query, 3); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + + private void createKnnByteIndex(final String indexName, final String fieldName, final int dimension, final KNNEngine knnEngine) + throws Exception { + KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() + .methodName(METHOD_HNSW) + .engine(knnEngine.getName()) + .build(); + + String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder() + .fieldName(fieldName) + .dimension(dimension) + .vectorDataType(VectorDataType.BYTE.getValue()) + .method(method) + .build() + .getIndexMapping(); + + createKnnIndex(indexName, knnIndexMapping); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java new file mode 100644 index 0000000000..7985d08a7d --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonIndexMappingsBuilder; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.NestedKnnDocBuilder; 
+import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class NestedSearchByteIT extends KNNRestTestCase { + @After + public void cleanUp() { + try { + deleteKNNIndex(INDEX_NAME); + } catch (Exception e) { + log.error(e); + } + } + + @SneakyThrows + public void testNestedSearchWithFaissHnswByte_whenKIsTwo_thenReturnTwoResults() { + String nestedFieldName = "nested"; + createKnnByteIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 2, KNNEngine.FAISS); + + int totalDocCount = 15; + for (byte i = 0; i < totalDocCount; i++) { + String doc = NestedKnnDocBuilder.create(nestedFieldName) + .addVectors(FIELD_NAME, new Byte[] { i, i }, new Byte[] { i, i }) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + Byte[] queryVector = { 14, 14 }; + String query = KNNJsonQueryBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(2) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, 2); + String entity = EntityUtils.toString(response.getEntity()); + + assertEquals(2, parseHits(entity)); + assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertNotEquals("14", parseIds(entity).get(1)); + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 1, 1, 1 + * ], + * "k": 3, + * "filter": { + * "term": { + * "parking": "true" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testNestedSearchWithFaissHnswByte_whenDoingExactSearch_thenReturnCorrectResults() { + String nestedFieldName = "nested"; + String filterFieldName = "parking"; + 
createKnnByteIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 3, KNNEngine.FAISS); + + for (byte i = 1; i < 4; i++) { + String doc = NestedKnnDocBuilder.create(nestedFieldName) + .addVectors(FIELD_NAME, new Byte[] { i, i, i }, new Byte[] { i, i, i }, new Byte[] { i, i, i }) + .addTopLevelField(filterFieldName, i % 2 == 1 ? "true" : "false") + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Make it as an exact search by setting the threshold larger than size of filteredIds(6) + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 100)); + + Byte[] queryVector = { 3, 3, 3 }; + String query = KNNJsonQueryBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(3) + .filterFieldName(filterFieldName) + .filterValue("true") + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, 3); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + + private void createKnnByteIndexWithNestedField( + final String indexName, + final String nestedFieldName, + final String fieldName, + final int dimension, + final KNNEngine knnEngine + ) throws Exception { + KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() + .methodName(METHOD_HNSW) + .engine(knnEngine.getName()) + .build(); + + String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(fieldName) + .dimension(dimension) + .vectorDataType(VectorDataType.BYTE.getValue()) + .method(method) + .build() + .getIndexMapping(); + + createKnnIndex(indexName, knnIndexMapping); + } +} From 
07f4df2015a0af5158859c42c98511b865005eac Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Mon, 30 Sep 2024 17:17:53 -0700 Subject: [PATCH 3/3] Adding Support to Enable/Disable Shard level Rescoring and Update Oversampling Factor (#2172) Signed-off-by: VIKASH TIWARI --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/KNNSettings.java | 36 ++++++++- .../knn/index/mapper/CompressionLevel.java | 31 ++++---- .../nativelib/NativeEngineKnnVectorQuery.java | 6 +- .../index/query/rescore/RescoreContext.java | 9 +++ .../knn/index/KNNSettingsTests.java | 35 +++++++++ .../index/mapper/CompressionLevelTests.java | 73 ++++++++++++------- .../NativeEngineKNNVectorQueryTests.java | 71 +++++++++++++++++- 8 files changed, 214 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5615509ded..6871074f09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Optimize reduceToTopK in ResultUtil by removing pre-filling and reducing peek calls [#2146](https://github.com/opensearch-project/k-NN/pull/2146) * Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149) * KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155) +* Adding Support to Enable/Disable Shard level Rescoring and Update Oversampling Factor [#2172](https://github.com/opensearch-project/k-NN/pull/2172) ### Bug Fixes * KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 5fcc51bb53..1753140e63 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -88,6 +88,7 @@ public class KNNSettings { public static
final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit"; public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes"; public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; + public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; /** * Default setting values @@ -112,11 +113,31 @@ public class KNNSettings { public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; + public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true; /** * Settings Definition */ + /** + * This setting controls whether shard-level re-scoring for KNN disk-based vectors is turned off. + * The setting uses: + *
 <ul>
 + * <li>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED: The name of the setting.</li>
 + * <li>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE: The default value ({@code true}, i.e. shard-level re-scoring is disabled by default).</li>
 + * <li>IndexScope: The setting works at the index level.</li>
 + * <li>Dynamic: This setting can be changed without restarting the cluster.</li>
 + * </ul>
+ * + * @see Setting#boolSetting(String, boolean, Setting.Property...) + */ + public static final Setting KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting( + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE, + IndexScope, + Dynamic + ); + // This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default // 1% of the JVM heap public static final Setting KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting( @@ -454,6 +475,10 @@ private Setting getSetting(String key) { return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; } + if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { + return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -475,7 +500,8 @@ public List> getSettings() { KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, KNN_FAISS_AVX512_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, - QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING + QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -528,6 +554,14 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { + return KNNSettings.state().clusterService.state() + .getMetadata() + .index(indexName) + .getSettings() + .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; 
this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index 3e1b47db7b..c9a169efca 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -97,32 +97,35 @@ public static boolean isConfigured(CompressionLevel compressionLevel) { /** * Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}. * - *

 <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of - * {@code dimension}: + *
 <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the + * {@code dimension} value: *
 <ul>
 - * <li>If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an - * oversample factor of 5.0f.</li>
 - * <li>If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with - * the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.</li>
 + * <li>If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).</li>
 + * <li>If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).</li>
 + * <li>If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).</li>
 * </ul>
- * If the {@code mode} is not valid, the method returns {@code null}. + * If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}. * * @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}. * @param dimension The dimensional value that determines the {@link RescoreContext} behavior. - * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than - * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode - * is invalid. + * @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode + * is not valid. */ public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) { if (modesForRescore.contains(mode)) { // Adjust RescoreContext based on dimension - if (dimension <= RescoreContext.DIMENSION_THRESHOLD) { - // For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD).build(); + if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) { + // No oversampling for dimensions >= 1000 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build(); + } else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) { + // 2x oversampling for dimensions >= 768 but < 1000 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build(); } else { - return defaultRescoreContext; + // 3x oversampling for dimensions < 768 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build(); } } return null; } + } diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 945da850ac..adb2875d5e 100644 
--- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; @@ -54,7 +55,6 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); int finalK = knnQuery.getK(); @@ -63,7 +63,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo } else { int firstPassK = rescoreContext.getFirstPassK(finalK); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - ResultUtil.reduceToTopK(perLeafResults, firstPassK); + if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) { + ResultUtil.reduceToTopK(perLeafResults, firstPassK); + } StopWatch stopWatch = new StopWatch().start(); perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 51d4e491c9..a2563b2a61 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -24,6 +24,15 @@ public final class RescoreContext { public static final int DIMENSION_THRESHOLD = 
1000; public static final float OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD = 5.0f; + // Dimension thresholds for adjusting oversample factor + public static final int DIMENSION_THRESHOLD_1000 = 1000; + public static final int DIMENSION_THRESHOLD_768 = 768; + + // Oversample factors based on dimension thresholds + public static final float OVERSAMPLE_FACTOR_1000 = 1.0f; // No oversampling for dimensions >= 1000 + public static final float OVERSAMPLE_FACTOR_768 = 2.0f; // 2x oversampling for dimensions >= 768 and < 1000 + public static final float OVERSAMPLE_FACTOR_BELOW_768 = 3.0f; // 3x oversampling for dimensions < 768 + // Todo:- We will improve this in upcoming releases public static final int MIN_FIRST_PASS_RESULTS = 100; diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index 75eb14713d..fd25699ccd 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -158,6 +158,41 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { assertEquals(userProvidedEfSearch, efSearchValue); } + @SneakyThrows + public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + mockNode.close(); + assertTrue(shardLevelRescoringDisabled); + } + + @SneakyThrows + public void 
testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() { + boolean userDefinedRescoringDisabled = false; + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + final Settings rescoringDisabledSetting = Settings.builder() + .put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled) + .build(); + + mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); + + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + mockNode.close(); + assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); + } + @SneakyThrows public void testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenValidateAndSucceed() { boolean expectedKNNFaissAVX2Disabled = true; diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java index cc70d4c2de..57372b11ee 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -44,65 +44,84 @@ public void testIsConfigured() { public void testGetDefaultRescoreContext() { // Test rescore context for ON_DISK mode Mode mode = Mode.ON_DISK; - int belowThresholdDimension = 500; // A dimension below the threshold - int aboveThresholdDimension = 1500; // A dimension above the threshold - // x32 with dimension <= 1000 should have an oversample factor of 5.0f + // Test various dimensions based on the updated 
oversampling logic + int belowThresholdDimension = 500; // A dimension below 768 + int between768and1000Dimension = 800; // A dimension between 768 and 1000 + int above1000Dimension = 1500; // A dimension above 1000 + + // Compression level x32 with dimension < 768 should have an oversample factor of 3.0f RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x32 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x16 with dimension <= 1000 should have an oversample factor of 5.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); + // Compression level x32 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x16 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x16 with dimension < 768 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); 
assertNotNull(rescoreContext); assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x8 with dimension <= 1000 should have an oversample factor of 5.0f + // Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension); + assertNotNull(rescoreContext); + assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x16 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension); + assertNotNull(rescoreContext); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x8 with dimension < 768 should have an oversample factor of 3.0f rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x8 with dimension > 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension); assertNotNull(rescoreContext); assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext) + // Compression level x8 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension); + assertNotNull(rescoreContext); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x4 with dimension < 768 
should return null (no RescoreContext) rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension); - assertNull(rescoreContext); - // Other compression levels should behave similarly with respect to dimension + // Compression level x4 with dimension > 1000 should return null (no RescoreContext) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension); + assertNull(rescoreContext); + // Compression level x2 with dimension < 768 should return null rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x2 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x2 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension); assertNull(rescoreContext); + // Compression level x1 with dimension < 768 should return null rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x1 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x1 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension); assertNull(rescoreContext); - // NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an oversample factor of 5.0f + // NOT_CONFIGURED compression level should return null for any dimension rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext); - } + } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 06350f39c7..7fd96c6df4 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -17,11 +17,16 @@ import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.invocation.InvocationOnMock; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.ResultUtil; @@ -35,12 +40,11 @@ import java.util.Map; import java.util.concurrent.Callable; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.times; import static org.mockito.MockitoAnnotations.openMocks; public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @@ -66,6 +70,9 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @Mock private LeafReader leafReader2; + @Mock + private ClusterService clusterService; + @InjectMocks private NativeEngineKnnVectorQuery objectUnderTest; @@ -91,6 
+98,11 @@ public void setUp() throws Exception { }); when(reader.getContext()).thenReturn(indexReaderContext); + + when(clusterService.state()).thenReturn(mock(ClusterState.class)); // Mock ClusterState + + // Set ClusterService in KNNSettings + KNNSettings.state().setClusterService(clusterService); } @SneakyThrows @@ -127,6 +139,49 @@ public void testMultiLeaf() { assertEquals(expected, actual.getQuery()); } + @SneakyThrows + public void testRescoreWhenShardLevelRescoringEnabled() { + // Given + List leaves = List.of(leaf1, leaf2); + when(reader.leaves()).thenReturn(leaves); + + int k = 2; + int firstPassK = 3; + Map initialLeaf1Results = new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f)); + Map initialLeaf2Results = new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f)); + Map rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f)); + Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); + + when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); + when(knnQuery.getK()).thenReturn(k); + when(knnWeight.getQuery()).thenReturn(knnQuery); + when(knnWeight.searchLeaf(leaf1, firstPassK)).thenReturn(initialLeaf1Results); + when(knnWeight.searchLeaf(leaf2, firstPassK)).thenReturn(initialLeaf2Results); + when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); + when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + + try ( + MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); + MockedStatic mockedResultUtil = mockStatic(ResultUtil.class) + ) { + + // When shard-level re-scoring is enabled + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); + + // Mock ResultUtil to return valid TopDocs + mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) + .thenReturn(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0])); + mockedResultUtil.when(() -> 
ResultUtil.reduceToTopK(any(), anyInt())).thenCallRealMethod(); + + // When + Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + + // Then + mockedResultUtil.verify(() -> ResultUtil.reduceToTopK(any(), anyInt()), times(2)); + assertNotNull(actual); + } + } + @SneakyThrows public void testSingleLeaf() { // Given @@ -188,7 +243,15 @@ public void testRescore() { when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); - try (MockedStatic mockedResultUtil = mockStatic(ResultUtil.class)) { + + try ( + MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); + MockedStatic mockedResultUtil = mockStatic(ResultUtil.class) + ) { + + // When shard-level re-scoring is disabled + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(true); + mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf2Results), anyInt())).thenAnswer(t -> topDocs2);