1+ using Microsoft . Extensions . AI ;
2+ using System ;
3+ using System . Collections . Generic ;
4+ using System . Runtime . CompilerServices ;
5+ using System . Text ;
6+ using System . Threading ;
7+ using System . Threading . Tasks ;
8+
9+ namespace Microsoft . ML . OnnxRuntimeGenAI ;
10+
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;

    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;

    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;

    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        // This instance created the model, so it owns (and must dispose) it.
        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        // No model path is available here, so metadata carries only the provider name.
        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        // Only dispose the model when this instance created it or was told it owns it.
        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        StringBuilder text = new();

        // Token generation is synchronous/CPU-bound in the native library, so offload
        // the whole loop to the thread pool rather than blocking the caller.
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                // Generation can run for many iterations; honor cancellation between tokens.
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                // Decode only the most recently generated token from sequence 0.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = Guid.NewGuid().ToString(),
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        // All updates in one streaming response share a single completion ID.
        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            // Each token step is synchronous native work; run it off the caller's context
            // so the enumerator yields without blocking.
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                // Decode only the most recently generated token from sequence 0.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    /// <param name="token">The decoded token text to check.</param>
    /// <param name="options">Optional per-request options whose stop sequences are also honored.</param>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    /// <param name="numInputTokens">The number of tokens in the prompt, used to compute the total max length.</param>
    /// <param name="generatorParams">The parameters object to populate.</param>
    /// <param name="options">The request options to map onto search options; may be null.</param>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // "max_length" counts prompt tokens plus generated tokens, so add the prompt length.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        // Pass through any additional scalar properties as raw search options.
        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}