|
30 | 30 |
|
31 | 31 | import org.springframework.core.MethodParameter;
|
32 | 32 | import org.springframework.core.annotation.MergedAnnotation;
|
| 33 | +import org.springframework.dao.IncorrectResultSizeDataAccessException; |
33 | 34 | import org.springframework.data.domain.Pageable;
|
34 | 35 | import org.springframework.data.domain.Score;
|
35 | 36 | import org.springframework.data.domain.SliceImpl;
|
|
46 | 47 | import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
|
47 | 48 | import org.springframework.data.repository.query.ReturnedType;
|
48 | 49 | import org.springframework.data.support.PageableExecutionUtils;
|
| 50 | +import org.springframework.data.util.ReflectionUtils; |
49 | 51 | import org.springframework.javapoet.CodeBlock;
|
50 | 52 | import org.springframework.javapoet.CodeBlock.Builder;
|
51 | 53 | import org.springframework.javapoet.TypeName;
|
@@ -662,13 +664,14 @@ public CodeBlock build() {
|
662 | 664 | : TypeName.get(context.getDomainType());
|
663 | 665 | builder.add("\n");
|
664 | 666 |
|
| 667 | + Class<?> methodReturnType = context.getMethod().getReturnType(); |
665 | 668 | if (modifying.isPresent()) {
|
666 | 669 |
|
667 | 670 | if (modifying.getBoolean("flushAutomatically")) {
|
668 | 671 | builder.addStatement("this.$L.flush()", context.fieldNameOf(EntityManager.class));
|
669 | 672 | }
|
670 | 673 |
|
671 |
| - Class<?> returnType = context.getMethod().getReturnType(); |
| 674 | + Class<?> returnType = methodReturnType; |
672 | 675 |
|
673 | 676 | if (returnsModifying(returnType)) {
|
674 | 677 | builder.addStatement("int $L = $L.executeUpdate()", context.localVariable("result"), queryVariableName);
|
@@ -697,19 +700,42 @@ public CodeBlock build() {
|
697 | 700 |
|
698 | 701 | builder.addStatement("$T $L = $L.getResultList()", List.class,
|
699 | 702 | context.localVariable("resultList"), queryVariableName);
|
| 703 | + |
| 704 | + boolean returnCount = ClassUtils.isAssignable(Number.class, methodReturnType); |
| 705 | + boolean simpleBatch = returnCount || ReflectionUtils.isVoid(methodReturnType); |
| 706 | + boolean collectionQuery = queryMethod.isCollectionQuery(); |
| 707 | + |
| 708 | + if (!simpleBatch && !collectionQuery) { |
| 709 | + |
| 710 | + builder.beginControlFlow("if ($L.size() > 1)", context.localVariable("resultList")); |
| 711 | + builder.addStatement("throw new $1T($2S + $3L.size(), 1, $3L.size())", |
| 712 | + IncorrectResultSizeDataAccessException.class, |
| 713 | + "Delete query returned more than one element: expected 1, actual ", context.localVariable("resultList")); |
| 714 | + builder.endControlFlow(); |
| 715 | + |
| 716 | + builder.addStatement("$L.forEach($L::remove)", context.localVariable("resultList"), |
| 717 | + context.fieldNameOf(EntityManager.class)); |
| 718 | + } |
| 719 | + |
700 | 720 | builder.addStatement("$L.forEach($L::remove)", context.localVariable("resultList"),
|
701 | 721 | context.fieldNameOf(EntityManager.class));
|
702 |
| - if (!Collection.class.isAssignableFrom(context.getReturnType().toClass())) { |
703 |
| - if (ClassUtils.isAssignable(Number.class, context.getMethod().getReturnType())) { |
704 |
| - builder.addStatement("return $T.valueOf($L.size())", context.getMethod().getReturnType(), |
| 722 | + |
| 723 | + if (collectionQuery) { |
| 724 | + builder.addStatement("return ($T) $L", List.class, context.localVariable("resultList")); |
| 725 | + |
| 726 | + } else if (returnCount) { |
| 727 | + builder.addStatement("return $T.valueOf($L.size())", methodReturnType, |
705 | 728 | context.localVariable("resultList"));
|
706 | 729 | } else {
|
707 |
| - builder.addStatement("return ($T) ($L.isEmpty() ? null : $L.iterator().next())", actualReturnType, |
708 |
| - context.localVariable("resultList"), context.localVariable("resultList")); |
| 730 | + |
| 731 | + if (Optional.class.isAssignableFrom(methodReturnType)) { |
| 732 | + builder.addStatement("return ($1T) $1T.ofNullable($2L.isEmpty() ? null : $2L.iterator().next())", |
| 733 | + Optional.class, context.localVariable("resultList")); |
| 734 | + } else { |
| 735 | + builder.addStatement("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, |
| 736 | + context.localVariable("resultList")); |
| 737 | + } |
709 | 738 | }
|
710 |
| - } else { |
711 |
| - builder.addStatement("return ($T) $L", List.class, context.localVariable("resultList")); |
712 |
| - } |
713 | 739 | } else if (aotQuery != null && aotQuery.isExists()) {
|
714 | 740 | builder.addStatement("return !$L.getResultList().isEmpty()", queryVariableName);
|
715 | 741 | } else if (aotQuery != null) {
|
|
0 commit comments