|
5 | 5 | import java.nio.charset.Charset;
|
6 | 6 | import java.util.ArrayList;
|
7 | 7 | import java.util.Arrays;
|
8 |
| -import java.util.Collections; |
9 | 8 | import java.util.HashMap;
|
10 | 9 | import java.util.List;
|
11 | 10 | import java.util.Map;
|
|
19 | 18 | import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
20 | 19 | import org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec;
|
21 | 20 | import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
|
22 |
| -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; |
23 |
| -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; |
24 | 21 | import org.springframework.ai.chat.client.advisor.api.Advisor;
|
25 |
| -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; |
26 |
| -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; |
27 |
| -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; |
28 |
| -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; |
29 | 22 | import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
|
30 | 23 | import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
|
31 | 24 | import org.springframework.ai.chat.messages.Message;
|
|
34 | 27 | import org.springframework.ai.chat.prompt.Prompt;
|
35 | 28 | import org.springframework.ai.model.Media;
|
36 | 29 | import org.springframework.ai.model.function.FunctionCallback;
|
37 |
| -import org.springframework.core.Ordered; |
38 | 30 | import org.springframework.core.io.Resource;
|
39 | 31 | import org.springframework.lang.Nullable;
|
40 | 32 | import org.springframework.util.Assert;
|
41 | 33 | import org.springframework.util.MimeType;
|
42 | 34 | import org.springframework.util.StringUtils;
|
43 | 35 |
|
44 |
| -import reactor.core.publisher.Flux; |
45 |
| -import reactor.core.scheduler.Schedulers; |
46 |
| - |
| 36 | +import io.arconia.ai.core.client.advisor.CallAdvisor; |
| 37 | +import io.arconia.ai.core.client.advisor.StreamAdvisor; |
47 | 38 | import io.arconia.ai.core.tools.ToolCallback;
|
48 | 39 | import io.arconia.ai.core.tools.ToolCallbacks;
|
49 | 40 |
|
|
52 | 43 | */
|
53 | 44 | public class DefaultArconiaChatClient implements ArconiaChatClient {
|
54 | 45 |
|
55 |
| - private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention(); |
| 46 | + private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION |
| 47 | + = new DefaultChatClientObservationConvention(); |
56 | 48 |
|
57 | 49 | private final DefaultArconiaChatClientRequestSpec defaultArconiaChatClientRequest;
|
58 | 50 |
|
@@ -142,8 +134,7 @@ public ArconiaPromptUserSpec text(Resource text, Charset charset) {
|
142 | 134 | Assert.notNull(charset, "charset cannot be null");
|
143 | 135 | try {
|
144 | 136 | this.text(text.getContentAsString(charset));
|
145 |
| - } |
146 |
| - catch (IOException e) { |
| 137 | + } catch (IOException e) { |
147 | 138 | throw new RuntimeException(e);
|
148 | 139 | }
|
149 | 140 | return this;
|
@@ -208,8 +199,7 @@ public ArconiaPromptSystemSpec text(Resource text, Charset charset) {
|
208 | 199 | Assert.notNull(charset, "charset cannot be null");
|
209 | 200 | try {
|
210 | 201 | this.text(text.getContentAsString(charset));
|
211 |
| - } |
212 |
| - catch (IOException e) { |
| 202 | + } catch (IOException e) { |
213 | 203 | throw new RuntimeException(e);
|
214 | 204 | }
|
215 | 205 | return this;
|
@@ -393,59 +383,23 @@ public DefaultArconiaChatClientRequestSpec(ChatModel chatModel, @Nullable String
|
393 | 383 | this.userParams.putAll(userParams);
|
394 | 384 | this.systemText = systemText;
|
395 | 385 | this.systemParams.putAll(systemParams);
|
| 386 | + this.messages.addAll(messages); |
| 387 | + this.media.addAll(media); |
396 | 388 |
|
397 | 389 | this.toolNames.addAll(toolNames);
|
398 | 390 | this.toolCallbacks.addAll(toolCallbacks);
|
399 |
| - this.messages.addAll(messages); |
400 |
| - this.media.addAll(media); |
401 |
| - this.advisors.addAll(advisors); |
402 |
| - this.advisorParams.putAll(advisorParams); |
| 391 | + this.toolContext.putAll(toolContext); |
| 392 | + |
403 | 393 | this.observationRegistry = observationRegistry;
|
404 | 394 | this.customObservationConvention = customObservationConvention != null ? customObservationConvention
|
405 | 395 | : DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
|
406 |
| - this.toolContext.putAll(toolContext); |
407 | 396 |
|
408 |
| - this.advisors.add(new CallAroundAdvisor() { |
409 |
| - @Override |
410 |
| - public String getName() { |
411 |
| - return CallAroundAdvisor.class.getSimpleName(); |
412 |
| - } |
413 |
| - |
414 |
| - @Override |
415 |
| - public int getOrder() { |
416 |
| - return Ordered.LOWEST_PRECEDENCE; |
417 |
| - } |
418 |
| - |
419 |
| - @Override |
420 |
| - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { |
421 |
| - return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), |
422 |
| - Collections.unmodifiableMap(advisedRequest.adviseContext())); |
423 |
| - } |
424 |
| - }); |
425 |
| - |
426 |
| - this.advisors.add(new StreamAroundAdvisor() { |
427 |
| - @Override |
428 |
| - public String getName() { |
429 |
| - return StreamAroundAdvisor.class.getSimpleName(); |
430 |
| - } |
431 |
| - |
432 |
| - @Override |
433 |
| - public int getOrder() { |
434 |
| - return Ordered.LOWEST_PRECEDENCE; |
435 |
| - } |
436 |
| - |
437 |
| - @Override |
438 |
| - public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, |
439 |
| - StreamAroundAdvisorChain chain) { |
440 |
| - return chatModel.stream(advisedRequest.toPrompt()) |
441 |
| - .map(chatResponse -> new AdvisedResponse(chatResponse, |
442 |
| - Collections.unmodifiableMap(advisedRequest.adviseContext()))) |
443 |
| - .publishOn(Schedulers.boundedElastic()); |
444 |
| - } |
445 |
| - }); |
446 |
| - |
447 |
| - this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) |
448 |
| - .pushAll(this.advisors); |
| 397 | + this.advisors.addAll(advisors); |
| 398 | + this.advisorParams.putAll(advisorParams); |
| 399 | + this.advisors.add(new CallAdvisor(chatModel)); |
| 400 | + this.advisors.add(new StreamAdvisor(chatModel)); |
| 401 | + |
| 402 | + this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry).pushAll(this.advisors); |
449 | 403 | }
|
450 | 404 |
|
451 | 405 | private ObservationRegistry getObservationRegistry() {
|
@@ -516,12 +470,16 @@ public ArconiaBuilder mutate() {
|
516 | 470 | .defaultTools(StringUtils.toStringArray(this.toolNames));
|
517 | 471 |
|
518 | 472 | if (StringUtils.hasText(this.userText)) {
|
519 |
| - builder.defaultUser( |
520 |
| - u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0]))); |
| 473 | + builder.defaultUser(u -> u |
| 474 | + .text(this.userText) |
| 475 | + .params(this.userParams) |
| 476 | + .media(this.media.toArray(new Media[0]))); |
521 | 477 | }
|
522 | 478 |
|
523 | 479 | if (StringUtils.hasText(this.systemText)) {
|
524 |
| - builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams)); |
| 480 | + builder.defaultSystem(s -> s |
| 481 | + .text(this.systemText) |
| 482 | + .params(this.systemParams)); |
525 | 483 | }
|
526 | 484 |
|
527 | 485 | if (this.chatOptions != null) {
|
@@ -657,10 +615,10 @@ public ArconiaChatClientRequestSpec system(Resource text, Charset charset) {
|
657 | 615 |
|
658 | 616 | try {
|
659 | 617 | this.systemText = text.getContentAsString(charset);
|
660 |
| - } |
661 |
| - catch (IOException e) { |
| 618 | + } catch (IOException e) { |
662 | 619 | throw new RuntimeException(e);
|
663 | 620 | }
|
| 621 | + |
664 | 622 | return this;
|
665 | 623 | }
|
666 | 624 |
|
@@ -696,8 +654,7 @@ public ArconiaChatClientRequestSpec user(Resource text, Charset charset) {
|
696 | 654 |
|
697 | 655 | try {
|
698 | 656 | this.userText = text.getContentAsString(charset);
|
699 |
| - } |
700 |
| - catch (IOException e) { |
| 657 | + } catch (IOException e) { |
701 | 658 | throw new RuntimeException(e);
|
702 | 659 | }
|
703 | 660 | return this;
|
|
0 commit comments