Skip to content

Commit 10ba08d

Browse files
refactor a tiny bit
1 parent e951a3b commit 10ba08d

File tree

7 files changed

+1097
-840
lines changed

7 files changed

+1097
-840
lines changed
Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.repository.aot;
17+
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.stream.Collectors;
21+
import java.util.stream.Stream;
22+
23+
import org.bson.Document;
24+
import org.jspecify.annotations.NullUnmarked;
25+
import org.springframework.core.annotation.MergedAnnotation;
26+
import org.springframework.data.domain.SliceImpl;
27+
import org.springframework.data.domain.Sort.Order;
28+
import org.springframework.data.mongodb.core.MongoOperations;
29+
import org.springframework.data.mongodb.core.aggregation.Aggregation;
30+
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
31+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
32+
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
33+
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
34+
import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes;
35+
import org.springframework.data.mongodb.core.query.Collation;
36+
import org.springframework.data.mongodb.repository.Hint;
37+
import org.springframework.data.mongodb.repository.ReadPreference;
38+
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
39+
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
40+
import org.springframework.data.util.ReflectionUtils;
41+
import org.springframework.javapoet.CodeBlock;
42+
import org.springframework.javapoet.CodeBlock.Builder;
43+
import org.springframework.util.ClassUtils;
44+
import org.springframework.util.CollectionUtils;
45+
import org.springframework.util.StringUtils;
46+
47+
/**
48+
* @author Christoph Strobl
49+
* @since 5.0
50+
*/
51+
class AggregationBlocks {
52+
53+
@NullUnmarked
54+
static class AggregationExecutionCodeBlockBuilder {
55+
56+
private final AotQueryMethodGenerationContext context;
57+
private final MongoQueryMethod queryMethod;
58+
private String aggregationVariableName;
59+
60+
AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
61+
62+
this.context = context;
63+
this.queryMethod = queryMethod;
64+
}
65+
66+
AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) {
67+
68+
this.aggregationVariableName = aggregationVariableName;
69+
return this;
70+
}
71+
72+
CodeBlock build() {
73+
74+
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
75+
Builder builder = CodeBlock.builder();
76+
77+
builder.add("\n");
78+
79+
Class<?> outputType = queryMethod.getReturnedObjectType();
80+
if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) {
81+
outputType = Document.class;
82+
} else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) {
83+
outputType = queryMethod.getReturnType().getComponentType().getType();
84+
}
85+
86+
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
87+
builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
88+
return builder.build();
89+
}
90+
91+
if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) {
92+
builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
93+
return builder.build();
94+
}
95+
96+
if (outputType == Document.class) {
97+
98+
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
99+
100+
if (queryMethod.isStreamQuery()) {
101+
102+
builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class,
103+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
104+
105+
builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))",
106+
context.localVariable("results"), returnType);
107+
} else {
108+
109+
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
110+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
111+
112+
if (!queryMethod.isCollectionQuery()) {
113+
builder.addStatement(
114+
"return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))",
115+
CollectionUtils.class, returnType, context.localVariable("results"));
116+
} else {
117+
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
118+
context.localVariable("results"));
119+
}
120+
}
121+
} else {
122+
if (queryMethod.isSliceQuery()) {
123+
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
124+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
125+
builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
126+
context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
127+
builder.addStatement(
128+
"return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)",
129+
SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"),
130+
context.getPageableParameterName());
131+
} else {
132+
133+
if (queryMethod.isStreamQuery()) {
134+
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
135+
outputType);
136+
} else {
137+
138+
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
139+
aggregationVariableName, outputType);
140+
}
141+
}
142+
}
143+
144+
return builder.build();
145+
}
146+
}
147+
148+
@NullUnmarked
149+
static class AggregationCodeBlockBuilder {
150+
151+
private final AotQueryMethodGenerationContext context;
152+
private final MongoQueryMethod queryMethod;
153+
private final List<CodeBlock> arguments;
154+
155+
private AggregationInteraction source;
156+
157+
private String aggregationVariableName;
158+
private boolean pipelineOnly;
159+
160+
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
161+
162+
this.context = context;
163+
this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList());
164+
this.queryMethod = queryMethod;
165+
}
166+
167+
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
168+
169+
this.source = aggregation;
170+
return this;
171+
}
172+
173+
AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) {
174+
175+
this.aggregationVariableName = aggregationVariableName;
176+
return this;
177+
}
178+
179+
AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) {
180+
181+
this.pipelineOnly = pipelineOnly;
182+
return this;
183+
}
184+
185+
CodeBlock build() {
186+
187+
Builder builder = CodeBlock.builder();
188+
builder.add("\n");
189+
190+
String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
191+
builder.add(pipeline(pipelineName));
192+
193+
if (!pipelineOnly) {
194+
195+
builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())",
196+
TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName,
197+
Aggregation.class, pipelineName);
198+
199+
builder.add(aggregationOptions(aggregationVariableName));
200+
}
201+
202+
return builder.build();
203+
}
204+
205+
private CodeBlock pipeline(String pipelineVariableName) {
206+
207+
String sortParameter = context.getSortParameterName();
208+
String limitParameter = context.getLimitParameterName();
209+
String pageableParameter = context.getPageableParameterName();
210+
211+
boolean mightBeSorted = StringUtils.hasText(sortParameter);
212+
boolean mightBeLimited = StringUtils.hasText(limitParameter);
213+
boolean mightBePaged = StringUtils.hasText(pageableParameter);
214+
215+
int stageCount = source.stages().size();
216+
if (mightBeSorted) {
217+
stageCount++;
218+
}
219+
if (mightBeLimited) {
220+
stageCount++;
221+
}
222+
if (mightBePaged) {
223+
stageCount += 3;
224+
}
225+
226+
Builder builder = CodeBlock.builder();
227+
builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
228+
229+
if (mightBeSorted) {
230+
builder.add(sortingStage(sortParameter));
231+
}
232+
233+
if (mightBeLimited) {
234+
builder.add(limitingStage(limitParameter));
235+
}
236+
237+
if (mightBePaged) {
238+
builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery()));
239+
}
240+
241+
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
242+
context.localVariable("stages"));
243+
return builder.build();
244+
}
245+
246+
private CodeBlock aggregationOptions(String aggregationVariableName) {
247+
248+
Builder builder = CodeBlock.builder();
249+
List<CodeBlock> options = new ArrayList<>(5);
250+
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
251+
options.add(CodeBlock.of(".skipOutput()"));
252+
}
253+
254+
MergedAnnotation<Hint> hintAnnotation = context.getAnnotation(Hint.class);
255+
String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null;
256+
if (StringUtils.hasText(hint)) {
257+
options.add(CodeBlock.of(".hint($S)", hint));
258+
}
259+
260+
MergedAnnotation<ReadPreference> readPreferenceAnnotation = context.getAnnotation(ReadPreference.class);
261+
String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null;
262+
if (StringUtils.hasText(readPreference)) {
263+
options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference));
264+
}
265+
266+
if (queryMethod.hasAnnotatedCollation()) {
267+
options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation()));
268+
}
269+
270+
if (!options.isEmpty()) {
271+
272+
Builder optionsBuilder = CodeBlock.builder();
273+
optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class,
274+
context.localVariable("aggregationOptions"));
275+
optionsBuilder.indent();
276+
for (CodeBlock optionBlock : options) {
277+
optionsBuilder.add(optionBlock);
278+
optionsBuilder.add("\n");
279+
}
280+
optionsBuilder.add(".build();\n");
281+
optionsBuilder.unindent();
282+
builder.add(optionsBuilder.build());
283+
284+
builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName,
285+
context.localVariable("aggregationOptions"));
286+
}
287+
return builder.build();
288+
}
289+
290+
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
291+
List<CodeBlock> arguments) {
292+
293+
Builder builder = CodeBlock.builder();
294+
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
295+
stageCount);
296+
int stageCounter = 0;
297+
298+
for (String stage : stages) {
299+
String stageName = context.localVariable("stage_%s".formatted(stageCounter++));
300+
builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments));
301+
builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName);
302+
}
303+
304+
return builder.build();
305+
}
306+
307+
private CodeBlock sortingStage(String sortProvider) {
308+
309+
Builder builder = CodeBlock.builder();
310+
311+
builder.beginControlFlow("if ($L.isSorted())", sortProvider);
312+
builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument"));
313+
builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
314+
builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);",
315+
context.localVariable("sortDocument"), context.localVariable("order"));
316+
builder.endControlFlow();
317+
builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
318+
context.localVariable("sortDocument"));
319+
builder.endControlFlow();
320+
321+
return builder.build();
322+
}
323+
324+
private CodeBlock pagingStage(String pageableProvider, boolean slice) {
325+
326+
Builder builder = CodeBlock.builder();
327+
328+
builder.add(sortingStage(pageableProvider + ".getSort()"));
329+
330+
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
331+
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
332+
builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
333+
pageableProvider);
334+
builder.endControlFlow();
335+
if (slice) {
336+
builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
337+
Aggregation.class, pageableProvider);
338+
} else {
339+
builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
340+
pageableProvider);
341+
}
342+
builder.endControlFlow();
343+
344+
return builder.build();
345+
}
346+
347+
private CodeBlock limitingStage(String limitProvider) {
348+
349+
Builder builder = CodeBlock.builder();
350+
351+
builder.beginControlFlow("if ($L.isLimited())", limitProvider);
352+
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
353+
limitProvider);
354+
builder.endControlFlow();
355+
356+
return builder.build();
357+
}
358+
359+
}
360+
}

0 commit comments

Comments
 (0)