Skip to content

Commit a62ad66

Browse files
committed
feat(ai): More robust tool support for @Tool-annotated methods
1 parent fe7cba5 commit a62ad66

File tree

9 files changed

+77
-171
lines changed

9 files changed

+77
-171
lines changed

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/ArconiaChatClient.java

-5
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.springframework.util.Assert;
2424

2525
import io.arconia.ai.core.tools.ToolCallback;
26-
import io.arconia.ai.core.tools.ToolCallbackProvider;
2726

2827
/**
2928
* A {@link ChatClient} enhanced for more advanced features.
@@ -94,8 +93,6 @@ interface ArconiaChatClientRequestSpec extends ChatClientRequestSpec {
9493

9594
ArconiaChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks);
9695

97-
ArconiaChatClientRequestSpec toolCallbackProviders(ToolCallbackProvider... toolCallbackProviders);
98-
9996
ArconiaChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
10097

10198
ArconiaChatClientRequestSpec functions(String... functionBeanNames);
@@ -166,8 +163,6 @@ interface ArconiaBuilder extends Builder {
166163

167164
ArconiaBuilder defaultToolCallbacks(ToolCallback... toolCallbacks);
168165

169-
ArconiaBuilder defaultToolCallbackProviders(ToolCallbackProvider... toolCallbackProviders);
170-
171166
ArconiaBuilder defaultFunctions(String... functionNames);
172167

173168
ArconiaBuilder defaultFunctions(FunctionCallback... functionCallbacks);

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/DefaultArconiaChatClient.java

+4-16
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@
4545
import reactor.core.scheduler.Schedulers;
4646

4747
import io.arconia.ai.core.tools.ToolCallback;
48-
import io.arconia.ai.core.tools.ToolCallbackProvider;
49-
import io.arconia.ai.core.tools.method.MethodToolCallbackProvider;
48+
import io.arconia.ai.core.tools.ToolCallbacks;
5049

5150
/**
5251
* Default implementation of {@link ArconiaChatClient} based on {@link DefaultChatClient}.
@@ -606,18 +605,15 @@ public ArconiaChatClientRequestSpec tools(String... toolNames) {
606605

607606
@Override
608607
public ArconiaChatClientRequestSpec tools(Class<?>... toolBoxes) {
609-
Assert.notNull(toolBoxes, "toolBoxes cannot be null");
610-
Assert.noNullElements(toolBoxes, "toolBoxes cannot contain null elements");
611-
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
612-
return toolCallbackProviders(toolCallbackProvider);
608+
throw new UnsupportedOperationException("Not yet supported");
613609
}
614610

615611
@Override
616612
public ArconiaChatClientRequestSpec tools(Object... toolBoxes) {
617613
Assert.notNull(toolBoxes, "toolBoxes cannot be null");
618614
Assert.noNullElements(toolBoxes, "toolBoxes cannot contain null elements");
619-
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
620-
return toolCallbackProviders(toolCallbackProvider);
615+
this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolBoxes)));
616+
return this;
621617
}
622618

623619
@Override
@@ -628,14 +624,6 @@ public ArconiaChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks)
628624
return this;
629625
}
630626

631-
@Override
632-
public ArconiaChatClientRequestSpec toolCallbackProviders(ToolCallbackProvider... toolCallbackProviders) {
633-
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
634-
this.toolCallbacks.addAll(Arrays.asList(toolCallbackProvider.getToolCallbacks()));
635-
}
636-
return this;
637-
}
638-
639627
@Override
640628
public ArconiaChatClientRequestSpec functions(String... functionBeanNames) {
641629
return tools(functionBeanNames);

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/client/DefaultArconiaChatClientBuilder.java

+4-15
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323

2424
import io.arconia.ai.core.client.DefaultArconiaChatClient.DefaultArconiaChatClientRequestSpec;
2525
import io.arconia.ai.core.tools.ToolCallback;
26-
import io.arconia.ai.core.tools.ToolCallbackProvider;
27-
import io.arconia.ai.core.tools.method.MethodToolCallbackProvider;
26+
import io.arconia.ai.core.tools.ToolCallbacks;
2827

2928
/**
3029
* Default implementation of {@link ArconiaChatClient.ArconiaBuilder} based on
@@ -140,14 +139,13 @@ public ArconiaChatClient.ArconiaBuilder defaultTools(String... toolNames) {
140139

141140
@Override
142141
public ArconiaChatClient.ArconiaBuilder defaultTools(Class<?>... toolBoxes) {
143-
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
144-
return defaultToolCallbackProviders(toolCallbackProvider);
142+
throw new UnsupportedOperationException("Not yet supported");
145143
}
146144

147145
@Override
148146
public ArconiaChatClient.ArconiaBuilder defaultTools(Object... toolBoxes) {
149-
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
150-
return defaultToolCallbackProviders(toolCallbackProvider);
147+
this.arconiaRequest.functions(ToolCallbacks.from(toolBoxes));
148+
return this;
151149
}
152150

153151
@Override
@@ -156,15 +154,6 @@ public ArconiaChatClient.ArconiaBuilder defaultToolCallbacks(ToolCallback... too
156154
return this;
157155
}
158156

159-
@Override
160-
public ArconiaChatClient.ArconiaBuilder defaultToolCallbackProviders(
161-
ToolCallbackProvider... toolCallbackProviders) {
162-
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
163-
this.arconiaRequest.functions(toolCallbackProvider.getToolCallbacks());
164-
}
165-
return this;
166-
}
167-
168157
@Override
169158
public ArconiaChatClient.ArconiaBuilder defaultFunctions(String... functionNames) {
170159
return defaultTools(functionNames);

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallbacks.java

+1-5
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,8 @@ public final class ToolCallbacks {
1010
private ToolCallbacks() {
1111
}
1212

13-
public static ToolCallback[] from(Class<?>... sources) {
14-
return MethodToolCallbackProvider.builder().sources(sources).build().getToolCallbacks();
15-
}
16-
1713
public static ToolCallback[] from(Object... sources) {
18-
return MethodToolCallbackProvider.builder().sources(sources).build().getToolCallbacks();
14+
return MethodToolCallbackProvider.builder().toolObjects(sources).build().getToolCallbacks();
1915
}
2016

2117
}

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallbackProvider.java

+33-78
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package io.arconia.ai.core.tools.method;
22

33
import java.lang.reflect.Method;
4-
import java.lang.reflect.Modifier;
54
import java.util.Arrays;
65
import java.util.List;
76
import java.util.function.Consumer;
@@ -12,7 +11,6 @@
1211

1312
import org.slf4j.Logger;
1413
import org.slf4j.LoggerFactory;
15-
import org.springframework.lang.Nullable;
1614
import org.springframework.util.Assert;
1715
import org.springframework.util.ClassUtils;
1816
import org.springframework.util.ReflectionUtils;
@@ -31,115 +29,72 @@ public class MethodToolCallbackProvider implements ToolCallbackProvider {
3129

3230
private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);
3331

34-
@Nullable
35-
private final List<Object> sourceObjects;
32+
private final List<Object> toolObjects;
3633

37-
@Nullable
38-
private final List<Class<?>> sourceTypes;
39-
40-
private MethodToolCallbackProvider(@Nullable List<Object> sourceObjects, @Nullable List<Class<?>> sourceTypes) {
41-
Assert.isTrue(sourceObjects != null || sourceTypes != null, "sourceObjects or sourceTypes cannot be null");
42-
if (sourceObjects != null) {
43-
Assert.noNullElements(sourceObjects, "sourceObjects cannot contain null elements");
44-
}
45-
if (sourceTypes != null) {
46-
Assert.noNullElements(sourceTypes, "sourceTypes cannot contain null elements");
47-
}
48-
this.sourceObjects = sourceObjects;
49-
this.sourceTypes = sourceTypes;
34+
private MethodToolCallbackProvider(List<Object> toolObjects) {
35+
Assert.notNull(toolObjects, "toolObjects cannot be null");
36+
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
37+
this.toolObjects = toolObjects;
5038
}
5139

5240
@Override
5341
public ToolCallback[] getToolCallbacks() {
54-
if (sourceObjects != null) {
55-
return getToolCallbacksFromObjects();
56-
}
57-
return getToolCallbacksFromTypes();
58-
}
59-
60-
private ToolCallback[] getToolCallbacksFromObjects() {
61-
var toolCallbacks = sourceObjects.stream()
62-
.map(sourceObject -> getDeclaredMethodsWithToolAnnotation(sourceObject.getClass())
63-
.map(method -> MethodToolCallback.builder()
64-
.toolMetadata(ToolMetadata.from(method))
65-
.toolMethod(method)
66-
.toolObject(sourceObject)
67-
.build())
68-
.toArray(ToolCallback[]::new))
69-
.flatMap(Stream::of)
70-
.toArray(ToolCallback[]::new);
71-
72-
if (ToolUtils.hasDuplicateToolNames(toolCallbacks)) {
73-
throw new IllegalStateException("Multiple tools with the same name found in sources: "
74-
+ sourceObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", ")));
75-
}
76-
77-
return toolCallbacks;
78-
}
79-
80-
private ToolCallback[] getToolCallbacksFromTypes() {
81-
var toolCallbacks = sourceTypes.stream()
82-
.map(sourceType -> getDeclaredMethodsWithToolAnnotation(sourceType)
83-
.filter(method -> Modifier.isStatic(method.getModifiers()))
84-
.map(method -> MethodToolCallback.builder()
85-
.toolMetadata(ToolMetadata.from(method))
86-
.toolMethod(method)
42+
var toolCallbacks = toolObjects.stream()
43+
.map(toolObject -> Stream.of(ReflectionUtils.getDeclaredMethods(toolObject.getClass()))
44+
.filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
45+
.filter(toolMethod -> !isFunctionalType(toolMethod))
46+
.map(toolMethod -> MethodToolCallback.builder()
47+
.toolMetadata(ToolMetadata.from(toolMethod))
48+
.toolMethod(toolMethod)
49+
.toolObject(toolObject)
8750
.build())
8851
.toArray(ToolCallback[]::new))
8952
.flatMap(Stream::of)
9053
.toArray(ToolCallback[]::new);
9154

92-
if (ToolUtils.hasDuplicateToolNames(toolCallbacks)) {
93-
throw new IllegalStateException("Multiple tools with the same name found in sources: "
94-
+ sourceTypes.stream().map(Class::getName).collect(Collectors.joining(", ")));
95-
}
55+
validateToolCallbacks(toolCallbacks);
9656

9757
return toolCallbacks;
9858
}
9959

100-
private Stream<Method> getDeclaredMethodsWithToolAnnotation(Class<?> sourceType) {
101-
return Stream.of(ReflectionUtils.getDeclaredMethods(sourceType))
102-
.filter(method -> method.isAnnotationPresent(Tool.class))
103-
.filter(method -> !isFunctionalType(method));
104-
}
105-
106-
private static boolean isFunctionalType(Method method) {
107-
var isFunction = ClassUtils.isAssignable(method.getReturnType(), Function.class)
108-
|| ClassUtils.isAssignable(method.getReturnType(), Supplier.class)
109-
|| ClassUtils.isAssignable(method.getReturnType(), Consumer.class);
60+
private static boolean isFunctionalType(Method toolMethod) {
61+
var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class)
62+
|| ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class)
63+
|| ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class);
11064

11165
if (isFunction) {
11266
logger.warn("Method {} is annotated with @Tool but returns a functional type. "
113-
+ "This is not supported and the method will be ignored.", method.getName());
67+
+ "This is not supported and the method will be ignored.", toolMethod.getName());
11468
}
11569

11670
return isFunction;
11771
}
11872

73+
private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
74+
List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
75+
if (!duplicateToolNames.isEmpty()) {
76+
throw new IllegalStateException("Multiple tools with the same name (%s) found in sources: %s".formatted(
77+
String.join(", ", duplicateToolNames),
78+
toolObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", "))));
79+
}
80+
}
81+
11982
public static Builder builder() {
12083
return new Builder();
12184
}
12285

12386
public static class Builder {
12487

125-
private List<Object> sourceObjects;
126-
127-
private List<Class<?>> sourceTypes;
128-
129-
public Builder sources(Object... sourceObjects) {
130-
Assert.isNull(this.sourceTypes, "only one of sourceObjects or sourceTypes can be set");
131-
this.sourceObjects = Arrays.asList(sourceObjects);
132-
return this;
133-
}
88+
private List<Object> toolObjects;
13489

135-
public Builder sources(Class<?>... sourceTypes) {
136-
Assert.isNull(this.sourceObjects, "only one of sourceObjects or sourceTypes can be set");
137-
this.sourceTypes = Arrays.asList(sourceTypes);
90+
public Builder toolObjects(Object... toolObjects) {
91+
Assert.notNull(toolObjects, "toolObjects cannot be null");
92+
this.toolObjects = Arrays.asList(toolObjects);
13893
return this;
13994
}
14095

14196
public MethodToolCallbackProvider build() {
142-
return new MethodToolCallbackProvider(sourceObjects, sourceTypes);
97+
return new MethodToolCallbackProvider(toolObjects);
14398
}
14499

145100
}

Diff for: arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/util/ToolUtils.java

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package io.arconia.ai.core.tools.util;
22

3+
import java.util.List;
4+
import java.util.Map;
5+
import java.util.stream.Collectors;
36
import java.util.stream.Stream;
47

58
import org.springframework.ai.model.function.FunctionCallback;
@@ -35,11 +38,14 @@ public static FunctionCallback.SchemaType getToolSchemaType(@Nullable Tool tool)
3538
return tool.schemaType();
3639
}
3740

38-
public static boolean hasDuplicateToolNames(FunctionCallback... functionCallbacks) {
41+
public static List<String> getDuplicateToolNames(FunctionCallback... functionCallbacks) {
3942
return Stream.of(functionCallbacks)
40-
.map(FunctionCallback::getName)
41-
.distinct()
42-
.count() != functionCallbacks.length;
43+
.collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting()))
44+
.entrySet()
45+
.stream()
46+
.filter(entry -> entry.getValue() > 1)
47+
.map(Map.Entry::getKey)
48+
.collect(Collectors.toList());
4349
}
4450

4551
}

Diff for: arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.arconia.ai.core.tools;
22

3-
import io.arconia.ai.core.tools.util.ToolUtils;
3+
import java.util.List;
44

55
import org.junit.jupiter.api.Test;
66

7+
import io.arconia.ai.core.tools.util.ToolUtils;
8+
79
import static org.assertj.core.api.Assertions.assertThat;
810

911
/**
@@ -16,15 +18,22 @@ void shouldDetectDuplicateToolNames() {
1618
ToolCallback callback1 = new TestToolCallback("tool_a");
1719
ToolCallback callback2 = new TestToolCallback("tool_a");
1820
ToolCallback callback3 = new TestToolCallback("tool_b");
19-
assertThat(ToolUtils.hasDuplicateToolNames(callback1, callback2, callback3)).isTrue();
21+
22+
List<String> duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3);
23+
24+
assertThat(duplicates).isNotEmpty();
25+
assertThat(duplicates).contains("tool_a");
2026
}
2127

2228
@Test
2329
void shouldNotDetectDuplicateToolNames() {
2430
ToolCallback callback1 = new TestToolCallback("tool_a");
2531
ToolCallback callback2 = new TestToolCallback("tool_b");
2632
ToolCallback callback3 = new TestToolCallback("tool_c");
27-
assertThat(ToolUtils.hasDuplicateToolNames(callback1, callback2, callback3)).isFalse();
33+
34+
List<String> duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3);
35+
36+
assertThat(duplicates).isEmpty();
2837
}
2938

3039
static class TestToolCallback implements ToolCallback {

0 commit comments

Comments
 (0)