diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java index 56038587fa6..801de92d1fb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java @@ -17,6 +17,7 @@ package org.springframework.ai.rag.retrieval.join; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -35,6 +36,7 @@ * documents, the first occurrence is kept. The score of each document is kept as is. * * @author Thomas Vitale + * @author ghdcksgml1 * @since 1.0.0 */ public class ConcatenationDocumentJoiner implements DocumentJoiner { @@ -54,7 +56,10 @@ public List join(Map>> documentsForQuery) { .flatMap(List::stream) .flatMap(List::stream) .collect(Collectors.toMap(Document::getId, Function.identity(), (existing, duplicate) -> existing)) - .values()); + .values() + .stream() + .sorted(Comparator.comparing((Document d1) -> (d1.getScore() != null) ? d1.getScore() : 0.0).reversed()) + .toList()); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java index 39a588555e6..f955ea68e21 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java @@ -32,6 +32,7 @@ * Unit tests for {@link ConcatenationDocumentJoiner}. * * @author Thomas Vitale + * @author ghdcksgml1 */ class ConcatenationDocumentJoinerTests { @@ -92,4 +93,22 @@ void whenDuplicatedDocumentsThenOnlyFirstOccurrenceIsKept() { assertThat(result).extracting(Document::getText).containsOnlyOnce("Content 2"); } + @Test + void whenSeveralQueryExistsInMapThenDocumentsAreJoinedInDescendingScoreOrder() { + DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner(); + var documentsForQuery = new HashMap>>(); + documentsForQuery.put(new Query("query1"), + List.of(List.of(Document.builder().id("1").text("Content 1").score(0.9).build(), + Document.builder().id("4").text("Content 4").score(0.6).build()), + List.of(Document.builder().id("2").text("Content 2").score(0.8).build()))); + documentsForQuery.put(new Query("query2"), + List.of(List.of(Document.builder().id("3").text("Content 3").score(0.7).build()))); + + List result = documentJoiner.join(documentsForQuery); + + assertThat(result).hasSize(4); + assertThat(result).extracting(Document::getId).containsExactly("1", "2", "3", "4"); + assertThat(result).extracting(Document::getScore).containsExactly(0.9, 0.8, 0.7, 0.6); + } + }