Skip to content

Commit e86b011

Browse files
committed
MLE-17146 Added support for annTopK
I removed the restriction in BaseTypeImpl that requires every value to be a subclass of `BaseArgImpl`. That puts a big burden on the server in that it must accept "wrapped" values for primitive values. The server function `annTopK` does not support "wrapped" values for `k` or `queryTolerance`, and thus that restriction caused things to break. Not sure we'll keep this change yet.
1 parent 8048ee2 commit e86b011

File tree

5 files changed

+74
-21
lines changed

5 files changed

+74
-21
lines changed

Jenkinsfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def setupDockerMarkLogic(String image){
2121
docker compose down -v || true
2222
docker volume prune -f
2323
echo "Using image: "'''+image+'''
24+
docker pull '''+image+'''
2425
MARKLOGIC_IMAGE='''+image+''' MARKLOGIC_LOGS_VOLUME=marklogicLogs docker compose up -d --build
2526
echo "mlPassword=admin" > gradle-local.properties
2627
echo "Waiting for MarkLogic server to initialize."

marklogic-client-api/src/main/java/com/marklogic/client/expression/PlanBuilder.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,21 @@ public interface ModifyPlan extends PreparePlan, PlanBuilderBase.ModifyPlanBase
14981498
* @return a ModifyPlan object
14991499
*/
15001500
public abstract ModifyPlan bindAs(PlanColumn column, ServerExpression expression);
1501+
1502+
/**
1503+
* Facilitates Approximate Nearest Neighbor (ann) vector search. Given a query vector, it searches for K nearest
1504+
* neighbor vector embeddings that are stored in the database.
1505+
*
1506+
* @param k This positive integer k is the top-K rows to return as a result of the index lookup.
1507+
* @param vectorColumn The column representing the vector ann-indexed column to perform the index lookup against.
1508+
* @param queryVector Specifies the query vector to perform the index lookup with.
1509+
* @param distanceColumn Optional output column that captures the values of the distance metric of the vectors retrieved from the index associated with vectorColumn and the queryVector.
1510+
* @param queryTolerance Specifies the query tolerance to help balance recall and search time. The value is between 0.0 and 1.0. At 0.0, the recall will be highest. At 1.0 the recall will likely see a large degradation, but queries will be quick. The default value is 0.0.
1511+
* @return
1512+
* @since 7.1.0
1513+
*/
1514+
ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance);
1515+
15011516
/**
15021517
* This method restricts the left row set to rows where a row with the same columns and values doesn't exist in the right row set.
15031518
* @param right The row set from the right view.

marklogic-client-api/src/main/java/com/marklogic/client/impl/BaseTypeImpl.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,11 @@ static <T extends BaseArgImpl> T[] convertList(Object[] items, Class<T> as) {
409409
Arrays.stream(items)
410410
.map(item -> {
411411
if (item != null && !as.isInstance(item)) {
412-
throw new IllegalArgumentException("expected "+as.getName()+" argument instead of "+item.getClass().getName());
412+
// Prior to 7.1.0, this threw an exception, as it was requiring every item to be an instance of the given
413+
// class. This meant that a primitive value could never be passed. But that forces the server to support
414+
// both a primitive value and a "wrapped" value (e.g. with ns=xs, fn=float, args=value) for every
415+
// argument. This instead assumes that it can just write the item as-is and the server will accept it.
416+
return (BaseArgImpl) serializedPlanBuilder -> serializedPlanBuilder.append(item);
413417
}
414418
return (T) item;
415419
})

marklogic-client-api/src/main/java/com/marklogic/client/impl/PlanBuilderSubImpl.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,13 @@ static class ModifyPlanSubImpl
986986
super(prior, fnPrefix, fnName, fnArgs);
987987
}
988988

989+
@Override
990+
public ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance) {
991+
return new PlanBuilderSubImpl.ModifyPlanSubImpl(this, "op", "annTopK", new Object[]{
992+
k, vectorColumn, queryVector, distanceColumn, queryTolerance
993+
});
994+
}
995+
989996
@Override
990997
public ModifyPlan patch(String docColumn, PatchBuilder patchDef) {
991998
return new PlanBuilderSubImpl.ModifyPlanSubImpl(this, "op", "patch", new Object[]{ this.col(docColumn), patchDef });

marklogic-client-api/src/test/java/com/marklogic/client/test/rows/VectorTest.java

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ void vectorFunctionsHappyPath() {
3838
PlanBuilder.ModifyPlan plan =
3939
op.fromView("vectors", "persons")
4040
.bind(op.as("sampleVector", op.vec.vector(sampleVector)))
41-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"),op.col("sampleVector"))))
42-
.bind(op.as("dotProduct", op.vec.dotProduct(op.col("embedding"),op.col("sampleVector"))))
43-
.bind(op.as("euclideanDistance", op.vec.euclideanDistance(op.col("embedding"),op.col("sampleVector"))))
41+
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
42+
.bind(op.as("dotProduct", op.vec.dotProduct(op.col("embedding"), op.col("sampleVector"))))
43+
.bind(op.as("euclideanDistance", op.vec.euclideanDistance(op.col("embedding"), op.col("sampleVector"))))
4444
.bind(op.as("dimension", op.vec.dimension(op.col("sampleVector"))))
4545
.bind(op.as("normalize", op.vec.normalize(op.col("sampleVector"))))
4646
.bind(op.as("magnitude", op.vec.magnitude(op.col("sampleVector"))))
4747
.bind(op.as("get", op.vec.get(op.col("sampleVector"), op.xs.integer(2))))
48-
.bind(op.as("add", op.vec.add(op.col("embedding"),op.col("sampleVector"))))
49-
.bind(op.as("subtract", op.vec.subtract(op.col("embedding"),op.col("sampleVector"))))
48+
.bind(op.as("add", op.vec.add(op.col("embedding"), op.col("sampleVector"))))
49+
.bind(op.as("subtract", op.vec.subtract(op.col("embedding"), op.col("sampleVector"))))
5050
.bind(op.as("base64Encode", op.vec.base64Encode(op.col("sampleVector"))))
5151
.bind(op.as("base64Decode", op.vec.base64Decode(op.col("base64Encode"))))
5252
.bind(op.as("subVector", op.vec.subvector(op.col("sampleVector"), op.xs.integer(1), op.xs.integer(1))))
@@ -64,15 +64,15 @@ void vectorFunctionsHappyPath() {
6464
rows.forEach(row -> {
6565
// Simple a sanity checks to verify that the functions ran. Very little concern about the actual return values.
6666
double cosineSimilarity = row.getDouble("cosineSimilarity");
67-
assertTrue((cosineSimilarity > 0) && (cosineSimilarity < 1),"Unexpected value: " + cosineSimilarity);
67+
assertTrue((cosineSimilarity > 0) && (cosineSimilarity < 1), "Unexpected value: " + cosineSimilarity);
6868
double dotProduct = row.getDouble("dotProduct");
6969
Assertions.assertTrue(dotProduct > 0, "Unexpected value: " + dotProduct);
7070
double euclideanDistance = row.getDouble("euclideanDistance");
7171
Assertions.assertTrue(euclideanDistance > 0, "Unexpected value: " + euclideanDistance);
7272
assertEquals(3, row.getInt("dimension"));
7373
assertEquals(3, ((ArrayNode) row.get("normalize")).size());
7474
double magnitude = row.getDouble("magnitude");
75-
assertTrue( magnitude > 0, "Unexpected value: " + magnitude);
75+
assertTrue(magnitude > 0, "Unexpected value: " + magnitude);
7676
assertEquals(3, ((ArrayNode) row.get("add")).size());
7777
assertEquals(3, ((ArrayNode) row.get("subtract")).size());
7878
assertFalse(row.getString("base64Encode").isEmpty());
@@ -89,7 +89,7 @@ void cosineSimilarity_DimensionMismatch() {
8989
PlanBuilder.ModifyPlan plan =
9090
op.fromView("vectors", "persons")
9191
.bind(op.as("sampleVector", op.vec.vector(twoDimensionalVector)))
92-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"),op.col("sampleVector"))))
92+
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
9393
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
9494
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
9595
String actualMessage = exception.getMessage();
@@ -102,7 +102,7 @@ void cosineSimilarity_InvalidVector() {
102102
PlanBuilder.ModifyPlan plan =
103103
op.fromView("vectors", "persons")
104104
.bind(op.as("sampleVector", invalidVector))
105-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"),op.col("sampleVector"))))
105+
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
106106
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
107107
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
108108
String actualMessage = exception.getMessage();
@@ -111,20 +111,20 @@ void cosineSimilarity_InvalidVector() {
111111
}
112112

113113
@Test
114-
// As of 07/26/24, this test will fail with the ML12 develop branch.
115-
// However, it will succeed with the 12ea1 build.
116-
// See https://progresssoftware.atlassian.net/browse/MLE-15707
114+
// As of 07/26/24, this test will fail with the ML12 develop branch.
115+
// However, it will succeed with the 12ea1 build.
116+
// See https://progresssoftware.atlassian.net/browse/MLE-15707
117117
void bindVectorFromDocs() {
118118
PlanBuilder.ModifyPlan plan =
119119
op.fromSearchDocs(
120-
op.cts.andQuery(
121-
op.cts.documentQuery("/optic/vectors/alice.json"),
122-
op.cts.elementQuery(
123-
"person",
124-
op.cts.trueQuery()
125-
)
126-
))
127-
.bind(op.as("embedding", op.vec.vector(op.xpath("doc", "/person/embedding"))));
120+
op.cts.andQuery(
121+
op.cts.documentQuery("/optic/vectors/alice.json"),
122+
op.cts.elementQuery(
123+
"person",
124+
op.cts.trueQuery()
125+
)
126+
))
127+
.bind(op.as("embedding", op.vec.vector(op.xpath("doc", "/person/embedding"))));
128128
List<RowRecord> rows = resultRows(plan);
129129
assertEquals(1, rows.size());
130130
}
@@ -138,4 +138,30 @@ void vecVectorWithCol() {
138138
List<RowRecord> rows = resultRows(plan);
139139
assertEquals(2, rows.size());
140140
}
141+
142+
@Test
143+
void annTopK() {
144+
PlanBuilder.ModifyPlan plan = op.fromView("vectors", "persons")
145+
.annTopK(10, op.col("embedding"), op.vec.vector(sampleVector), op.col("distance"), 0.5f);
146+
147+
List<RowRecord> rows = resultRows(plan);
148+
assertEquals(2, rows.size(), "Verifying that annTopK worked and returned both rows from the view.");
149+
150+
rows.forEach(row -> {
151+
float distance = row.getFloat("distance");
152+
assertTrue(distance > 0, "Just verifying that annTopK both worked and put a valid value into the 'distance' column.");
153+
});
154+
}
155+
156+
@Test
157+
void dslAnnTopK() {
158+
String query = "const qualityVector = vec.vector([ 1.1, 2.2, 3.3 ]);\n" +
159+
"op.fromView('vectors', 'persons')\n" +
160+
" .bind(op.as('myVector', op.vec.vector(op.col('embedding'))))\n" +
161+
" .annTopK(2, op.col('myVector'), qualityVector, op.col('distance'), 0.5)";
162+
163+
RawQueryDSLPlan plan = rowManager.newRawQueryDSLPlan(new StringHandle(query));
164+
List<RowRecord> rows = resultRows(plan);
165+
assertEquals(2, rows.size(), "Just verifying that 'annTopK' works via the DSL and v1/rows.");
166+
}
141167
}

0 commit comments

Comments
 (0)