Skip to content

Commit a191cfa

Browse files
authored
Merge pull request #18 from cnblogs/support-function-behavior
feat: support auto invoke functions for non-stream chat
2 parents b2d1e95 + 041318f commit a191cfa

26 files changed

+925
-60
lines changed

src/KernelMemory.DashScope/DashScopeTextEmbeddingGenerator.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Cnblogs.DashScope.Sdk;
2-
using Cnblogs.DashScope.Sdk.TextEmbedding;
1+
using Cnblogs.DashScope.Core;
32
using Microsoft.KernelMemory;
43
using Microsoft.KernelMemory.AI;
54

@@ -30,7 +29,13 @@ public async Task<Embedding> GenerateEmbeddingAsync(
3029
string text,
3130
CancellationToken cancellationToken = new())
3231
{
33-
var result = await dashScopeClient.GetTextEmbeddingsAsync(modelId, [text], null, cancellationToken);
32+
var result = await dashScopeClient.GetEmbeddingsAsync(
33+
new ModelRequest<TextEmbeddingInput, ITextEmbeddingParameters>
34+
{
35+
Input = new TextEmbeddingInput { Texts = [text] },
36+
Model = modelId
37+
},
38+
cancellationToken);
3439
return result.Output.Embeddings[0].Embedding;
3540
}
3641

src/KernelMemory.DashScope/DashScopeTextGenerator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using System.Runtime.CompilerServices;
2-
using Cnblogs.DashScope.Sdk;
2+
using Cnblogs.DashScope.Core;
33
using Microsoft.Extensions.Logging;
44
using Microsoft.KernelMemory.AI;
55
using Microsoft.KernelMemory.Diagnostics;

src/KernelMemory.DashScope/DependencyInjector.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Cnblogs.DashScope.Sdk;
1+
using Cnblogs.DashScope.Core;
22
using Cnblogs.KernelMemory.AI.DashScope;
33
using Microsoft.Extensions.Configuration;
44
using Microsoft.Extensions.DependencyInjection;

src/KernelMemory.DashScope/KernelMemory.DashScope.csproj

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
<ItemGroup>
2020
<PackageReference Include="Microsoft.DeepDev.TokenizerLib" Version="1.3.3" />
21-
<PackageReference Include="Microsoft.KernelMemory.Abstractions" Version="0.32.240307.1"/>
22-
<PackageReference Include="Cnblogs.DashScope.Sdk" Version="0.0.3"/>
21+
<PackageReference Include="Microsoft.KernelMemory.Abstractions" Version="0.34.240313.1" />
22+
<PackageReference Include="Cnblogs.DashScope.Core" Version="0.2.0" />
2323
</ItemGroup>
2424

2525
<ItemGroup>

src/SemanticKernel.DashScope/DashScopeChatCompletionService.cs

+191-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using System.Runtime.CompilerServices;
2-
using Cnblogs.DashScope.Sdk;
2+
using System.Text.Json;
3+
using Cnblogs.DashScope.Core;
4+
using Microsoft.Extensions.Logging;
35
using Microsoft.SemanticKernel;
46
using Microsoft.SemanticKernel.ChatCompletion;
57
using Microsoft.SemanticKernel.Services;
@@ -15,45 +17,132 @@ public sealed class DashScopeChatCompletionService : IChatCompletionService, ITe
1517
private readonly IDashScopeClient _dashScopeClient;
1618
private readonly Dictionary<string, object?> _attributes = new();
1719
private readonly string _modelId;
20+
private readonly ILogger<DashScopeChatCompletionService> _logger;
1821

1922
/// <summary>
2023
/// Creates a new DashScope chat completion service.
2124
/// </summary>
2225
/// <param name="modelId"></param>
2326
/// <param name="dashScopeClient"></param>
24-
public DashScopeChatCompletionService(string modelId, IDashScopeClient dashScopeClient)
27+
/// <param name="logger"></param>
28+
public DashScopeChatCompletionService(
29+
string modelId,
30+
IDashScopeClient dashScopeClient,
31+
ILogger<DashScopeChatCompletionService> logger)
2532
{
2633
_dashScopeClient = dashScopeClient;
2734
_modelId = modelId;
35+
_logger = logger;
2836
_attributes.Add(AIServiceExtensions.ModelIdKey, _modelId);
2937
}
3038

