Skip to content

Commit fb27240

Browse files
committed
feat: OpenAI Web Search Annotations
This PR adds support for retrieving web search annotations from the OpenAI API, as described in their [web search documentation](https://platform.openai.com/docs/guides/web-search). This allows us to access citation URLs and their context within generated responses when using models like `gpt-4o-search-preview`. **Changes:** * Added `annotations` (with `Annotation` and `UrlCitation` records) to `ChatCompletionMessage` in `OpenAiApi.java`. * Updated `OpenAiChatModel` to populate the `annotations` field (via metadata) for both regular and streaming responses. * Added integration tests (`webSearchAnnotationsTest`, `streamWebSearchAnnotationsTest`) to `OpenAiChatModelIT.java`. * Added `GPT_4_O_SEARCH_PREVIEW` and `GPT_4_O_MINI_SEARCH_PREVIEW` to `OpenAiApi.ChatModel`. * Added `WebSearchOptions` and related records to `OpenAiApi`. * Minor updates to `ChatCompletionRequest` and its `Builder`. Resolves spring-projects#2449 Signed-off-by: Alexandros Pappas <[email protected]>
1 parent ea3fd92 commit fb27240

File tree

7 files changed

+217
-27
lines changed

7 files changed

+217
-27
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
218218
"role", choice.message().role() != null ? choice.message().role().name() : "",
219219
"index", choice.index(),
220220
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
221-
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
221+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
222+
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
222223
return buildGeneration(choice, metadata, request);
223224
}).toList();
224225
// @formatter:on
@@ -316,8 +317,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
316317
"role", roleMap.getOrDefault(id, ""),
317318
"index", choice.index(),
318319
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
319-
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
320-
320+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
321+
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
321322
return buildGeneration(choice, metadata, request);
322323
}).toList();
323324
// @formatter:on
@@ -580,7 +581,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
580581

581582
}
582583
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
583-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
584+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
584585
}
585586
else if (message.getMessageType() == MessageType.TOOL) {
586587
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -590,7 +591,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
590591
return toolMessage.getResponses()
591592
.stream()
592593
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
593-
tr.id(), null, null, null))
594+
tr.id(), null, null, null, null))
594595
.toList();
595596
}
596597
else {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
3838
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
3939
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
40+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.WebSearchOptions;
4041
import org.springframework.ai.openai.api.ResponseFormat;
4142
import org.springframework.ai.tool.ToolCallback;
4243
import org.springframework.lang.Nullable;
@@ -194,6 +195,11 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
194195
*/
195196
private @JsonProperty("reasoning_effort") String reasoningEffort;
196197

198+
/**
199+
* This tool searches the web for relevant results to use in a response.
200+
*/
201+
private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions;
202+
197203
/**
198204
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
199205
*/
@@ -593,6 +599,14 @@ public void setReasoningEffort(String reasoningEffort) {
593599
this.reasoningEffort = reasoningEffort;
594600
}
595601

602+
public WebSearchOptions getWebSearchOptions() {
603+
return this.webSearchOptions;
604+
}
605+
606+
public void setWebSearchOptions(WebSearchOptions webSearchOptions) {
607+
this.webSearchOptions = webSearchOptions;
608+
}
609+
596610
@Override
597611
public OpenAiChatOptions copy() {
598612
return OpenAiChatOptions.fromOptions(this);
@@ -605,7 +619,7 @@ public int hashCode() {
605619
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
606620
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
607621
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
608-
this.store, this.metadata, this.reasoningEffort);
622+
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions);
609623
}
610624

611625
@Override
@@ -637,7 +651,8 @@ public boolean equals(Object o) {
637651
&& Objects.equals(this.outputModalities, other.outputModalities)
638652
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
639653
&& Objects.equals(this.metadata, other.metadata)
640-
&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
654+
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
655+
&& Objects.equals(this.webSearchOptions, other.webSearchOptions);
641656
}
642657

643658
@Override
@@ -848,6 +863,11 @@ public Builder reasoningEffort(String reasoningEffort) {
848863
return this;
849864
}
850865

