Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to use a LangChain4j CustomMessage with Ollama #1255

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static String extractDialogue(List<ChatMessage> chatMessages, String userPrefix,
joiner.add("%s%s".formatted(assistantPrefix, aiMessage.text()));
}
case USER -> joiner.add("%s%s".formatted(userPrefix, chatMessage.text()));
case SYSTEM, TOOL_EXECUTION_RESULT -> {
case SYSTEM, TOOL_EXECUTION_RESULT, CUSTOM -> {
continue;
}
default -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Map;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.CustomMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaCustomMessageTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.model-name", "granite3-guardian")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true");

@Inject
ChatLanguageModel chatLanguageModel;

@Test
void extract() {
wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(equalToJson(
"""
{
"model": "granite3-guardian",
"messages": [
{
"role": "system",
"content": "context_relevance"
},
{
"role": "user",
"content": "What is the history of treaty making?"
},
{
"role": "context",
"content": "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding."
}
],
"options": {
"temperature": 0.8,
"top_k": 40,
"top_p": 0.9
},
"stream": false
}"""))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "granite3-guardian",
"created_at": "2025-01-28T15:21:23.422542932Z",
"message": {
"role": "assistant",
"content": "Yes"
},
"done_reason": "stop",
"done": true,
"total_duration": 8125806496,
"load_duration": 4223887064,
"prompt_eval_count": 31,
"prompt_eval_duration": 1331000000,
"eval_count": 2,
"eval_duration": 2569000000
}""")));

String retrievedContext = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding.";

List<ChatMessage> messages = List.of(
SystemMessage.from("context_relevance"),
UserMessage.from("What is the history of treaty making?"),
CustomMessage.from(Map.of("role", "context", "content", retrievedContext)));

ChatResponse chatResponse = chatLanguageModel.chat(ChatRequest.builder().messages(messages).build());
assertThat(chatResponse.aiMessage().text()).isEqualTo("Yes");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ public record ChatResponse(String model, String createdAt, Message message, Bool
Integer evalCount) {

public static ChatResponse emptyNotDone() {
return new ChatResponse(null, null, new Message(Role.ASSISTANT, "", Collections.emptyList(), Collections.emptyList()),
return new ChatResponse(null, null,
new Message(Role.ASSISTANT, "", Collections.emptyList(), Collections.emptyList(), null),
true, null, null);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
package io.quarkiverse.langchain4j.ollama;

import java.util.List;
import java.util.Map;

public record Message(Role role, String content, List<ToolCall> toolCalls, List<String> images) {
import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnore;

public record Message(Role role, String content, List<ToolCall> toolCalls, List<String> images,
@JsonIgnore Map<String, Object> additionalFields) {

@JsonAnyGetter
public Map<String, Object> getAdditionalFields() {
return additionalFields;
}

public static Builder builder() {
return new Builder();
Expand All @@ -13,6 +24,7 @@ public static class Builder {
private String content;
private List<ToolCall> toolCalls;
private List<String> images;
private Map<String, Object> additionalFields;

public Builder role(Role role) {
this.role = role;
Expand All @@ -34,8 +46,14 @@ public Builder images(List<String> images) {
return this;
}

@JsonAnySetter
public Builder additionalFields(Map<String, Object> additionalFields) {
this.additionalFields = additionalFields;
return this;
}

public Message build() {
return new Message(role, content, toolCalls, images);
return new Message(role, content, toolCalls, images, additionalFields);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.CustomMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
Expand Down Expand Up @@ -114,6 +115,12 @@ private static Message otherMessages(ChatMessage message) {
.build();
}

if (message instanceof CustomMessage customMessage) {
return Message.builder()
.additionalFields(customMessage.attributes())
.build();
}

return Message.builder()
.role(toOllamaRole(message.type()))
.content(message.text())
Expand Down
Loading