Skip to content

Commit bc375ab

Browse files
apappascstzolov
authored andcommitted
This commit enhances AzureOpenAiChatOptions by:
- Adding `equals` and `hashCode` methods for proper object comparison. - Implementing a deep `copy()` method, creating new instances of mutable collections (List, Set, Map, Metadata) to prevent shared state. - Adding `AzureOpenAiChatOptionsTests` to verify `copy()`, builders, setters, and default values. Signed-off-by: Alexandros Pappas <[email protected]>
1 parent beb81ef commit bc375ab

File tree

2 files changed

+226
-6
lines changed

2 files changed

+226
-6
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashSet;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Objects;
2526
import java.util.Set;
2627

2728
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
@@ -46,6 +47,7 @@
4647
* @author Thomas Vitale
4748
* @author Soby Chacko
4849
* @author Ilayaperumal Gopinathan
50+
* @author Alexandros Pappas
4951
*/
5052
@JsonInclude(Include.NON_NULL)
5153
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
@@ -250,22 +252,24 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
250252
.maxTokens(fromOptions.getMaxTokens())
251253
.N(fromOptions.getN())
252254
.presencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null)
253-
.stop(fromOptions.getStop())
255+
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
254256
.temperature(fromOptions.getTemperature())
255257
.topP(fromOptions.getTopP())
256258
.user(fromOptions.getUser())
257-
.functionCallbacks(fromOptions.getFunctionCallbacks())
258-
.functions(fromOptions.getFunctions())
259+
.functionCallbacks(fromOptions.getFunctionCallbacks() != null
260+
? new ArrayList<>(fromOptions.getFunctionCallbacks()) : null)
261+
.functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null)
259262
.responseFormat(fromOptions.getResponseFormat())
260263
.seed(fromOptions.getSeed())
261264
.logprobs(fromOptions.isLogprobs())
262265
.topLogprobs(fromOptions.getTopLogProbs())
263266
.enhancements(fromOptions.getEnhancements())
264-
.toolContext(fromOptions.getToolContext())
267+
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
265268
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
266269
.streamOptions(fromOptions.getStreamOptions())
267-
.toolCallbacks(fromOptions.getToolCallbacks())
268-
.toolNames(fromOptions.getToolNames())
270+
.toolCallbacks(
271+
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
272+
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
269273
.build();
270274
}
271275

@@ -479,10 +483,44 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
479483
}
480484

481485
@Override
486+
@SuppressWarnings("unchecked")
482487
public AzureOpenAiChatOptions copy() {
483488
return fromOptions(this);
484489
}
485490

