diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04c2ce587..27486b69c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -139,7 +139,7 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. */ - if (isFilteredExactSearchPreferred(cardinality)) { + if (isFilteredExactSearchPreferred(cardinality, k)) { return doExactSearch(context, filterBitSet, k); } Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); @@ -153,6 +153,18 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I return docIdsToScoreMap; } + public boolean isExactSearchPreferred(LeafReaderContext context, int k) throws IOException { + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); + if (isFilteredExactSearchPreferred(cardinality, k)) { + return true; + } + if (isMissingNativeEngineFiles(context)) { + return true; + } + return false; + } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { if (this.filterWeight == null) { return new FixedBitSet(0); @@ -398,7 +410,7 @@ public static float normalizeScore(float score) { return -score + 1; } - private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { + private boolean isFilteredExactSearchPreferred(final int filterIdsCount, int k) { if (filterWeight == null) { return false; } @@ -409,7 +421,7 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { ); int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { + if (knnQuery.getRadius() == null && filterIdsCount <= k) { return true; } // See user has defined Exact Search filtered threshold. if yes, then use that setting. diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 8b861b430..f067b0b42 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -31,9 +31,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.Callable; /** @@ -64,13 +66,24 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); - perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); + List leafReaderContextForExactSearch = new ArrayList<>(); + List leafReaderContextForApproxSearch = new ArrayList<>(); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + if (knnWeight.isExactSearchPreferred(leafReaderContext, firstPassK)) { + leafReaderContextForExactSearch.add(leafReaderContext); + } else { + leafReaderContextForApproxSearch.add(leafReaderContext); + } + } + perLeafResults = doSearch(indexSearcher, leafReaderContextForApproxSearch, knnWeight, firstPassK); if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } StopWatch stopWatch = new StopWatch().start(); - perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); + perLeafResults = doRescore(indexSearcher, leafReaderContextForApproxSearch, knnWeight, perLeafResults, finalK); + perLeafResults.addAll(score(indexSearcher, leafReaderContextForExactSearch, knnWeight, finalK)); + long rescoreTime = stopWatch.stop().totalTime().millis(); log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); } @@ -87,6 +100,20 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); } + private Set filterLeafReaderContextsForExactSearch( + final List leafReaderContexts, + KNNWeight knnWeight, + int k + ) throws IOException { + Set filteredLeafReaderContexts = new HashSet<>(); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + if (knnWeight.isExactSearchPreferred(leafReaderContext, k)) { + filteredLeafReaderContexts.add(leafReaderContext); + } + } + return filteredLeafReaderContexts; + } + private List> doSearch( final IndexSearcher indexSearcher, List leafReaderContexts, @@ -127,6 +154,30 @@ private List> doRescore( return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } + private List> score( + final IndexSearcher indexSearcher, + List leafReaderContexts, + KNNWeight knnWeight, + int k + ) throws IOException { + List>> tasks = new ArrayList<>(leafReaderContexts.size()); + for (int i = 0; i < leafReaderContexts.size(); i++) { + LeafReaderContext leafReaderContext = leafReaderContexts.get(i); + int finalI = i; + tasks.add(() -> { + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + // setting to false because in re-scoring we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(false) + .k(k) + .isParentHits(false) + .knnQuery(knnQuery) + .build(); + return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + }); + } + return indexSearcher.getTaskExecutor().invokeAll(tasks); + } + private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));