Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
VijayanB committed Feb 14, 2025
1 parent 36e3128 commit f2239d8
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

/**
 * A marker/transport {@link KnnCollector} used to tunnel an exact (brute-force) search request
 * through the Lucene {@code KnnVectorsReader#search} API.
 *
 * <p>This collector does not actually collect hits. The caller builds it with the
 * {@link ExactSearcher.ExactSearcherContext} and {@link LeafReaderContext} describing the request,
 * passes it to the vectors reader, and the reader responds by populating {@link #setKnnIterator}
 * with a {@link KNNIterator} over the matching vectors. All {@link KnnCollector} methods are
 * therefore inert stubs.
 */
@Builder
public class KnnExactSearchCollector implements KnnCollector {

    /** Number of nearest neighbors requested; exposed through {@link #k()}. */
    private final int k;

    /** Name of the k-NN vector field being searched. */
    @Getter
    private final String field;

    // NOTE(review): boxed Boolean may be null when not supplied to the builder — confirm callers
    // null-check before unboxing. Kept as Boolean to preserve the generated getter's signature.
    @Getter
    private final Boolean useQuantization;

    /** Context describing the exact-search request (query, matched docs, quantization flags). */
    @Getter
    private final ExactSearcher.ExactSearcherContext context;

    /** Leaf (segment) context the exact search runs against. */
    @Getter
    private final LeafReaderContext leafReaderContext;

    /** Populated by the vectors reader as the effective "result" of the search call. */
    @Getter
    @Setter
    private KNNIterator knnIterator;

    /** Never terminates early; this collector performs no real collection. */
    @Override
    public boolean earlyTerminated() {
        return false;
    }

    @Override
    public void incVisitedCount(int i) {
        // no-op: visited counts are not tracked by this marker collector
    }

    @Override
    public long visitedCount() {
        return 0;
    }

    @Override
    public long visitLimit() {
        return 0;
    }

    @Override
    public int k() {
        return k;
    }

    /** Always reports the hit as non-competitive; results flow through {@link #getKnnIterator()}. */
    @Override
    public boolean collect(int i, float v) {
        return false;
    }

    @Override
    public float minCompetitiveSimilarity() {
        return 0;
    }

    /** No {@link TopDocs} are produced; callers must consume the {@link KNNIterator} instead. */
    @Override
    public TopDocs topDocs() {
        return null;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,29 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.opensearch.common.UUIDs;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.index.query.iterators.*;
import org.opensearch.knn.index.vectorvalues.*;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
Expand All @@ -42,6 +48,7 @@
* Vectors reader class for reading the flat vectors for native engines. The class provides methods for iterating
* over the vectors and retrieving their values.
*/
@Log4j2
public class NativeEngines990KnnVectorsReader extends KnnVectorsReader {

private final FlatVectorsReader flatVectorsReader;
Expand Down Expand Up @@ -135,6 +142,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState);
return;
}

if(knnCollector instanceof KnnExactSearchCollector){
((KnnExactSearchCollector) knnCollector).setKnnIterator(getKNNIterator(((KnnExactSearchCollector) knnCollector).getContext(), ((KnnExactSearchCollector) knnCollector).getLeafReaderContext()));
return;
}
throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader");
}

Expand Down Expand Up @@ -221,4 +233,85 @@ private static List<String> getVectorCacheKeysFromSegmentReaderState(SegmentRead

return cacheKeys;
}

/**
 * Builds the {@link KNNIterator} used for exact (brute-force) scoring over this segment's vectors.
 *
 * <p>The iterator variant is chosen from the query's {@link VectorDataType} (BINARY, BYTE, or
 * FLOAT) and whether nested/parent-join handling is required. For float vectors, the query vector
 * may additionally be quantized once up front when the context requests quantized search.
 *
 * @param exactSearcherContext context carrying the {@link KNNQuery}, the matched-docs iterator,
 *                             and quantization flags
 * @param leafReaderContext    leaf (segment) context the search runs against
 * @return an iterator over the matched vectors, or {@code null} if the field is absent from this
 *         segment's field infos
 * @throws IOException if reading vector values from the segment fails
 */
