From dd89a8e4af7e128f49640e4563ac9fb2688f6709 Mon Sep 17 00:00:00 2001
From: Marco Minerva
Date: Wed, 15 Jan 2025 15:42:38 +0100
Subject: [PATCH] Add token usage tracking (#947)

## Motivation and Context (Why the change? What's the scenario?)

Adds a new TokenUsage property to MemoryAnswer to hold information about token usage.

## High level description (Approach, Design)

* Include the token count measured by the internal tokenizer
* Include the token count reported by the service, if available
* Support streaming and multiple services, if needed
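At its core, the change turns `ITextGenerator`'s stream of strings into a stream of `GeneratedTextContent` chunks, so a connector can attach a `TokenUsage` report to any chunk (typically the last one). The sketch below illustrates the new contract for a custom generator; it is not part of this patch, and the model name, service type, and token counts are placeholder assumptions.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

// Illustrative only: a custom ITextGenerator emitting token usage with the final chunk.
// "my-model" and the crude token counting are placeholder assumptions.
public class EchoTextGenerator : ITextGenerator
{
    public int MaxTokenTotal => 4096;

    public int CountTokens(string text) => text.Length / 4; // crude tokenizer stand-in

    public IReadOnlyList<string> GetTokens(string text) => text.Split(' ');

    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
        string prompt,
        TextGenerationOptions options,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Stream the text first; strings convert implicitly to GeneratedTextContent.
        yield return "echo: ";
        yield return prompt;

        // Attach the usage report to a last, text-less chunk.
        yield return new GeneratedTextContent(string.Empty, new TokenUsage
        {
            Timestamp = DateTimeOffset.UtcNow,
            ServiceType = "Custom",
            ModelType = Constants.ModelType.TextGeneration,
            ModelName = "my-model",
            ServiceTokensIn = this.CountTokens(prompt),
            ServiceTokensOut = this.CountTokens(prompt) + 2
        });
        await Task.CompletedTask;
    }
}
```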
Console.WriteLine($"- Output: {report.TokenizerTokensOut} tokens (measured by KM tokenizer)"); + Console.WriteLine(); + } + Console.WriteLine("\n\n====================================\n"); /* OUTPUT Question: What's E = m*c^2? - - Answer: E = m*c^2 is the formula representing the principle of mass-energy equivalence, which was introduced by Albert Einstein. In this equation, - E stands for energy, m represents mass, and c is the speed of light in a vacuum, which is approximately 299,792,458 meters per second (m/s). - The equation states that the energy (E) of a system in its rest frame is equal to its mass (m) multiplied by the square of the speed of light (c^2). - This implies that mass and energy are interchangeable; a small amount of mass can be converted into a large amount of energy and vice versa, - due to the speed of light being a very large number when squared. This concept is a fundamental principle in physics and has important implications - in various fields, including nuclear physics and cosmology. + Expected result: formula explanation using the information loaded + + Answer: E = m*c^2 is a formula derived by the physicist Albert Einstein, which describes the principle of + mass–energy equivalence. In this equation, E represents energy, m represents mass, and c represents the + speed of light in a vacuum (approximately 3 x 10^8 meters per second). The formula indicates that mass and + energy are interchangeable; they are different forms of the same thing and can be converted into each other. + This principle is fundamental in physics and has significant implications in various fields, including nuclear + physics and cosmology. + + Token usage report: + Azure OpenAI: gpt-4o [TextGeneration] + - Input : 15657 tokens (measured by KM tokenizer) + - Input : 15664 tokens (measured by remote service) + - Output: 110 tokens (measured by remote service) + - Output: 110 tokens (measured by KM tokenizer) */ } diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs index 77b77bc5a..c38c980af 100644 --- a/examples/002-dotnet-Serverless/Program.cs +++ b/examples/002-dotnet-Serverless/Program.cs @@ -311,31 +311,57 @@ private static async Task AskSimpleQuestionStreamingTheAnswer() { var question = "What's E = m*c^2?"; Console.WriteLine($"Question: {question}"); - Console.WriteLine($"Expected result: formula explanation using the information loaded"); + Console.WriteLine("Expected result: formula explanation using the information loaded"); Console.Write("\nAnswer: "); + var tokenUsage = new List(); var answerStream = s_memory.AskStreamingAsync(question, options: new SearchOptions { Stream = true }); await foreach (var answer in answerStream) { // Print token received by LLM Console.Write(answer.Result); + + // Collect token usage + if (answer.TokenUsage?.Count > 0) + { + tokenUsage = tokenUsage.Union(answer.TokenUsage).ToList(); + } + // Slow down the stream for demo purpose await Task.Delay(25); } + Console.WriteLine("\n\nToken usage report:"); + foreach (var report in tokenUsage) + { + Console.WriteLine($"{report.ServiceType}: {report.ModelName} [{report.ModelType}]"); + Console.WriteLine($"- Input : {report.TokenizerTokensIn} tokens (measured by KM tokenizer)"); + Console.WriteLine($"- Input : {report.ServiceTokensIn} tokens (measured by remote service)"); + Console.WriteLine($"- Output: {report.ServiceTokensOut} tokens (measured by remote service)"); + Console.WriteLine($"- Output: {report.TokenizerTokensOut} tokens (measured by KM tokenizer)"); 
diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs
index 77b77bc5a..c38c980af 100644
--- a/examples/002-dotnet-Serverless/Program.cs
+++ b/examples/002-dotnet-Serverless/Program.cs
@@ -311,31 +311,57 @@ private static async Task AskSimpleQuestionStreamingTheAnswer()
     {
         var question = "What's E = m*c^2?";
         Console.WriteLine($"Question: {question}");
-        Console.WriteLine($"Expected result: formula explanation using the information loaded");
+        Console.WriteLine("Expected result: formula explanation using the information loaded");
 
         Console.Write("\nAnswer: ");
+        var tokenUsage = new List<TokenUsage>();
         var answerStream = s_memory.AskStreamingAsync(question, options: new SearchOptions { Stream = true });
 
         await foreach (var answer in answerStream)
         {
             // Print token received by LLM
             Console.Write(answer.Result);
+
+            // Collect token usage
+            if (answer.TokenUsage?.Count > 0)
+            {
+                tokenUsage = tokenUsage.Union(answer.TokenUsage).ToList();
+            }
+
             // Slow down the stream for demo purpose
             await Task.Delay(25);
         }
 
+        Console.WriteLine("\n\nToken usage report:");
+        foreach (var report in tokenUsage)
+        {
+            Console.WriteLine($"{report.ServiceType}: {report.ModelName} [{report.ModelType}]");
+            Console.WriteLine($"- Input : {report.TokenizerTokensIn} tokens (measured by KM tokenizer)");
+            Console.WriteLine($"- Input : {report.ServiceTokensIn} tokens (measured by remote service)");
+            Console.WriteLine($"- Output: {report.ServiceTokensOut} tokens (measured by remote service)");
+            Console.WriteLine($"- Output: {report.TokenizerTokensOut} tokens (measured by KM tokenizer)");
+            Console.WriteLine();
+        }
+
         Console.WriteLine("\n\n====================================\n");

        /* OUTPUT
 
         Question: What's E = m*c^2?
-
-        Answer: E = m*c^2 is the formula representing the principle of mass-energy equivalence, which was introduced by Albert Einstein. In this equation,
-        E stands for energy, m represents mass, and c is the speed of light in a vacuum, which is approximately 299,792,458 meters per second (m/s).
-        The equation states that the energy (E) of a system in its rest frame is equal to its mass (m) multiplied by the square of the speed of light (c^2).
-        This implies that mass and energy are interchangeable; a small amount of mass can be converted into a large amount of energy and vice versa,
-        due to the speed of light being a very large number when squared. This concept is a fundamental principle in physics and has important implications
-        in various fields, including nuclear physics and cosmology.
+        Expected result: formula explanation using the information loaded
+
+        Answer: E = m*c^2 is a formula derived by physicist Albert Einstein, which expresses the principle of
+        mass–energy equivalence. In this equation, E represents energy, m represents mass, and c represents the
+        speed of light in a vacuum (approximately 3 x 10^8 meters per second). The formula indicates that mass and
+        energy are interchangeable; a small amount of mass can be converted into a large amount of energy, and vice
+        versa, differing only by a multiplicative constant (c^2).
+
+        Token usage report:
+        Azure OpenAI: gpt-4o [TextGeneration]
+        - Input : 24349 tokens (measured by KM tokenizer)
+        - Input : 24356 tokens (measured by remote service)
+        - Output: 103 tokens (measured by remote service)
+        - Output: 103 tokens (measured by KM tokenizer)
 
        */
    }

diff --git a/examples/104-dotnet-custom-LLM/Program.cs b/examples/104-dotnet-custom-LLM/Program.cs
index c6f9db6b3..e036d208c 100644
--- a/examples/104-dotnet-custom-LLM/Program.cs
+++ b/examples/104-dotnet-custom-LLM/Program.cs
@@ -68,7 +68,7 @@ public IReadOnlyList<string> GetTokens(string text)
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)

diff --git a/extensions/Anthropic/AnthropicTextGeneration.cs b/extensions/Anthropic/AnthropicTextGeneration.cs
index 9257853bd..4571b4a32 100644
--- a/extensions/Anthropic/AnthropicTextGeneration.cs
+++ b/extensions/Anthropic/AnthropicTextGeneration.cs
@@ -97,7 +97,7 @@ public IReadOnlyList<string> GetTokens(string text)
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs
index 94375aa2b..fb8be581e 100644
--- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs
+++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Net.Http;
@@ -12,6 +13,7 @@
 using Microsoft.KernelMemory.Diagnostics;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
+using OpenAI.Chat;
 
 namespace Microsoft.KernelMemory.AI.AzureOpenAI;
 
@@ -28,6 +30,8 @@ public sealed class AzureOpenAITextGenerator : ITextGenerator
     private readonly ITextTokenizer _textTokenizer;
     private readonly ILogger<AzureOpenAITextGenerator> _log;
 
+    private readonly string _deployment;
+
     /// <inheritdoc/>
     public int MaxTokenTotal { get; }
 
@@ -87,6 +91,7 @@ public AzureOpenAITextGenerator(
     {
         this._client = skClient;
         this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<AzureOpenAITextGenerator>();
+        this._deployment = config.Deployment;
         this.MaxTokenTotal = config.MaxTokenTotal;
 
         textTokenizer ??= TokenizerFactory.GetTokenizerForEncoding(config.Tokenizer);
@@ -114,7 +119,7 @@ public IReadOnlyList<string> GetTokens(string text)
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -153,9 +158,33 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
 
         await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken))
         {
-            if (x.Text == null) { continue; }
-
-            yield return x.Text;
+            TokenUsage? tokenUsage = null;
+
+            // The last message includes token usage metadata.
+            // https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options
+            if (x.Metadata?["Usage"] is ChatTokenUsage usage)
+            {
+                this._log.LogTrace("Usage report: input tokens: {InputTokenCount}, output tokens: {OutputTokenCount}, output reasoning tokens: {ReasoningTokenCount}",
+                    usage.InputTokenCount, usage.OutputTokenCount, usage.OutputTokenDetails?.ReasoningTokenCount ?? 0);
+
+                tokenUsage = new TokenUsage
+                {
+                    Timestamp = (DateTimeOffset?)x.Metadata["CreatedAt"] ?? DateTimeOffset.UtcNow,
+                    ServiceType = "Azure OpenAI",
+                    ModelType = Constants.ModelType.TextGeneration,
+                    ModelName = this._deployment,
+                    ServiceTokensIn = usage.InputTokenCount,
+                    ServiceTokensOut = usage.OutputTokenCount,
+                    ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
+                };
+            }
+
+            // NOTE: as stated at https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices,
+            // the Choice can also be empty for the last chunk if we set stream_options: { "include_usage": true } to get token counts, so it is possible that
+            // x.Text is null, but tokenUsage is not (token usage statistics for the entire request are included in the last chunk).
+            if (x.Text is null && tokenUsage is null) { continue; }
+
+            yield return new(x.Text ?? string.Empty, tokenUsage);
         }
     }
 }
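The connector above relies on OpenAI's `stream_options: { "include_usage": true }` behavior: the service appends one extra chunk with an empty choice list and the aggregate usage, so text and usage must be handled independently. A minimal consumer-side sketch, not part of the patch, with `generator` standing in for any `ITextGenerator`:

```csharp
using System.Text;
using System.Threading.Tasks;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

public static class UsageAwareStreaming
{
    // Drains a generator stream, tolerating a final usage-only chunk with empty text.
    public static async Task<(string Text, TokenUsage? Usage)> CollectAsync(
        ITextGenerator generator, string prompt)
    {
        var text = new StringBuilder();
        TokenUsage? usage = null;

        await foreach (GeneratedTextContent chunk in generator.GenerateTextAsync(prompt, new TextGenerationOptions()))
        {
            text.Append(chunk.Text);            // empty string on the usage-only chunk
            if (chunk.TokenUsage != null)
            {
                usage = chunk.TokenUsage;       // usually attached to the last chunk
            }
        }

        return (text.ToString(), usage);
    }
}
```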
diff --git a/extensions/LlamaSharp/LlamaSharp.FunctionalTests/LlamaSharpTextGeneratorTest.cs b/extensions/LlamaSharp/LlamaSharp.FunctionalTests/LlamaSharpTextGeneratorTest.cs
index 285ff1425..55796dd52 100644
--- a/extensions/LlamaSharp/LlamaSharp.FunctionalTests/LlamaSharpTextGeneratorTest.cs
+++ b/extensions/LlamaSharp/LlamaSharp.FunctionalTests/LlamaSharpTextGeneratorTest.cs
@@ -40,7 +40,7 @@ public void ItCountsTokens()
 
         // Assert
         Console.WriteLine("Phi3 token count: " + tokenCount);
-        Console.WriteLine("GPT4 token count: " + (new CL100KTokenizer()).CountTokens(text));
+        Console.WriteLine("GPT4 token count: " + new CL100KTokenizer().CountTokens(text));
         Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs");
 
         // Expected result with Phi-3-mini-4k-instruct-q4.gguf, without BoS (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf)
@@ -90,9 +90,8 @@ public async Task ItGeneratesText()
         this._timer.Restart();
         var tokens = this._target.GenerateTextAsync(prompt, options);
         var result = new StringBuilder();
-        await foreach (string token in tokens)
+        await foreach (var token in tokens)
         {
-            // Console.WriteLine(token);
             result.Append(token);
         }
 
diff --git a/extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs b/extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs
index 0290de95f..278ab4e39 100644
--- a/extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs
+++ b/extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs
@@ -74,7 +74,7 @@ public IReadOnlyList<string> GetTokens(string text)
     }
 
     /// <inheritdoc/>
-    public IAsyncEnumerable<string> GenerateTextAsync(
+    public IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         CancellationToken cancellationToken = default)
@@ -85,7 +85,7 @@ public IAsyncEnumerable<string> GenerateTextAsync(
             ? options.TokenSelectionBiases.ToDictionary(pair => (LLamaToken)pair.Key, pair => pair.Value)
             : [];
 
-        var samplingPipeline = new DefaultSamplingPipeline()
+        var samplingPipeline = new DefaultSamplingPipeline
         {
             Temperature = (float)options.Temperature,
             TopP = (float)options.NucleusSampling,
@@ -103,7 +103,7 @@ public IAsyncEnumerable<string> GenerateTextAsync(
         };
 
         this._log.LogTrace("Generating text, temperature {0}, max tokens {1}", samplingPipeline.Temperature, settings.MaxTokens);
 
-        return executor.InferAsync(prompt, settings, cancellationToken);
+        return executor.InferAsync(prompt, settings, cancellationToken).Select(x => new GeneratedTextContent(x));
     }
 
     /// <inheritdoc/>
diff --git a/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
index a2e491b6d..acf38fda9 100644
--- a/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
+++ b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
@@ -45,7 +45,7 @@ public async Task ItGeneratesText()
         this._timer.Restart();
         var tokens = this._target.GenerateTextAsync(prompt, options);
         var result = new StringBuilder();
-        await foreach (string token in tokens)
+        await foreach (var token in tokens)
         {
             result.Append(token);
         }
 
diff --git a/extensions/ONNX/Onnx/OnnxTextGenerator.cs b/extensions/ONNX/Onnx/OnnxTextGenerator.cs
index 7f31e49a7..6d7aebcb6 100644
--- a/extensions/ONNX/Onnx/OnnxTextGenerator.cs
+++ b/extensions/ONNX/Onnx/OnnxTextGenerator.cs
@@ -85,7 +85,7 @@ public OnnxTextGenerator(
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions? options = null,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)

diff --git a/extensions/Ollama/Ollama/OllamaTextGenerator.cs b/extensions/Ollama/Ollama/OllamaTextGenerator.cs
index 34900713c..ec14cf9c5 100644
--- a/extensions/Ollama/Ollama/OllamaTextGenerator.cs
+++ b/extensions/Ollama/Ollama/OllamaTextGenerator.cs
@@ -91,7 +91,7 @@ public IReadOnlyList<string> GetTokens(string text)
         return this._textTokenizer.GetTokens(text);
     }
 
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
index dbc9cb857..f36812dde 100644
--- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
+++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Net.Http;
@@ -12,6 +13,7 @@
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Connectors.OpenAI;
 using OpenAI;
+using OpenAI.Chat;
 
 namespace Microsoft.KernelMemory.AI.OpenAI;
 
@@ -29,6 +31,8 @@ public sealed class OpenAITextGenerator : ITextGenerator
     private readonly ITextTokenizer _textTokenizer;
     private readonly ILogger<OpenAITextGenerator> _log;
 
+    private readonly string _textModel;
+
     /// <inheritdoc/>
     public int MaxTokenTotal { get; }
 
@@ -87,6 +91,7 @@ public OpenAITextGenerator(
     {
         this._client = skClient;
         this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<OpenAITextGenerator>();
+        this._textModel = config.TextModel;
         this.MaxTokenTotal = config.TextModelMaxTokenTotal;
 
         if (textTokenizer == null && !string.IsNullOrEmpty(config.TextModelTokenizer))
@@ -119,7 +124,7 @@ public IReadOnlyList<string> GetTokens(string text)
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -159,17 +164,33 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
 
         await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken))
         {
-            // TODO: try catch
-            // if (x.Metadata?["Usage"] is not null)
-            // {
-            //     var usage = x.Metadata["Usage"] as ChatTokenUsage;
-            //     this._log.LogTrace("Usage report: input tokens {0}, output tokens {1}, output reasoning tokens {2}",
-            //         usage?.InputTokenCount, usage?.OutputTokenCount, usage?.OutputTokenDetails.ReasoningTokenCount);
-            // }
-
-            if (x.Text == null) { continue; }
-
-            yield return x.Text;
+            TokenUsage? tokenUsage = null;
+
+            // The last message in the chunk has the usage metadata.
+            // https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options
+            if (x.Metadata?["Usage"] is ChatTokenUsage usage)
+            {
+                this._log.LogTrace("Usage report: input tokens {0}, output tokens {1}, output reasoning tokens {2}",
+                    usage.InputTokenCount, usage.OutputTokenCount, usage.OutputTokenDetails?.ReasoningTokenCount ?? 0);
+
+                tokenUsage = new TokenUsage
+                {
+                    Timestamp = (DateTimeOffset?)x.Metadata["CreatedAt"] ?? DateTimeOffset.UtcNow,
+                    ServiceType = "OpenAI",
+                    ModelType = Constants.ModelType.TextGeneration,
+                    ModelName = this._textModel,
+                    ServiceTokensIn = usage.InputTokenCount,
+                    ServiceTokensOut = usage.OutputTokenCount,
+                    ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
+                };
+            }
+
+            // NOTE: as stated at https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices,
+            // the Choice can also be empty for the last chunk if we set stream_options: { "include_usage": true } to get token counts, so it is possible that
+            // x.Text is null, but tokenUsage is not (token usage statistics for the entire request are included in the last chunk).
+            if (x.Text is null && tokenUsage is null) { continue; }
+
+            yield return new(x.Text ?? string.Empty, tokenUsage);
         }
     }
 }

diff --git a/service/Abstractions/AI/ITextGenerator.cs b/service/Abstractions/AI/ITextGenerator.cs
index fdd1aef8d..5b9f9b457 100644
--- a/service/Abstractions/AI/ITextGenerator.cs
+++ b/service/Abstractions/AI/ITextGenerator.cs
@@ -19,7 +19,7 @@ public interface ITextGenerator : ITextTokenizer
     /// <param name="options">Options for the LLM request</param>
     /// <param name="cancellationToken">Async task cancellation token</param>
     /// <returns>Text generated, returned as a stream of strings/tokens</returns>
-    public IAsyncEnumerable<string> GenerateTextAsync(
+    public IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         CancellationToken cancellationToken = default);

diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs
index 40987040c..2cb087917 100644
--- a/service/Abstractions/Constants.cs
+++ b/service/Abstractions/Constants.cs
@@ -25,6 +25,12 @@ public static class WebService
         public const string ArgsField = "args";
     }
 
+    public static class ModelType
+    {
+        public const string EmbeddingGeneration = "EmbeddingGeneration";
+        public const string TextGeneration = "TextGeneration";
+    }
+
     public static class CustomContext
     {
         public static class Partitioning

diff --git a/service/Abstractions/Models/GeneratedTextContent.cs b/service/Abstractions/Models/GeneratedTextContent.cs
new file mode 100644
index 000000000..de7db53a5
--- /dev/null
+++ b/service/Abstractions/Models/GeneratedTextContent.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+namespace Microsoft.KernelMemory;
+
+#pragma warning disable CA2225
+public class GeneratedTextContent
+{
+    public string Text { get; set; }
+
+    public TokenUsage? TokenUsage { get; set; }
+
+    public GeneratedTextContent(string text, TokenUsage? tokenUsage = null)
+    {
+        this.Text = text;
+        this.TokenUsage = tokenUsage;
+    }
+
+    /// <inheritdoc/>
+    public override string ToString()
+    {
+        return this.Text;
+    }
+
+    /// <summary>
+    /// Convert a string to an instance of GeneratedTextContent
+    /// </summary>
+    /// <param name="text">Text content</param>
+    public static implicit operator GeneratedTextContent(string text)
+    {
+        return new GeneratedTextContent(text);
+    }
+}
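GeneratedTextContent deliberately converts from string, so generators that only stream text keep working with a one-line change. Two illustrative conversions follow; `plainTextStream` is an assumed `IAsyncEnumerable<string>`, and `Select` over async streams comes from the System.Linq.Async package already used by the LlamaSharp connector above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq; // Select over IAsyncEnumerable (System.Linq.Async)
using Microsoft.KernelMemory;

public static class ConversionExamples
{
    // The adapter used by connectors that have no usage data to report:
    public static IAsyncEnumerable<GeneratedTextContent> Wrap(IAsyncEnumerable<string> plainTextStream)
    {
        return plainTextStream.Select(x => new GeneratedTextContent(x));
    }

    public static void ImplicitConversion()
    {
        GeneratedTextContent chunk = "partial answer"; // TokenUsage stays null
        Console.WriteLine(chunk);                      // ToString() returns the text
    }
}
```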
diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs
index c78e695c6..cde6f7177 100644
--- a/service/Abstractions/Models/MemoryAnswer.cs
+++ b/service/Abstractions/Models/MemoryAnswer.cs
@@ -46,6 +46,15 @@ public class MemoryAnswer
     [JsonPropertyOrder(10)]
     public string Result { get; set; } = string.Empty;
 
+    /// <summary>
+    /// The tokens used by the model to generate the answer.
+    /// </summary>
+    /// <remarks>Not all the models and text generators return token usage information.</remarks>
+    [JsonPropertyName("tokenUsage")]
+    [JsonPropertyOrder(11)]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public List<TokenUsage>? TokenUsage { get; set; }
+
     /// <summary>
     /// List of the relevant sources used to produce the answer.
     /// Key = Document ID

diff --git a/service/Abstractions/Models/TokenUsage.cs b/service/Abstractions/Models/TokenUsage.cs
new file mode 100644
index 000000000..2a350f9b9
--- /dev/null
+++ b/service/Abstractions/Models/TokenUsage.cs
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.KernelMemory;
+
+/// <summary>
+/// Represents the usage of tokens in a request and response cycle.
+/// </summary>
+public class TokenUsage
+{
+    [JsonPropertyName("timestamp")]
+    public DateTimeOffset Timestamp { get; set; }
+
+    [JsonPropertyName("serviceType")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public string? ServiceType { get; set; }
+
+    [JsonPropertyName("modelType")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public string? ModelType { get; set; }
+
+    [JsonPropertyName("modelName")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public string? ModelName { get; set; }
+
+    /// <summary>
+    /// The number of tokens in the request message input, spanning all message content items, measured by the tokenizer.
+    /// </summary>
+    [JsonPropertyName("tokenizerTokensIn")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public int? TokenizerTokensIn { get; set; }
+
+    /// <summary>
+    /// The combined number of output tokens in the generated completion, measured by the tokenizer.
+    /// </summary>
+    [JsonPropertyName("tokenizerTokensOut")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public int? TokenizerTokensOut { get; set; }
+
+    /// <summary>
+    /// The number of tokens in the request message input, spanning all message content items, measured by the service.
+    /// </summary>
+    [JsonPropertyName("serviceTokensIn")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public int? ServiceTokensIn { get; set; }
+
+    /// <summary>
+    /// The combined number of output tokens in the generated completion, as consumed by the model.
+    /// </summary>
+    [JsonPropertyName("serviceTokensOut")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public int? ServiceTokensOut { get; set; }
+
+    [JsonPropertyName("serviceReasoningTokens")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public int? ServiceReasoningTokens { get; set; }
+
+    public TokenUsage()
+    {
+    }
+
+    public void Merge(TokenUsage? input)
+    {
+        if (input == null)
+        {
+            return;
+        }
+
+        this.Timestamp = input.Timestamp;
+        this.ServiceType = input.ServiceType;
+        this.ModelType = input.ModelType;
+        this.ModelName = input.ModelName;
+
+        this.TokenizerTokensIn = (this.TokenizerTokensIn ?? 0) + (input.TokenizerTokensIn ?? 0);
+        this.TokenizerTokensOut = (this.TokenizerTokensOut ?? 0) + (input.TokenizerTokensOut ?? 0);
+        this.ServiceTokensIn = (this.ServiceTokensIn ?? 0) + (input.ServiceTokensIn ?? 0);
+        this.ServiceTokensOut = (this.ServiceTokensOut ?? 0) + (input.ServiceTokensOut ?? 0);
+        this.ServiceReasoningTokens = (this.ServiceReasoningTokens ?? 0) + (input.ServiceReasoningTokens ?? 0);
+    }
+}
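Note the asymmetry in Merge: the metadata fields (Timestamp, ServiceType, ModelType, ModelName) are overwritten by the incoming report, while all counters accumulate. A small sketch of the resulting behavior, with made-up values:

```csharp
using System;
using Microsoft.KernelMemory;

// Illustrative values only.
var total = new TokenUsage { ServiceType = "OpenAI", ServiceTokensIn = 100, ServiceTokensOut = 10 };
var chunk = new TokenUsage { ServiceType = "OpenAI", ServiceTokensIn = 40, ServiceTokensOut = 5 };

total.Merge(chunk);
Console.WriteLine(total.ServiceTokensIn);  // 140: counters are summed across chunks
Console.WriteLine(total.ServiceTokensOut); // 15
total.Merge(null);                         // no-op: null input is ignored
```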
0); + } +} diff --git a/service/Core/AI/NoTextGenerator.cs b/service/Core/AI/NoTextGenerator.cs index 6dae9122c..f3cbbb311 100644 --- a/service/Core/AI/NoTextGenerator.cs +++ b/service/Core/AI/NoTextGenerator.cs @@ -37,7 +37,7 @@ public IReadOnlyList GetTokens(string text) } /// - public IAsyncEnumerable GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default) + public IAsyncEnumerable GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default) { throw this.Error(); } diff --git a/service/Core/Handlers/SummarizationHandler.cs b/service/Core/Handlers/SummarizationHandler.cs index 1620108dd..b3999d07b 100644 --- a/service/Core/Handlers/SummarizationHandler.cs +++ b/service/Core/Handlers/SummarizationHandler.cs @@ -224,7 +224,7 @@ public SummarizationHandler( this._log.LogTrace("Summarizing paragraph {0}", index); var filledPrompt = summarizationPrompt.Replace("{{$input}}", paragraph, StringComparison.OrdinalIgnoreCase); - await foreach (string token in textGenerator.GenerateTextAsync(filledPrompt, new TextGenerationOptions()).ConfigureAwait(false)) + await foreach (var token in textGenerator.GenerateTextAsync(filledPrompt, new TextGenerationOptions()).ConfigureAwait(false)) { newContent.Append(token); } diff --git a/service/Core/Handlers/SummarizationParallelHandler.cs b/service/Core/Handlers/SummarizationParallelHandler.cs index 118e4f962..3e103305d 100644 --- a/service/Core/Handlers/SummarizationParallelHandler.cs +++ b/service/Core/Handlers/SummarizationParallelHandler.cs @@ -183,7 +183,7 @@ await Parallel.ForEachAsync(uploadedFile.GeneratedFiles, options, async (generat this._log.LogTrace("Summarizing paragraph {0}", index); var filledPrompt = this._summarizationPrompt.Replace("{{$input}}", paragraph, StringComparison.OrdinalIgnoreCase); - await foreach (string token in textGenerator.GenerateTextAsync(filledPrompt, new TextGenerationOptions()).ConfigureAwait(false)) + await foreach (var token in textGenerator.GenerateTextAsync(filledPrompt, new TextGenerationOptions()).ConfigureAwait(false)) { newContent.Append(token); } diff --git a/service/Core/Search/AnswerGenerator.cs b/service/Core/Search/AnswerGenerator.cs index 352160851..b8dcaac79 100644 --- a/service/Core/Search/AnswerGenerator.cs +++ b/service/Core/Search/AnswerGenerator.cs @@ -55,6 +55,18 @@ internal async IAsyncEnumerable GenerateAnswerAsync( string question, SearchClientResult result, IContext? 
diff --git a/service/Core/Search/AnswerGenerator.cs b/service/Core/Search/AnswerGenerator.cs
index 352160851..b8dcaac79 100644
--- a/service/Core/Search/AnswerGenerator.cs
+++ b/service/Core/Search/AnswerGenerator.cs
@@ -55,6 +55,18 @@ internal async IAsyncEnumerable<MemoryAnswer> GenerateAnswerAsync(
         string question, SearchClientResult result, IContext? context, [EnumeratorCancellation] CancellationToken cancellationToken)
     {
+        var prompt = this.PreparePrompt(question, result.Facts.ToString(), context);
+        var promptSize = this._textGenerator.CountTokens(prompt);
+        this._log.LogInformation("RAG prompt ({0} tokens): {1}", promptSize, prompt);
+
+        var tokenUsage = new TokenUsage
+        {
+            Timestamp = DateTimeOffset.UtcNow,
+            ModelType = Constants.ModelType.TextGeneration,
+            TokenizerTokensIn = promptSize,
+        };
+        result.AddTokenUsageToStaticResults(tokenUsage);
+
         if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0)
         {
             this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
@@ -69,42 +81,48 @@ internal async IAsyncEnumerable<MemoryAnswer> GenerateAnswerAsync(
             yield break;
         }
 
-        var completeAnswer = new StringBuilder();
-        await foreach (var answerToken in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false))
+        var completeAnswerTokens = new StringBuilder();
+        await foreach (GeneratedTextContent answerToken in this.GenerateAnswerTokensAsync(prompt, context, cancellationToken).ConfigureAwait(false))
         {
-            completeAnswer.Append(answerToken);
-            result.AskResult.Result = answerToken;
+            completeAnswerTokens.Append(answerToken.Text);
+            tokenUsage.Merge(answerToken.TokenUsage);
+            result.AskResult.Result = answerToken.Text;
+
             yield return result.AskResult;
         }
 
-        // Finalize the answer, checking if it's empty
-        result.AskResult.Result = completeAnswer.ToString();
-        if (string.IsNullOrWhiteSpace(result.AskResult.Result)
-            || ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer))
+        // Check if the complete answer is empty
+        string completeAnswer = completeAnswerTokens.ToString();
+        if (string.IsNullOrWhiteSpace(completeAnswer) || ValueIsEquivalentTo(completeAnswer, this._config.EmptyAnswer))
         {
             this._log.LogInformation("No relevant memories found, returning empty answer.");
             yield return result.NoFactsResult;
             yield break;
         }
 
-        this._log.LogSensitive("Answer: {0}", result.AskResult.Result);
+        this._log.LogSensitive("Answer: {0}", completeAnswer);
 
+        // Check if the complete answer is safe
         if (this._config.UseContentModeration
             && this._contentModeration != null
-            && !await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false))
+            && !await this._contentModeration.IsSafeAsync(completeAnswer, cancellationToken).ConfigureAwait(false))
         {
             this._log.LogWarning("Unsafe answer detected. Returning error message instead.");
             yield return result.UnsafeAnswerResult;
+            yield break;
         }
+
+        // Add token usage report at the end
+        result.AskResult.Result = string.Empty;
+        tokenUsage.TokenizerTokensOut = this._textGenerator.CountTokens(completeAnswer);
+        result.AskResult.TokenUsage = [tokenUsage];
+
+        yield return result.AskResult;
     }
 
-    private IAsyncEnumerable<string> GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken)
+    private string PreparePrompt(string question, string facts, IContext? context)
     {
         string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
         string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer);
-        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
-        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
-        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
 
         question = question.Trim();
         question = question.EndsWith('?') ? question : $"{question}?";
@@ -112,7 +130,15 @@ private IAsyncEnumerable<string> GenerateAnswerTokensAsync(string question, stri
         prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
         prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
         prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase);
-        this._log.LogInformation("New prompt: {0}", prompt);
+
+        return prompt;
+    }
+
+    private IAsyncEnumerable<GeneratedTextContent> GenerateAnswerTokensAsync(string prompt, IContext? context, CancellationToken cancellationToken)
+    {
+        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
+        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
+        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
 
         var options = new TextGenerationOptions
         {
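AnswerGenerator now keeps two parallel measurements: TokenizerTokensIn/Out from KM's own tokenizer (always available) and ServiceTokensIn/Out merged in from the provider (only when the service reports usage), which is why the example reports above can differ slightly between the two. A rough sketch of this double accounting, not part of the patch, using an assumed `generator` that doubles as the tokenizer since `ITextGenerator` extends `ITextTokenizer`:

```csharp
using System;
using System.Text;
using System.Threading.Tasks;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

public static class DoubleAccounting
{
    public static async Task<TokenUsage> MeasureAsync(ITextGenerator generator, string prompt)
    {
        var usage = new TokenUsage
        {
            Timestamp = DateTimeOffset.UtcNow,
            ModelType = Constants.ModelType.TextGeneration,
            TokenizerTokensIn = generator.CountTokens(prompt) // measured locally, before the call
        };

        var answer = new StringBuilder();
        await foreach (var chunk in generator.GenerateTextAsync(prompt, new TextGenerationOptions()))
        {
            answer.Append(chunk.Text);
            usage.Merge(chunk.TokenUsage); // folds in service-side counts, when reported
        }

        usage.TokenizerTokensOut = generator.CountTokens(answer.ToString()); // measured locally, after
        return usage;
    }
}
```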
diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs
index 8216e8a83..f942acf27 100644
--- a/service/Core/Search/SearchClient.cs
+++ b/service/Core/Search/SearchClient.cs
@@ -135,6 +135,8 @@ public async Task<MemoryAnswer> AskAsync(
         {
             if (done) { break; }
 
+            result.TokenUsage = part.TokenUsage;
+
             switch (part.StreamState)
             {
                 case StreamStates.Error:
@@ -253,7 +255,7 @@ public async IAsyncEnumerable<MemoryAnswer> AskStreamingAsync(
         this._log.LogTrace("{Count} records processed", result.RecordCount);
 
         var first = true;
-        await foreach (var answer in this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false))
+        await foreach (MemoryAnswer answer in this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false))
         {
             yield return answer;
 
diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs
index a795d5cce..791c22694 100644
--- a/service/Core/Search/SearchClientResult.cs
+++ b/service/Core/Search/SearchClientResult.cs
@@ -33,7 +33,6 @@ internal class SearchClientResult
     public MemoryAnswer NoQuestionResult { get; private init; } = new();
     public MemoryAnswer UnsafeAnswerResult { get; private init; } = new();
     public MemoryAnswer InsufficientTokensResult { get; private init; } = new();
-    public MemoryAnswer ErrorResult { get; private init; } = new();
 
     // Use by Ask mode
     public SearchResult SearchResult { get; private init; } = new();
@@ -92,13 +91,6 @@ public static SearchClientResult AskResultInstance(
             NoResult = true,
             NoResultReason = "Content moderation",
             Result = moderatedAnswer
-        },
-        ErrorResult = new MemoryAnswer
-        {
-            StreamState = StreamStates.Error,
-            Question = question,
-            NoResult = true,
-            NoResultReason = "An error occurred"
         }
     };
 }
@@ -112,7 +104,14 @@ public void AddSource(Citation citation)
         this.AskResult.RelevantSources?.Add(citation);
         this.InsufficientTokensResult.RelevantSources?.Add(citation);
         this.UnsafeAnswerResult.RelevantSources?.Add(citation);
-        this.ErrorResult.RelevantSources?.Add(citation);
+    }
+
+    public void AddTokenUsageToStaticResults(TokenUsage tokenUsage)
+    {
+        // Add report only to non-streamed results
+        this.InsufficientTokensResult.TokenUsage = [tokenUsage];
+        this.UnsafeAnswerResult.TokenUsage = [tokenUsage];
+        this.NoFactsResult.TokenUsage = [tokenUsage];
     }
 
     /// <summary>

diff --git a/service/Core/SemanticKernel/SemanticKernelTextGenerator.cs b/service/Core/SemanticKernel/SemanticKernelTextGenerator.cs
index 06c4844d3..5f2110b03 100644
--- a/service/Core/SemanticKernel/SemanticKernelTextGenerator.cs
+++ b/service/Core/SemanticKernel/SemanticKernelTextGenerator.cs
@@ -52,7 +52,7 @@ public SemanticKernelTextGenerator(
     }
 
     /// <inheritdoc/>
-    public async IAsyncEnumerable<string> GenerateTextAsync(
+    public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
         string prompt,
         TextGenerationOptions options,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)

diff --git a/service/tests/Core.FunctionalTests/ServerLess/AIClients/AzureOpenAITextGeneratorTest.cs b/service/tests/Core.FunctionalTests/ServerLess/AIClients/AzureOpenAITextGeneratorTest.cs
index abaee1d6e..e5d636da1 100644
--- a/service/tests/Core.FunctionalTests/ServerLess/AIClients/AzureOpenAITextGeneratorTest.cs
+++ b/service/tests/Core.FunctionalTests/ServerLess/AIClients/AzureOpenAITextGeneratorTest.cs
@@ -26,12 +26,12 @@ public async Task ItStreamsFromChatModel()
         var client = new AzureOpenAITextGenerator(this._config, loggerFactory: null);
 
         // Act
-        IAsyncEnumerable<string> text = client.GenerateTextAsync(
+        IAsyncEnumerable<GeneratedTextContent> text = client.GenerateTextAsync(
             "write 100 words about the Earth", new TextGenerationOptions());
 
         // Assert
         var count = 0;
-        await foreach (string word in text)
+        await foreach (var word in text)
         {
             Console.Write(word);
             if (count++ > 10) { break; }

diff --git a/service/tests/Core.FunctionalTests/ServerLess/AIClients/OpenAITextGeneratorTest.cs b/service/tests/Core.FunctionalTests/ServerLess/AIClients/OpenAITextGeneratorTest.cs
index 6458661f2..49e9d975d 100644
--- a/service/tests/Core.FunctionalTests/ServerLess/AIClients/OpenAITextGeneratorTest.cs
+++ b/service/tests/Core.FunctionalTests/ServerLess/AIClients/OpenAITextGeneratorTest.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft. All rights reserved.
+// Copyright (c) Microsoft. All rights reserved.
 
 using Microsoft.KernelMemory;
 using Microsoft.KernelMemory.AI;
@@ -26,12 +26,12 @@ public async Task ItStreamsFromChatModel()
         var client = new OpenAITextGenerator(this._config);
 
         // Act
-        IAsyncEnumerable<string> text = client.GenerateTextAsync(
+        IAsyncEnumerable<GeneratedTextContent> text = client.GenerateTextAsync(
             "write 100 words about the Earth", new TextGenerationOptions());
 
         // Assert
         var count = 0;
-        await foreach (string word in text)
+        await foreach (var word in text)
         {
             Console.Write(word);
             if (count++ > 10) { break; }
@@ -49,12 +49,12 @@ public async Task ItStreamsFromTextModel()
         var client = new OpenAITextGenerator(this._config);
 
         // Act
-        IAsyncEnumerable<string> text = client.GenerateTextAsync(
+        IAsyncEnumerable<GeneratedTextContent> text = client.GenerateTextAsync(
            "write 100 words about the Earth", new TextGenerationOptions());
 
         // Assert
         var count = 0;
-        await foreach (string word in text)
+        await foreach (var word in text)
         {
             Console.Write(word);
             if (count++ > 10) { break; }
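On the non-streaming path, SearchClient.AskAsync copies the TokenUsage of the final part onto the MemoryAnswer it returns, so callers get one consolidated report. A consumer-side sketch, not part of the patch, assuming an already configured `IKernelMemory` instance:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.KernelMemory;

public static class UsageReport
{
    // Prints the consolidated token usage report from a non-streaming Ask call.
    public static async Task PrintAsync(IKernelMemory memory, string question)
    {
        MemoryAnswer answer = await memory.AskAsync(question);

        foreach (TokenUsage report in answer.TokenUsage ?? [])
        {
            Console.WriteLine($"{report.ServiceType}/{report.ModelName}: " +
                $"in={report.ServiceTokensIn ?? report.TokenizerTokensIn}, " +
                $"out={report.ServiceTokensOut ?? report.TokenizerTokensOut}");
        }
    }
}
```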