Skip to content

Support userText rendering strategy in RetrievalAugmentationAdvisor #2468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.springframework.ai.chat.client.advisor;

import java.util.Map;

/**
* A {@link UserTextProcessor} that returns the user text as is.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class NoOpUserTextProcessor implements UserTextProcessor {

@Override
public String process(String userText, Map<String, Object> userParams) {
return userText;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.util.Assert;

import java.util.Map;

/**
* Processes the advised user text with the given user parameters using a
* {@link PromptTemplate}.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class PromptTemplateUserTextProcessor implements UserTextProcessor {

@Override
public String process(String userText, Map<String, Object> userParams) {
Assert.hasText(userText, "userText cannot be null or empty");
Assert.notNull(userParams, "userParams cannot be null");
Assert.noNullElements(userParams.keySet(), "userParams keys cannot be null");

return new PromptTemplate(userText, userParams).render();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
Expand Down Expand Up @@ -61,6 +60,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

public static final String DOCUMENT_CONTEXT = "rag_document_context";

private final UserTextProcessor userTextProcessor;

private final List<QueryTransformer> queryTransformers;

@Nullable
Expand All @@ -78,10 +79,12 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

private final int order;

public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
public RetrievalAugmentationAdvisor(@Nullable UserTextProcessor userTextProcessor,
@Nullable List<QueryTransformer> queryTransformers, @Nullable QueryExpander queryExpander,
DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner,
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
@Nullable Integer order) {
this.userTextProcessor = userTextProcessor != null ? userTextProcessor : new PromptTemplateUserTextProcessor();
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
Expand All @@ -102,9 +105,11 @@ public static Builder builder() {
public AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

String processedUserText = this.userTextProcessor.apply(request.userText(), request.userParams());

// 0. Create a query from the user text, parameters, and conversation history.
Query originalQuery = Query.builder()
.text(new PromptTemplate(request.userText(), request.userParams()).render())
.text(processedUserText)
.history(request.messages())
.context(context)
.build();
Expand Down Expand Up @@ -183,6 +188,8 @@ private static TaskExecutor buildDefaultTaskExecutor() {

public static final class Builder {

private UserTextProcessor userTextProcessor;

private List<QueryTransformer> queryTransformers;

private QueryExpander queryExpander;
Expand All @@ -202,6 +209,11 @@ public static final class Builder {
private Builder() {
}

public Builder userTextProcessor(UserTextProcessor userTextProcessor) {
this.userTextProcessor = userTextProcessor;
return this;
}

public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
this.queryTransformers = queryTransformers;
return this;
Expand Down Expand Up @@ -248,8 +260,9 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
return new RetrievalAugmentationAdvisor(this.userTextProcessor, this.queryTransformers, this.queryExpander,
this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler,
this.order);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import java.util.Map;
import java.util.function.BiFunction;

/**
* Processes the advised user text with the given user parameters.
*
* @author Thomas Vitale
* @since 1.0.0
*/
@FunctionalInterface
public interface UserTextProcessor extends BiFunction<String, Map<String, Object>, String> {

String process(String userText, Map<String, Object> userParams);

@Override
default String apply(String userText, Map<String, Object> userParams) {
return process(userText, userParams);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.springframework.ai.chat.client.advisor;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
* Unit tests for {@link NoOpUserTextProcessor}.
*
* @author Thomas Vitale
*/
class NoOpUserTextProcessorTests {

@Test
void process() {
NoOpUserTextProcessor processor = new NoOpUserTextProcessor();
String userText = "Hello, {World}!";
String processedText = processor.process(userText, null);
assertEquals(userText, processedText);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;

import java.util.HashMap;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;

/**
* Unit tests for {@link PromptTemplateUserTextProcessor}.
*
* @author Thomas Vitale
*/
class PromptTemplateUserTextProcessorTests {

@ParameterizedTest
@NullAndEmptySource
void processWithNullOrEmptyUserText(String userText) {
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
Map<String, Object> userParams = Map.of("name", "William");
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
.withMessage("userText cannot be null or empty");
}

@Test
void processWithNullUserParams() {
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
String userText = "Hello, {name}!";
Map<String, Object> userParams = null;
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
.withMessage("userParams cannot be null");
}

@Test
void processWithNullUserParamsKeys() {
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
String userText = "Hello, {name}!";
Map<String, Object> userParams = new HashMap<>();
userParams.put(null, "William");
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
.withMessage("userParams keys cannot be null");
}

@Test
void process() {
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
String userText = "Hello, {name}!";
Map<String, Object> userParams = Map.of("name", "William");
String processedText = processor.process(userText, userParams);
assertThat(processedText).isEqualTo("Hello, William!");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ String answer = chatClient.prompt()

See xref:api/retrieval-augmented-generation.adoc#_vectorstoredocumentretriever for more information.

By default, the `RetrievalAugmentationAdvisor` process the input user text with a `PromptTemplate`, ensuring that any template placeholder is correctly rendered before using the text for the retrieval process.
If you want to customize the processing logic, you can provide a custom `UserTextProcessor` to the advisor, either as a lambda or a class.
For example, in case you want to skip the rendering step, you can provide a `NoOpUserTextProcessor`. That is useful if you're planning to use the templating special characters in the user text for other purposes.

[source,java]
----
Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore().build())
.userTextProcessor(new NoOpUserTextProcessor())
.build();

String answer = chatClient.prompt()
.advisors(retrievalAugmentationAdvisor)
.user(question)
.call()
.content();
----

===== Advanced RAG

[source,java]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@

package org.springframework.ai.integration.tests.client.advisor;

import java.util.List;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.NoOpUserTextProcessor;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatResponse;
Expand All @@ -49,6 +47,8 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.Resource;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

/**
Expand Down Expand Up @@ -131,6 +131,31 @@ void ragWithRequestFilter() {
.isNull();
}

@Test
void ragWithCustomUserTextProcessor() {
String question = "Where does the adventure of {Anacletus} and {Birba} take place?";

RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
.userTextProcessor(new NoOpUserTextProcessor())
.build();

ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel)
.build()
.prompt(question)
.advisors(ragAdvisor)
.call()
.chatResponse();

assertThat(chatResponse).isNotNull();

String response = chatResponse.getResult().getOutput().getText();
System.out.println(response);
assertThat(response).containsIgnoringCase("Highlands");

evaluateRelevancy(question, chatResponse);
}

@Test
void ragWithCompression() {
MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();
Expand Down