Skip to content

Commit a24f152

Browse files
committed
refactor(ai): Improve maintainability and apply code conventions
1 parent db2d463 commit a24f152

File tree

13 files changed

+145
-118
lines changed

13 files changed

+145
-118
lines changed

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/ArconiaChatClient.java

-4
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ static ArconiaBuilder builder(ChatModel chatModel, ObservationRegistry observati
4040
return new DefaultArconiaChatClientBuilder(chatModel, observationRegistry, customObservationConvention);
4141
}
4242

43-
// @formatter:off
44-
4543
ArconiaChatClientRequestSpec prompt();
4644

4745
ArconiaChatClientRequestSpec prompt(String content);
@@ -184,6 +182,4 @@ default <I, O> Builder defaultFunction(String name, String description,
184182
ArconiaChatClient build();
185183
}
186184

187-
// @formatter:on
188-
189185
}

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/DefaultArconiaChatClient.java

+26-69
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import java.nio.charset.Charset;
66
import java.util.ArrayList;
77
import java.util.Arrays;
8-
import java.util.Collections;
98
import java.util.HashMap;
109
import java.util.List;
1110
import java.util.Map;
@@ -19,13 +18,7 @@
1918
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
2019
import org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec;
2120
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
22-
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
23-
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
2421
import org.springframework.ai.chat.client.advisor.api.Advisor;
25-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
26-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
27-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
28-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
2922
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
3023
import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
3124
import org.springframework.ai.chat.messages.Message;
@@ -34,16 +27,14 @@
3427
import org.springframework.ai.chat.prompt.Prompt;
3528
import org.springframework.ai.model.Media;
3629
import org.springframework.ai.model.function.FunctionCallback;
37-
import org.springframework.core.Ordered;
3830
import org.springframework.core.io.Resource;
3931
import org.springframework.lang.Nullable;
4032
import org.springframework.util.Assert;
4133
import org.springframework.util.MimeType;
4234
import org.springframework.util.StringUtils;
4335

44-
import reactor.core.publisher.Flux;
45-
import reactor.core.scheduler.Schedulers;
46-
36+
import io.arconia.ai.core.client.advisor.CallAdvisor;
37+
import io.arconia.ai.core.client.advisor.StreamAdvisor;
4738
import io.arconia.ai.core.tools.ToolCallback;
4839
import io.arconia.ai.core.tools.ToolCallbacks;
4940

@@ -52,7 +43,8 @@
5243
*/
5344
public class DefaultArconiaChatClient implements ArconiaChatClient {
5445

55-
private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention();
46+
private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION
47+
= new DefaultChatClientObservationConvention();
5648

5749
private final DefaultArconiaChatClientRequestSpec defaultArconiaChatClientRequest;
5850

@@ -142,8 +134,7 @@ public ArconiaPromptUserSpec text(Resource text, Charset charset) {
142134
Assert.notNull(charset, "charset cannot be null");
143135
try {
144136
this.text(text.getContentAsString(charset));
145-
}
146-
catch (IOException e) {
137+
} catch (IOException e) {
147138
throw new RuntimeException(e);
148139
}
149140
return this;
@@ -208,8 +199,7 @@ public ArconiaPromptSystemSpec text(Resource text, Charset charset) {
208199
Assert.notNull(charset, "charset cannot be null");
209200
try {
210201
this.text(text.getContentAsString(charset));
211-
}
212-
catch (IOException e) {
202+
} catch (IOException e) {
213203
throw new RuntimeException(e);
214204
}
215205
return this;
@@ -393,59 +383,23 @@ public DefaultArconiaChatClientRequestSpec(ChatModel chatModel, @Nullable String
393383
this.userParams.putAll(userParams);
394384
this.systemText = systemText;
395385
this.systemParams.putAll(systemParams);
386+
this.messages.addAll(messages);
387+
this.media.addAll(media);
396388

397389
this.toolNames.addAll(toolNames);
398390
this.toolCallbacks.addAll(toolCallbacks);
399-
this.messages.addAll(messages);
400-
this.media.addAll(media);
401-
this.advisors.addAll(advisors);
402-
this.advisorParams.putAll(advisorParams);
391+
this.toolContext.putAll(toolContext);
392+
403393
this.observationRegistry = observationRegistry;
404394
this.customObservationConvention = customObservationConvention != null ? customObservationConvention
405395
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
406-
this.toolContext.putAll(toolContext);
407396

408-
this.advisors.add(new CallAroundAdvisor() {
409-
@Override
410-
public String getName() {
411-
return CallAroundAdvisor.class.getSimpleName();
412-
}
413-
414-
@Override
415-
public int getOrder() {
416-
return Ordered.LOWEST_PRECEDENCE;
417-
}
418-
419-
@Override
420-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
421-
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()),
422-
Collections.unmodifiableMap(advisedRequest.adviseContext()));
423-
}
424-
});
425-
426-
this.advisors.add(new StreamAroundAdvisor() {
427-
@Override
428-
public String getName() {
429-
return StreamAroundAdvisor.class.getSimpleName();
430-
}
431-
432-
@Override
433-
public int getOrder() {
434-
return Ordered.LOWEST_PRECEDENCE;
435-
}
436-
437-
@Override
438-
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest,
439-
StreamAroundAdvisorChain chain) {
440-
return chatModel.stream(advisedRequest.toPrompt())
441-
.map(chatResponse -> new AdvisedResponse(chatResponse,
442-
Collections.unmodifiableMap(advisedRequest.adviseContext())))
443-
.publishOn(Schedulers.boundedElastic());
444-
}
445-
});
446-
447-
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry)
448-
.pushAll(this.advisors);
397+
this.advisors.addAll(advisors);
398+
this.advisorParams.putAll(advisorParams);
399+
this.advisors.add(new CallAdvisor(chatModel));
400+
this.advisors.add(new StreamAdvisor(chatModel));
401+
402+
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry).pushAll(this.advisors);
449403
}
450404

