Skip to content

Commit f573c4f

Browse files
committed
fix: ensure system role is first in advised request messages
Start with existing conversation messages, if present Signed-off-by: Alexandros Pappas <[email protected]>
1 parent ef0a202 commit f573c4f

File tree

2 files changed

+95
-5
lines changed

2 files changed

+95
-5
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java

+13-5
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
* @author Christian Tzolov
6161
* @author Thomas Vitale
6262
* @author Ilayaperumal Gopinathan
63+
* @author Alexandros Pappas
6364
* @since 1.0.0
6465
*/
6566
public record AdvisedRequest(
@@ -147,18 +148,24 @@ public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Ob
147148
}
148149

149150
public Prompt toPrompt() {
150-
var messages = new ArrayList<>(this.messages());
151+
List<Message> promptMessages = new ArrayList<>();
151152

153+
// 1. Start with existing conversation messages, if present
154+
if (!CollectionUtils.isEmpty(this.messages())) {
155+
promptMessages.addAll(this.messages());
156+
}
157+
158+
// 2. Process the new SystemMessage, if present
152159
String processedSystemText = this.systemText();
153160
if (StringUtils.hasText(processedSystemText)) {
154161
if (!CollectionUtils.isEmpty(this.systemParams())) {
155162
processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render();
156163
}
157-
messages.add(new SystemMessage(processedSystemText));
164+
promptMessages.add(new SystemMessage(processedSystemText));
158165
}
159166

167+
// 3. Process the new UserMessage, if present
160168
String formatParam = (String) this.adviseContext().get("formatParam");
161-
162169
var processedUserText = StringUtils.hasText(formatParam)
163170
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();
164171

@@ -170,9 +177,10 @@ public Prompt toPrompt() {
170177
if (!CollectionUtils.isEmpty(userParams)) {
171178
processedUserText = new PromptTemplate(processedUserText, userParams).render();
172179
}
173-
messages.add(new UserMessage(processedUserText, this.media()));
180+
promptMessages.add(new UserMessage(processedUserText, this.media()));
174181
}
175182

183+
// 4. Configure function-calling options, if applicable
176184
if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
177185
if (!this.functionNames().isEmpty()) {
178186
functionCallingOptions.setFunctions(new HashSet<>(this.functionNames()));
@@ -185,7 +193,7 @@ public Prompt toPrompt() {
185193
}
186194
}
187195

188-
return new Prompt(messages, this.chatOptions());
196+
return new Prompt(promptMessages, chatOptions);
189197
}
190198

191199
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.springframework.ai.integration.tests.client.advisor;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
5+
import org.springframework.ai.chat.client.ChatClient;
6+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
7+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
8+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
9+
import org.springframework.ai.chat.model.ChatResponse;
10+
import org.springframework.ai.openai.OpenAiChatModel;
11+
import org.springframework.beans.factory.annotation.Autowired;
12+
import org.springframework.boot.test.context.SpringBootTest;
13+
14+
import static org.assertj.core.api.Assertions.assertThat;
15+
16+
/**
17+
* Integration tests for {@link MessageChatMemoryAdvisor}.
18+
*
19+
* @author Alexandros Pappas
20+
*/
21+
@SpringBootTest
22+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
23+
class MessageChatMemoryAdvisorIT {
24+
25+
@Autowired
26+
OpenAiChatModel openAiChatModel;
27+
28+
@Test
29+
void chatMemoryStoresAndRecallsConversation() {
30+
var chatMemory = new InMemoryChatMemory();
31+
var conversationId = "test-conversation";
32+
33+
var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build();
34+
35+
var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();
36+
37+
// First interaction
38+
ChatResponse response1 = chatClient.prompt().user("Hello, my name is John.").call().chatResponse();
39+
40+
assertThat(response1).isNotNull();
41+
String assistantReply1 = response1.getResult().getOutput().getText();
42+
System.out.println("Assistant reply 1: " + assistantReply1);
43+
44+
// Second interaction - Verify memory recall
45+
ChatResponse response2 = chatClient.prompt().user("What is my name?").call().chatResponse();
46+
47+
assertThat(response2).isNotNull();
48+
String assistantReply2 = response2.getResult().getOutput().getText();
49+
System.out.println("Assistant reply 2: " + assistantReply2);
50+
51+
assertThat(assistantReply2.toLowerCase()).contains("john");
52+
}
53+
54+
@Test
55+
void separateConversationsDoNotMixMemory() {
56+
var chatMemory = new InMemoryChatMemory();
57+
58+
var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
59+
60+
var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();
61+
62+
// First conversation
63+
chatClient.prompt()
64+
.user("Remember my secret code is blue.")
65+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-1"))
66+
.call();
67+
68+
// Second conversation
69+
ChatResponse response = chatClient.prompt()
70+
.user("Do you remember my secret code?")
71+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-2"))
72+
.call()
73+
.chatResponse();
74+
75+
assertThat(response).isNotNull();
76+
String assistantReply = response.getResult().getOutput().getText();
77+
System.out.println("Assistant reply: " + assistantReply);
78+
79+
assertThat(assistantReply.toLowerCase()).doesNotContain("blue");
80+
}
81+
82+
}

0 commit comments

Comments
 (0)