2
2
3
3
import static org .hamcrest .MatcherAssert .assertThat ;
4
4
import static org .junit .Assert .*;
5
+ import static org .junit .Assume .assumeTrue ;
5
6
import static redis .clients .jedis .util .AssertUtil .assertOK ;
7
+ import static redis .clients .jedis .util .RedisConditions .ModuleVersion .SEARCH_MOD_VER_80M3 ;
6
8
7
9
import java .util .*;
8
10
import java .util .stream .Collectors ;
9
11
10
12
import io .redis .test .annotations .SinceRedisVersion ;
11
13
import io .redis .test .utils .RedisVersion ;
12
14
import org .hamcrest .Matchers ;
15
+ import org .junit .After ;
13
16
import org .junit .BeforeClass ;
14
17
import org .junit .Test ;
15
18
import org .junit .runner .RunWith ;
33
36
import redis .clients .jedis .search .schemafields .GeoShapeField .CoordinateSystem ;
34
37
import redis .clients .jedis .search .schemafields .VectorField .VectorAlgorithm ;
35
38
import redis .clients .jedis .modules .RedisModuleCommandsTestBase ;
39
+ import redis .clients .jedis .util .RedisConditions ;
36
40
import redis .clients .jedis .util .RedisVersionUtil ;
37
41
38
42
@ RunWith (Parameterized .class )
@@ -44,11 +48,13 @@ public class SearchWithParamsTest extends RedisModuleCommandsTestBase {
44
48
public static void prepare () {
45
49
RedisModuleCommandsTestBase .prepare ();
46
50
}
47
- //
48
- // @AfterClass
49
- // public static void tearDown() {
50
- //// RedisModuleCommandsTestBase.tearDown();
51
- // }
51
+
52
+ @ After
53
+ public void cleanUp () {
54
+ if (client .ftList ().contains (index )) {
55
+ client .ftDropIndex (index );
56
+ }
57
+ }
52
58
53
59
public SearchWithParamsTest (RedisProtocol protocol ) {
54
60
super (protocol );
@@ -1248,6 +1254,32 @@ public void testFlatVectorSimilarity() {
1248
1254
assertEquals ("0" , doc1 .get ("__v_score" ));
1249
1255
}
1250
1256
1257
+ @ Test
1258
+ public void testFlatVectorSimilarityInt8 () {
1259
+ assumeTrue ("INT8" ,
1260
+ RedisConditions .of (client ).moduleVersionIsGreaterThanOrEqual (SEARCH_MOD_VER_80M3 ));
1261
+ assertOK (client .ftCreate (index ,
1262
+ VectorField .builder ().fieldName ("v" ).algorithm (VectorAlgorithm .FLAT )
1263
+ .addAttribute ("TYPE" , "INT8" ).addAttribute ("DIM" , 2 )
1264
+ .addAttribute ("DISTANCE_METRIC" , "L2" ).build ()));
1265
+
1266
+ byte [] a = { 127 , 1 };
1267
+ byte [] b = { 127 , 10 };
1268
+ byte [] c = { 127 , 100 };
1269
+
1270
+ client .hset ("a" .getBytes (), "v" .getBytes (), a );
1271
+ client .hset ("b" .getBytes (), "v" .getBytes (), b );
1272
+ client .hset ("c" .getBytes (), "v" .getBytes (), c );
1273
+
1274
+ FTSearchParams searchParams = FTSearchParams .searchParams ().addParam ("vec" , a )
1275
+ .sortBy ("__v_score" , SortingOrder .ASC ).returnFields ("__v_score" );
1276
+
1277
+ Document doc1 = client .ftSearch (index , "*=>[KNN 2 @v $vec]" , searchParams ).getDocuments ()
1278
+ .get (0 );
1279
+ assertEquals ("a" , doc1 .getId ());
1280
+ assertEquals ("0" , doc1 .get ("__v_score" ));
1281
+ }
1282
+
1251
1283
@ Test
1252
1284
@ SinceRedisVersion (value = "7.4.0" , message = "no optional params before 7.4.0" )
1253
1285
public void vectorFieldParams () {
@@ -1286,6 +1318,26 @@ public void bfloat16StorageType() {
1286
1318
.build ()));
1287
1319
}
1288
1320
1321
+ @ Test
1322
+ public void int8StorageType () {
1323
+ assumeTrue ("INT8" ,
1324
+ RedisConditions .of (client ).moduleVersionIsGreaterThanOrEqual (SEARCH_MOD_VER_80M3 ));
1325
+ assertOK (client .ftCreate (index ,
1326
+ VectorField .builder ().fieldName ("v" ).algorithm (VectorAlgorithm .HNSW )
1327
+ .addAttribute ("TYPE" , "INT8" ).addAttribute ("DIM" , 4 )
1328
+ .addAttribute ("DISTANCE_METRIC" , "L2" ).build ()));
1329
+ }
1330
+
1331
+ @ Test
1332
+ public void uint8StorageType () {
1333
+ assumeTrue ("UINT8" ,
1334
+ RedisConditions .of (client ).moduleVersionIsGreaterThanOrEqual (SEARCH_MOD_VER_80M3 ));
1335
+ assertOK (client .ftCreate (index ,
1336
+ VectorField .builder ().fieldName ("v" ).algorithm (VectorAlgorithm .HNSW )
1337
+ .addAttribute ("TYPE" , "UINT8" ).addAttribute ("DIM" , 4 )
1338
+ .addAttribute ("DISTANCE_METRIC" , "L2" ).build ()));
1339
+ }
1340
+
1289
1341
@ Test
1290
1342
public void searchProfile () {
1291
1343
assertOK (client .ftCreate (index , TextField .of ("t1" ), TextField .of ("t2" )));
0 commit comments