451405
private ObservationRegistry getObservationRegistry() {
@@ -516,12 +470,16 @@ public ArconiaBuilder mutate() {
516470
.defaultTools(StringUtils.toStringArray(this.toolNames));
517471

518472
if (StringUtils.hasText(this.userText)) {
519-
builder.defaultUser(
520-
u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0])));
473+
builder.defaultUser(u -> u
474+
.text(this.userText)
475+
.params(this.userParams)
476+
.media(this.media.toArray(new Media[0])));
521477
}
522478

523479
if (StringUtils.hasText(this.systemText)) {
524-
builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
480+
builder.defaultSystem(s -> s
481+
.text(this.systemText)
482+
.params(this.systemParams));
525483
}
526484

527485
if (this.chatOptions != null) {
@@ -657,10 +615,10 @@ public ArconiaChatClientRequestSpec system(Resource text, Charset charset) {
657615

658616
try {
659617
this.systemText = text.getContentAsString(charset);
660-
}
661-
catch (IOException e) {
618+
} catch (IOException e) {
662619
throw new RuntimeException(e);
663620
}
621+
664622
return this;
665623
}
666624

@@ -696,8 +654,7 @@ public ArconiaChatClientRequestSpec user(Resource text, Charset charset) {
696654

697655
try {
698656
this.userText = text.getContentAsString(charset);
699-
}
700-
catch (IOException e) {
657+
} catch (IOException e) {
701658
throw new RuntimeException(e);
702659
}
703660
return this;

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/DefaultArconiaChatClientBuilder.java

+5-8
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import io.arconia.ai.core.tools.ToolCallbacks;
2727

2828
/**
29-
* Default implementation of {@link ArconiaChatClient.ArconiaBuilder} based on
30-
* {@link DefaultChatClientBuilder}.
29+
* Default implementation of {@link ArconiaChatClient.ArconiaBuilder}
30+
* based on {@link DefaultChatClientBuilder}.
3131
*/
3232
public class DefaultArconiaChatClientBuilder implements ArconiaChatClient.ArconiaBuilder {
3333

@@ -81,8 +81,7 @@ public ArconiaChatClient.ArconiaBuilder defaultUser(Resource text, Charset chars
8181
Assert.notNull(charset, "charset cannot be null");
8282
try {
8383
this.arconiaRequest.user(text.getContentAsString(charset));
84-
}
85-
catch (IOException e) {
84+
} catch (IOException e) {
8685
throw new RuntimeException(e);
8786
}
8887
return this;
@@ -112,8 +111,7 @@ public ArconiaChatClient.ArconiaBuilder defaultSystem(Resource text, Charset cha
112111
Assert.notNull(charset, "charset cannot be null");
113112
try {
114113
this.arconiaRequest.system(text.getContentAsString(charset));
115-
}
116-
catch (IOException e) {
114+
} catch (IOException e) {
117115
throw new RuntimeException(e);
118116
}
119117
return this;
@@ -161,8 +159,7 @@ public ArconiaChatClient.ArconiaBuilder defaultFunctions(String... functionNames
161159

162160
@Override
163161
public ArconiaChatClient.ArconiaBuilder defaultFunctions(FunctionCallback... functionCallbacks) {
164-
return defaultToolCallbacks(
165-
Stream.of(functionCallbacks).map(f -> (ToolCallback) f).toArray(ToolCallback[]::new));
162+
return defaultToolCallbacks(Stream.of(functionCallbacks).map(f -> (ToolCallback) f).toArray(ToolCallback[]::new));
166163
}
167164

168165
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.arconia.ai.core.client.advisor;
2+
3+
import java.util.Map;
4+
5+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
6+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
7+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
8+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
9+
import org.springframework.ai.chat.model.ChatModel;
10+
import org.springframework.ai.chat.model.ChatResponse;
11+
import org.springframework.core.Ordered;
12+
13+
/**
14+
* An advisor that calls a {@link ChatModel}.
15+
*/
16+
public class CallAdvisor implements CallAroundAdvisor {
17+
18+
private final ChatModel chatModel;
19+
20+
public CallAdvisor(ChatModel chatModel) {
21+
this.chatModel = chatModel;
22+
}
23+
24+
@Override
25+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
26+
ChatResponse chatResponse = chatModel.call(advisedRequest.toPrompt());
27+
return new AdvisedResponse(chatResponse, Map.copyOf(advisedRequest.adviseContext()));
28+
}
29+
30+
@Override
31+
public String getName() {
32+
return CallAdvisor.class.getSimpleName();
33+
}
34+
35+
@Override
36+
public int getOrder() {
37+
return Ordered.LOWEST_PRECEDENCE;
38+
}
39+
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package io.arconia.ai.core.client.advisor;
2+
3+
import java.util.Map;
4+
5+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
6+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
7+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
8+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
9+
import org.springframework.ai.chat.model.ChatModel;
10+
import org.springframework.core.Ordered;
11+
12+
import reactor.core.publisher.Flux;
13+
import reactor.core.scheduler.Schedulers;
14+
15+
/**
16+
* An advisor that calls a {@link ChatModel} in stream mode.
17+
*/
18+
public class StreamAdvisor implements StreamAroundAdvisor {
19+
20+
private final ChatModel chatModel;
21+
22+
public StreamAdvisor(ChatModel chatModel) {
23+
this.chatModel = chatModel;
24+
}
25+
26+
@Override
27+
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
28+
return chatModel.stream(advisedRequest.toPrompt())
29+
.map(chatResponse -> new AdvisedResponse(chatResponse, Map.copyOf(advisedRequest.adviseContext())))
30+
.publishOn(Schedulers.boundedElastic());
31+
}
32+
33+
@Override
34+
public String getName() {
35+
return StreamAdvisor.class.getSimpleName();
36+
}
37+
38+
@Override
39+
public int getOrder() {
40+
return Ordered.LOWEST_PRECEDENCE;
41+
}
42+
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@NonNullApi
2+
@NonNullFields
3+
package io.arconia.ai.core.client.advisor;
4+
5+
import org.springframework.lang.NonNullApi;
6+
import org.springframework.lang.NonNullFields;

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallback.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,4 @@
55
/**
66
* Wrapper for {@link FunctionCallback} to identify tools in Spring AI.
77
*/
8-
public interface ToolCallback extends FunctionCallback {
9-
10-
}
8+
public interface ToolCallback extends FunctionCallback {}

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallbacks.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
*/
88
public final class ToolCallbacks {
99

10-
private ToolCallbacks() {
11-
}
10+
private ToolCallbacks() {}
1211

1312
public static ToolCallback[] from(Object... sources) {
1413
return MethodToolCallbackProvider.builder().toolObjects(sources).build().getToolCallbacks();

0 commit comments

Comments
 (0)