Skip to content

Commit

Permalink
split search leaves
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Nov 11, 2024
1 parent 5d98552 commit 5cb48f0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 5 deletions.
18 changes: 15 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
if (isFilteredExactSearchPreferred(cardinality, k)) {
return doExactSearch(context, filterBitSet, k);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
Expand All @@ -153,6 +153,18 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
return docIdsToScoreMap;
}

public boolean isExactSearchPreferred(LeafReaderContext context, int k) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
if (isFilteredExactSearchPreferred(cardinality, k)) {
return true;
}
if (isMissingNativeEngineFiles(context)) {
return true;
}
return false;
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
Expand Down Expand Up @@ -398,7 +410,7 @@ public static float normalizeScore(float score) {
return -score + 1;
}

private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
private boolean isFilteredExactSearchPreferred(final int filterIdsCount, int k) {
if (filterWeight == null) {
return false;
}
Expand All @@ -409,7 +421,7 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
);
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
if (knnQuery.getRadius() == null && filterIdsCount <= k) {
return true;
}
// See user has defined Exact Search filtered threshold. if yes, then use that setting.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;

/**
Expand Down Expand Up @@ -64,13 +66,24 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
List<LeafReaderContext> leafReaderContextForExactSearch = new ArrayList<>();
List<LeafReaderContext> leafReaderContextForApproxSearch = new ArrayList<>();
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
if (knnWeight.isExactSearchPreferred(leafReaderContext, firstPassK)) {
leafReaderContextForExactSearch.add(leafReaderContext);
} else {
leafReaderContextForApproxSearch.add(leafReaderContext);
}
}
perLeafResults = doSearch(indexSearcher, leafReaderContextForApproxSearch, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled == true) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
perLeafResults = doRescore(indexSearcher, leafReaderContextForApproxSearch, knnWeight, perLeafResults, finalK);
perLeafResults.addAll(score(indexSearcher, leafReaderContextForExactSearch, knnWeight, finalK));

long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
}
Expand All @@ -87,6 +100,20 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
}

private Set<LeafReaderContext> filterLeafReaderContextsForExactSearch(
final List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
int k
) throws IOException {
Set<LeafReaderContext> filteredLeafReaderContexts = new HashSet<>();
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
if (knnWeight.isExactSearchPreferred(leafReaderContext, k)) {
filteredLeafReaderContexts.add(leafReaderContext);
}
}
return filteredLeafReaderContexts;
}

private List<Map<Integer, Float>> doSearch(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
Expand Down Expand Up @@ -127,6 +154,30 @@ private List<Map<Integer, Float>> doRescore(
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
}

private List<Map<Integer, Float>> score(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size());
for (int i = 0; i < leafReaderContexts.size(); i++) {
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
tasks.add(() -> {
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
// setting to false because in re-scoring we want to do exact search on full precision vectors
.useQuantizedVectorsForSearch(false)
.k(k)
.isParentHits(false)
.knnQuery(knnQuery)
.build();
return knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
});
}
return indexSearcher.getTaskExecutor().invokeAll(tasks);
}

private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
Expand Down

0 comments on commit 5cb48f0

Please sign in to comment.