From 5c2d188d59ce6a3a182097b944b943e3b4f647ce Mon Sep 17 00:00:00 2001 From: Alexandros Pappas Date: Fri, 21 Mar 2025 11:23:36 +0100 Subject: [PATCH] fix: ensure system role is first in advised request messages Start with existing conversation messages, if present Signed-off-by: Alexandros Pappas --- .../client/advisor/api/AdvisedRequest.java | 18 ++-- .../advisor/MessageChatMemoryAdvisorIT.java | 82 +++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/MessageChatMemoryAdvisorIT.java diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 6d58b77ed4b..f8d328b53de 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -60,6 +60,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0.0 */ public record AdvisedRequest( @@ -147,18 +148,24 @@ public AdvisedRequest updateContext(Function, Map(this.messages()); + List promptMessages = new ArrayList<>(); + // 1. Start with existing conversation messages, if present + if (!CollectionUtils.isEmpty(this.messages())) { + promptMessages.addAll(this.messages()); + } + + // 2. Process the new SystemMessage, if present String processedSystemText = this.systemText(); if (StringUtils.hasText(processedSystemText)) { if (!CollectionUtils.isEmpty(this.systemParams())) { processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); } - messages.add(new SystemMessage(processedSystemText)); + promptMessages.add(new SystemMessage(processedSystemText)); } + // 3. Process the new UserMessage, if present String formatParam = (String) this.adviseContext().get("formatParam"); - var processedUserText = StringUtils.hasText(formatParam) ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); @@ -170,9 +177,10 @@ public Prompt toPrompt() { if (!CollectionUtils.isEmpty(userParams)) { processedUserText = new PromptTemplate(processedUserText, userParams).render(); } - messages.add(new UserMessage(processedUserText, this.media())); + promptMessages.add(new UserMessage(processedUserText, this.media())); } + // 4. Configure function-calling options, if applicable if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { if (!this.functionNames().isEmpty()) { functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); @@ -185,7 +193,7 @@ public Prompt toPrompt() { } } - return new Prompt(messages, this.chatOptions()); + return new Prompt(promptMessages, this.chatOptions()); } /** diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/MessageChatMemoryAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/MessageChatMemoryAdvisorIT.java new file mode 100644 index 00000000000..e7b127c9ec1 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/MessageChatMemoryAdvisorIT.java @@ -0,0 +1,82 @@ +package org.springframework.ai.integration.tests.client.advisor; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link MessageChatMemoryAdvisor}. + * + * @author Alexandros Pappas + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class MessageChatMemoryAdvisorIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void chatMemoryStoresAndRecallsConversation() { + var chatMemory = new InMemoryChatMemory(); + var conversationId = "test-conversation"; + + var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build(); + + var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build(); + + // First interaction + ChatResponse response1 = chatClient.prompt().user("Hello, my name is John.").call().chatResponse(); + + assertThat(response1).isNotNull(); + String assistantReply1 = response1.getResult().getOutput().getText(); + System.out.println("Assistant reply 1: " + assistantReply1); + + // Second interaction - Verify memory recall + ChatResponse response2 = chatClient.prompt().user("What is my name?").call().chatResponse(); + + assertThat(response2).isNotNull(); + String assistantReply2 = response2.getResult().getOutput().getText(); + System.out.println("Assistant reply 2: " + assistantReply2); + + assertThat(assistantReply2.toLowerCase()).contains("john"); + } + + @Test + void separateConversationsDoNotMixMemory() { + var chatMemory = new InMemoryChatMemory(); + + var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build(); + + var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build(); + + // First conversation + chatClient.prompt() + .user("Remember my secret code is blue.") + .advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-1")) + .call(); + + // Second conversation + ChatResponse response = chatClient.prompt() + .user("Do you remember my secret code?") + .advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-2")) + .call() + .chatResponse(); + + assertThat(response).isNotNull(); + String assistantReply = response.getResult().getOutput().getText(); + System.out.println("Assistant reply: " + assistantReply); + + assertThat(assistantReply.toLowerCase()).doesNotContain("blue"); + } + +}