Skip to content

Commit 5b0666d

Browse files
authored
Add tests for vector search INT8/UINT8 types (#4091)
1 parent a2a0601 commit 5b0666d

File tree

4 files changed

+101
-22
lines changed

4 files changed

+101
-22
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ start: cleanup compile-module
473473
echo "$$REDIS_UDS" | redis-server -
474474
echo "$$REDIS_UNAVAILABLE_CONF" | redis-server -
475475
redis-cli -a cluster --cluster create 127.0.0.1:7479 127.0.0.1:7480 127.0.0.1:7481 --cluster-yes
476-
docker run -p 6479:6379 --name jedis-stack -d redis/redis-stack-server:edge
476+
docker run -p 6479:6379 --name jedis-stack -e PORT=6379 -d redislabs/client-libs-test:8.0-M04-pre
477477

478478
cleanup:
479479
- rm -vf /tmp/redis_cluster_node*.conf 2>/dev/null

src/test/java/redis/clients/jedis/modules/search/AggregationTest.java

+2-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static org.junit.Assert.assertNull;
66
import static org.junit.Assert.assertThat;
77
import static org.junit.Assert.fail;
8+
import static redis.clients.jedis.util.RedisConditions.ModuleVersion.SEARCH_MOD_VER_80M3;
89

910
import io.redis.test.annotations.SinceRedisVersion;
1011
import io.redis.test.utils.RedisVersion;
@@ -39,11 +40,6 @@ public class AggregationTest extends RedisModuleCommandsTestBase {
3940
public static void prepare() {
4041
RedisModuleCommandsTestBase.prepare();
4142
}
42-
//
43-
// @AfterClass
44-
// public static void tearDown() {
45-
//// RedisModuleCommandsTestBase.tearDown();
46-
// }
4743

4844
public AggregationTest(RedisProtocol redisProtocol) {
4945
super(redisProtocol);
@@ -205,7 +201,7 @@ public void testAggregationBuilderAddScores() {
205201
.apply("@__score * 100", "normalized_score").dialect(3);
206202

207203
AggregationResult res = client.ftAggregate(index, r);
208-
if (RedisConditions.of(client).moduleVersionIsGreatherThan("SEARCH", 79900)) {
204+
if (RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3)) {
209205
// Default scorer is BM25
210206
assertEquals(0.6931, res.getRow(0).getDouble("__score"), 0.0001);
211207
assertEquals(69.31, res.getRow(0).getDouble("normalized_score"), 0.01);

src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java

+57-5
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import static org.hamcrest.MatcherAssert.assertThat;
44
import static org.junit.Assert.*;
5+
import static org.junit.Assume.assumeTrue;
56
import static redis.clients.jedis.util.AssertUtil.assertOK;
7+
import static redis.clients.jedis.util.RedisConditions.ModuleVersion.SEARCH_MOD_VER_80M3;
68

79
import java.util.*;
810
import java.util.stream.Collectors;
911

1012
import io.redis.test.annotations.SinceRedisVersion;
1113
import io.redis.test.utils.RedisVersion;
1214
import org.hamcrest.Matchers;
15+
import org.junit.After;
1316
import org.junit.BeforeClass;
1417
import org.junit.Test;
1518
import org.junit.runner.RunWith;
@@ -33,6 +36,7 @@
3336
import redis.clients.jedis.search.schemafields.GeoShapeField.CoordinateSystem;
3437
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
3538
import redis.clients.jedis.modules.RedisModuleCommandsTestBase;
39+
import redis.clients.jedis.util.RedisConditions;
3640
import redis.clients.jedis.util.RedisVersionUtil;
3741

3842
@RunWith(Parameterized.class)
@@ -44,11 +48,13 @@ public class SearchWithParamsTest extends RedisModuleCommandsTestBase {
4448
public static void prepare() {
4549
RedisModuleCommandsTestBase.prepare();
4650
}
47-
//
48-
// @AfterClass
49-
// public static void tearDown() {
50-
//// RedisModuleCommandsTestBase.tearDown();
51-
// }
51+
52+
@After
53+
public void cleanUp() {
54+
if (client.ftList().contains(index)) {
55+
client.ftDropIndex(index);
56+
}
57+
}
5258

5359
public SearchWithParamsTest(RedisProtocol protocol) {
5460
super(protocol);
@@ -1248,6 +1254,32 @@ public void testFlatVectorSimilarity() {
12481254
assertEquals("0", doc1.get("__v_score"));
12491255
}
12501256

1257+
@Test
1258+
public void testFlatVectorSimilarityInt8() {
1259+
assumeTrue("INT8",
1260+
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
1261+
assertOK(client.ftCreate(index,
1262+
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.FLAT)
1263+
.addAttribute("TYPE", "INT8").addAttribute("DIM", 2)
1264+
.addAttribute("DISTANCE_METRIC", "L2").build()));
1265+
1266+
byte[] a = { 127, 1 };
1267+
byte[] b = { 127, 10 };
1268+
byte[] c = { 127, 100 };
1269+
1270+
client.hset("a".getBytes(), "v".getBytes(), a);
1271+
client.hset("b".getBytes(), "v".getBytes(), b);
1272+
client.hset("c".getBytes(), "v".getBytes(), c);
1273+
1274+
FTSearchParams searchParams = FTSearchParams.searchParams().addParam("vec", a)
1275+
.sortBy("__v_score", SortingOrder.ASC).returnFields("__v_score");
1276+
1277+
Document doc1 = client.ftSearch(index, "*=>[KNN 2 @v $vec]", searchParams).getDocuments()
1278+
.get(0);
1279+
assertEquals("a", doc1.getId());
1280+
assertEquals("0", doc1.get("__v_score"));
1281+
}
1282+
12511283
@Test
12521284
@SinceRedisVersion(value = "7.4.0", message = "no optional params before 7.4.0")
12531285
public void vectorFieldParams() {
@@ -1286,6 +1318,26 @@ public void bfloat16StorageType() {
12861318
.build()));
12871319
}
12881320

1321+
@Test
1322+
public void int8StorageType() {
1323+
assumeTrue("INT8",
1324+
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
1325+
assertOK(client.ftCreate(index,
1326+
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.HNSW)
1327+
.addAttribute("TYPE", "INT8").addAttribute("DIM", 4)
1328+
.addAttribute("DISTANCE_METRIC", "L2").build()));
1329+
}
1330+
1331+
@Test
1332+
public void uint8StorageType() {
1333+
assumeTrue("UINT8",
1334+
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
1335+
assertOK(client.ftCreate(index,
1336+
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.HNSW)
1337+
.addAttribute("TYPE", "UINT8").addAttribute("DIM", 4)
1338+
.addAttribute("DISTANCE_METRIC", "L2").build()));
1339+
}
1340+
12891341
@Test
12901342
public void searchProfile() {
12911343
assertOK(client.ftCreate(index, TextField.of("t1"), TextField.of("t2")));

src/test/java/redis/clients/jedis/util/RedisConditions.java

+41-10
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,33 @@
1818

1919
public class RedisConditions {
2020

21+
public enum ModuleVersion {
22+
23+
SEARCH_MOD_VER_80M3("SEARCH", 79903);
24+
25+
private final String moduleName;
26+
private final int version;
27+
28+
ModuleVersion(String moduleName, int version) {
29+
this.moduleName = moduleName;
30+
this.version = version;
31+
}
32+
33+
public String getModuleName() {
34+
return moduleName;
35+
}
36+
37+
public int getVersion() {
38+
return version;
39+
}
40+
}
41+
2142
private final RedisVersion version;
2243
private final Map<String, Integer> modules;
2344
private final Map<String, CommandInfo> commands;
2445

25-
private RedisConditions(RedisVersion version, Map< String, CommandInfo> commands, Map<String, Integer> modules) {
46+
private RedisConditions(RedisVersion version, Map<String, CommandInfo> commands,
47+
Map<String, Integer> modules) {
2648
this.version = version;
2749
this.commands = commands;
2850
this.modules = modules;
@@ -31,15 +53,14 @@ private RedisConditions(RedisVersion version, Map< String, CommandInfo> commands
3153
public static RedisConditions of(UnifiedJedis jedis) {
3254
RedisVersion version = RedisVersionUtil.getRedisVersion(jedis);
3355

34-
CommandObject<Map<String, CommandInfo>> commandInfoCmd
35-
= new CommandObject<>(new CommandArguments(COMMAND), CommandInfo.COMMAND_INFO_RESPONSE);
56+
CommandObject<Map<String, CommandInfo>> commandInfoCmd = new CommandObject<>(
57+
new CommandArguments(COMMAND), CommandInfo.COMMAND_INFO_RESPONSE);
3658
Map<String, CommandInfo> commands = jedis.executeCommand(commandInfoCmd);
3759

38-
CommandObject<List<Module>> moduleListCmd
39-
= new CommandObject<>(new CommandArguments(MODULE).add(LIST), MODULE_LIST);
60+
CommandObject<List<Module>> moduleListCmd = new CommandObject<>(
61+
new CommandArguments(MODULE).add(LIST), MODULE_LIST);
4062

41-
Map<String, Integer> modules = jedis.executeCommand(moduleListCmd)
42-
.stream()
63+
Map<String, Integer> modules = jedis.executeCommand(moduleListCmd).stream()
4364
.collect(Collectors.toMap((m) -> m.getName().toUpperCase(), Module::getVersion));
4465

4566
return new RedisConditions(version, commands, modules);
@@ -68,10 +89,20 @@ public boolean hasModule(String module) {
6889
/**
6990
* @param module
7091
* @param version
71-
* @return {@code true} if the module is present.
92+
* @return {@code true} if the module with the requested minimum version is present.
7293
*/
73-
public boolean moduleVersionIsGreatherThan(String module, int version) {
94+
public boolean moduleVersionIsGreaterThanOrEqual(String module, int version) {
7495
Integer moduleVersion = modules.get(module.toUpperCase());
75-
return moduleVersion != null && moduleVersion > version;
96+
return moduleVersion != null && moduleVersion >= version;
97+
}
98+
99+
/**
100+
* @param moduleVersion
101+
* @return {@code true} if the module version is greater than or equal to the specified version.
102+
*/
103+
public boolean moduleVersionIsGreaterThanOrEqual(ModuleVersion moduleVersion) {
104+
return moduleVersionIsGreaterThanOrEqual(moduleVersion.getModuleName(),
105+
moduleVersion.getVersion());
76106
}
107+
77108
}

0 commit comments

Comments
 (0)