866+
public Builder webSearchOptions(WebSearchOptions webSearchOptions) {
867+
this.options.webSearchOptions = webSearchOptions;
868+
return this;
869+
}
870+
851871
public OpenAiChatOptions build() {
852872
return this.options;
853873
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 116 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,21 @@ public enum ChatModel implements ChatModelDescription {
475475
* Context window: 4,096 tokens. Max output tokens: 4,096 tokens. Knowledge
476476
* cutoff: September, 2021.
477477
*/
478-
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct");
478+
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct"),
479+
480+
/**
481+
* <b>GPT-4o Search Preview</b> is a specialized model for web search in Chat
482+
* Completions. It is trained to understand and execute web search queries. See
483+
* the web search guide for more information.
484+
*/
485+
GPT_4_O_SEARCH_PREVIEW("gpt-4o-search-preview"),
486+
487+
/**
488+
* <b>GPT-4o mini Search Preview</b> is a specialized model for web search in Chat
489+
* Completions. It is trained to understand and execute web search queries. See
490+
* the web search guide for more information.
491+
*/
492+
GPT_4_O_MINI_SEARCH_PREVIEW("gpt-4o-mini-search-preview");
479493

480494
public final String value;
481495

@@ -835,6 +849,10 @@ public enum OutputModality {
835849
* @param parallelToolCalls If set to true, the model will call all functions in the
836850
* tools list in parallel. Otherwise, the model will call the functions in the tools
837851
* list in the order they are provided.
852+
* @param reasoningEffort Constrains effort on reasoning for reasoning models.
853+
* Currently supported values are low, medium, and high. Reducing reasoning effort can
854+
* result in faster responses and fewer tokens used on reasoning in a response.
855+
* @param webSearchOptions Options for web search.
838856
*/
839857
@JsonInclude(Include.NON_NULL)
840858
public record ChatCompletionRequest(// @formatter:off
@@ -864,7 +882,8 @@ public record ChatCompletionRequest(// @formatter:off
864882
@JsonProperty("tool_choice") Object toolChoice,
865883
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
866884
@JsonProperty("user") String user,
867-
@JsonProperty("reasoning_effort") String reasoningEffort) {
885+
@JsonProperty("reasoning_effort") String reasoningEffort,
886+
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {
868887

869888
/**
870889
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
@@ -876,7 +895,7 @@ public record ChatCompletionRequest(// @formatter:off
876895
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
877896
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
878897
null, null, null, false, null, temperature, null,
879-
null, null, null, null, null);
898+
null, null, null, null, null, null);
880899
}
881900

882901
/**
@@ -890,7 +909,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
890909
this(messages, model, null, null, null, null, null, null,
891910
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
892911
null, null, null, stream, null, null, null,
893-
null, null, null, null, null);
912+
null, null, null, null, null, null);
894913
}
895914

896915
/**
@@ -905,7 +924,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
905924
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
906925
this(messages, model, null, null, null, null, null, null, null, null, null,
907926
null, null, null, null, null, null, null, stream, null, temperature, null,
908-
null, null, null, null, null);
927+
null, null, null, null, null, null);
909928
}
910929

911930
/**
@@ -921,7 +940,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
921940
List<FunctionTool> tools, Object toolChoice) {
922941
this(messages, model, null, null, null, null, null, null, null, null, null,
923942
null, null, null, null, null, null, null, false, null, 0.8, null,
924-
tools, toolChoice, null, null, null);
943+
tools, toolChoice, null, null, null, null);
925944
}
926945

927946
/**
@@ -934,7 +953,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
934953
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
935954
this(messages, null, null, null, null, null, null, null, null, null, null,
936955
null, null, null, null, null, null, null, stream, null, null, null,
937-
null, null, null, null, null);
956+
null, null, null, null, null, null);
938957
}
939958

940959
/**
@@ -947,7 +966,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
947966
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
948967
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
949968
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
950-
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort);
969+
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions);
951970
}
952971

953972
/**
@@ -1029,6 +1048,61 @@ public record StreamOptions(
10291048

10301049
public static StreamOptions INCLUDE_USAGE = new StreamOptions(true);
10311050
}
1051+
1052+
/**
1053+
* This tool searches the web for relevant results to use in a response.
1054+
*
1055+
* @param searchContextSize
1056+
* @param userLocation
1057+
*/
1058+
@JsonInclude(Include.NON_NULL)
1059+
public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize,
1060+
@JsonProperty("user_location") UserLocation userLocation) {
1061+
1062+
/**
1063+
* High level guidance for the amount of context window space to use for the
1064+
* search. One of low, medium, or high. medium is the default.
1065+
*/
1066+
public enum SearchContextSize {
1067+
1068+
/**
1069+
* Low context size.
1070+
*/
1071+
@JsonProperty("low")
1072+
LOW,
1073+
1074+
/**
1075+
* Medium context size. This is the default.
1076+
*/
1077+
@JsonProperty("medium")
1078+
MEDIUM,
1079+
1080+
/**
1081+
* High context size.
1082+
*/
1083+
@JsonProperty("high")
1084+
HIGH
1085+
1086+
}
1087+
1088+
/**
1089+
* Approximate location parameters for the search.
1090+
*
1091+
* @param type The type of location approximation. Always "approximate".
1092+
* @param approximate The approximate location details.
1093+
*/
1094+
@JsonInclude(Include.NON_NULL)
1095+
public record UserLocation(@JsonProperty("type") String type,
1096+
@JsonProperty("approximate") Approximate approximate) {
1097+
1098+
@JsonInclude(Include.NON_NULL)
1099+
public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country,
1100+
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
1101+
}
1102+
}
1103+
1104+
}
1105+
10321106
} // @formatter:on
10331107

10341108
/**
@@ -1047,19 +1121,22 @@ public record StreamOptions(
10471121
* Applicable only for {@link Role#ASSISTANT} role and null otherwise.
10481122
* @param refusal The refusal message by the assistant. Applicable only for
10491123
* {@link Role#ASSISTANT} role and null otherwise.
1050-
* @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
1051-
* Support audio input modality)
1124+
* @param audioOutput Audio response from the model.
1125+
* @param annotations Annotations for the message, when applicable, as when using the
1126+
* web search tool.
10521127
*/
1053-
@JsonInclude(Include.NON_NULL)
1054-
public record ChatCompletionMessage(// @formatter:off
1128+
@JsonInclude(JsonInclude.Include.NON_NULL)
1129+
public record ChatCompletionMessage(
1130+
// @formatter:off
10551131
@JsonProperty("content") Object rawContent,
10561132
@JsonProperty("role") Role role,
10571133
@JsonProperty("name") String name,
10581134
@JsonProperty("tool_call_id") String toolCallId,
1059-
@JsonProperty("tool_calls")
1060-
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
1135+
@JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
10611136
@JsonProperty("refusal") String refusal,
1062-
@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
1137+
@JsonProperty("audio") AudioOutput audioOutput,
1138+
@JsonProperty("annotations") List<Annotation> annotations
1139+
) { // @formatter:on
10631140

10641141
/**
10651142
* Create a chat completion message with the given content and role. All other
@@ -1068,8 +1145,7 @@ public record ChatCompletionMessage(// @formatter:off
10681145
* @param role The role of the author of this message.
10691146
*/
10701147
public ChatCompletionMessage(Object content, Role role) {
1071-
this(content, role, null, null, null, null, null);
1072-
1148+
this(content, role, null, null, null, null, null, null);
10731149
}
10741150

10751151
/**
@@ -1246,6 +1322,29 @@ public record AudioOutput(// @formatter:off
12461322
@JsonProperty("transcript") String transcript
12471323
) { // @formatter:on
12481324
}
1325+
1326+
/**
1327+
* Represents an annotation within a message, specifically for URL citations.
1328+
*/
1329+
@JsonInclude(JsonInclude.Include.NON_NULL)
1330+
public record Annotation(@JsonProperty("type") String type,
1331+
@JsonProperty("url_citation") UrlCitation urlCitation) {
1332+
/**
1333+
* A URL citation when using web search.
1334+
*
1335+
* @param endIndex The index of the last character of the URL citation in the
1336+
* message.
1337+
* @param startIndex The index of the first character of the URL citation in
1338+
* the message.
1339+
* @param title The title of the web resource.
1340+
* @param url The URL of the web resource.
1341+
*/
1342+
@JsonInclude(JsonInclude.Include.NON_NULL)
1343+
public record UrlCitation(@JsonProperty("end_index") Integer endIndex,
1344+
@JsonProperty("start_index") Integer startIndex, @JsonProperty("title") String title,
1345+
@JsonProperty("url") String url) {
1346+
}
1347+
}
12491348
}
12501349

12511350
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
*
4141
* @author Christian Tzolov
4242
* @author Thomas Vitale
43+
* @author Alexandros Pappas
4344
* @since 0.8.1
4445
*/
4546
public class OpenAiStreamFunctionCallingHelper {
@@ -98,6 +99,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
9899
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
99100
ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
100101
: previous.audioOutput());
102+
List<ChatCompletionMessage.Annotation> annotations = (current.annotations() != null ? current.annotations()
103+
: previous.annotations());
101104

102105
List<ToolCall> toolCalls = new ArrayList<>();
103106
ToolCall lastPreviousTooCall = null;
@@ -127,7 +130,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
127130
toolCalls.add(lastPreviousTooCall);
128131
}
129132
}
130-
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
133+
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations);
131134
}
132135

133136
private ToolCall merge(ToolCall previous, ToolCall current) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void validateReasoningTokens() {
7575
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
7676
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
7777
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
78-
null, null, null, "low");
78+
null, null, null, "low", null);
7979
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
8080

8181
assertThat(response).isNotNull();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
*
4545
* @author Christian Tzolov
4646
* @author Thomas Vitale
47+
* @author Alexandros Pappas
4748
*/
4849
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
4950
public class OpenAiApiToolFunctionCallIT {
@@ -129,7 +130,7 @@ public void toolFunctionCall() {
129130

130131
// extend conversation with function response.
131132
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
132-
functionName, toolCall.id(), null, null, null));
133+
functionName, toolCall.id(), null, null, null, null));
133134
}
134135
}
135136

0 commit comments

Comments
 (0)