491+
@Override
492+
public boolean equals(Object o) {
493+
if (this == o) {
494+
return true;
495+
}
496+
if (!(o instanceof AzureOpenAiChatOptions that)) {
497+
return false;
498+
}
499+
return Objects.equals(this.logitBias, that.logitBias) && Objects.equals(this.user, that.user)
500+
&& Objects.equals(this.n, that.n) && Objects.equals(this.stop, that.stop)
501+
&& Objects.equals(this.deploymentName, that.deploymentName)
502+
&& Objects.equals(this.responseFormat, that.responseFormat)
503+
504+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
505+
&& Objects.equals(this.toolNames, that.toolNames)
506+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
507+
&& Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs)
508+
&& Objects.equals(this.enhancements, that.enhancements)
509+
&& Objects.equals(this.streamOptions, that.streamOptions)
510+
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens)
511+
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
512+
&& Objects.equals(this.presencePenalty, that.presencePenalty)
513+
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP);
514+
}
515+
516+
@Override
517+
public int hashCode() {
518+
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
519+
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
520+
this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens,
521+
this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
522+
}
523+
486524
public static class Builder {
487525

488526
protected AzureOpenAiChatOptions options;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.azure.openai;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
23+
import com.azure.ai.openai.models.AzureChatGroundingEnhancementConfiguration;
24+
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
25+
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
26+
import org.junit.jupiter.api.Test;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
30+
/**
31+
* Tests for {@link AzureOpenAiChatOptions}.
32+
*
33+
* @author Alexandros Pappas
34+
*/
35+
class AzureOpenAiChatOptionsTests {
36+
37+
@Test
38+
void testBuilderWithAllFields() {
39+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
40+
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
41+
streamOptions.setIncludeUsage(true);
42+
43+
AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();
44+
enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true));
45+
enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true));
46+
47+
AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
48+
.deploymentName("test-deployment")
49+
.frequencyPenalty(0.5)
50+
.logitBias(Map.of("token1", 1, "token2", -1))
51+
.maxTokens(200)
52+
.N(2)
53+
.presencePenalty(0.8)
54+
.stop(List.of("stop1", "stop2"))
55+
.temperature(0.7)
56+
.topP(0.9)
57+
.user("test-user")
58+
.responseFormat(responseFormat)
59+
.seed(12345L)
60+
.logprobs(true)
61+
.topLogprobs(5)
62+
.enhancements(enhancements)
63+
.streamOptions(streamOptions)
64+
.build();
65+
66+
assertThat(options)
67+
.extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
68+
"temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements",
69+
"streamOptions")
70+
.containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
71+
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements,
72+
streamOptions);
73+
}
74+
75+
@Test
76+
void testCopy() {
77+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
78+
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
79+
streamOptions.setIncludeUsage(true);
80+
81+
AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();
82+
enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true));
83+
enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true));
84+
85+
AzureOpenAiChatOptions originalOptions = AzureOpenAiChatOptions.builder()
86+
.deploymentName("test-deployment")
87+
.frequencyPenalty(0.5)
88+
.logitBias(Map.of("token1", 1, "token2", -1))
89+
.maxTokens(200)
90+
.N(2)
91+
.presencePenalty(0.8)
92+
.stop(List.of("stop1", "stop2"))
93+
.temperature(0.7)
94+
.topP(0.9)
95+
.user("test-user")
96+
.responseFormat(responseFormat)
97+
.seed(12345L)
98+
.logprobs(true)
99+
.topLogprobs(5)
100+
.enhancements(enhancements)
101+
.streamOptions(streamOptions)
102+
.build();
103+
104+
AzureOpenAiChatOptions copiedOptions = originalOptions.copy();
105+
106+
assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions);
107+
// Ensure deep copy
108+
assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop());
109+
assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext());
110+
}
111+
112+
@Test
113+
void testSetters() {
114+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
115+
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
116+
streamOptions.setIncludeUsage(true);
117+
AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();
118+
119+
AzureOpenAiChatOptions options = new AzureOpenAiChatOptions();
120+
options.setDeploymentName("test-deployment");
121+
options.setFrequencyPenalty(0.5);
122+
options.setLogitBias(Map.of("token1", 1, "token2", -1));
123+
options.setMaxTokens(200);
124+
options.setN(2);
125+
options.setPresencePenalty(0.8);
126+
options.setStop(List.of("stop1", "stop2"));
127+
options.setTemperature(0.7);
128+
options.setTopP(0.9);
129+
options.setUser("test-user");
130+
options.setResponseFormat(responseFormat);
131+
options.setSeed(12345L);
132+
options.setLogprobs(true);
133+
options.setTopLogProbs(5);
134+
options.setEnhancements(enhancements);
135+
options.setStreamOptions(streamOptions);
136+
137+
assertThat(options.getDeploymentName()).isEqualTo("test-deployment");
138+
options.setModel("test-model");
139+
assertThat(options.getDeploymentName()).isEqualTo("test-model");
140+
141+
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
142+
assertThat(options.getLogitBias()).isEqualTo(Map.of("token1", 1, "token2", -1));
143+
assertThat(options.getMaxTokens()).isEqualTo(200);
144+
assertThat(options.getN()).isEqualTo(2);
145+
assertThat(options.getPresencePenalty()).isEqualTo(0.8);
146+
assertThat(options.getStop()).isEqualTo(List.of("stop1", "stop2"));
147+
assertThat(options.getTemperature()).isEqualTo(0.7);
148+
assertThat(options.getTopP()).isEqualTo(0.9);
149+
assertThat(options.getUser()).isEqualTo("test-user");
150+
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
151+
assertThat(options.getSeed()).isEqualTo(12345L);
152+
assertThat(options.isLogprobs()).isTrue();
153+
assertThat(options.getTopLogProbs()).isEqualTo(5);
154+
assertThat(options.getEnhancements()).isEqualTo(enhancements);
155+
assertThat(options.getStreamOptions()).isEqualTo(streamOptions);
156+
assertThat(options.getModel()).isEqualTo("test-model");
157+
}
158+
159+
@Test
160+
void testDefaultValues() {
161+
AzureOpenAiChatOptions options = new AzureOpenAiChatOptions();
162+
163+
assertThat(options.getDeploymentName()).isNull();
164+
assertThat(options.getFrequencyPenalty()).isNull();
165+
assertThat(options.getLogitBias()).isNull();
166+
assertThat(options.getMaxTokens()).isNull();
167+
assertThat(options.getN()).isNull();
168+
assertThat(options.getPresencePenalty()).isNull();
169+
assertThat(options.getStop()).isNull();
170+
assertThat(options.getTemperature()).isNull();
171+
assertThat(options.getTopP()).isNull();
172+
assertThat(options.getUser()).isNull();
173+
assertThat(options.getResponseFormat()).isNull();
174+
assertThat(options.getSeed()).isNull();
175+
assertThat(options.isLogprobs()).isNull();
176+
assertThat(options.getTopLogProbs()).isNull();
177+
assertThat(options.getEnhancements()).isNull();
178+
assertThat(options.getStreamOptions()).isNull();
179+
assertThat(options.getModel()).isNull();
180+
}
181+
182+
}

0 commit comments

Comments
 (0)