Skip to content

Commit aa891d0

Browse files
committed
Introduce TemplateRenderer for prompt templating
- Introduce new TemplateRenderer API providing the logic for rendering an input template. - Update the PromptTemplate API to accept a TemplateRenderer object at construction time. - Move ST logic to StTemplateRenderer implementation, used by default in PromptTemplate. - Extend ChatClient API to support passing a custom TemplateRenderer. - Add integration tests showing how to customize prompts in QuestionAnswerAdvisor and RetrievalAugmentationAdvisor. - Support PromptTemplate instead of String in QuestionAnswerAdvisor. Relates to spring-projectsgh-2655 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 67cbcf5 commit aa891d0

File tree

21 files changed

+1321
-190
lines changed

21 files changed

+1321
-190
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java

+45-25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
2121
import java.util.Map;
2222
import java.util.stream.Collectors;
2323

24+
import org.springframework.lang.Nullable;
2425
import reactor.core.publisher.Flux;
2526
import reactor.core.publisher.Mono;
2627
import reactor.core.scheduler.Schedulers;
@@ -49,6 +50,7 @@
4950
* @author Christian Tzolov
5051
* @author Timo Salm
5152
* @author Ilayaperumal Gopinathan
53+
* @author Thomas Vitale
5254
* @since 1.0.0
5355
*/
5456
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
@@ -57,7 +59,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
5759

5860
public static final String FILTER_EXPRESSION = "qa_filter_expression";
5961