3139
/// <inheritdoc />
3240
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
33-
ChatHistory chatHistory,
41+
ChatHistory chat,
3442
PromptExecutionSettings? executionSettings = null,
3543
Kernel? kernel = null,
3644
CancellationToken cancellationToken = default)
3745
{
38-
var chatMessages = chatHistory.ToChatMessages();
3946
var chatParameters = DashScopePromptExecutionSettings.FromPromptExecutionSettings(executionSettings);
4047
chatParameters ??= new DashScopePromptExecutionSettings();
4148
chatParameters.IncrementalOutput = false;
4249
chatParameters.ResultFormat = ResultFormats.Message;
43-
var response = await _dashScopeClient.GetTextCompletionAsync(
44-
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
50+
chatParameters.ToolCallBehavior?.ConfigureOptions(kernel, chatParameters);
51+
52+
var autoInvoke = kernel is not null && chatParameters.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0;
53+
for (var it = 1;; it++)
54+
{
55+
var response = await _dashScopeClient.GetTextCompletionAsync(
56+
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
57+
{
58+
Input = new TextGenerationInput { Messages = chat.ToChatMessages() },
59+
Model = string.IsNullOrEmpty(chatParameters.ModelId) ? _modelId : chatParameters.ModelId,
60+
Parameters = chatParameters
61+
},
62+
cancellationToken);
63+
CaptureTokenUsage(response.Usage);
64+
EnsureChoiceExists(response.Output.Choices);
65+
var message = response.Output.Choices![0].Message;
66+
var chatMessageContent = new DashScopeChatMessageContent(
67+
new AuthorRole(message.Role),
68+
message.Content,
69+
name: null,
70+
toolCalls: message.ToolCalls,
71+
metadata: response.ToMetaData());
72+
if (autoInvoke == false || message.ToolCalls is null)
4573
{
46-
Input = new TextGenerationInput { Messages = chatMessages },
47-
Model = string.IsNullOrEmpty(chatParameters.ModelId) ? _modelId : chatParameters.ModelId,
48-
Parameters = chatParameters
49-
},
50-
cancellationToken);
51-
var message = response.Output.Choices![0].Message;
52-
var chatMessageContent = new ChatMessageContent(
53-
new AuthorRole(message.Role),
54-
message.Content,
55-
metadata: response.ToMetaData());
56-
return [chatMessageContent];
74+
// no needs to invoke tool
75+
return [chatMessageContent];
76+
}
77+
78+
LogToolCalls(message.ToolCalls);
79+
chat.Add(chatMessageContent);
80+
81+
foreach (var call in message.ToolCalls)
82+
{
83+
if (call.Type is not ToolTypes.Function || call.Function is null)
84+
{
85+
AddResponseMessage(chat, null, "Error: Tool call was not a function call.", call.Id);
86+
continue;
87+
}
88+
89+
// ensure not calling function that was not included in request list.
90+
if (chatParameters.Tools?.Any(
91+
x => string.Equals(x.Function?.Name, call.Function.Name, StringComparison.OrdinalIgnoreCase))
92+
!= true)
93+
{
94+
AddResponseMessage(
95+
chat,
96+
null,
97+
"Error: Function call requests for a function that wasn't defined.",
98+
call.Id);
99+
continue;
100+
}
101+
102+
object? callResult;
103+
try
104+
{
105+
if (kernel!.Plugins.TryGetKernelFunctionAndArguments(
106+
call.Function,
107+
out var kernelFunction,
108+
out var kernelArguments)
109+
== false)
110+
{
111+
AddResponseMessage(chat, null, "Error: Requested function could not be found.", call.Id);
112+
continue;
113+
}
114+
115+
var functionResult = await kernelFunction.InvokeAsync(kernel, kernelArguments, cancellationToken);
116+
callResult = functionResult.GetValue<object>() ?? string.Empty;
117+
}
118+
catch (JsonException)
119+
{
120+
AddResponseMessage(chat, null, "Error: Function call arguments were invalid JSON.", call.Id);
121+
continue;
122+
}
123+
catch (Exception)
124+
{
125+
AddResponseMessage(chat, null, "Error: Exception while invoking function. {e.Message}", call.Id);
126+
continue;
127+
}
128+
129+
var stringResult = ProcessFunctionResult(callResult, chatParameters.ToolCallBehavior);
130+
AddResponseMessage(chat, stringResult, null, call.Id);
131+
}
132+
133+
chatParameters.Tools?.Clear();
134+
chatParameters.ToolCallBehavior?.ConfigureOptions(kernel, chatParameters);
135+
if (it >= chatParameters.ToolCallBehavior!.MaximumAutoInvokeAttempts)
136+
{
137+
autoInvoke = false;
138+
if (_logger.IsEnabled(LogLevel.Debug))
139+
{
140+
_logger.LogDebug(
141+
"Maximum auto-invoke ({MaximumAutoInvoke}) reached",
142+
chatParameters.ToolCallBehavior!.MaximumAutoInvokeAttempts);
143+
}
144+
}
145+
}
57146
}
58147