private KNNIterator getKNNIterator(ExactSearcher.ExactSearcherContext exactSearcherContext, LeafReaderContext leafReaderContext) throws IOException {
    final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
    final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
    final FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(exactSearcherContext.getKnnQuery().getField());
    final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
    // Field may be missing in this segment (e.g. sparse field across segments); signal via null.
    if (fieldInfo == null) {
        log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(),segmentReadState.segmentInfo.name);
        return null;
    }
    ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
    final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

    // Nested handling is needed only when the caller asked for parent hits AND a parent filter exists.
    boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;

    if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
        // NOTE(review): the BINARY branch reads vector values from flatVectorsReader while the
        // BYTE and FLOAT branches below read from the SegmentReader — confirm this asymmetry is
        // intentional (half-finished migration vs. deliberate).
        final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, flatVectorsReader);
        if (isNestedRequired) {
            return new NestedBinaryVectorIdsKNNIterator(
                matchedDocs,
                knnQuery.getByteQueryVector(),
                (KNNBinaryVectorValues) vectorValues,
                spaceType,
                knnQuery.getParentsFilter().getBitSet(leafReaderContext)
            );
        }
        return new BinaryVectorIdsKNNIterator(
            matchedDocs,
            knnQuery.getByteQueryVector(),
            (KNNBinaryVectorValues) vectorValues,
            spaceType
        );
    }

    if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
        final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
        if (isNestedRequired) {
            return new NestedByteVectorIdsKNNIterator(
                matchedDocs,
                knnQuery.getQueryVector(),
                (KNNByteVectorValues) vectorValues,
                spaceType,
                knnQuery.getParentsFilter().getBitSet(leafReaderContext)
            );
        }
        return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
    }
    // FLOAT path: optionally quantize the query vector once so per-document comparisons reuse it.
    final byte[] quantizedQueryVector;
    final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
    if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
        // Build Segment Level Quantization info.
        segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
        // Quantize the Query Vector Once.
        quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);
    } else {
        segmentLevelQuantizationInfo = null;
        quantizedQueryVector = null;
    }

    final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
    if (isNestedRequired) {
        return new NestedVectorIdsKNNIterator(
            matchedDocs,
            knnQuery.getQueryVector(),
            (KNNFloatVectorValues) vectorValues,
            spaceType,
            knnQuery.getParentsFilter().getBitSet(leafReaderContext),
            quantizedQueryVector,
            segmentLevelQuantizationInfo
        );
    }
    return new VectorIdsKNNIterator(
        matchedDocs,
        knnQuery.getQueryVector(),
        (KNNFloatVectorValues) vectorValues,
        spaceType,
        quantizedQueryVector,
        segmentLevelQuantizationInfo
    );
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
fieldInfo.name,
Expand Down Expand Up @@ -144,7 +144,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
fieldInfo.name,
Expand Down
13 changes: 8 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN990Codec.KnnExactSearchCollector;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
Expand All @@ -37,10 +40,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.*;
import java.util.function.Predicate;

@Log4j2
Expand All @@ -59,7 +59,9 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
final KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
KnnExactSearchCollector collector = KnnExactSearchCollector.builder().context(exactSearcherContext).leafReaderContext(leafReaderContext).build();
leafReaderContext.reader().searchNearestVectors(exactSearcherContext.getKnnQuery().getField(), exactSearcherContext.getKnnQuery().getQueryVector(),collector, null);
final KNNIterator iterator = collector.getKnnIterator();
// because of any reason if we are not able to get KNNIterator, return an empty map
if (iterator == null) {
return Collections.emptyMap();
Expand Down Expand Up @@ -245,6 +247,7 @@ public static class ExactSearcherContext {
boolean useQuantizedVectorsForSearch;
int k;
DocIdSetIterator matchedDocsIterator;
Bits acceptedDocs;
long numberOfMatchedDocs;
KNNQuery knnQuery;
/**
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
Map<Integer, Float> result = doExactSearch(context, filterBitSet, cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}

Expand All @@ -184,7 +184,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, cardinality) : null;
final BitSetIterator docs = filterWeight != null ? filterBitSet : null;
Map<Integer, Float> result = doExactSearch(context, docs, cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
Expand Down Expand Up @@ -253,7 +253,7 @@ private int[] bitSetToIntArray(final BitSet bitSet) {

private Map<Integer, Float> doExactSearch(
final LeafReaderContext context,
final DocIdSetIterator acceptedDocs,
final Bits acceptedDocs,
final long numberOfAcceptedDocs,
final int k
) throws IOException {
Expand All @@ -264,7 +264,7 @@ private Map<Integer, Float> doExactSearch(
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery)
.matchedDocsIterator(acceptedDocs)
.acceptedDocs(acceptedDocs)
.numberOfMatchedDocs(numberOfAcceptedDocs);
return exactSearch(context, exactSearcherContextBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo,
}
}

/**
 * Returns a {@link KNNVectorValues} for the given {@link FieldInfo}, backed by the supplied
 * {@link KnnVectorsReader}.
 *
 * @param fieldInfo     field whose vector values should be read; its {@link VectorEncoding}
 *                      selects the byte vs. float read path
 * @param vectorsReader reader the per-document vector values are pulled from
 * @param <T>           vector value type ({@code byte[]} or {@code float[]})
 * @return vector values wrapping the reader's doc-ids iterator for the field
 * @throws IOException              if the underlying reader fails to provide vector values
 * @throws IllegalArgumentException if the field's vector encoding is neither BYTE nor FLOAT32
 */
public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo, final KnnVectorsReader vectorsReader) throws IOException {
    // Resolve the encoding once; it fully determines which reader accessor to use.
    final VectorEncoding encoding = fieldInfo.getVectorEncoding();
    if (encoding == VectorEncoding.BYTE) {
        return getVectorValues(
            FieldInfoExtractor.extractVectorDataType(fieldInfo),
            new KNNVectorValuesIterator.DocIdsIteratorValues(vectorsReader.getByteVectorValues(fieldInfo.getName()))
        );
    }
    if (encoding == VectorEncoding.FLOAT32) {
        return getVectorValues(
            FieldInfoExtractor.extractVectorDataType(fieldInfo),
            new KNNVectorValuesIterator.DocIdsIteratorValues(vectorsReader.getFloatVectorValues(fieldInfo.getName()))
        );
    }
    throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues");
}

/**
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}
*
Expand Down

0 comments on commit f2239d8

Please sign in to comment.