Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure system role is first in advised request messages #2541

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @since 1.0.0
*/
public record AdvisedRequest(
Expand Down Expand Up @@ -147,18 +148,24 @@ public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Ob
}

public Prompt toPrompt() {
var messages = new ArrayList<>(this.messages());
List<Message> promptMessages = new ArrayList<>();

// 1. Start with existing conversation messages, if present
if (!CollectionUtils.isEmpty(this.messages())) {
promptMessages.addAll(this.messages());
}

// 2. Process the new SystemMessage, if present
String processedSystemText = this.systemText();
if (StringUtils.hasText(processedSystemText)) {
if (!CollectionUtils.isEmpty(this.systemParams())) {
processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render();
}
messages.add(new SystemMessage(processedSystemText));
promptMessages.add(new SystemMessage(processedSystemText));
}

// 3. Process the new UserMessage, if present
String formatParam = (String) this.adviseContext().get("formatParam");

var processedUserText = StringUtils.hasText(formatParam)
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();

Expand All @@ -170,9 +177,10 @@ public Prompt toPrompt() {
if (!CollectionUtils.isEmpty(userParams)) {
processedUserText = new PromptTemplate(processedUserText, userParams).render();
}
messages.add(new UserMessage(processedUserText, this.media()));
promptMessages.add(new UserMessage(processedUserText, this.media()));
}

// 4. Configure function-calling options, if applicable
if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
if (!this.functionNames().isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(this.functionNames()));
Expand All @@ -185,7 +193,7 @@ public Prompt toPrompt() {
}
}

return new Prompt(messages, this.chatOptions());
return new Prompt(promptMessages, this.chatOptions());
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.springframework.ai.integration.tests.client.advisor;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for {@link MessageChatMemoryAdvisor}.
*
* @author Alexandros Pappas
*/
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
class MessageChatMemoryAdvisorIT {

@Autowired
OpenAiChatModel openAiChatModel;

@Test
void chatMemoryStoresAndRecallsConversation() {
var chatMemory = new InMemoryChatMemory();
var conversationId = "test-conversation";

var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build();

var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();

// First interaction
ChatResponse response1 = chatClient.prompt().user("Hello, my name is John.").call().chatResponse();

assertThat(response1).isNotNull();
String assistantReply1 = response1.getResult().getOutput().getText();
System.out.println("Assistant reply 1: " + assistantReply1);

// Second interaction - Verify memory recall
ChatResponse response2 = chatClient.prompt().user("What is my name?").call().chatResponse();

assertThat(response2).isNotNull();
String assistantReply2 = response2.getResult().getOutput().getText();
System.out.println("Assistant reply 2: " + assistantReply2);

assertThat(assistantReply2.toLowerCase()).contains("john");
}

@Test
void separateConversationsDoNotMixMemory() {
var chatMemory = new InMemoryChatMemory();

var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build();

var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();

// First conversation
chatClient.prompt()
.user("Remember my secret code is blue.")
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-1"))
.call();

// Second conversation
ChatResponse response = chatClient.prompt()
.user("Do you remember my secret code?")
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-2"))
.call()
.chatResponse();

assertThat(response).isNotNull();
String assistantReply = response.getResult().getOutput().getText();
System.out.println("Assistant reply: " + assistantReply);

assertThat(assistantReply.toLowerCase()).doesNotContain("blue");
}

}