59148
/// <inheritdoc />
@@ -68,6 +157,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
68157
var parameters = DashScopePromptExecutionSettings.FromPromptExecutionSettings(executionSettings);
69158
parameters.IncrementalOutput = true;
70159
parameters.ResultFormat = ResultFormats.Message;
160+
parameters.ToolCallBehavior?.ConfigureOptions(kernel, parameters);
71161
var responses = _dashScopeClient.GetTextCompletionStreamAsync(
72162
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
73163
{
@@ -141,4 +231,88 @@ public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsyn
141231
metadata: response.ToMetaData());
142232
}
143233
}
234+
235+
private void CaptureTokenUsage(TextGenerationTokenUsage? usage)
236+
{
237+
if (usage is null)
238+
{
239+
if (_logger.IsEnabled(LogLevel.Debug))
240+
{
241+
_logger.LogDebug("Usage info is not available");
242+
}
243+
244+
return;
245+
}
246+
247+
if (_logger.IsEnabled(LogLevel.Information))
248+
{
249+
_logger.LogInformation(
250+
"Input tokens: {InputTokens}. Output tokens: {CompletionTokens}. Total tokens: {TotalTokens}",
251+
usage.InputTokens,
252+
usage.OutputTokens,
253+
usage.TotalTokens);
254+
}
255+
}
256+
257+
private void LogToolCalls(IReadOnlyCollection<ToolCall>? calls)
258+
{
259+
if (calls is null)
260+
{
261+
return;
262+
}
263+
264+
if (_logger.IsEnabled(LogLevel.Debug))
265+
{
266+
_logger.LogDebug("Tool requests: {Requests}", calls.Count);
267+
}
268+
269+
if (_logger.IsEnabled(LogLevel.Trace))
270+
{
271+
_logger.LogTrace(
272+
"Function call requests: {Requests}",
273+
string.Join(", ", calls.Select(ftc => $"{ftc.Function?.Name}({ftc.Function?.Arguments})")));
274+
}
275+
}
276+
277+
private void AddResponseMessage(ChatHistory chat, string? result, string? errorMessage, string? toolId)
278+
{
279+
// Log any error
280+
if (errorMessage is not null && _logger.IsEnabled(LogLevel.Debug))
281+
{
282+
_logger.LogDebug("Failed to handle tool request ({ToolId}). {Error}", toolId, errorMessage);
283+
}
284+
285+
// Add the tool response message to both the chat options and to the chat history.
286+
result ??= errorMessage ?? string.Empty;
287+
chat.Add(new DashScopeChatMessageContent(AuthorRole.Tool, result, name: toolId));
288+
}
289+
290+
private static void EnsureChoiceExists(List<TextGenerationChoice>? choices)
291+
{
292+
if (choices is null || choices.Count == 0)
293+
{
294+
throw new KernelException("No choice was returned from model");
295+
}
296+
}
297+
298+
private static string ProcessFunctionResult(object functionResult, ToolCallBehavior? toolCallBehavior)
299+
{
300+
if (functionResult is string stringResult)
301+
{
302+
return stringResult;
303+
}
304+
305+
// This is an optimization to use ChatMessageContent content directly
306+
// without unnecessary serialization of the whole message content class.
307+
if (functionResult is ChatMessageContent chatMessageContent)
308+
{
309+
return chatMessageContent.ToString();
310+
}
311+
312+
// For polymorphic serialization of unknown in advance child classes of the KernelContent class,
313+
// a corresponding JsonTypeInfoResolver should be provided via the JsonSerializerOptions.TypeInfoResolver property.
314+
// For more details about the polymorphic serialization, see the article at:
315+
// https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0
316+
return JsonSerializer.Serialize(functionResult, toolCallBehavior?.ToolCallResultSerializerOptions);
317+
}
144318
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using Cnblogs.DashScope.Core;
2+
using Microsoft.SemanticKernel;
3+
using Microsoft.SemanticKernel.ChatCompletion;
4+
5+
namespace Cnblogs.SemanticKernel.Connectors.DashScope;
6+
7+
/// <summary>
8+
/// DashScope specialized message content
9+
/// </summary>
10+
public class DashScopeChatMessageContent(
11+
AuthorRole role,
12+
string content,
13+
Dictionary<string, object?>? metadata = null,
14+
string? name = null,
15+
List<ToolCall>? toolCalls = null)
16+
: ChatMessageContent(role, content, metadata: metadata)
17+
{
18+
/// <summary>
19+
/// The name of tool if role is tool.
20+
/// </summary>
21+
public string? Name { get; } = name;
22+
23+
/// <summary>
24+
/// Optional tool calls.
25+
/// </summary>
26+
public List<ToolCall>? ToolCalls { get; } = toolCalls;
27+
}

src/SemanticKernel.DashScope/DashScopeMapper.cs

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Cnblogs.DashScope.Sdk;
1+
using Cnblogs.DashScope.Core;
22
using Microsoft.SemanticKernel.ChatCompletion;
33

44
namespace Cnblogs.SemanticKernel.Connectors.DashScope;
@@ -7,7 +7,16 @@ internal static class DashScopeMapper
77
{
88
public static List<ChatMessage> ToChatMessages(this ChatHistory history)
99
{
10-
return history.Select(x => new ChatMessage(x.Role.Label, x.Content ?? string.Empty)).ToList();
10+
return history.Select(
11+
x =>
12+
{
13+
if (x is DashScopeChatMessageContent d)
14+
{
15+
return new ChatMessage(x.Role.Label, x.Content ?? string.Empty, d.Name, ToolCalls: d.ToolCalls);
16+
}
17+
18+
return new ChatMessage(x.Role.Label, x.Content ?? string.Empty);
19+
}).ToList();
1120
}
1221

1322
public static Dictionary<string, object?>? ToMetaData<TOutput, TUsage>(

0 commit comments

Comments
 (0)