@@ -515,6 +515,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() {
515
515
validateANNWithFilterQuery_whenDoingANN_thenSuccess (true );
516
516
}
517
517
518
+ @ SneakyThrows
519
+ public void testScorerWithQuantizedVector () {
520
+ // Given
521
+ int k = 3 ;
522
+ byte [] quantizedVector = new byte [] { 1 , 2 , 3 }; // Mocked quantized vector
523
+ float [] queryVector = new float [] { 0.1f , 0.3f };
524
+
525
+ // Mock the JNI service to return KNNQueryResults
526
+ KNNQueryResult [] knnQueryResults = new KNNQueryResult [] {
527
+ new KNNQueryResult (1 , 10.0f ), // Mock result with id 1 and score 10
528
+ new KNNQueryResult (2 , 20.0f ) // Mock result with id 2 and score 20
529
+ };
530
+ jniServiceMockedStatic .when (
531
+ () -> JNIService .queryBinaryIndex (anyLong (), eq (quantizedVector ), eq (k ), any (), any (), any (), anyInt (), any ())
532
+ ).thenReturn (knnQueryResults );
533
+
534
+ KNNEngine knnEngine = mock (KNNEngine .class );
535
+ when (knnEngine .score (anyFloat (), eq (SpaceType .HAMMING ))).thenAnswer (invocation -> {
536
+ Float score = invocation .getArgument (0 );
537
+ return 1 / (1 + score );
538
+ });
539
+
540
+ // Build the KNNQuery object
541
+ final KNNQuery query = KNNQuery .builder ()
542
+ .field (FIELD_NAME )
543
+ .queryVector (queryVector )
544
+ .k (k )
545
+ .indexName (INDEX_NAME )
546
+ .vectorDataType (VectorDataType .BINARY ) // Simulate binary vector type for quantization
547
+ .build ();
548
+
549
+ final float boost = 1.0F ;
550
+ final KNNWeight knnWeight = new KNNWeight (query , boost );
551
+
552
+ final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
553
+ final SegmentReader reader = mock (SegmentReader .class );
554
+ when (leafReaderContext .reader ()).thenReturn (reader );
555
+
556
+ final FieldInfos fieldInfos = mock (FieldInfos .class );
557
+ final FieldInfo fieldInfo = mock (FieldInfo .class );
558
+ when (reader .getFieldInfos ()).thenReturn (fieldInfos );
559
+ when (fieldInfos .fieldInfo (FIELD_NAME )).thenReturn (fieldInfo );
560
+
561
+ when (fieldInfo .attributes ()).thenReturn (Map .of (KNN_ENGINE , KNNEngine .FAISS .getName (), SPACE_TYPE , SpaceType .HAMMING .getValue ()));
562
+
563
+ FSDirectory directory = mock (FSDirectory .class );
564
+ when (reader .directory ()).thenReturn (directory );
565
+ Path path = mock (Path .class );
566
+ when (directory .getDirectory ()).thenReturn (path );
567
+ when (path .toString ()).thenReturn ("/fake/directory" );
568
+
569
+ SegmentInfo segmentInfo = new SegmentInfo (
570
+ directory , // The directory where the segment is stored
571
+ Version .LATEST , // Lucene version
572
+ Version .LATEST , // Version of the segment info
573
+ "0" , // Segment name
574
+ 100 , // Max document count for this segment
575
+ false , // Is this a compound file segment
576
+ false , // Is this a merged segment
577
+ KNNCodecVersion .current ().getDefaultCodecDelegate (), // Codec delegate for KNN
578
+ Map .of (), // Diagnostics map
579
+ new byte [StringHelper .ID_LENGTH ], // Segment ID
580
+ Map .of (), // Attributes
581
+ Sort .RELEVANCE // Default sort order
582
+ );
583
+
584
+ final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo (segmentInfo , 0 , 0 , 0 , 0 , 0 , new byte [StringHelper .ID_LENGTH ]);
585
+
586
+ when (reader .getSegmentInfo ()).thenReturn (segmentCommitInfo );
587
+
588
+ try (MockedStatic <KNNCodecUtil > knnCodecUtilMockedStatic = mockStatic (KNNCodecUtil .class )) {
589
+ List <String > engineFiles = List .of ("_0_1_target_field.faiss" );
590
+ knnCodecUtilMockedStatic .when (() -> KNNCodecUtil .getEngineFiles (anyString (), anyString (), eq (segmentInfo )))
591
+ .thenReturn (engineFiles );
592
+
593
+ try (MockedStatic <SegmentLevelQuantizationUtil > quantizationUtilMockedStatic = mockStatic (SegmentLevelQuantizationUtil .class )) {
594
+ quantizationUtilMockedStatic .when (() -> SegmentLevelQuantizationUtil .quantizeVector (any (), any ()))
595
+ .thenReturn (quantizedVector );
596
+
597
+ // When: Call the scorer method
598
+ final KNNScorer knnScorer = (KNNScorer ) knnWeight .scorer (leafReaderContext );
599
+
600
+ // Then: Ensure scorer is not null
601
+ assertNotNull (knnScorer );
602
+
603
+ // Verify that JNIService.queryBinaryIndex is called with the quantized vector
604
+ jniServiceMockedStatic .verify (
605
+ () -> JNIService .queryBinaryIndex (anyLong (), eq (quantizedVector ), eq (k ), any (), any (), any (), anyInt (), any ()),
606
+ times (1 )
607
+ );
608
+
609
+ // Iterate over the results and ensure they are scored with SpaceType.HAMMING
610
+ final DocIdSetIterator docIdSetIterator = knnScorer .iterator ();
611
+ assertNotNull (docIdSetIterator );
612
+ while (docIdSetIterator .nextDoc () != DocIdSetIterator .NO_MORE_DOCS ) {
613
+ int docId = docIdSetIterator .docID ();
614
+ float expectedScore = knnEngine .score (knnQueryResults [docId - 1 ].getScore (), SpaceType .HAMMING );
615
+ float actualScore = knnScorer .score ();
616
+ // Check if the score is calculated using HAMMING
617
+ assertEquals (expectedScore , actualScore , 0.01f ); // Tolerance for floating-point comparison
618
+ }
619
+ }
620
+ }
621
+ }
622
+
518
623
public void validateANNWithFilterQuery_whenDoingANN_thenSuccess (final boolean isBinary ) throws IOException {
519
624
// Given
520
625
int k = 3 ;
0 commit comments