Skip to content

Commit 88fc955

Browse files
Initial support for AOT generated VectorSearch
1 parent 09835d8 commit 88fc955

File tree

11 files changed

+709
-48
lines changed

11 files changed

+709
-48
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.aot;
1717

18+
import java.lang.reflect.Method;
1819
import java.util.ArrayList;
1920
import java.util.Iterator;
2021
import java.util.List;
@@ -46,6 +47,7 @@
4647
import org.springframework.data.mongodb.core.query.Query;
4748
import org.springframework.data.mongodb.core.query.TextCriteria;
4849
import org.springframework.data.mongodb.core.query.UpdateDefinition;
50+
import org.springframework.data.mongodb.repository.VectorSearch;
4951
import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor;
5052
import org.springframework.data.mongodb.repository.query.MongoParameterAccessor;
5153
import org.springframework.data.mongodb.repository.query.MongoQueryCreator;
@@ -79,14 +81,16 @@ public AotQueryCreator() {
7981
}
8082

8183
@SuppressWarnings("NullAway")
82-
StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) {
83-
84+
StringQuery createQuery(PartTree partTree, QueryMethod queryMethod, Method source) {
8485

8586
boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false;
87+
boolean searchQuery = queryMethod instanceof MongoQueryMethod mqm
88+
? mqm.isSearchQuery() || source.isAnnotationPresent(VectorSearch.class)
89+
: source.isAnnotationPresent(VectorSearch.class);
8690

8791
Query query = new MongoQueryCreator(partTree,
88-
new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, geoNear, queryMethod.isSearchQuery())
89-
.createQuery();
92+
new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext,
93+
geoNear, searchQuery).createQuery();
9094

9195
if (partTree.isLimiting()) {
9296
query.limit(partTree.getMaxResults());
@@ -141,8 +145,7 @@ public PlaceholderParameterAccessor(QueryMethod queryMethod) {
141145
for (Parameter parameter : parameters.toList()) {
142146
if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) {
143147
placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), ""));
144-
}
145-
else if (ClassUtils.isAssignable(Point.class, parameter.getType())) {
148+
} else if (ClassUtils.isAssignable(Point.class, parameter.getType())) {
146149
placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex()));
147150
} else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) {
148151
placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex()));
@@ -152,8 +155,7 @@ else if (ClassUtils.isAssignable(Point.class, parameter.getType())) {
152155
placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex()));
153156
} else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) {
154157
placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex()));
155-
}
156-
else {
158+
} else {
157159
placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex()));
158160
}
159161
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@
1616
package org.springframework.data.mongodb.repository.aot;
1717

1818
import java.util.ArrayList;
19+
import java.util.LinkedHashMap;
1920
import java.util.List;
2021
import java.util.Locale;
2122
import java.util.Map;
23+
import java.util.function.Consumer;
2224

2325
import org.bson.Document;
2426
import org.jspecify.annotations.Nullable;
27+
import org.springframework.data.domain.Range;
28+
import org.springframework.data.domain.Score;
29+
import org.springframework.data.domain.ScoringFunction;
2530
import org.springframework.data.expression.ValueEvaluationContext;
2631
import org.springframework.data.expression.ValueExpression;
2732
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
@@ -33,6 +38,7 @@
3338
import org.springframework.data.mongodb.core.mapping.FieldName;
3439
import org.springframework.data.mongodb.core.query.BasicQuery;
3540
import org.springframework.data.mongodb.core.query.Collation;
41+
import org.springframework.data.mongodb.core.query.Criteria;
3642
import org.springframework.data.mongodb.repository.query.MongoParameters;
3743
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
3844
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
@@ -42,7 +48,9 @@
4248
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
4349
import org.springframework.data.repository.query.ValueExpressionDelegate;
4450
import org.springframework.expression.EvaluationContext;
51+
import org.springframework.util.Assert;
4552
import org.springframework.util.ClassUtils;
53+
import org.springframework.util.CollectionUtils;
4654
import org.springframework.util.ObjectUtils;
4755

