diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java new file mode 100644 index 00000000000..07168c35b62 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java @@ -0,0 +1,18 @@ +package org.springframework.ai.chat.client.advisor; + +import java.util.Map; + +/** + * A {@link UserTextProcessor} that returns the user text as is. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class NoOpUserTextProcessor implements UserTextProcessor { + + @Override + public String process(String userText, Map userParams) { + return userText; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java new file mode 100644 index 00000000000..9764f044cf1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat.client.advisor; + +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.util.Assert; + +import java.util.Map; + +/** + * Processes the advised user text with the given user parameters using a + * {@link PromptTemplate}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class PromptTemplateUserTextProcessor implements UserTextProcessor { + + @Override + public String process(String userText, Map userParams) { + Assert.hasText(userText, "userText cannot be null or empty"); + Assert.notNull(userParams, "userParams cannot be null"); + Assert.noNullElements(userParams.keySet(), "userParams keys cannot be null"); + + return new PromptTemplate(userText, userParams).render(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index f2dfbde24aa..eab6d8507f5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -29,7 +29,6 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.rag.Query; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; @@ -61,6 +60,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { public static final String DOCUMENT_CONTEXT = "rag_document_context"; + private final UserTextProcessor userTextProcessor; + private final List queryTransformers; @Nullable @@ -78,10 +79,12 @@ public 
final class RetrievalAugmentationAdvisor implements BaseAdvisor { private final int order; - public RetrievalAugmentationAdvisor(@Nullable List queryTransformers, - @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever, - @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter, - @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) { + public RetrievalAugmentationAdvisor(@Nullable UserTextProcessor userTextProcessor, + @Nullable List queryTransformers, @Nullable QueryExpander queryExpander, + DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner, + @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, + @Nullable Integer order) { + this.userTextProcessor = userTextProcessor != null ? userTextProcessor : new PromptTemplateUserTextProcessor(); Assert.notNull(documentRetriever, "documentRetriever cannot be null"); Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = queryTransformers != null ? queryTransformers : List.of(); @@ -102,9 +105,11 @@ public static Builder builder() { public AdvisedRequest before(AdvisedRequest request) { Map context = new HashMap<>(request.adviseContext()); + String processedUserText = this.userTextProcessor.apply(request.userText(), request.userParams()); + // 0. Create a query from the user text, parameters, and conversation history. 
Query originalQuery = Query.builder() - .text(new PromptTemplate(request.userText(), request.userParams()).render()) + .text(processedUserText) .history(request.messages()) .context(context) .build(); @@ -183,6 +188,8 @@ private static TaskExecutor buildDefaultTaskExecutor() { public static final class Builder { + private UserTextProcessor userTextProcessor; + private List queryTransformers; private QueryExpander queryExpander; @@ -202,6 +209,11 @@ public static final class Builder { private Builder() { } + public Builder userTextProcessor(UserTextProcessor userTextProcessor) { + this.userTextProcessor = userTextProcessor; + return this; + } + public Builder queryTransformers(List queryTransformers) { this.queryTransformers = queryTransformers; return this; @@ -248,8 +260,9 @@ public Builder order(Integer order) { } public RetrievalAugmentationAdvisor build() { - return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever, - this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order); + return new RetrievalAugmentationAdvisor(this.userTextProcessor, this.queryTransformers, this.queryExpander, + this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, + this.order); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java new file mode 100644 index 00000000000..92d5fba303a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Processes the advised user text with the given user parameters. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@FunctionalInterface +public interface UserTextProcessor extends BiFunction, String> { + + String process(String userText, Map userParams); + + @Override + default String apply(String userText, Map userParams) { + return process(userText, userParams); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java new file mode 100644 index 00000000000..4eb6100f2c9 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java @@ -0,0 +1,22 @@ +package org.springframework.ai.chat.client.advisor; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Unit tests for {@link NoOpUserTextProcessor}. 
+ * + * @author Thomas Vitale + */ +class NoOpUserTextProcessorTests { + + @Test + void process() { + NoOpUserTextProcessor processor = new NoOpUserTextProcessor(); + String userText = "Hello, {World}!"; + String processedText = processor.process(userText, null); + assertEquals(userText, processedText); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java new file mode 100644 index 00000000000..161b14be649 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Unit tests for {@link PromptTemplateUserTextProcessor}. 
+ * + * @author Thomas Vitale + */ +class PromptTemplateUserTextProcessorTests { + + @ParameterizedTest + @NullAndEmptySource + void processWithNullOrEmptyUserText(String userText) { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + Map userParams = Map.of("name", "William"); + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userText cannot be null or empty"); + } + + @Test + void processWithNullUserParams() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = null; + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userParams cannot be null"); + } + + @Test + void processWithNullUserParamsKeys() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = new HashMap<>(); + userParams.put(null, "William"); + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userParams keys cannot be null"); + } + + @Test + void process() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = Map.of("name", "William"); + String processedText = processor.process(userText, userParams); + assertThat(processedText).isEqualTo("Hello, William!"); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc index 44d10874283..d874f157117 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc @@ -138,6 +138,24 @@ String answer = 
chatClient.prompt() See xref:api/retrieval-augmented-generation.adoc#_vectorstoredocumentretriever for more information. +By default, the `RetrievalAugmentationAdvisor` processes the input user text with a `PromptTemplate`, ensuring that any template placeholder is correctly rendered before using the text for the retrieval process. +If you want to customize the processing logic, you can provide a custom `UserTextProcessor` to the advisor, either as a lambda or a class. +For example, in case you want to skip the rendering step, you can provide a `NoOpUserTextProcessor`. That is useful if you're planning to use the templating special characters in the user text for other purposes. + +[source,java] +---- +Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(vectorStore).build()) + .userTextProcessor(new NoOpUserTextProcessor()) + .build(); + +String answer = chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(question) + .call() + .content(); +---- + ===== Advanced RAG [source,java] diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java index e7170c68b75..db090141710 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java @@ -16,16 +16,14 @@ package org.springframework.ai.integration.tests.client.advisor; -import java.util.List; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - 
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.NoOpUserTextProcessor; import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.model.ChatResponse; @@ -49,6 +47,8 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -131,6 +131,31 @@ void ragWithRequestFilter() { .isNull(); } + @Test + void ragWithCustomUserTextProcessor() { + String question = "Where does the adventure of {Anacletus} and {Birba} take place?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build()) + .userTextProcessor(new NoOpUserTextProcessor()) + .build(); + + ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel) + .build() + .prompt(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getText(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Highlands"); + + evaluateRelevancy(question, chatResponse); + } + @Test void ragWithCompression() { MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();