60-
private static final String DEFAULT_USER_TEXT_ADVISE = """
62+
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
6163
6264
Context information is below, surrounded by ---------------------
6365
@@ -68,13 +70,13 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
6870
Given the context and provided history information and not prior knowledge,
6971
reply to the user comment. If the answer is not in the context, inform
7072
the user that you can't answer the question.
71-
""";
73+
""");
7274

7375
private static final int DEFAULT_ORDER = 0;
7476

7577
private final VectorStore vectorStore;
7678

77-
private final String userTextAdvise;
79+
private final PromptTemplate promptTemplate;
7880

7981
private final SearchRequest searchRequest;
8082

@@ -88,7 +90,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
8890
* @param vectorStore The vector store to use
8991
*/
9092
public QuestionAnswerAdvisor(VectorStore vectorStore) {
91-
this(vectorStore, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE);
93+
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
9294
}
9395

9496
/**
@@ -97,9 +99,11 @@ public QuestionAnswerAdvisor(VectorStore vectorStore) {
9799
* @param vectorStore The vector store to use
98100
* @param searchRequest The search request defined using the portable filter
99101
* expression syntax
102+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
100103
*/
104+
@Deprecated
101105
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
102-
this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
106+
this(vectorStore, searchRequest, DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
103107
}
104108

105109
/**
@@ -110,9 +114,12 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
110114
* expression syntax
111115
* @param userTextAdvise The user text to append to the existing user prompt. The text
112116
* should contain a placeholder named "question_answer_context".
117+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
113118
*/
119+
@Deprecated
114120
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
115-
this(vectorStore, searchRequest, userTextAdvise, true);
121+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), true,
122+
DEFAULT_ORDER);
116123
}
117124

118125
/**
@@ -127,10 +134,13 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
127134
* blocking threads. If false the advisor will not protect the execution from blocking
128135
* threads. This is useful when the advisor is used in a non-blocking environment. It
129136
* is true by default.
137+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
130138
*/
139+
@Deprecated
131140
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
132141
boolean protectFromBlocking) {
133-
this(vectorStore, searchRequest, userTextAdvise, protectFromBlocking, DEFAULT_ORDER);
142+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
143+
DEFAULT_ORDER);
134144
}
135145

136146
/**
@@ -146,17 +156,23 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
146156
* threads. This is useful when the advisor is used in a non-blocking environment. It
147157
* is true by default.
148158
* @param order The order of the advisor.
159+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
149160
*/
161+
@Deprecated
150162
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
151163
boolean protectFromBlocking, int order) {
164+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
165+
order);
166+
}
152167

153-
Assert.notNull(vectorStore, "The vectorStore must not be null!");
154-
Assert.notNull(searchRequest, "The searchRequest must not be null!");
155-
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
168+
QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, @Nullable PromptTemplate promptTemplate,
169+
boolean protectFromBlocking, int order) {
170+
Assert.notNull(vectorStore, "vectorStore cannot be null");
171+
Assert.notNull(searchRequest, "searchRequest cannot be null");
156172

157173
this.vectorStore = vectorStore;
158174
this.searchRequest = searchRequest;
159-
this.userTextAdvise = userTextAdvise;
175+
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
160176
this.protectFromBlocking = protectFromBlocking;
161177
this.order = order;
162178
}
@@ -212,10 +228,7 @@ private AdvisedRequest before(AdvisedRequest request) {
212228

213229
var context = new HashMap<>(request.adviseContext());
214230

215-
// 1. Advise the system text.
216-
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;
217-
218-
// 2. Search for similar documents in the vector store.
231+
// 1. Search for similar documents in the vector store.
219232
String query = new PromptTemplate(request.userText(), request.userParams()).render();
220233
var searchRequestToUse = SearchRequest.from(this.searchRequest)
221234
.query(query)
@@ -224,20 +237,21 @@ private AdvisedRequest before(AdvisedRequest request) {
224237

225238
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
226239

227-
// 3. Create the context from the documents.
240+
// 2. Create the context from the documents.
228241
context.put(RETRIEVED_DOCUMENTS, documents);
229242

230243
String documentContext = documents.stream()
231244
.map(Document::getText)
232245
.collect(Collectors.joining(System.lineSeparator()));
233246

234-
// 4. Advise the user parameters.
235-
Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
236-
advisedUserParams.put("question_answer_context", documentContext);
247+
// 3. Augment the user prompt with the document context.
248+
String augmentedUserText = this.promptTemplate.mutate()
249+
.template(request.userText() + System.lineSeparator() + this.promptTemplate.getTemplate())
250+
.build()
251+
.render(Map.of("question_answer_context", documentContext));
237252

238253
AdvisedRequest advisedRequest = AdvisedRequest.from(request)
239-
.userText(advisedUserText)
240-
.userParams(advisedUserParams)
254+
.userText(augmentedUserText)
241255
.adviseContext(context)
242256
.build();
243257

@@ -266,7 +280,7 @@ public static final class Builder {
266280

267281
private SearchRequest searchRequest = SearchRequest.builder().build();
268282

269-
private String userTextAdvise = DEFAULT_USER_TEXT_ADVISE;
283+
private PromptTemplate promptTemplate;
270284

271285
private boolean protectFromBlocking = true;
272286

@@ -283,9 +297,15 @@ public Builder searchRequest(SearchRequest searchRequest) {
283297
return this;
284298
}
285299

300+
public Builder promptTemplate(PromptTemplate promptTemplate) {
301+
Assert.notNull(promptTemplate, "promptTemplate cannot be null");
302+
this.promptTemplate = promptTemplate;
303+
return this;
304+
}
305+
286306
public Builder userTextAdvise(String userTextAdvise) {
287307
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
288-
this.userTextAdvise = userTextAdvise;
308+
this.promptTemplate = PromptTemplate.builder().template(userTextAdvise).build();
289309
return this;
290310
}
291311

@@ -300,7 +320,7 @@ public Builder order(int order) {
300320
}
301321

302322
public QuestionAnswerAdvisor build() {
303-
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise,
323+
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate,
304324
this.protectFromBlocking, this.order);
305325
}
306326

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.ai.chat.model.ChatResponse;
3333
import org.springframework.ai.chat.prompt.ChatOptions;
3434
import org.springframework.ai.chat.prompt.Prompt;
35+
import org.springframework.ai.chat.prompt.template.TemplateRenderer;
3536
import org.springframework.ai.content.Media;
3637
import org.springframework.ai.converter.StructuredOutputConverter;
3738
import org.springframework.ai.model.function.FunctionCallback;
@@ -254,6 +255,8 @@ interface ChatClientRequestSpec {
254255

255256
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
256257

258+
ChatClientRequestSpec promptTemplateRenderer(TemplateRenderer templateRenderer);
259+
257260
CallResponseSpec call();
258261

259262
StreamResponseSpec stream();
@@ -313,6 +316,8 @@ interface Builder {
313316

314317
Builder defaultToolContext(Map<String, Object> toolContext);
315318

319+
Builder defaultPromptTemplateRenderer(TemplateRenderer templateRenderer);
320+
316321
Builder clone();
317322

318323
ChatClient build();

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

+26-6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
import org.springframework.ai.chat.model.StreamingChatModel;
6060
import org.springframework.ai.chat.prompt.ChatOptions;
6161
import org.springframework.ai.chat.prompt.Prompt;
62+
import org.springframework.ai.chat.prompt.template.TemplateRenderer;
63+
import org.springframework.ai.chat.prompt.template.st.StTemplateRenderer;
6264
import org.springframework.ai.content.Media;
6365
import org.springframework.ai.converter.BeanOutputConverter;
6466
import org.springframework.ai.converter.StructuredOutputConverter;
@@ -88,6 +90,8 @@ public class DefaultChatClient implements ChatClient {
8890

8991
private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention();
9092

93+
private static final TemplateRenderer DEFAULT_PROMPT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();
94+
9195
private final DefaultChatClientRequestSpec defaultChatClientRequest;
9296

9397
public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) {
@@ -138,7 +142,7 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise
138142
advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(),
139143
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
140144
advisedRequest.advisorParams(), observationRegistry, customObservationConvention,
141-
advisedRequest.toolContext());
145+
advisedRequest.toolContext(), null);
142146
}
143147

144148
@Override
@@ -666,6 +670,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
666670

667671
private final Map<String, Object> toolContext = new HashMap<>();
668672

673+
private TemplateRenderer promptTemplateRenderer;
674+
669675
@Nullable
670676
private String userText;
671677

@@ -679,15 +685,16 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
679685
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
680686
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
681687
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
682-
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext);
688+
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.promptTemplateRenderer);
683689
}
684690

685691
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
686692
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
687693
List<FunctionCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
688694
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
689695
ObservationRegistry observationRegistry,
690-
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext) {
696+
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
697+
@Nullable TemplateRenderer promptTemplateRenderer) {
691698

692699
Assert.notNull(chatModel, "chatModel cannot be null");
693700
Assert.notNull(userParams, "userParams cannot be null");
@@ -720,6 +727,8 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
720727
this.observationConvention = observationConvention != null ? observationConvention
721728
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
722729
this.toolContext.putAll(toolContext);
730+
this.promptTemplateRenderer = promptTemplateRenderer != null ? promptTemplateRenderer
731+
: DEFAULT_PROMPT_TEMPLATE_RENDERER;
723732

724733
// At the stack bottom add the model call advisors.
725734
// They play the role of the last advisors in the advisor chain.
@@ -789,6 +798,10 @@ public Map<String, Object> getToolContext() {
789798
return this.toolContext;
790799
}
791800

801+
public TemplateRenderer getPromptTemplateRenderer() {
802+
return this.promptTemplateRenderer;
803+
}
804+
792805
/**
793806
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
794807
* settings are replicated from this {@code ChatClientRequest}.
@@ -997,15 +1010,22 @@ public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
9971010
return this;
9981011
}
9991012

1013+
public ChatClientRequestSpec promptTemplateRenderer(TemplateRenderer templateRenderer) {
1014+
Assert.notNull(templateRenderer, "templateRenderer cannot be null");
1015+
this.promptTemplateRenderer = templateRenderer;
1016+
return this;
1017+
}
1018+
10001019
public CallResponseSpec call() {
10011020
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
1002-
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
1003-
observationRegistry, observationConvention);
1021+
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(this.promptTemplateRenderer),
1022+
advisorChain, observationRegistry, observationConvention);
10041023
}
10051024

10061025
public StreamResponseSpec stream() {
10071026
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
1008-
return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
1027+
return new DefaultStreamResponseSpec(
1028+
toAdvisedRequest(this).toChatClientRequest(this.promptTemplateRenderer), advisorChain,
10091029
observationRegistry, observationConvention);
10101030
}
10111031

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.chat.model.ChatModel;
3535
import org.springframework.ai.chat.model.ToolContext;
3636
import org.springframework.ai.chat.prompt.ChatOptions;
37+
import org.springframework.ai.chat.prompt.template.TemplateRenderer;
3738
import org.springframework.ai.model.function.FunctionCallback;
3839
import org.springframework.ai.tool.ToolCallback;
3940
import org.springframework.ai.tool.ToolCallbackProvider;
@@ -67,7 +68,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6768
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
6869
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(),
6970
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
70-
customObservationConvention, Map.of());
71+
customObservationConvention, Map.of(), null);
7172
}
7273

7374
public ChatClient build() {
@@ -212,6 +213,12 @@ public Builder defaultToolContext(Map<String, Object> toolContext) {
212213
return this;
213214
}
214215

216+
public Builder defaultPromptTemplateRenderer(TemplateRenderer templateRenderer) {
217+
Assert.notNull(templateRenderer, "templateRenderer cannot be null");
218+
this.defaultRequest.promptTemplateRenderer(templateRenderer);
219+
return this;
220+
}
221+
215222
void addMessages(List<Message> messages) {
216223
this.defaultRequest.messages(messages);
217224
}

0 commit comments

Comments
 (0)