4856
/**
@@ -108,7 +116,27 @@ protected Document bindParameters(String source, Map<String, Object> parameters)
108116
return new ParameterBindingDocumentCodec().decode(source, bindingContext);
109117
}
110118

111-
protected Object evaluate(String source, Map<String, Object> parameters) {
119+
protected Object[] arguments(Object... arguments) {
120+
return arguments;
121+
}
122+
123+
protected Map<String, Object> argumentMap(Object... parameters) {
124+
125+
Assert.state(parameters.length % 2 == 0, "even number of args required");
126+
127+
LinkedHashMap<String, Object> argumentMap = CollectionUtils.newLinkedHashMap(parameters.length / 2);
128+
for (int i = 0; i < parameters.length; i += 2) {
129+
130+
if (!(parameters[i] instanceof String key)) {
131+
throw new IllegalArgumentException("key must be a String");
132+
}
133+
argumentMap.put(key, parameters[i + 1]);
134+
}
135+
136+
return argumentMap;
137+
}
138+
139+
protected @Nullable Object evaluate(String source, Map<String, Object> parameters) {
112140

113141
ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor()
114142
.create(new NoMongoParameters()).getEvaluationContext(parameters.values());
@@ -120,9 +148,63 @@ protected Object evaluate(String source, Map<String, Object> parameters) {
120148
return parse.evaluate(valueEvaluationContext);
121149
}
122150

151+
protected Consumer<Criteria> scoreBetween(Range.Bound<? extends Score> lower, Range.Bound<? extends Score> upper) {
152+
153+
return criteria -> {
154+
if (lower.isBounded()) {
155+
double value = lower.getValue().get().getValue();
156+
if (lower.isInclusive()) {
157+
criteria.gte(value);
158+
} else {
159+
criteria.gt(value);
160+
}
161+
}
162+
163+
if (upper.isBounded()) {
164+
165+
double value = upper.getValue().get().getValue();
166+
if (upper.isInclusive()) {
167+
criteria.lte(value);
168+
} else {
169+
criteria.lt(value);
170+
}
171+
}
172+
173+
};
174+
}
175+
176+
protected ScoringFunction scoringFunction(Range<? extends Score> scoreRange) {
177+
178+
if (scoreRange != null) {
179+
if (scoreRange.getUpperBound().isBounded()) {
180+
return scoreRange.getUpperBound().getValue().get().getFunction();
181+
}
182+
183+
if (scoreRange.getLowerBound().isBounded()) {
184+
return scoreRange.getLowerBound().getValue().get().getFunction();
185+
}
186+
}
187+
188+
return ScoringFunction.unspecified();
189+
}
190+
191+
// Range<Score> scoreRange = accessor.getScoreRange();
192+
//
193+
// if (scoreRange != null) {
194+
// if (scoreRange.getUpperBound().isBounded()) {
195+
// return scoreRange.getUpperBound().getValue().get().getFunction();
196+
// }
197+
//
198+
// if (scoreRange.getLowerBound().isBounded()) {
199+
// return scoreRange.getLowerBound().getValue().get().getFunction();
200+
// }
201+
// }
202+
//
203+
// return ScoringFunction.unspecified();
204+
123205
protected Collation collationOf(@Nullable Object source) {
124206

125-
if(source == null) {
207+
if (source == null) {
126208
return Collation.simple();
127209
}
128210
if (source instanceof String) {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
3838
import org.springframework.javapoet.CodeBlock;
3939
import org.springframework.javapoet.CodeBlock.Builder;
40+
import org.springframework.util.NumberUtils;
4041
import org.springframework.util.StringUtils;
4142

4243
/**
@@ -49,6 +50,7 @@ class MongoCodeBlocks {
4950

5051
private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)");
5152
private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}");
53+
private static final Pattern VALUE_EXPRESSION_PATTERN = Pattern.compile("^#\\{.*}$");
5254

5355
/**
5456
* Builder for generating query parsing {@link CodeBlock}.
@@ -179,7 +181,7 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari
179181
} else {
180182
builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source);
181183
if (containsNamedPlaceholder(source)) {
182-
renderArgumentMap(arguments);
184+
builder.add(renderArgumentMap(arguments));
183185
} else {
184186
builder.add(renderArgumentArray(arguments));
185187
}
@@ -191,7 +193,7 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari
191193
static CodeBlock renderArgumentMap(Map<String, CodeBlock> arguments) {
192194

193195
Builder builder = CodeBlock.builder();
194-
builder.add("$T.of(", Map.class);
196+
builder.add("argumentMap(");
195197
Iterator<Entry<String, CodeBlock>> iterator = arguments.entrySet().iterator();
196198
while (iterator.hasNext()) {
197199
Entry<String, CodeBlock> next = iterator.next();
@@ -208,24 +210,41 @@ static CodeBlock renderArgumentMap(Map<String, CodeBlock> arguments) {
208210
static CodeBlock renderArgumentArray(Map<String, CodeBlock> arguments) {
209211

210212
Builder builder = CodeBlock.builder();
211-
builder.add("new $T[]{ ", Object.class);
213+
builder.add("arguments(");
212214
Iterator<CodeBlock> iterator = arguments.values().iterator();
213215
while (iterator.hasNext()) {
214216
builder.add(iterator.next());
215217
if (iterator.hasNext()) {
216218
builder.add(", ");
217-
} else {
218-
builder.add(" ");
219219
}
220220
}
221-
builder.add("}");
221+
builder.add(")");
222222
return builder.build();
223223
}
224224

225+
static CodeBlock evaluateNumberPotentially(String value, Class<? extends Number> targetType,
226+
Map<String, CodeBlock> arguments) {
227+
try {
228+
Number number = NumberUtils.parseNumber(value, targetType);
229+
return CodeBlock.of("$L", number);
230+
} catch (IllegalArgumentException e) {
231+
232+
Builder builder = CodeBlock.builder();
233+
builder.add("($T) evaluate($S, ", targetType, value);
234+
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
235+
builder.add(")");
236+
return builder.build();
237+
}
238+
}
239+
225240
static boolean containsPlaceholder(String source) {
226241
return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source);
227242
}
228243

244+
static boolean containsExpression(String source) {
245+
return VALUE_EXPRESSION_PATTERN.matcher(source).find();
246+
}
247+
229248
static boolean containsNamedPlaceholder(String source) {
230249
return EXPRESSION_BINDING_PATTERN.matcher(source).find();
231250
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
3838
import org.springframework.data.mongodb.repository.Query;
3939
import org.springframework.data.mongodb.repository.Update;
40+
import org.springframework.data.mongodb.repository.VectorSearch;
4041
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
4142
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
4243
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
@@ -107,7 +108,11 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
107108
}
108109

109110
QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod,
110-
AnnotatedElementUtils.findMergedAnnotation(method, Query.class));
111+
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method);
112+
113+
if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) {
114+
return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery()));
115+
}
111116

112117
if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1
113118
&& queryMethod.getReturnType().isCollectionLike())) {
@@ -126,8 +131,8 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
126131

127132
UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
128133
return updateMethodContributor(queryMethod, update);
129-
130134
} else {
135+
131136
Update updateSource = queryMethod.getUpdateSource();
132137
if (StringUtils.hasText(updateSource.value())) {
133138
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
@@ -146,7 +151,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
146151

147152
@SuppressWarnings("NullAway")
148153
private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod,
149-
@Nullable Query queryAnnotation) {
154+
@Nullable Query queryAnnotation, Method source) {
150155

151156
QueryInteraction query;
152157
if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) {
@@ -155,8 +160,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor
155160
} else {
156161

157162
PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType());
158-
query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), partTree.isCountProjection(),
159-
partTree.isDelete(), partTree.isExistsProjection());
163+
query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod, source),
164+
partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection());
160165
}
161166

162167
if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) {
@@ -172,7 +177,7 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor
172177
private static boolean backoff(MongoQueryMethod method) {
173178

174179
// TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays.
175-
boolean skip = method.isSearchQuery() || method.getReturnType().getType().isArray();
180+
boolean skip = method.getReturnType().getType().isArray();
176181

177182
if (skip && logger.isDebugEnabled()) {
178183
logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query"
@@ -220,6 +225,21 @@ static MethodContributor<MongoQueryMethod> aggregationMethodContributor(MongoQue
220225
});
221226
}
222227

228+
static MethodContributor<MongoQueryMethod> searchMethodContributor(MongoQueryMethod queryMethod,
229+
SearchInteraction interaction) {
230+
return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> {
231+
232+
CodeBlock.Builder builder = CodeBlock.builder();
233+
234+
String variableName = "search";
235+
236+
builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod)
237+
.usingVariableName(variableName).withFilter(interaction.getFilter()).build());
238+
239+
return builder.build();
240+
});
241+
}
242+
223243
static MethodContributor<MongoQueryMethod> updateMethodContributor(MongoQueryMethod queryMethod,
224244
UpdateInteraction update) {
225245

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ CodeBlock build() {
211211

212212
Builder builder = CodeBlock.builder();
213213

214-
builder.add("\n");
215-
builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName));
214+
builder.add(buildJustTheQuery());
216215

217216
if (StringUtils.hasText(source.getQuery().getFieldsString())) {
218217

@@ -289,6 +288,14 @@ CodeBlock build() {
289288
return builder.build();
290289
}
291290

291+
CodeBlock buildJustTheQuery() {
292+
293+
Builder builder = CodeBlock.builder();
294+
builder.add("\n");
295+
builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName));
296+
return builder.build();
297+
}
298+
292299
private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) {
293300

294301
Builder builder = CodeBlock.builder();

0 commit comments

Comments
 (0)