Skip to content

Commit 1982afa

Browse files
committed
Fixing Scoring Issue with Binary Quantization
1 parent b0d82b7 commit 1982afa

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

src/main/java/org/opensearch/knn/index/query/KNNWeight.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,10 @@ private Map<Integer, Float> doANNSearch(
373373
log.debug("[KNN] Query yielded 0 results");
374374
return null;
375375
}
376-
376+
if (quantizedVector != null) {
377+
return Arrays.stream(results)
378+
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING)));
379+
}
377380
return Arrays.stream(results)
378381
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
379382
}

src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() {
515515
validateANNWithFilterQuery_whenDoingANN_thenSuccess(true);
516516
}
517517

518+
@SneakyThrows
519+
public void testScorerWithQuantizedVector() {
520+
// Given
521+
int k = 3;
522+
byte[] quantizedVector = new byte[] { 1, 2, 3 }; // Mocked quantized vector
523+
float[] queryVector = new float[] { 0.1f, 0.3f };
524+
525+
// Mock the JNI service to return KNNQueryResults
526+
KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {
527+
new KNNQueryResult(1, 10.0f), // Mock result with id 1 and score 10
528+
new KNNQueryResult(2, 20.0f) // Mock result with id 2 and score 20
529+
};
530+
jniServiceMockedStatic.when(
531+
() -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any())
532+
).thenReturn(knnQueryResults);
533+
534+
KNNEngine knnEngine = mock(KNNEngine.class);
535+
when(knnEngine.score(anyFloat(), eq(SpaceType.HAMMING))).thenAnswer(invocation -> {
536+
Float score = invocation.getArgument(0);
537+
return 1 / (1 + score);
538+
});
539+
540+
// Build the KNNQuery object
541+
final KNNQuery query = KNNQuery.builder()
542+
.field(FIELD_NAME)
543+
.queryVector(queryVector)
544+
.k(k)
545+
.indexName(INDEX_NAME)
546+
.vectorDataType(VectorDataType.BINARY) // Simulate binary vector type for quantization
547+
.build();
548+
549+
final float boost = 1.0F;
550+
final KNNWeight knnWeight = new KNNWeight(query, boost);
551+
552+
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
553+
final SegmentReader reader = mock(SegmentReader.class);
554+
when(leafReaderContext.reader()).thenReturn(reader);
555+
556+
final FieldInfos fieldInfos = mock(FieldInfos.class);
557+
final FieldInfo fieldInfo = mock(FieldInfo.class);
558+
when(reader.getFieldInfos()).thenReturn(fieldInfos);
559+
when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo);
560+
561+
when(fieldInfo.attributes()).thenReturn(Map.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.HAMMING.getValue()));
562+
563+
FSDirectory directory = mock(FSDirectory.class);
564+
when(reader.directory()).thenReturn(directory);
565+
Path path = mock(Path.class);
566+
when(directory.getDirectory()).thenReturn(path);
567+
when(path.toString()).thenReturn("/fake/directory");
568+
569+
SegmentInfo segmentInfo = new SegmentInfo(
570+
directory, // The directory where the segment is stored
571+
Version.LATEST, // Lucene version
572+
Version.LATEST, // Version of the segment info
573+
"0", // Segment name
574+
100, // Max document count for this segment
575+
false, // Is this a compound file segment
576+
false, // Is this a merged segment
577+
KNNCodecVersion.current().getDefaultCodecDelegate(), // Codec delegate for KNN
578+
Map.of(), // Diagnostics map
579+
new byte[StringHelper.ID_LENGTH], // Segment ID
580+
Map.of(), // Attributes
581+
Sort.RELEVANCE // Default sort order
582+
);
583+
584+
final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
585+
586+
when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);
587+
588+
try (MockedStatic<KNNCodecUtil> knnCodecUtilMockedStatic = mockStatic(KNNCodecUtil.class)) {
589+
List<String> engineFiles = List.of("_0_1_target_field.faiss");
590+
knnCodecUtilMockedStatic.when(() -> KNNCodecUtil.getEngineFiles(anyString(), anyString(), eq(segmentInfo)))
591+
.thenReturn(engineFiles);
592+
593+
try (MockedStatic<SegmentLevelQuantizationUtil> quantizationUtilMockedStatic = mockStatic(SegmentLevelQuantizationUtil.class)) {
594+
quantizationUtilMockedStatic.when(() -> SegmentLevelQuantizationUtil.quantizeVector(any(), any()))
595+
.thenReturn(quantizedVector);
596+
597+
// When: Call the scorer method
598+
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
599+
600+
// Then: Ensure scorer is not null
601+
assertNotNull(knnScorer);
602+
603+
// Verify that JNIService.queryBinaryIndex is called with the quantized vector
604+
jniServiceMockedStatic.verify(
605+
() -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()),
606+
times(1)
607+
);
608+
609+
// Iterate over the results and ensure they are scored with SpaceType.HAMMING
610+
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
611+
assertNotNull(docIdSetIterator);
612+
while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
613+
int docId = docIdSetIterator.docID();
614+
float expectedScore = knnEngine.score(knnQueryResults[docId - 1].getScore(), SpaceType.HAMMING);
615+
float actualScore = knnScorer.score();
616+
// Check if the score is calculated using HAMMING
617+
assertEquals(expectedScore, actualScore, 0.01f); // Tolerance for floating-point comparison
618+
}
619+
}
620+
}
621+
}
622+
518623
public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException {
519624
// Given
520625
int k = 3;

0 commit comments

Comments
 (0)