diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 038d8da1382..6a8d5525faf 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -156,6 +156,8 @@ interface CallResponseSpec { @Nullable T entity(Class type); + ChatClientResponse chatClientResponse(); + @Nullable ChatResponse chatResponse(); @@ -172,6 +174,8 @@ interface CallResponseSpec { interface StreamResponseSpec { + Flux chatClientResponse(); + Flux chatResponse(); Flux content(); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java new file mode 100644 index 00000000000..e87078ce4a2 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +/** + * Common attributes used in {@link ChatClient} context. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public enum ChatClientAttributes { + + //@formatter:off + + @Deprecated // Only for backward compatibility until the next release. + ADVISORS("spring.ai.chat.client.advisors"), + @Deprecated // Only for backward compatibility until the next release. + CHAT_MODEL("spring.ai.chat.client.model"), + @Deprecated // Only for backward compatibility until the next release. + OUTPUT_FORMAT("spring.ai.chat.client.output.format"), + @Deprecated // Only for backward compatibility until the next release. + USER_PARAMS("spring.ai.chat.client.user.params"), + @Deprecated // Only for backward compatibility until the next release. + SYSTEM_PARAMS("spring.ai.chat.client.system.params"); + + //@formatter:on + + private final String key; + + ChatClientAttributes(String key) { + this.key = key; + } + + public String getKey() { + return key; + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java new file mode 100644 index 00000000000..c7d60b00b93 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.util.Assert; + +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a request processed by a {@link ChatClient} that ultimately is used to build + * a {@link Prompt} to be sent to an AI model. + * + * @param prompt The prompt to be sent to the AI model + * @param context The contextual data through the execution chain + * @author Thomas Vitale + * @since 1.0.0 + */ +public record ChatClientRequest(Prompt prompt, Map context) { + + public ChatClientRequest { + Assert.notNull(prompt, "prompt cannot be null"); + Assert.notNull(context, "context cannot be null"); + Assert.noNullElements(context.keySet(), "context keys cannot be null"); + } + + public Builder mutate() { + return new Builder().prompt(this.prompt).context(this.context); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private Prompt prompt; + + private Map context = new HashMap<>(); + + private Builder() { + } + + public Builder prompt(Prompt prompt) { + Assert.notNull(prompt, "prompt cannot be null"); + this.prompt = prompt; + return this; + } + + public Builder context(Map context) { + Assert.notNull(context, "context cannot be null"); + this.context.putAll(context); + return this; + } + + public Builder context(String key, Object value) { + Assert.notNull(key, "key cannot be null"); + this.context.put(key, value); + return this; + } + + public ChatClientRequest build() { + return new ChatClientRequest(prompt, context); + } + + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java new file mode 100644 index 00000000000..eae44a51728 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a response returned by a {@link ChatClient}. + * + * @param chatResponse The response returned by the AI model + * @param context The contextual data propagated through the execution chain + * @author Thomas Vitale + * @since 1.0.0 + */ +public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map context) { + + public ChatClientResponse { + Assert.notNull(context, "context cannot be null"); + Assert.noNullElements(context.keySet(), "context keys cannot be null"); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ChatResponse chatResponse; + + private Map context = new HashMap<>(); + + private Builder() { + } + + public Builder chatResponse(ChatResponse chatResponse) { + this.chatResponse = chatResponse; + return this; + } + + public Builder context(Map context) { + Assert.notNull(context, "context cannot be null"); + this.context.putAll(context); + return this; + } + + public Builder context(String key, Object value) { + Assert.notNull(key, "key cannot be null"); + this.context.put(key, value); + return this; + } + + public ChatClientResponse build() { + return new ChatClientResponse(this.chatResponse, this.context); + } + + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index fb4d3eb35dc..5e4abd074c4 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -33,17 +32,19 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor; +import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.lang.NonNull; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; @@ -62,10 +63,6 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.ToolCallbackProvider; -import org.springframework.ai.tool.ToolCallbacks; -import org.springframework.core.Ordered; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; @@ -98,14 +95,10 @@ public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) this.defaultChatClientRequest = defaultChatClientRequest; } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, - @Nullable String formatParam) { + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { Assert.notNull(inputRequest, "inputRequest cannot be null"); Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - if (StringUtils.hasText(formatParam)) { - advisorContext.put("formatParam", formatParam); - } // Process userText, media and messages before creating the AdvisedRequest. String userText = inputRequest.userText; @@ -131,11 +124,12 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu } return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions, - media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams, + media, inputRequest.toolNames, inputRequest.toolCallbacks, messages, inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, inputRequest.toolContext); } + @Deprecated public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { @@ -391,11 +385,25 @@ public Map getParams() { public static class DefaultCallResponseSpec implements CallResponseSpec { - private final DefaultChatClientRequestSpec request; + private final ChatClientRequest request; - public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) { - Assert.notNull(request, "request cannot be null"); - this.request = request; + private final BaseAdvisorChain advisorChain; + + private final ObservationRegistry observationRegistry; + + private final ChatClientObservationConvention observationConvention; + + public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, + ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { + Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); + Assert.notNull(advisorChain, "advisorChain cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(observationConvention, "observationConvention cannot be null"); + + this.request = chatClientRequest; + this.advisorChain = advisorChain; + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; } @Override @@ -419,7 +427,8 @@ public ResponseEntity responseEntity( protected ResponseEntity doResponseEntity(StructuredOutputConverter outputConverter) { Assert.notNull(outputConverter, "structuredOutputConverter cannot be null"); - var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat()); + var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat()) + .chatResponse(); var responseContent = getContentFromChatResponse(chatResponse); if (responseContent == null) { return new ResponseEntity<>(chatResponse, null); @@ -452,7 +461,8 @@ public T entity(Class type) { @Nullable private T doSingleWithBeanOutputConverter(StructuredOutputConverter outputConverter) { - var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat()); + var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat()) + .chatResponse(); var stringResponse = getContentFromChatResponse(chatResponse); if (stringResponse == null) { return null; @@ -460,38 +470,85 @@ private T doSingleWithBeanOutputConverter(StructuredOutputConverter outpu return outputConverter.convert(stringResponse); } + @Override + public ChatClientResponse chatClientResponse() { + return doGetObservableChatClientResponse(this.request); + } + + @Override @Nullable - private ChatResponse doGetChatResponse() { - return this.doGetObservableChatResponse(this.request, null); + public ChatResponse chatResponse() { + return doGetObservableChatClientResponse(this.request).chatResponse(); } + @Override @Nullable - private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec inputRequest, - @Nullable String formatParam) { + public String content() { + ChatResponse chatResponse = doGetObservableChatClientResponse(this.request).chatResponse(); + return getContentFromChatResponse(chatResponse); + } + + private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest) { + return doGetObservableChatClientResponse(chatClientRequest, null); + } + + private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, + @Nullable String outputFormat) { + ChatClientRequest formattedChatClientRequest = StringUtils.hasText(outputFormat) + ? addFormatInstructionsToPrompt(chatClientRequest, outputFormat) : chatClientRequest; ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(inputRequest) - .withFormat(formatParam) - .withStream(false) + .request(formattedChatClientRequest) + .stream(false) + .withFormat(outputFormat) .build(); - var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( - inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.getObservationRegistry()); - return observation.observe(() -> doGetChatResponse(inputRequest, formatParam, observation)); + var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(observationConvention, + DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry); + var chatClientResponse = observation.observe(() -> { + // Apply the advisor chain that terminates with the ChatModelCallAdvisor. + return advisorChain.nextCall(formattedChatClientRequest); + }); + return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build(); } - private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, - @Nullable String formatParam, Observation parentObservation) { + @NonNull + private static ChatClientRequest addFormatInstructionsToPrompt(ChatClientRequest chatClientRequest, + String outputFormat) { + List originalMessages = chatClientRequest.prompt().getInstructions(); + + if (CollectionUtils.isEmpty(originalMessages)) { + return chatClientRequest; + } + + // Create a copy of the message list to avoid modifying the original. + List modifiedMessages = new ArrayList<>(originalMessages); - AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam); + // Get the last message (without removing it from original list) + Message lastMessage = modifiedMessages.get(modifiedMessages.size() - 1); - // Apply the around advisor chain that terminates with the last model call - // advisor. - AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build() - .nextAroundCall(advisedRequest); + // If the last message is a UserMessage, replace it with the modified version + if (lastMessage instanceof UserMessage userMessage) { + // Remove last message + modifiedMessages.remove(modifiedMessages.size() - 1); - return advisedResponse.response(); + // Create new user message with format instructions + UserMessage userMessageWithFormat = userMessage.mutate() + .text(userMessage.getText() + System.lineSeparator() + outputFormat) + .build(); + + // Add modified message back + modifiedMessages.add(userMessageWithFormat); + + // Build new ChatClientRequest preserving all properties but with modified + // prompt + return ChatClientRequest.builder() + .prompt(chatClientRequest.prompt().mutate().messages(modifiedMessages).build()) + .context(Map.copyOf(chatClientRequest.context())) + .build(); + } + + return chatClientRequest; } @Nullable @@ -503,53 +560,49 @@ private static String getContentFromChatResponse(@Nullable ChatResponse chatResp .orElse(null); } - @Override - @Nullable - public ChatResponse chatResponse() { - return doGetChatResponse(); - } - - @Override - @Nullable - public String content() { - ChatResponse chatResponse = doGetChatResponse(); - return getContentFromChatResponse(chatResponse); - } - } public static class DefaultStreamResponseSpec implements StreamResponseSpec { - private final DefaultChatClientRequestSpec request; + private final ChatClientRequest request; + + private final BaseAdvisorChain advisorChain; - public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) { - Assert.notNull(request, "request cannot be null"); - this.request = request; + private final ObservationRegistry observationRegistry; + + private final ChatClientObservationConvention observationConvention; + + public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, + ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { + Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); + Assert.notNull(advisorChain, "advisorChain cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(observationConvention, "observationConvention cannot be null"); + + this.request = chatClientRequest; + this.advisorChain = advisorChain; + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; } - private Flux doGetObservableFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { + private Flux doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) { return Flux.deferContextual(contextView -> { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(inputRequest) - .withStream(true) + .request(chatClientRequest) + .stream(true) .build(); Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( - inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.getObservationRegistry()); + observationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, + observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); - var initialAdvisedRequest = toAdvisedRequest(inputRequest, null); - // @formatter:off - // Apply the around advisor chain that terminates with the last model call advisor. - Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); - - return stream - .map(AdvisedResponse::response) + // Apply the advisor chain that terminates with the ChatModelStreamAdvisor. + return advisorChain.nextStream(chatClientRequest) .doOnError(observation::error) .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); @@ -558,19 +611,29 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ } @Override - public Flux chatResponse() { + public Flux chatClientResponse() { return doGetObservableFluxChatResponse(this.request); } + @Override + public Flux chatResponse() { + return doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse); + } + @Override public Flux content() { - return doGetObservableFluxChatResponse(this.request).map(r -> { - if (r.getResult() == null || r.getResult().getOutput() == null - || r.getResult().getOutput().getText() == null) { - return ""; - } - return r.getResult().getOutput().getText(); - }).filter(StringUtils::hasLength); + // @formatter:off + return doGetObservableFluxChatResponse(this.request) + .mapNotNull(ChatClientResponse::chatResponse) + .map(r -> { + if (r.getResult() == null || r.getResult().getOutput() == null + || r.getResult().getOutput().getText() == null) { + return ""; + } + return r.getResult().getOutput().getText(); + }) + .filter(StringUtils::hasLength); + // @formatter:on } } @@ -579,15 +642,15 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final ObservationRegistry observationRegistry; - private final ChatClientObservationConvention customObservationConvention; + private final ChatClientObservationConvention observationConvention; private final ChatModel chatModel; private final List media = new ArrayList<>(); - private final List functionNames = new ArrayList<>(); + private final List toolNames = new ArrayList<>(); - private final List functionCallbacks = new ArrayList<>(); + private final List toolCallbacks = new ArrayList<>(); private final List messages = new ArrayList<>(); @@ -614,25 +677,24 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, - ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); + this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks, + ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, + ccr.observationRegistry, ccr.observationConvention, ccr.toolContext); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List functionCallbacks, List messages, List functionNames, - List media, @Nullable ChatOptions chatOptions, List advisors, - Map advisorParams, ObservationRegistry observationRegistry, - @Nullable ChatClientObservationConvention customObservationConvention, - Map toolContext) { + List toolCallbacks, List messages, List toolNames, List media, + @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, + ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention observationConvention, Map toolContext) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.notNull(userParams, "userParams cannot be null"); Assert.notNull(systemParams, "systemParams cannot be null"); - Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.notNull(messages, "messages cannot be null"); - Assert.notNull(functionNames, "functionNames cannot be null"); + Assert.notNull(toolNames, "toolNames cannot be null"); Assert.notNull(media, "media cannot be null"); Assert.notNull(advisors, "advisors cannot be null"); Assert.notNull(advisorParams, "advisorParams cannot be null"); @@ -648,58 +710,21 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.systemText = systemText; this.systemParams.putAll(systemParams); - this.functionNames.addAll(functionNames); - this.functionCallbacks.addAll(functionCallbacks); + this.toolNames.addAll(toolNames); + this.toolCallbacks.addAll(toolCallbacks); this.messages.addAll(messages); this.media.addAll(media); this.advisors.addAll(advisors); this.advisorParams.putAll(advisorParams); this.observationRegistry = observationRegistry; - this.customObservationConvention = customObservationConvention != null ? customObservationConvention + this.observationConvention = observationConvention != null ? observationConvention : DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION; this.toolContext.putAll(toolContext); - // @formatter:off - // At the stack bottom add the non-streaming and streaming model call advisors. - // They play the role of the last advisor in the around advisor chain. - this.advisors.add(new CallAroundAdvisor() { - - @Override - public String getName() { - return CallAroundAdvisor.class.getSimpleName(); - } - - @Override - public int getOrder() { - return Ordered.LOWEST_PRECEDENCE; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext())); - } - }); - - this.advisors.add(new StreamAroundAdvisor() { - - @Override - public String getName() { - return StreamAroundAdvisor.class.getSimpleName(); - } - - @Override - public int getOrder() { - return Ordered.LOWEST_PRECEDENCE; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - return chatModel.stream(advisedRequest.toPrompt()) - .map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) - .publishOn(Schedulers.boundedElastic()); // TODO add option to disable. - } - }); - // @formatter:on + // At the stack bottom add the model call advisors. + // They play the role of the last advisors in the advisor chain. + this.advisors.add(new ChatModelCallAdvisor(chatModel)); + this.advisors.add(new ChatModelStreamAdvisor(chatModel)); this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) .pushAll(this.advisors); @@ -710,7 +735,7 @@ private ObservationRegistry getObservationRegistry() { } private ChatClientObservationConvention getCustomObservationConvention() { - return this.customObservationConvention; + return this.observationConvention; } @Nullable @@ -753,11 +778,11 @@ public List getMedia() { } public List getFunctionNames() { - return this.functionNames; + return this.toolNames; } public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.toolCallbacks; } public Map getToolContext() { @@ -770,8 +795,8 @@ public Map getToolContext() { */ public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient - .builder(this.chatModel, this.observationRegistry, this.customObservationConvention) - .defaultFunctions(StringUtils.toStringArray(this.functionNames)); + .builder(this.chatModel, this.observationRegistry, this.observationConvention) + .defaultTools(StringUtils.toStringArray(this.toolNames)); if (StringUtils.hasText(this.userText)) { builder.defaultUser( @@ -787,7 +812,7 @@ public Builder mutate() { } builder.addMessages(this.messages); - builder.addToolCallbacks(this.functionCallbacks); + builder.addToolCallbacks(this.toolCallbacks); builder.addToolContext(this.toolContext); return builder; @@ -843,7 +868,7 @@ public ChatClientRequestSpec options(T options) { public ChatClientRequestSpec tools(String... toolNames) { Assert.notNull(toolNames, "toolNames cannot be null"); Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); - this.functionNames.addAll(List.of(toolNames)); + this.toolNames.addAll(List.of(toolNames)); return this; } @@ -851,7 +876,7 @@ public ChatClientRequestSpec tools(String... toolNames) { public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - this.functionCallbacks.addAll(List.of(toolCallbacks)); + this.toolCallbacks.addAll(List.of(toolCallbacks)); return this; } @@ -859,7 +884,7 @@ public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) { public ChatClientRequestSpec tools(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - this.functionCallbacks.addAll(toolCallbacks); + this.toolCallbacks.addAll(toolCallbacks); return this; } @@ -867,7 +892,7 @@ public ChatClientRequestSpec tools(List toolCallbacks) { public ChatClientRequestSpec tools(Object... toolObjects) { Assert.notNull(toolObjects, "toolObjects cannot be null"); Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); - this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects))); + this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects))); return this; } @@ -876,7 +901,7 @@ public ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements"); for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) { - this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); + this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); } return this; } @@ -890,7 +915,7 @@ public ChatClientRequestSpec functions(String... functionBeanNames) { public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); - this.functionCallbacks.addAll(Arrays.asList(functionCallbacks)); + this.toolCallbacks.addAll(Arrays.asList(functionCallbacks)); return this; } @@ -973,17 +998,22 @@ public ChatClientRequestSpec user(Consumer consumer) { } public CallResponseSpec call() { - return new DefaultCallResponseSpec(this); + BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build(); + return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain, + observationRegistry, observationConvention); } public StreamResponseSpec stream() { - return new DefaultStreamResponseSpec(this); + BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build(); + return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain, + observationRegistry, observationConvention); } } // Prompt + @Deprecated // never used, to be removed public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec { private final ChatModel chatModel; @@ -1015,6 +1045,7 @@ private ChatResponse doGetChatResponse(Prompt prompt) { } + @Deprecated // never used, to be removed public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec { private final Prompt prompt; diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java new file mode 100644 index 00000000000..7d33112899d --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.core.Ordered; +import org.springframework.util.Assert; + +import java.util.Map; + +/** + * A {@link CallAdvisor} that uses a {@link ChatModel} to generate a response. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class ChatModelCallAdvisor implements CallAdvisor { + + private final ChatModel chatModel; + + public ChatModelCallAdvisor(ChatModel chatModel) { + this.chatModel = chatModel; + } + + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) { + Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + + ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt()); + return ChatClientResponse.builder() + .chatResponse(chatResponse) + .context(Map.copyOf(chatClientRequest.context())) + .build(); + } + + @Override + public String getName() { + return "call"; + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java new file mode 100644 index 00000000000..5743b69c69d --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.*; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.core.Ordered; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import java.util.Map; + +/** + * A {@link StreamAdvisor} that uses a {@link ChatModel} to generate a streaming response. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class ChatModelStreamAdvisor implements StreamAdvisor { + + private final ChatModel chatModel; + + public ChatModelStreamAdvisor(ChatModel chatModel) { + this.chatModel = chatModel; + } + + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) { + Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + + return chatModel.stream(chatClientRequest.prompt()) + .map(chatResponse -> ChatClientResponse.builder() + .chatResponse(chatResponse) + .context(Map.copyOf(chatClientRequest.context())) + .build()) + .publishOn(Schedulers.boundedElastic()); // TODO add option to disable + } + + @Override + public String getName() { + return "stream"; + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index 084afd001b4..6ee2aedb30b 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,15 +23,18 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation; @@ -41,16 +44,16 @@ import org.springframework.util.CollectionUtils; /** - * Implementation of the {@link CallAroundAdvisorChain} and - * {@link StreamAroundAdvisorChain}. Used by the + * Default implementation for the {@link BaseAdvisorChain}. Used by the * {@link org.springframework.ai.chat.client.ChatClient} to delegate the call to the next - * {@link CallAroundAdvisor} or {@link StreamAroundAdvisor} in the chain. + * {@link CallAdvisor} or {@link StreamAdvisor} in the chain. * * @author Christian Tzolov * @author Dariusz Jedrzejczyk + * @author Thomas Vitale * @since 1.0.0 */ -public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain { +public class DefaultAroundAdvisorChain implements BaseAdvisorChain { public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); @@ -77,7 +80,42 @@ public static Builder builder(ObservationRegistry observationRegistry) { } @Override + public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) { + Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + + if (this.callAroundAdvisors.isEmpty()) { + throw new IllegalStateException("No CallAdvisors available to execute"); + } + + var advisor = this.callAroundAdvisors.pop(); + + var observationContext = AdvisorObservationContext.builder() + .advisorName(advisor.getName()) + .chatClientRequest(chatClientRequest) + .order(advisor.getOrder()) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) + .observe(() -> { + // Supports both deprecated and new API. + if (advisor instanceof CallAdvisor callAdvisor) { + return callAdvisor.adviseCall(chatClientRequest, this); + } + AdvisedResponse advisedResponse = advisor.aroundCall(AdvisedRequest.from(chatClientRequest), this); + ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); + observationContext.setChatClientResponse(chatClientResponse); + return chatClientResponse; + }); + } + + /** + * @deprecated Use {@link #nextCall(ChatClientRequest)} instead + */ + @Override + @Deprecated public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { + Assert.notNull(advisedRequest, "the advisedRequest cannot be null"); if (this.callAroundAdvisors.isEmpty()) { throw new IllegalStateException("No AroundAdvisor available to execute"); @@ -87,19 +125,75 @@ public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { var observationContext = AdvisorObservationContext.builder() .advisorName(advisor.getName()) - .advisorType(AdvisorObservationContext.Type.AROUND) - .advisedRequest(advisedRequest) - .advisorRequestContext(advisedRequest.adviseContext()) + .chatClientRequest(advisedRequest.toChatClientRequest()) .order(advisor.getOrder()) .build(); return AdvisorObservationDocumentation.AI_ADVISOR .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> advisor.aroundCall(advisedRequest, this)); + .observe(() -> { + // Supports both deprecated and new API. + if (advisor instanceof CallAdvisor callAdvisor) { + ChatClientResponse chatClientResponse = callAdvisor.adviseCall(advisedRequest.toChatClientRequest(), + this); + return AdvisedResponse.from(chatClientResponse); + } + AdvisedResponse advisedResponse = advisor.aroundCall(advisedRequest, this); + ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); + observationContext.setChatClientResponse(chatClientResponse); + return advisedResponse; + }); } @Override + public Flux nextStream(ChatClientRequest chatClientRequest) { + Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + + return Flux.deferContextual(contextView -> { + if (this.streamAroundAdvisors.isEmpty()) { + return Flux.error(new IllegalStateException("No StreamAdvisors available to execute")); + } + + var advisor = this.streamAroundAdvisors.pop(); + + AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .advisorName(advisor.getName()) + .chatClientRequest(chatClientRequest) + .order(advisor.getOrder()) + .build(); + + var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null, + DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + // @formatter:off + return Flux.defer(() -> { + // Supports both deprecated and new API. + if (advisor instanceof StreamAdvisor streamAdvisor) { + return streamAdvisor.adviseStream(chatClientRequest, this) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + } + return advisor.aroundStream(AdvisedRequest.from(chatClientRequest), this) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)) + .map(AdvisedResponse::toChatClientResponse); + }); + // @formatter:on + }); + } + + /** + * @deprecated Use {@link #nextStream(ChatClientRequest)} instead. + */ + @Override + @Deprecated public Flux nextAroundStream(AdvisedRequest advisedRequest) { + Assert.notNull(advisedRequest, "the advisedRequest cannot be null"); + return Flux.deferContextual(contextView -> { if (this.streamAroundAdvisors.isEmpty()) { return Flux.error(new IllegalStateException("No AroundAdvisor available to execute")); @@ -109,9 +203,7 @@ public Flux nextAroundStream(AdvisedRequest advisedRequest) { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName(advisor.getName()) - .advisorType(AdvisorObservationContext.Type.AROUND) - .advisedRequest(advisedRequest) - .advisorRequestContext(advisedRequest.adviseContext()) + .chatClientRequest(advisedRequest.toChatClientRequest()) .order(advisor.getOrder()) .build(); @@ -121,10 +213,21 @@ public Flux nextAroundStream(AdvisedRequest advisedRequest) { observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); // @formatter:off - return Flux.defer(() -> advisor.aroundStream(advisedRequest, this)) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + return Flux.defer(() -> { + // Supports both deprecated and new API. + if (advisor instanceof StreamAdvisor streamAdvisor) { + return streamAdvisor.adviseStream(advisedRequest.toChatClientRequest(), this) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)) + .map(AdvisedResponse::from); + } + + return advisor.aroundStream(advisedRequest, this) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + }); // @formatter:on }); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 6d58b77ed4b..ff0fb8c14ec 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,10 +20,14 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.Objects; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -34,6 +38,7 @@ import org.springframework.ai.content.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -60,8 +65,10 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @deprecated Use {@link ChatClientRequest} instead. * @since 1.0.0 */ +@Deprecated public record AdvisedRequest( // @formatter:off ChatModel chatModel, @@ -77,6 +84,7 @@ public record AdvisedRequest( Map userParams, Map systemParams, List advisors, + @Deprecated // Not really used. Use "adviseContext" instead. Map advisorParams, Map adviseContext, Map toolContext @@ -139,6 +147,52 @@ public static Builder from(AdvisedRequest from) { return builder; } + @SuppressWarnings("unchecked") + public static AdvisedRequest from(ChatClientRequest from) { + Assert.notNull(from, "ChatClientRequest cannot be null"); + + List messages = new LinkedList<>(from.prompt().getInstructions()); + + Builder builder = new Builder(); + if (from.context().get(ChatClientAttributes.CHAT_MODEL.getKey()) instanceof ChatModel chatModel) { + builder.chatModel = chatModel; + } + + if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof UserMessage userMessage) { + builder.userText = userMessage.getText(); + builder.media = userMessage.getMedia(); + messages.remove(messages.size() - 1); + } + if (from.context().get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map contextUserParams) { + builder.userParams = (Map) contextUserParams; + } + + if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof SystemMessage systemMessage) { + builder.systemText = systemMessage.getText(); + messages.remove(messages.size() - 1); + } + if (from.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map contextSystemParams) { + builder.systemParams = (Map) contextSystemParams; + } + + builder.messages = messages; + + builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build()); + if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) { + builder.functionNames = options.getToolNames().stream().toList(); + builder.functionCallbacks = options.getToolCallbacks(); + builder.toolContext = options.getToolContext(); + } + + if (from.context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List advisors) { + builder.advisors = (List) advisors; + } + builder.advisorParams = Map.of(); + builder.adviseContext = from.context(); + + return builder.build(); + } + public AdvisedRequest updateContext(Function, Map> contextTransform) { Assert.notNull(contextTransform, "contextTransform cannot be null"); return from(this) @@ -146,6 +200,17 @@ public AdvisedRequest updateContext(Function, Map(this.messages()); @@ -157,16 +222,9 @@ public Prompt toPrompt() { messages.add(new SystemMessage(processedSystemText)); } - String formatParam = (String) this.adviseContext().get("formatParam"); - - var processedUserText = StringUtils.hasText(formatParam) - ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); - - if (StringUtils.hasText(processedUserText)) { + if (StringUtils.hasText(this.userText())) { Map userParams = new HashMap<>(this.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } + String processedUserText = this.userText(); if (!CollectionUtils.isEmpty(userParams)) { processedUserText = new PromptTemplate(processedUserText, userParams).render(); } @@ -338,7 +396,9 @@ public Builder advisors(List advisors) { * Set the advisor params. * @param advisorParams the advisor params * @return this {@link Builder} instance + * @deprecated in favor of {@link #adviseContext(Map)} */ + @Deprecated public Builder advisorParams(Map advisorParams) { this.advisorParams = advisorParams; return this; diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java index 7174aceacf7..04644c7db8a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.Map; import java.util.function.Function; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -33,8 +34,10 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @deprecated Use {@link ChatClientResponse} instead. * @since 1.0.0 */ +@Deprecated public record AdvisedResponse(@Nullable ChatResponse response, Map adviseContext) { /** @@ -66,6 +69,15 @@ public static Builder from(AdvisedResponse advisedResponse) { return new Builder().response(advisedResponse.response).adviseContext(advisedResponse.adviseContext); } + public static AdvisedResponse from(ChatClientResponse chatClientResponse) { + Assert.notNull(chatClientResponse, "chatClientResponse cannot be null"); + return new AdvisedResponse(chatClientResponse.chatResponse(), chatClientResponse.context()); + } + + public ChatClientResponse toChatClientResponse() { + return new ChatClientResponse(this.response, this.adviseContext); + } + /** * Update the context of the advised response. * @param contextTransform the function to transform the context diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java index d4831d45683..a8fd7bc7b3a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,9 +24,9 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @since 1.0.0 - * @see CallAroundAdvisor - * @see StreamAroundAdvisor - * @see CallAroundAdvisorChain + * @see CallAdvisor + * @see StreamAdvisor + * @see BaseAdvisor */ public interface Advisor extends Ordered { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java new file mode 100644 index 00000000000..7957d48e6e2 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +/** + * A base interface for advisor chains that can be used to chain multiple advisors + * together, both for call and stream advisors. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain { + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java new file mode 100644 index 00000000000..6478b92f51a --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; + +/** + * Advisor for execution flows ultimately resulting in a call to an AI model + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface CallAdvisor extends CallAroundAdvisor { + + /** + * @deprecated use {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)} + */ + @Deprecated + default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + ChatClientResponse chatClientResponse = adviseCall(advisedRequest.toChatClientRequest(), chain); + return AdvisedResponse.from(chatClientResponse); + } + + ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java new file mode 100644 index 00000000000..c9de67b990d --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; + +/** + * A chain of {@link CallAdvisor} instances orchestrating the execution of a + * {@link ChatClientRequest} on the next {@link CallAdvisor} in the chain. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface CallAdvisorChain extends CallAroundAdvisorChain { + + /** + * @deprecated use {@link #nextCall(ChatClientRequest)} + */ + @Deprecated + default AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { + ChatClientResponse chatClientResponse = nextCall(advisedRequest.toChatClientRequest()); + return AdvisedResponse.from(chatClientResponse); + } + + ChatClientResponse nextCall(ChatClientRequest chatClientRequest); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java index 12fcc45a30a..3faaf36a599 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,17 @@ package org.springframework.ai.chat.client.advisor.api; +import org.springframework.ai.chat.client.ChatClientRequest; + /** * Around advisor that wraps the ChatModel#call(Prompt) method. * * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @since 1.0.0 + * @deprecated in favor of {@link CallAdvisor} */ - +@Deprecated public interface CallAroundAdvisor extends Advisor { /** @@ -31,7 +34,10 @@ public interface CallAroundAdvisor extends Advisor { * @param advisedRequest the advised request * @param chain the advisor chain * @return the response + * @deprecated in favor of + * {@link CallAdvisor#adviseCall(ChatClientRequest, CallAroundAdvisorChain)} */ + @Deprecated AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java index 9158a721265..8f4a62825ef 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.ai.chat.client.advisor.api; +import org.springframework.ai.chat.client.ChatClientRequest; + /** * The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain. * Used for non-streaming responses. @@ -23,7 +25,9 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @since 1.0.0 + * @deprecated in favor of {@link CallAdvisorChain} */ +@Deprecated public interface CallAroundAdvisorChain { /** @@ -32,7 +36,9 @@ public interface CallAroundAdvisorChain { * @param advisedRequest the request containing the data to be processed by the next * advisor in the chain. * @return the response generated by the next advisor in the chain. + * @deprecated in favor of {@link CallAdvisorChain#nextCall(ChatClientRequest)} */ + @Deprecated AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java new file mode 100644 index 00000000000..9ea441d4486 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import reactor.core.publisher.Flux; + +/** + * Advisor for execution flows ultimately resulting in a streaming call to an AI model. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface StreamAdvisor extends StreamAroundAdvisor { + + /** + * @deprecated use {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} + */ + @Deprecated + default Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + Flux chatClientResponse = adviseStream(advisedRequest.toChatClientRequest(), chain); + return chatClientResponse.map(AdvisedResponse::from); + } + + Flux adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java new file mode 100644 index 00000000000..7ee12bbbdd7 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import reactor.core.publisher.Flux; + +/** + * A chain of {@link StreamAdvisor} instances orchestrating the execution of a + * {@link ChatClientRequest} on the next {@link StreamAdvisor} in the chain. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface StreamAdvisorChain extends StreamAroundAdvisorChain { + + /** + * @deprecated use {@link #nextStream(ChatClientRequest)} + */ + @Deprecated + default Flux nextAroundStream(AdvisedRequest advisedRequest) { + Flux chatClientResponse = nextStream(advisedRequest.toChatClientRequest()); + return chatClientResponse.map(AdvisedResponse::from); + } + + Flux nextStream(ChatClientRequest chatClientRequest); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java index 3e06d5df4e0..d7145e14246 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.chat.client.advisor.api; +import org.springframework.ai.chat.client.ChatClientRequest; import reactor.core.publisher.Flux; /** @@ -24,7 +25,9 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @since 1.0.0 + * @deprecated in favor of {@link StreamAdvisor} */ +@Deprecated public interface StreamAroundAdvisor extends Advisor { /** @@ -32,7 +35,10 @@ public interface StreamAroundAdvisor extends Advisor { * @param advisedRequest the advised request * @param chain the chain of advisors to execute * @return the result of the advised request + * @deprecated in favor of + * {@link StreamAdvisor#adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} */ + @Deprecated Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java index 175ae9e71fa..7ab9631785a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.chat.client.advisor.api; +import org.springframework.ai.chat.client.ChatClientRequest; import reactor.core.publisher.Flux; /** @@ -25,7 +26,9 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @since 1.0.0 + * @deprecated in favor of {@link StreamAdvisorChain} */ +@Deprecated public interface StreamAroundAdvisorChain { /** @@ -34,7 +37,9 @@ public interface StreamAroundAdvisorChain { * @param advisedRequest the request containing data of the chat client that can be * modified before execution * @return a Flux stream of AdvisedResponse objects + * @deprecated in favor of {@link StreamAdvisorChain#nextStream(ChatClientRequest)} */ + @Deprecated Flux nextAroundStream(AdvisedRequest advisedRequest); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index 7816d3b8e20..c5005c92471 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,10 +21,13 @@ import io.micrometer.observation.Observation; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Context used to store metadata for chat client advisors. @@ -37,26 +40,12 @@ public class AdvisorObservationContext extends Observation.Context { private final String advisorName; - private final Type advisorType; + private final ChatClientRequest chatClientRequest; - /** - * The order of the advisor in the advisor chain. - */ private final int order; - /** - * The {@link AdvisedRequest} data to be advised. Represents the row - * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. - */ - @Nullable - private AdvisedRequest advisorRequest; - - /** - * The shared data between the advisors in the chain. It is shared between all request - * and response advising points of all advisors in the chain. - */ @Nullable - private Map advisorRequestContext; + private ChatClientResponse chatClientResponse; /** * the shared data between the advisors in the chain. It is shared between all request @@ -73,18 +62,32 @@ public class AdvisorObservationContext extends Observation.Context { * @param advisorRequestContext the shared data between the advisors in the chain * @param advisorResponseContext the shared data between the advisors in the chain * @param order the order of the advisor in the advisor chain + * @deprecated use the builder instead */ + @Deprecated public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest, @Nullable Map advisorRequestContext, @Nullable Map advisorResponseContext, int order) { - Assert.hasText(advisorName, "advisorName must not be null or empty"); - Assert.notNull(advisorType, "advisorType must not be null"); + Assert.hasText(advisorName, "advisorName cannot be null or empty"); this.advisorName = advisorName; - this.advisorType = advisorType; - this.advisorRequest = advisorRequest; - this.advisorRequestContext = advisorRequestContext; - this.advisorResponseContext = advisorResponseContext; + this.chatClientRequest = advisorRequest != null ? advisorRequest.toChatClientRequest() + : ChatClientRequest.builder().prompt(new Prompt()).build(); + if (!CollectionUtils.isEmpty(advisorRequestContext)) { + this.chatClientRequest.context().putAll(advisorRequestContext); + } + if (!CollectionUtils.isEmpty(advisorResponseContext)) { + this.chatClientResponse = ChatClientResponse.builder().context(advisorResponseContext).build(); + } + this.order = order; + } + + AdvisorObservationContext(String advisorName, ChatClientRequest chatClientRequest, int order) { + Assert.hasText(advisorName, "advisorName cannot be null or empty"); + Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); + + this.advisorName = advisorName; + this.chatClientRequest = chatClientRequest; this.order = order; } @@ -96,89 +99,115 @@ public static Builder builder() { return new Builder(); } - /** - * The advisor name. - * @return the advisor name - */ public String getAdvisorName() { return this.advisorName; } + public ChatClientRequest getChatClientRequest() { + return this.chatClientRequest; + } + + public int getOrder() { + return this.order; + } + + @Nullable + public ChatClientResponse getChatClientResponse() { + return this.chatClientResponse; + } + + public void setChatClientResponse(@Nullable ChatClientResponse chatClientResponse) { + this.chatClientResponse = chatClientResponse; + } + /** * The type of the advisor. * @return the type of the advisor + * @deprecated advisors don't have types anymore, they're all "around" */ + @Deprecated public Type getAdvisorType() { - return this.advisorType; + return Type.AROUND; } /** * The order of the advisor in the advisor chain. * @return the order of the advisor in the advisor chain + * @deprecated not used anymore */ - @Nullable + @Deprecated public AdvisedRequest getAdvisedRequest() { - return this.advisorRequest; + return AdvisedRequest.from(this.chatClientRequest); } /** * Set the {@link AdvisedRequest} data to be advised. Represents the row * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. * @param advisedRequest the advised request + * @deprecated immutable object, use the builder instead to create a new instance */ + @Deprecated public void setAdvisedRequest(@Nullable AdvisedRequest advisedRequest) { - this.advisorRequest = advisedRequest; + throw new IllegalStateException( + "The AdvisedRequest is immutable. Build a new AdvisorObservationContext instead."); } /** * Get the shared data between the advisors in the chain. It is shared between all * request and response advising points of all advisors in the chain. * @return the shared data between the advisors in the chain + * @deprecated use {@link #getChatClientRequest()} instead */ - @Nullable + @Deprecated public Map getAdvisorRequestContext() { - return this.advisorRequestContext; + return this.chatClientRequest.context(); } /** * Set the shared data between the advisors in the chain. It is shared between all * request and response advising points of all advisors in the chain. * @param advisorRequestContext the shared data between the advisors in the chain + * @deprecated not supported anymore, use {@link #getChatClientRequest()} instead */ + @Deprecated public void setAdvisorRequestContext(@Nullable Map advisorRequestContext) { - this.advisorRequestContext = advisorRequestContext; + if (!CollectionUtils.isEmpty(advisorRequestContext)) { + this.chatClientRequest.context().putAll(advisorRequestContext); + } } /** * Get the shared data between the advisors in the chain. It is shared between all * request and response advising points of all advisors in the chain. * @return the shared data between the advisors in the chain + * @deprecated use {@link #getChatClientResponse()} instead */ @Nullable + @Deprecated public Map getAdvisorResponseContext() { - return this.advisorResponseContext; + if (this.chatClientResponse != null) { + return this.chatClientResponse.context(); + } + return null; } /** * Set the shared data between the advisors in the chain. It is shared between all * request and response advising points of all advisors in the chain. * @param advisorResponseContext the shared data between the advisors in the chain + * @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead */ + @Deprecated public void setAdvisorResponseContext(@Nullable Map advisorResponseContext) { this.advisorResponseContext = advisorResponseContext; } - /** - * The order of the advisor in the advisor chain. - * @return the order of the advisor in the advisor chain - */ - public int getOrder() { - return this.order; - } - /** * The type of the advisor. + * + * @deprecated advisors don't have types anymore, they're all "around" */ + @Deprecated public enum Type { /** @@ -203,7 +232,9 @@ public static final class Builder { private String advisorName; - private Type advisorType; + private ChatClientRequest chatClientRequest; + + private int order = 0; private AdvisedRequest advisorRequest; @@ -211,28 +242,32 @@ public static final class Builder { private Map advisorResponseContext; - private int order = 0; - private Builder() { } - /** - * Set the advisor name. - * @param advisorName the advisor name - * @return the builder - */ public Builder advisorName(String advisorName) { this.advisorName = advisorName; return this; } + public Builder chatClientRequest(ChatClientRequest chatClientRequest) { + this.chatClientRequest = chatClientRequest; + return this; + } + + public Builder order(int order) { + this.order = order; + return this; + } + /** * Set the advisor type. * @param advisorType the advisor type * @return the builder + * @deprecated advisors don't have types anymore, they're all "around" */ + @Deprecated public Builder advisorType(Type advisorType) { - this.advisorType = advisorType; return this; } @@ -240,7 +275,9 @@ public Builder advisorType(Type advisorType) { * Set the advised request. * @param advisedRequest the advised request * @return the builder + * @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead */ + @Deprecated public Builder advisedRequest(AdvisedRequest advisedRequest) { this.advisorRequest = advisedRequest; return this; @@ -250,7 +287,9 @@ public Builder advisedRequest(AdvisedRequest advisedRequest) { * Set the advisor request context. * @param advisorRequestContext the advisor request context * @return the builder + * @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead */ + @Deprecated public Builder advisorRequestContext(Map advisorRequestContext) { this.advisorRequestContext = advisorRequestContext; return this; @@ -260,29 +299,26 @@ public Builder advisorRequestContext(Map advisorRequestContext) * Set the advisor response context. * @param advisorResponseContext the advisor response context * @return the builder + * @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead */ + @Deprecated public Builder advisorResponseContext(Map advisorResponseContext) { this.advisorResponseContext = advisorResponseContext; return this; } - /** - * Set the order of the advisor in the advisor chain. - * @param order the order of the advisor in the advisor chain - * @return the builder - */ - public Builder order(int order) { - this.order = order; - return this; - } - - /** - * Build the {@link AdvisorObservationContext}. - * @return the {@link AdvisorObservationContext} - */ public AdvisorObservationContext build() { - return new AdvisorObservationContext(this.advisorName, this.advisorType, this.advisorRequest, - this.advisorRequestContext, this.advisorResponseContext, this.order); + if (chatClientRequest != null && advisorRequest != null) { + throw new IllegalArgumentException( + "ChatClientRequest and AdvisedRequest cannot be set at the same time"); + } + else if (chatClientRequest != null) { + return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order); + } + else { + return new AdvisorObservationContext(this.advisorName, Type.AROUND, this.advisorRequest, + this.advisorRequestContext, this.advisorResponseContext, this.order); + } } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java index edf07976d43..cfbb15eff65 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.util.ParsingUtils; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * Default implementation of the {@link AdvisorObservationConvention}. @@ -55,6 +56,7 @@ public String getName() { @Override @Nullable public String getContextualName(AdvisorObservationContext context) { + Assert.notNull(context, "context cannot be null"); return ParsingUtils.reConcatenateCamelCase(context.getAdvisorName(), "_") .replace("_around_advisor", "") .replace("_advisor", ""); @@ -66,6 +68,7 @@ public String getContextualName(AdvisorObservationContext context) { @Override public KeyValues getLowCardinalityKeyValues(AdvisorObservationContext context) { + Assert.notNull(context, "context cannot be null"); return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorType(context), advisorName(context)); } @@ -78,6 +81,7 @@ protected KeyValue aiProvider(AdvisorObservationContext context) { return KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER, AiProvider.SPRING_AI.value()); } + @Deprecated protected KeyValue advisorType(AdvisorObservationContext context) { return KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE, context.getAdvisorType().name()); } @@ -96,6 +100,7 @@ protected KeyValue advisorName(AdvisorObservationContext context) { @Override public KeyValues getHighCardinalityKeyValues(AdvisorObservationContext context) { + Assert.notNull(context, "context cannot be null"); return KeyValues.of(advisorOrder(context)); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java index bd9918d5631..f8a231b5665 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,15 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.observation.tracing.TracingHelper; import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; + +import java.util.List; +import java.util.Map; /** * An {@link ObservationFilter} to include the chat prompt content in the observation. @@ -37,7 +43,8 @@ public Observation.Context map(Observation.Context context) { if (!(context instanceof ChatClientObservationContext chatClientObservationContext)) { return context; } - + // TODO: we really want these? Should probably align with same format as chat + // model observation chatClientSystemText(chatClientObservationContext); chatClientSystemParams(chatClientObservationContext); chatClientUserText(chatClientObservationContext); @@ -47,39 +54,65 @@ public Observation.Context map(Observation.Context context) { } protected void chatClientSystemText(ChatClientObservationContext context) { - if (!StringUtils.hasText(context.getRequest().getSystemText())) { + List messages = context.getRequest().prompt().getInstructions(); + if (CollectionUtils.isEmpty(messages)) { + return; + } + + var systemMessage = messages.stream() + .filter(message -> message instanceof SystemMessage) + .reduce((first, second) -> second); + if (systemMessage.isEmpty()) { return; } context.addHighCardinalityKeyValue( KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT, - context.getRequest().getSystemText())); + systemMessage.get().getText())); } + @SuppressWarnings("unchecked") protected void chatClientSystemParams(ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getSystemParams())) { + if (!(context.getRequest() + .context() + .get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map systemParams)) { + return; + } + if (CollectionUtils.isEmpty(systemParams)) { return; } + context.addHighCardinalityKeyValue( KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM, - TracingHelper.concatenateMaps(context.getRequest().getSystemParams()))); + TracingHelper.concatenateMaps((Map) systemParams))); } protected void chatClientUserText(ChatClientObservationContext context) { - if (!StringUtils.hasText(context.getRequest().getUserText())) { + List messages = context.getRequest().prompt().getInstructions(); + if (CollectionUtils.isEmpty(messages)) { + return; + } + + if (!(messages.get(messages.size() - 1) instanceof UserMessage userMessage)) { return; } context.addHighCardinalityKeyValue( KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT, - context.getRequest().getUserText())); + userMessage.getText())); } + @SuppressWarnings("unchecked") protected void chatClientUserParams(ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getUserParams())) { + if (!(context.getRequest() + .context() + .get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map userParams)) { + return; + } + if (CollectionUtils.isEmpty(userParams)) { return; } context.addHighCardinalityKeyValue( KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS, - TracingHelper.concatenateMaps(context.getRequest().getUserParams()))); + TracingHelper.concatenateMaps((Map) userParams))); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java index 6ad0d244fac..f6400e70842 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,14 @@ import io.micrometer.observation.Observation; -import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * Context used to store metadata for chat client workflows. @@ -34,20 +36,16 @@ */ public class ChatClientObservationContext extends Observation.Context { - private final DefaultChatClientRequestSpec request; + private final ChatClientRequest request; private final AiOperationMetadata operationMetadata = new AiOperationMetadata(AiOperationType.FRAMEWORK.value(), AiProvider.SPRING_AI.value()); private final boolean stream; - @Nullable - private String format; - - ChatClientObservationContext(DefaultChatClientRequestSpec requestSpec, String format, boolean isStream) { - Assert.notNull(requestSpec, "requestSpec cannot be null"); - this.request = requestSpec; - this.format = format; + ChatClientObservationContext(ChatClientRequest chatClientRequest, boolean isStream) { + Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); + this.request = chatClientRequest; this.stream = isStream; } @@ -55,7 +53,7 @@ public static Builder builder() { return new Builder(); } - public DefaultChatClientRequestSpec getRequest() { + public ChatClientRequest getRequest() { return this.request; } @@ -67,18 +65,31 @@ public boolean isStream() { return this.stream; } + /** + * @deprecated not used anymore. The format instructions are already included in the + * ChatModelObservationContext. + */ @Nullable + @Deprecated public String getFormat() { - return this.format; + if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) { + return format; + } + return null; } + /** + * @deprecated not used anymore. The format instructions are already included in the + * ChatModelObservationContext. + */ + @Deprecated public void setFormat(@Nullable String format) { - this.format = format; + this.request.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format); } public static final class Builder { - private DefaultChatClientRequestSpec request; + private ChatClientRequest chatClientRequest; private String format; @@ -87,23 +98,41 @@ public static final class Builder { private Builder() { } - public Builder withRequest(DefaultChatClientRequestSpec request) { - this.request = request; + public Builder request(ChatClientRequest chatClientRequest) { + this.chatClientRequest = chatClientRequest; return this; } + @Deprecated // use request(ChatClientRequest chatClientRequest) + public Builder withRequest(ChatClientRequest chatClientRequest) { + return request(chatClientRequest); + } + + /** + * @deprecated not used anymore. The format instructions are already included in + * the ChatModelObservationContext. + */ + @Deprecated public Builder withFormat(String format) { this.format = format; return this; } - public Builder withStream(boolean isStream) { + public Builder stream(boolean isStream) { this.isStream = isStream; return this; } + @Deprecated // use stream(boolean isStream) + public Builder withStream(boolean isStream) { + return stream(isStream); + } + public ChatClientObservationContext build() { - return new ChatClientObservationContext(this.request, this.format, this.isStream); + if (StringUtils.hasText(format)) { + this.chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format); + } + return new ChatClientObservationContext(this.chatClientRequest, this.isStream); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java index b51ade63ca2..d52cb3be1a6 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,15 +19,20 @@ import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; +import org.springframework.ai.chat.client.ChatClientAttributes; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.tracing.TracingHelper; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; +import java.util.Arrays; +import java.util.List; + /** * Default conventions to populate observations for chat client workflows. * @@ -88,53 +93,71 @@ protected KeyValue stream(ChatClientObservationContext context) { public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext context) { var keyValues = KeyValues.empty(); keyValues = chatClientAdvisorNames(keyValues, context); + // TODO: rename attribute? any sensitive data here? keyValues = chatClientAdvisorParams(keyValues, context); - keyValues = toolFunctionNames(keyValues, context); - keyValues = toolFunctionCallbacks(keyValues, context); + // TODO: remove this? Already included in chat model observation + keyValues = toolNames(keyValues, context); + // TODO: remove this? Already included in chat model observation + keyValues = toolCallbacks(keyValues, context); return keyValues; } + @SuppressWarnings("unchecked") protected KeyValues chatClientAdvisorNames(KeyValues keyValues, ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getAdvisors())) { + if (!(context.getRequest().context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List advisors)) { return keyValues; } - var advisorNames = context.getRequest().getAdvisors().stream().map(Advisor::getName).toList(); + var advisorNames = ((List) advisors).stream().map(Advisor::getName).toList(); return keyValues.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), TracingHelper.concatenateStrings(advisorNames)); } protected KeyValues chatClientAdvisorParams(KeyValues keyValues, ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getAdvisorParams())) { + if (CollectionUtils.isEmpty(context.getRequest().context())) { return keyValues; } - var advisorParams = context.getRequest().getAdvisorParams(); + var chatClientContext = context.getRequest().context(); + Arrays.stream(ChatClientAttributes.values()).forEach(attribute -> chatClientContext.remove(attribute.getKey())); return keyValues.and( ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), - TracingHelper.concatenateMaps(advisorParams)); + TracingHelper.concatenateMaps(chatClientContext)); } - protected KeyValues toolFunctionNames(KeyValues keyValues, ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getFunctionNames())) { + protected KeyValues toolNames(KeyValues keyValues, ChatClientObservationContext context) { + if (context.getRequest().prompt().getOptions() == null) { + return keyValues; + } + if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) { return keyValues; } - var functionNames = context.getRequest().getFunctionNames(); + + var toolNames = options.getToolNames(); + if (CollectionUtils.isEmpty(toolNames)) { + return keyValues; + } + return keyValues.and( ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), - TracingHelper.concatenateStrings(functionNames)); + TracingHelper.concatenateStrings(toolNames.stream().sorted().toList())); } - protected KeyValues toolFunctionCallbacks(KeyValues keyValues, ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().getFunctionCallbacks())) { + protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationContext context) { + if (context.getRequest().prompt().getOptions() == null) { return keyValues; } - var functionCallbacks = context.getRequest() - .getFunctionCallbacks() - .stream() - .map(FunctionCallback::getName) - .toList(); + if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) { + return keyValues; + } + + var toolCallbacks = options.getToolCallbacks(); + if (CollectionUtils.isEmpty(toolCallbacks)) { + return keyValues; + } + + var toolCallbackNames = toolCallbacks.stream().map(FunctionCallback::getName).sorted().toList(); return keyValues .and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS - .asString(), TracingHelper.concatenateStrings(functionCallbacks)); + .asString(), TracingHelper.concatenateStrings(toolCallbackNames)); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java new file mode 100644 index 00000000000..071e0772190 --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ChatClientRequest}. + * + * @author Thomas Vitale + */ +class ChatClientRequestTests { + + @Test + void whenPromptIsNullThenThrow() { + assertThatThrownBy(() -> new ChatClientRequest(null, Map.of())).isInstanceOf(IllegalArgumentException.class) + .hasMessage("prompt cannot be null"); + + assertThatThrownBy(() -> ChatClientRequest.builder().prompt(null).context(Map.of()).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("prompt cannot be null"); + } + + @Test + void whenContextIsNullThenThrow() { + assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + + assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + void whenContextHasNullKeysThenThrow() { + Map context = new HashMap<>(); + context.put(null, "something"); + assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context keys cannot be null"); + } + +} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java new file mode 100644 index 00000000000..aa9076cd91b --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ChatClientResponse}. + * + * @author Thomas Vitale + */ +class ChatClientResponseTests { + + @Test + void whenContextIsNullThenThrow() { + assertThatThrownBy(() -> new ChatClientResponse(null, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + + assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + void whenContextHasNullKeysThenThrow() { + Map context = new HashMap<>(); + context.put(null, "something"); + assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("context keys cannot be null"); + } + +} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 048272359a0..78415dd7752 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -29,6 +29,9 @@ import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; +import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.tool.ToolCallback; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -46,7 +49,6 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.tool.ToolCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.ClassPathResource; @@ -598,17 +600,67 @@ void whenAdvisorListThenReturn() { void buildCallResponseSpec() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient - .prompt(); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + .prompt("question"); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); assertThat(spec).isNotNull(); } @Test void buildCallResponseSpecWithNullRequest() { - assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null)) + assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null, mock(BaseAdvisorChain.class), + mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatClientRequest cannot be null"); + } + + @Test + void buildCallResponseSpecWithNullAdvisorChain() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class), null, + mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisorChain cannot be null"); + } + + @Test + void buildCallResponseSpecWithNullObservationRegistry() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class), + mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class))) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("request cannot be null"); + .hasMessage("observationRegistry cannot be null"); + } + + @Test + void buildCallResponseSpecWithNullObservationConvention() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class), + mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("observationConvention cannot be null"); + } + + @Test + void whenSimplePromptThenChatClientResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); + + ChatClientResponse chatClientResponse = spec.chatClientResponse(); + assertThat(chatClientResponse).isNotNull(); + + ChatResponse chatResponse = chatClientResponse.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(1); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question"); } @Test @@ -621,8 +673,8 @@ void whenSimplePromptThenChatResponse() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ChatResponse chatResponse = spec.chatResponse(); assertThat(chatResponse).isNotNull(); @@ -644,8 +696,8 @@ void whenFullPromptThenChatResponse() { Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt(prompt); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ChatResponse chatResponse = spec.chatResponse(); assertThat(chatResponse).isNotNull(); @@ -669,8 +721,8 @@ void whenPromptAndUserTextThenChatResponse() { DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt(prompt) .user("another question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ChatResponse chatResponse = spec.chatResponse(); assertThat(chatResponse).isNotNull(); @@ -696,8 +748,8 @@ void whenUserTextAndMessagesThenChatResponse() { .prompt() .user("another question") .messages(messages); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ChatResponse chatResponse = spec.chatResponse(); assertThat(chatResponse).isNotNull(); @@ -719,8 +771,8 @@ void whenChatResponseIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ChatResponse chatResponse = spec.chatResponse(); assertThat(chatResponse).isNull(); @@ -736,8 +788,8 @@ void whenChatResponseContentIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); String content = spec.content(); assertThat(content).isNull(); @@ -748,10 +800,10 @@ void whenResponseEntityWithParameterizedTypeIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); - assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference) null)) + assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("type cannot be null"); } @@ -766,8 +818,8 @@ void whenResponseEntityWithParameterizedTypeAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity> responseEntity = spec .responseEntity(new ParameterizedTypeReference<>() { @@ -793,8 +845,8 @@ void whenResponseEntityWithParameterizedType() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity> responseEntity = spec .responseEntity(new ParameterizedTypeReference<>() { @@ -808,10 +860,10 @@ void whenResponseEntityWithConverterIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); - assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter) null)) + assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("structuredOutputConverter cannot be null"); } @@ -826,8 +878,8 @@ void whenResponseEntityWithConverterAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity> responseEntity = spec .responseEntity(new ListOutputConverter(new DefaultConversionService())); @@ -847,8 +899,8 @@ void whenResponseEntityWithConverter() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity> responseEntity = spec .responseEntity(new ListOutputConverter(new DefaultConversionService())); @@ -861,8 +913,8 @@ void whenResponseEntityWithTypeIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); assertThatThrownBy(() -> spec.responseEntity((Class) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("type cannot be null"); @@ -878,8 +930,8 @@ void whenResponseEntityWithTypeAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity responseEntity = spec.responseEntity(String.class); assertThat(responseEntity.response()).isNotNull(); @@ -898,8 +950,8 @@ void whenResponseEntityWithType() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); ResponseEntity responseEntity = spec.responseEntity(Person.class); assertThat(responseEntity.response()).isNotNull(); @@ -912,10 +964,10 @@ void whenEntityWithParameterizedTypeIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); - assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference) null)) + assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("type cannot be null"); } @@ -930,8 +982,8 @@ void whenEntityWithParameterizedTypeAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); List entity = spec.entity(new ParameterizedTypeReference<>() { }); @@ -954,8 +1006,8 @@ void whenEntityWithParameterizedType() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); List entity = spec.entity(new ParameterizedTypeReference<>() { }); @@ -967,10 +1019,10 @@ void whenEntityWithConverterIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); - assertThatThrownBy(() -> spec.entity((StructuredOutputConverter) null)) + assertThatThrownBy(() -> spec.entity((StructuredOutputConverter) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("structuredOutputConverter cannot be null"); } @@ -980,8 +1032,8 @@ void whenEntityWithConverterAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); List entity = spec.entity(new ListOutputConverter(new DefaultConversionService())); assertThat(entity).isNull(); @@ -999,8 +1051,8 @@ void whenEntityWithConverter() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); List entity = spec.entity(new ListOutputConverter(new DefaultConversionService())); assertThat(entity).hasSize(3); @@ -1011,10 +1063,10 @@ void whenEntityWithTypeIsNull() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); - assertThatThrownBy(() -> spec.entity((Class) null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> spec.entity((Class) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("type cannot be null"); } @@ -1028,8 +1080,8 @@ void whenEntityWithTypeAndChatResponseContentNull() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); String entity = spec.entity(String.class); assertThat(entity).isNull(); @@ -1047,8 +1099,8 @@ void whenEntityWithType() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); Person entity = spec.entity(Person.class); assertThat(entity).isNotNull(); @@ -1061,17 +1113,67 @@ void whenEntityWithType() { void buildStreamResponseSpec() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient - .prompt(); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + .prompt("question"); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); assertThat(spec).isNotNull(); } @Test void buildStreamResponseSpecWithNullRequest() { - assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null)) + assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null, mock(BaseAdvisorChain.class), + mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatClientRequest cannot be null"); + } + + @Test + void buildStreamResponseSpecWithNullAdvisorChain() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class), null, + mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisorChain cannot be null"); + } + + @Test + void buildStreamResponseSpecWithNullObservationRegistry() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class), + mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class))) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("request cannot be null"); + .hasMessage("observationRegistry cannot be null"); + } + + @Test + void buildStreamResponseSpecWithNullObservationConvention() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class), + mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("observationConvention cannot be null"); + } + + @Test + void whenSimplePromptThenFluxChatClientResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); + + ChatClientResponse chatClientResponse = spec.chatClientResponse().blockLast(); + assertThat(chatClientResponse).isNotNull(); + + ChatResponse chatResponse = chatClientResponse.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(1); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question"); } @Test @@ -1084,8 +1186,8 @@ void whenSimplePromptThenFluxChatResponse() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); ChatResponse chatResponse = spec.chatResponse().blockLast(); assertThat(chatResponse).isNotNull(); @@ -1107,8 +1209,8 @@ void whenFullPromptThenFluxChatResponse() { Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt(prompt); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); ChatResponse chatResponse = spec.chatResponse().blockLast(); assertThat(chatResponse).isNotNull(); @@ -1132,8 +1234,8 @@ void whenPromptAndUserTextThenFluxChatResponse() { DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt(prompt) .user("another question"); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); ChatResponse chatResponse = spec.chatResponse().blockLast(); assertThat(chatResponse).isNotNull(); @@ -1159,8 +1261,9 @@ void whenUserTextAndMessagesThenFluxChatResponse() { .prompt() .user("another question") .messages(messages); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); ChatResponse chatResponse = spec.chatResponse().blockLast(); assertThat(chatResponse).isNotNull(); @@ -1183,8 +1286,8 @@ void whenChatResponseContentIsNullThenReturnFlux() { ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt("my question"); - DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( - chatClientRequestSpec); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); String content = spec.content().blockLast(); assertThat(content).isNull(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java index 4f2d4415ec2..3e4e8e0c270 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,16 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -147,4 +156,58 @@ void whenToolContextIsNullThenThrows() { .hasMessage("toolContext cannot be null"); } + @Test + void whenConvertToAndFromChatClientRequest() { + ChatModel chatModel = mock(ChatModel.class); + ChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + List messages = List.of(mock(UserMessage.class)); + SystemMessage systemMessage = new SystemMessage("Instructions {key}"); + UserMessage userMessage = new UserMessage("Question {key}", mock(Media.class)); + Map systemParams = Map.of("key", "value"); + Map userParams = Map.of("key", "value"); + List toolNames = List.of("tool1", "tool2"); + ToolCallback toolCallback = mock(ToolCallback.class); + Map toolContext = Map.of("key", "value"); + List advisors = List.of(mock(Advisor.class)); + Map advisorContext = Map.of("key", "value"); + + AdvisedRequest advisedRequest = AdvisedRequest.builder() + .chatModel(chatModel) + .chatOptions(chatOptions) + .messages(messages) + .systemText(systemMessage.getText()) + .systemParams(systemParams) + .userText(userMessage.getText()) + .userParams(userParams) + .media(userMessage.getMedia()) + .functionNames(toolNames) + .functionCallbacks(List.of(toolCallback)) + .toolContext(toolContext) + .advisors(advisors) + .adviseContext(advisorContext) + .build(); + + ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest(); + + assertThat(chatClientRequest.context().get(ChatClientAttributes.CHAT_MODEL.getKey())).isEqualTo(chatModel); + assertThat(chatClientRequest.prompt().getOptions()).isEqualTo(chatOptions); + assertThat(chatClientRequest.prompt().getInstructions()).hasSize(3); + assertThat(chatClientRequest.prompt().getInstructions().get(0)).isEqualTo(messages.get(0)); + assertThat(chatClientRequest.prompt().getInstructions().get(1).getText()).isEqualTo("Instructions value"); + assertThat(chatClientRequest.prompt().getInstructions().get(2).getText()).isEqualTo("Question value"); + assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolNames()) + .containsAll(toolNames); + assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolCallbacks()) + .contains(toolCallback); + assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolContext()) + .containsAllEntriesOf(toolContext); + assertThat((List) chatClientRequest.context().get(ChatClientAttributes.ADVISORS.getKey())) + .containsAll(advisors); + assertThat(chatClientRequest.context()).containsAllEntriesOf(advisorContext); + + AdvisedRequest convertedAdvisedRequest = AdvisedRequest.from(chatClientRequest); + assertThat(convertedAdvisedRequest.toPrompt()).isEqualTo(chatClientRequest.prompt()); + assertThat(convertedAdvisedRequest.adviseContext()).containsAllEntriesOf(chatClientRequest.context()); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java index ccf3637da82..dd52eea8361 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.model.ChatResponse; import static org.assertj.core.api.Assertions.assertThat; @@ -67,7 +68,8 @@ void whenAdviseContextValuesIsNullThenThrows() { @Test void whenBuildFromNullAdvisedResponseThenThrows() { - assertThatThrownBy(() -> AdvisedResponse.from(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> AdvisedResponse.from((AdvisedResponse) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("advisedResponse cannot be null"); } @@ -85,4 +87,16 @@ void whenUpdateFromNullContextThenThrows() { .hasMessage("contextTransform cannot be null"); } + @Test + void whenConvertToAndFromChatClientResponse() { + ChatResponse chatResponse = mock(ChatResponse.class); + Map context = Map.of("key", "value"); + AdvisedResponse advisedResponse = new AdvisedResponse(chatResponse, context); + + ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); + + AdvisedResponse newAdvisedResponse = AdvisedResponse.from(chatClientResponse); + assertThat(newAdvisedResponse).isEqualTo(advisedResponse); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java index 06b34bf672b..f548346ba64 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,13 @@ package org.springframework.ai.chat.client.advisor.observation; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; /** * Unit tests for {@link AdvisorObservationContext}. @@ -32,8 +36,7 @@ class AdvisorObservationContextTests { @Test void whenMandatoryOptionsThenReturn() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() - .advisorName("MyName") - .advisorType(AdvisorObservationContext.Type.BEFORE) + .advisorName("AdvisorName") .build(); assertThat(observationContext).isNotNull(); @@ -41,17 +44,38 @@ void whenMandatoryOptionsThenReturn() { @Test void missingAdvisorName() { - assertThatThrownBy( - () -> AdvisorObservationContext.builder().advisorType(AdvisorObservationContext.Type.BEFORE).build()) + assertThatThrownBy(() -> AdvisorObservationContext.builder().build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("advisorName must not be null or empty"); + .hasMessageContaining("advisorName cannot be null or empty"); } @Test - void missingAdvisorType() { - assertThatThrownBy(() -> AdvisorObservationContext.builder().advisorName("MyName").build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("advisorType must not be null"); + void whenBuilderWithAdvisedRequestThenReturn() { + AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .advisorName("AdvisorName") + .advisedRequest(mock(AdvisedRequest.class)) + .build(); + + assertThat(observationContext).isNotNull(); + } + + @Test + void whenBuilderWithChatClientRequestThenReturn() { + AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .advisorName("AdvisorName") + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build()) + .build(); + + assertThat(observationContext).isNotNull(); + } + + @Test + void missingBuilderWithBothRequestsThenThrow() { + assertThatThrownBy(() -> AdvisorObservationContext.builder() + .advisedRequest(mock(AdvisedRequest.class)) + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build()) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ChatClientRequest and AdvisedRequest cannot be set at the same time"); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java index 39de4b46e57..c24098ad25a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,6 @@ void shouldHaveName() { void contextualName() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName("MyName") - .advisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("my_name"); } @@ -56,7 +55,6 @@ void contextualName() { void supportsAdvisorObservationContext() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName("MyName") - .advisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); @@ -66,11 +64,8 @@ void supportsAdvisorObservationContext() { void shouldHaveLowCardinalityKeyValuesWhenDefined() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName("MyName") - .advisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( - KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE.asString(), - AdvisorObservationContext.Type.AROUND.name()), KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.FRAMEWORK.value()), KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.SPRING_AI.value()), KeyValue.of(LowCardinalityKeyNames.ADVISOR_NAME.asString(), "MyName"), @@ -81,7 +76,6 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { void shouldHaveKeyValuesWhenDefinedAndResponse() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName("MyName") - .advisorType(AdvisorObservationContext.Type.AROUND) .order(678) .build(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 31d017d749e..70be4c4e620 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,22 @@ package org.springframework.ai.chat.client.observation; -import java.util.List; import java.util.Map; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -57,14 +59,9 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { @Test void whenEmptyInputContentThenReturnOriginalContext() { - ObservationRegistry observationRegistry = ObservationRegistry.NOOP; - ChatClientObservationConvention customObservationConvention = null; + var request = ChatClientRequest.builder().prompt(new Prompt()).build(); - var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, - Map.of()); - - var expectedContext = ChatClientObservationContext.builder().withRequest(request).build(); + var expectedContext = ChatClientObservationContext.builder().request(request).build(); var actualContext = this.observationFilter.map(expectedContext); @@ -73,14 +70,13 @@ void whenEmptyInputContentThenReturnOriginalContext() { @Test void whenWithTextThenAugmentContext() { - ObservationRegistry observationRegistry = ObservationRegistry.NOOP; - ChatClientObservationConvention customObservationConvention = null; - - var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"), - "sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null, - List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); + var request = ChatClientRequest.builder() + .prompt(new Prompt(new SystemMessage("sample system text"), new UserMessage("sample user text"))) + .context(ChatClientAttributes.USER_PARAMS.getKey(), Map.of("up1", "upv1")) + .context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), Map.of("sp1", "sp1v")) + .build(); - var originalContext = ChatClientObservationContext.builder().withRequest(request).build(); + var originalContext = ChatClientObservationContext.builder().request(request).build(); var augmentedContext = this.observationFilter.map(originalContext); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java index cf8f644248a..810c1abd87a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,14 @@ package org.springframework.ai.chat.client.observation; -import java.util.List; -import java.util.Map; - -import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -44,11 +41,10 @@ class ChatClientObservationContextTests { @Test void whenMandatoryRequestOptionsThenReturn() { - - var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); - - var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build(); + var observationContext = ChatClientObservationContext.builder() + .request(ChatClientRequest.builder().prompt(new Prompt()).build()) + .stream(true) + .build(); assertThat(observationContext).isNotNull(); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 99fb24c571d..36ba65f0eb6 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,18 +17,17 @@ package org.springframework.ai.chat.client.observation; import java.util.List; -import java.util.Map; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.ChatClientAttributes; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -36,7 +35,9 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -56,7 +57,7 @@ class DefaultChatClientObservationConventionTests { @Mock ChatModel chatModel; - DefaultChatClientRequestSpec request; + ChatClientRequest request; static CallAroundAdvisor dummyAdvisor(String name) { return new CallAroundAdvisor() { @@ -109,8 +110,7 @@ public String call(String functionInput) { @BeforeEach public void beforeEach() { - this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); + this.request = ChatClientRequest.builder().prompt(new Prompt()).build(); } @Test @@ -121,8 +121,8 @@ void shouldHaveName() { @Test void shouldHaveContextualName() { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(this.request) - .withStream(true) + .request(this.request) + .stream(true) .build(); assertThat(this.observationConvention.getContextualName(observationContext)) @@ -132,8 +132,8 @@ void shouldHaveContextualName() { @Test void supportsOnlyChatClientObservationContext() { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(this.request) - .withStream(true) + .request(this.request) + .stream(true) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); @@ -143,8 +143,8 @@ void supportsOnlyChatClientObservationContext() { @Test void shouldHaveRequiredKeyValues() { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(this.request) - .withStream(true) + .request(this.request) + .stream(true) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( @@ -154,27 +154,31 @@ void shouldHaveRequiredKeyValues() { @Test void shouldHaveOptionalKeyValues() { - var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), - List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), - List.of("function1", "function2"), List.of(), null, - List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), - ObservationRegistry.NOOP, null, Map.of()); + var request = ChatClientRequest.builder() + .prompt(new Prompt("", + ToolCallingChatOptions.builder() + .toolNames("tool1", "tool2") + .toolCallbacks(dummyFunction("toolCallback1"), dummyFunction("toolCallback2")) + .build())) + .context("advParam1", "advisorParam1Value") + .context(ChatClientAttributes.ADVISORS.getKey(), + List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2"))) + .build(); ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) + .request(request) .withFormat("json") - .withStream(true) + .stream(true) .build(); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( - KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), - "[\"advisor1\", \"advisor2\", \"CallAroundAdvisor\", \"StreamAroundAdvisor\"]"), + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), "[\"advisor1\", \"advisor2\"]"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), "[\"advParam1\":\"advisorParam1Value\"]"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), - "[\"function1\", \"function2\"]"), + "[\"tool1\", \"tool2\"]"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(), - "[\"functionCallback1\", \"functionCallback2\"]")); + "[\"toolCallback1\", \"toolCallback2\"]")); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java index 0ef04415c04..5fe295458fc 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -148,4 +148,21 @@ public void testPromptCopy() { assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions()); } + @Test + public void mutatePrompt() { + String template = "Hello, {name}! Your age is {age}."; + Map model = new HashMap<>(); + model.put("name", "Alice"); + model.put("age", 30); + PromptTemplate promptTemplate = new PromptTemplate(template, model); + ChatOptions chatOptions = ChatOptions.builder().temperature(0.5).maxTokens(100).build(); + + Prompt prompt = promptTemplate.create(model, chatOptions); + + Prompt copiedPrompt = prompt.mutate().build(); + assertThat(prompt).isNotSameAs(copiedPrompt); + assertThat(prompt.getOptions()).isNotSameAs(copiedPrompt.getOptions()); + assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions()); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 8f3dd228510..6e37fd7548b 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -24,6 +24,7 @@ import java.util.Objects; import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.StreamUtils; @@ -49,6 +50,7 @@ public abstract class AbstractMessage implements Message { /** * The content of the message. */ + @Nullable protected final String textContent; /** @@ -63,11 +65,12 @@ public abstract class AbstractMessage implements Message { * @param textContent the text content * @param metadata the metadata */ - protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { + protected AbstractMessage(MessageType messageType, @Nullable String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); } + Assert.notNull(metadata, "Metadata must not be null"); this.messageType = messageType; this.textContent = textContent; this.metadata = new HashMap<>(metadata); @@ -81,7 +84,9 @@ protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { + Assert.notNull(messageType, "Message type must not be null"); Assert.notNull(resource, "Resource must not be null"); + Assert.notNull(metadata, "Metadata must not be null"); try (InputStream inputStream = resource.getInputStream()) { this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); } @@ -98,6 +103,7 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata) { + super(MessageType.SYSTEM, textContent, metadata); } @Override + @NonNull public String getText() { return this.textContent; } @@ -68,4 +77,53 @@ public String toString() { + ", metadata=" + this.metadata + '}'; } + public SystemMessage copy() { + return new SystemMessage(getText(), Map.copyOf(this.metadata)); + } + + public Builder mutate() { + return new Builder().text(this.textContent).metadata(this.metadata); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + @Nullable + private String textContent; + + @Nullable + private Resource resource; + + private Map metadata = new HashMap<>(); + + public Builder text(String textContent) { + this.textContent = textContent; + return this; + } + + public Builder text(Resource resource) { + this.resource = resource; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public SystemMessage build() { + if (StringUtils.hasText(textContent) && resource != null) { + throw new IllegalArgumentException("textContent and resource cannot be set at the same time"); + } + else if (resource != null) { + this.textContent = MessageUtils.readResource(resource); + } + return new SystemMessage(this.textContent, this.metadata); + } + + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 7bae70b64ee..52caeddf75e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,13 +19,17 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.springframework.ai.content.Media; import org.springframework.ai.content.MediaContent; import org.springframework.core.io.Resource; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * A message of the type 'user' passed as input Messages with the user role are from the @@ -37,26 +41,45 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; public UserMessage(String textContent) { - this(MessageType.USER, textContent, new ArrayList<>(), Map.of()); + this(textContent, new ArrayList<>(), Map.of()); } public UserMessage(Resource resource) { - super(MessageType.USER, resource, Map.of()); - this.media = new ArrayList<>(); + this(MessageUtils.readResource(resource)); } + /** + * @deprecated use {@link #builder()} instead. + */ + @Deprecated public UserMessage(String textContent, List media) { this(MessageType.USER, textContent, media, Map.of()); } + /** + * @deprecated use {@link #builder()} instead. + */ + @Deprecated public UserMessage(String textContent, Media... media) { this(textContent, Arrays.asList(media)); } - public UserMessage(String textContent, Collection mediaList, Map metadata) { - this(MessageType.USER, textContent, mediaList, metadata); + /** + * @deprecated use {@link #builder()} instead. Will be made private in the next + * release. + */ + @Deprecated + public UserMessage(String textContent, Collection media, Map metadata) { + super(MessageType.USER, textContent, metadata); + Assert.notNull(media, "media cannot be null"); + Assert.noNullElements(media, "media cannot have null elements"); + this.media = new ArrayList<>(media); } + /** + * @deprecated use {@link #builder()} instead. + */ + @Deprecated public UserMessage(MessageType messageType, String textContent, Collection media, Map metadata) { super(messageType, textContent, metadata); @@ -70,14 +93,78 @@ public String toString() { + this.messageType + '}'; } + @Override + @NonNull + public String getText() { + return this.textContent; + } + @Override public List getMedia() { return this.media; } - @Override - public String getText() { - return this.textContent; + public UserMessage copy() { + return new UserMessage(getText(), List.copyOf(getMedia()), Map.copyOf(getMetadata())); + } + + public Builder mutate() { + return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata())); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + @Nullable + private String textContent; + + @Nullable + private Resource resource; + + private List media = new ArrayList<>(); + + private Map metadata = new HashMap<>(); + + public Builder text(String textContent) { + this.textContent = textContent; + return this; + } + + public Builder text(Resource resource) { + this.resource = resource; + return this; + } + + public Builder media(List media) { + this.media = media; + return this; + } + + public Builder media(@Nullable Media... media) { + if (media != null) { + this.media = Arrays.asList(media); + } + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public UserMessage build() { + if (StringUtils.hasText(textContent) && resource != null) { + throw new IllegalArgumentException("textContent and resource cannot be set at the same time"); + } + else if (resource != null) { + this.textContent = MessageUtils.readResource(resource); + } + return new UserMessage(this.textContent, this.media, this.metadata); + } + } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/package-info.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/package-info.java new file mode 100644 index 00000000000..226bb13428e --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.messages; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index a5620b5e7e8..6993cd29775 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -30,6 +30,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.model.ModelRequest; import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * The Prompt class represents a prompt used in AI model requests. A prompt consists of @@ -62,15 +64,15 @@ public Prompt(Message... messages) { this(Arrays.asList(messages), null); } - public Prompt(String contents, ChatOptions chatOptions) { + public Prompt(String contents, @Nullable ChatOptions chatOptions) { this(new UserMessage(contents), chatOptions); } - public Prompt(Message message, ChatOptions chatOptions) { + public Prompt(Message message, @Nullable ChatOptions chatOptions) { this(Collections.singletonList(message), chatOptions); } - public Prompt(List messages, ChatOptions chatOptions) { + public Prompt(List messages, @Nullable ChatOptions chatOptions) { this.messages = messages; this.chatOptions = chatOptions; } @@ -123,10 +125,17 @@ private List instructionsCopy() { List messagesCopy = new ArrayList<>(); this.messages.forEach(message -> { if (message instanceof UserMessage userMessage) { - messagesCopy.add(new UserMessage(userMessage.getText(), userMessage.getMedia(), message.getMetadata())); + messagesCopy.add(UserMessage.builder() + .text(userMessage.getText()) + .media(userMessage.getMedia()) + .metadata(message.getMetadata()) + .build()); } else if (message instanceof SystemMessage systemMessage) { - messagesCopy.add(new SystemMessage(systemMessage.getText())); + messagesCopy.add(SystemMessage.builder() + .text(systemMessage.getText()) + .metadata(systemMessage.getMetadata()) + .build()); } else if (message instanceof AssistantMessage assistantMessage) { messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), @@ -144,4 +153,69 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) { return messagesCopy; } + public Builder mutate() { + Builder builder = new Builder().messages(instructionsCopy()); + if (this.chatOptions != null) { + builder.chatOptions(this.chatOptions.copy()); + } + return builder; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + @Nullable + private String content; + + @Nullable + private List messages = new ArrayList<>(); + + @Nullable + private ChatOptions chatOptions; + + public Builder content(@Nullable String content) { + this.content = content; + return this; + } + + public Builder messages(Message... messages) { + if (messages != null) { + this.messages = Arrays.asList(messages); + } + return this; + } + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder addMessage(Message message) { + if (this.messages == null) { + this.messages = new ArrayList<>(); + } + this.messages.add(message); + return this; + } + + public Builder chatOptions(ChatOptions chatOptions) { + this.chatOptions = chatOptions; + return this; + } + + public Prompt build() { + if (StringUtils.hasText(this.content) && !CollectionUtils.isEmpty(this.messages)) { + throw new IllegalArgumentException("content and messages cannot be set at the same time"); + } + else if (StringUtils.hasText(this.content)) { + this.messages = List.of(new UserMessage(this.content)); + } + return new Prompt(this.messages, this.chatOptions); + } + + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/MessageUtilsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/MessageUtilsTests.java new file mode 100644 index 00000000000..9fee140b12c --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/MessageUtilsTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.messages; + +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; + +import java.nio.charset.StandardCharsets; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MessageUtils}. + * + * @author Thomas Vitale + */ +class MessageUtilsTests { + + @Test + void readResource() { + String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt")); + assertThat(content).isEqualTo("Hello, world!"); + } + + @Test + void readResourceWhenNull() { + assertThatThrownBy(() -> MessageUtils.readResource(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("resource cannot be null"); + } + + @Test + void readResourceWithCharset() { + String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), StandardCharsets.UTF_8); + assertThat(content).isEqualTo("Hello, world!"); + } + + @Test + void readResourceWithCharsetWhenNull() { + assertThatThrownBy(() -> MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("charset cannot be null"); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/SystemMessageTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/SystemMessageTests.java new file mode 100644 index 00000000000..188a9816617 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/SystemMessageTests.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.messages; + +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.*; +import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE; + +/** + * Unit tests for {@link SystemMessage}. + * + * @author Thomas Vitale + */ +class SystemMessageTests { + + @Test + void systemMessageWithNullText() { + assertThrows(IllegalArgumentException.class, () -> new SystemMessage((String) null)); + } + + @Test + void systemMessageWithTextContent() { + String text = "Tell me, did you sail across the sun?"; + SystemMessage message = new SystemMessage(text); + assertEquals(text, message.getText()); + assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE)); + } + + @Test + void systemMessageWithNullResource() { + assertThrows(IllegalArgumentException.class, () -> new SystemMessage((Resource) null)); + } + + @Test + void systemMessageWithResource() { + SystemMessage message = new SystemMessage(new ClassPathResource("prompt-system.txt")); + assertEquals("Tell me, did you sail across the sun?", message.getText()); + assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE)); + } + + @Test + void systemMessageFromBuilderWithText() { + String text = "Tell me, did you sail across the sun?"; + SystemMessage message = SystemMessage.builder().text(text).metadata(Map.of("key", "value")).build(); + assertEquals(text, message.getText()); + assertThat(message.getMetadata()).hasSize(2) + .containsEntry(MESSAGE_TYPE, MessageType.SYSTEM) + .containsEntry("key", "value"); + } + + @Test + void systemMessageFromBuilderWithResource() { + Resource resource = new ClassPathResource("prompt-system.txt"); + SystemMessage message = SystemMessage.builder().text(resource).metadata(Map.of("key", "value")).build(); + assertEquals("Tell me, did you sail across the sun?", message.getText()); + assertThat(message.getMetadata()).hasSize(2) + .containsEntry(MESSAGE_TYPE, MessageType.SYSTEM) + .containsEntry("key", "value"); + } + + @Test + void systemMessageCopy() { + String text1 = "Tell me, did you sail across the sun?"; + Map metadata1 = Map.of("key", "value"); + SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build(); + + SystemMessage systemMessage2 = systemMessage1.copy(); + + assertThat(systemMessage2.getText()).isEqualTo(text1); + assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + } + + @Test + void systemMessageMutate() { + String text1 = "Tell me, did you sail across the sun?"; + Map metadata1 = Map.of("key", "value"); + SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build(); + + SystemMessage systemMessage2 = systemMessage1.mutate().build(); + + assertThat(systemMessage2.getText()).isEqualTo(text1); + assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + + String text3 = "Farewell, Aragog!"; + SystemMessage systemMessage3 = systemMessage2.mutate().text(text3).build(); + + assertThat(systemMessage3.getText()).isEqualTo(text3); + assertThat(systemMessage3.getMetadata()).hasSize(2).isNotSameAs(systemMessage2.getMetadata()); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java new file mode 100644 index 00000000000..26bb59718bd --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.messages; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.content.Media; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeTypeUtils; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE; + +/** + * Unit tests for {@link UserMessage}. + * + * @author Thomas Vitale + */ +class UserMessageTests { + + @Test + void userMessageWithNullText() { + assertThatThrownBy(() -> new UserMessage((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Content must not be null for SYSTEM or USER messages"); + ; + } + + @Test + void userMessageWithTextContent() { + String text = "Hello, world!"; + UserMessage message = new UserMessage(text); + assertThat(message.getText()).isEqualTo(text); + assertThat(message.getMedia()).isEmpty(); + assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER); + } + + @Test + void userMessageWithNullResource() { + assertThatThrownBy(() -> new UserMessage((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("resource cannot be null"); + ; + } + + @Test + void userMessageWithResource() { + UserMessage message = new UserMessage(new ClassPathResource("prompt-user.txt")); + assertThat(message.getText()).isEqualTo("Hello, world!"); + assertThat(message.getMedia()).isEmpty(); + assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER); + } + + @Test + void userMessageFromBuilderWithText() { + String text = "Hello, world!"; + UserMessage message = UserMessage.builder() + .text(text) + .media(new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt"))) + .metadata(Map.of("key", "value")) + .build(); + assertThat(message.getText()).isEqualTo(text); + assertThat(message.getMedia()).hasSize(1); + assertThat(message.getMetadata()).hasSize(2) + .containsEntry(MESSAGE_TYPE, MessageType.USER) + .containsEntry("key", "value"); + } + + @Test + void userMessageFromBuilderWithResource() { + UserMessage message = UserMessage.builder().text(new ClassPathResource("prompt-user.txt")).build(); + assertThat(message.getText()).isEqualTo("Hello, world!"); + assertThat(message.getMedia()).isEmpty(); + assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER); + } + + @Test + void userMessageCopy() { + String text1 = "Hello, world!"; + Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt")); + Map metadata1 = Map.of("key", "value"); + UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build(); + + UserMessage userMessage2 = userMessage1.copy(); + + assertThat(userMessage2.getText()).isEqualTo(text1); + assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1); + assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + } + + @Test + void userMessageMutate() { + String text1 = "Hello, world!"; + Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt")); + Map metadata1 = Map.of("key", "value"); + UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build(); + + UserMessage userMessage2 = userMessage1.mutate().build(); + + assertThat(userMessage2.getText()).isEqualTo(text1); + assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1); + assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + + String text3 = "Farewell, Aragog!"; + UserMessage userMessage3 = userMessage2.mutate().text(text3).build(); + + assertThat(userMessage3.getText()).isEqualTo(text3); + assertThat(userMessage3.getMedia()).hasSize(1).isNotSameAs(metadata1); + assertThat(userMessage3.getMetadata()).hasSize(2).isNotSameAs(metadata1); + } + +} diff --git a/spring-ai-model/src/test/resources/prompt-system.txt b/spring-ai-model/src/test/resources/prompt-system.txt new file mode 100644 index 00000000000..b292fd2f45c --- /dev/null +++ b/spring-ai-model/src/test/resources/prompt-system.txt @@ -0,0 +1 @@ +Tell me, did you sail across the sun? \ No newline at end of file diff --git a/spring-ai-model/src/test/resources/prompt-user.txt b/spring-ai-model/src/test/resources/prompt-user.txt new file mode 100644 index 00000000000..5dd01c177f5 --- /dev/null +++ b/spring-ai-model/src/test/resources/prompt-user.txt @@ -0,0 +1 @@ +Hello, world! \ No newline at end of file