@@ -4,9 +4,9 @@ import dev.langchain4j.data.message.ChatMessage
4
4
import dev.langchain4j.data.message.SystemMessage
5
5
import dev.langchain4j.data.message.UserMessage
6
6
import dev.langchain4j.internal.Utils
7
+ import dev.langchain4j.kotlin.model.chat.chat
7
8
import dev.langchain4j.memory.ChatMemory
8
9
import dev.langchain4j.model.chat.Capability
9
- import dev.langchain4j.model.chat.request.ChatRequest
10
10
import dev.langchain4j.model.chat.request.ChatRequestParameters
11
11
import dev.langchain4j.model.chat.request.ResponseFormat
12
12
import dev.langchain4j.model.chat.request.ResponseFormatType
@@ -30,8 +30,8 @@ import dev.langchain4j.service.memory.ChatMemoryService
30
30
import dev.langchain4j.service.output.ServiceOutputParser
31
31
import dev.langchain4j.service.tool.ToolServiceContext
32
32
import dev.langchain4j.spi.services.TokenStreamAdapter
33
+ import kotlinx.coroutines.asCoroutineDispatcher
33
34
import me.kpavlov.langchain4j.kotlin.ChatMemoryId
34
- import me.kpavlov.langchain4j.kotlin.model.chat.chatAsync
35
35
import me.kpavlov.langchain4j.kotlin.service.ReflectionHelper.validateParameters
36
36
import me.kpavlov.langchain4j.kotlin.service.ReflectionVariableResolver.asString
37
37
import me.kpavlov.langchain4j.kotlin.service.ReflectionVariableResolver.findMemoryId
@@ -53,16 +53,47 @@ import java.util.concurrent.ExecutorService
53
53
import java.util.concurrent.Executors
54
54
import java.util.concurrent.Future
55
55
import java.util.function.Supplier
56
-
56
+ import kotlin.coroutines.CoroutineContext
57
+
58
+
59
+ /* *
60
+ * AiServiceOrchestrator is an internal class responsible for coordinating interactions with AI services.
61
+ * It handles the flow of messages, memory management, augmentation, moderation, and parsing outputs
62
+ * from the AI services. This class orchestrates the functional components to provide responses based
63
+ * on the provided inputs, ensuring the required AI operations are performed.
64
+ *
65
+ * @constructor
66
+ * Initializes the orchestrator with the specified context, parser, adapters, and optional configurations.
67
+ *
68
+ * @param T The type constraint applied to instances handled by this orchestrator.
69
+ * @param context The context for the AI service, providing access to shared configurations and tools.
70
+ * @param serviceOutputParser A parser to process and transform the raw service outputs into desired formats.
71
+ * @param tokenStreamAdapters A collection of adapters to facilitate converting token streams into different formats.
72
+ * @param executor The executor service used for managing asynchronous tasks, defaulted to a cached thread pool.
73
+ * @param coroutineContext An optional coroutine context for managing coroutine-based execution.
74
+ * If not provided, the coroutine dispatcher, created from [executor] will be used.
75
+ * This parameter is experimental and subject to change.
76
+ *
77
+ * @see AiServiceContext
78
+ * @see ServiceOutputParser
79
+ * @see TokenStreamAdapter
80
+ * @see ExecutorService
81
+ * @see CoroutineContext
82
+ *
83
+ * @author Konstantin Pavlov
84
+ *
85
+ * @throws Exception General exception thrown during various processing tasks and execution points
86
+ * depending on the scenario.
87
+ */
57
88
@ApiStatus.Internal
58
89
@Suppress(" TooManyFunctions" , " detekt:all" )
59
- internal class AiServiceOrchestrator <T : Any >(
90
+ internal class AiServiceOrchestrator <T : Any > @JvmOverloads constructor (
60
91
private val context : AiServiceContext ,
61
92
private val serviceOutputParser : ServiceOutputParser ,
62
93
private val tokenStreamAdapters : Collection <TokenStreamAdapter >,
94
+ private val executor : ExecutorService = Executors .newCachedThreadPool(),
95
+ private val coroutineContext : CoroutineContext = executor.asCoroutineDispatcher(),
63
96
) {
64
- private val executor: ExecutorService = Executors .newCachedThreadPool()
65
-
66
97
@Throws(Exception ::class )
67
98
@Suppress(
68
99
" LongMethod" ,
@@ -82,11 +113,12 @@ internal class AiServiceOrchestrator<T : Any>(
82
113
}
83
114
84
115
val chatMemoryService = context.chatMemoryService
85
- if (method.declaringClass == ChatMemoryAccess ::class .java && args.size >= 1 ) {
116
+ if (method.declaringClass == ChatMemoryAccess ::class .java && args.isNotEmpty()) {
117
+ val memoryId = args[0 ]!!
86
118
return when (method.name) {
87
- " getChatMemory" -> chatMemoryService.getChatMemoryAsync(args[ 0 ] !! )
119
+ " getChatMemory" -> chatMemoryService.getChatMemoryAsync(memoryId )
88
120
" evictChatMemory" -> {
89
- chatMemoryService.evictChatMemoryAsync(args[ 0 ] !! ) != null
121
+ chatMemoryService.evictChatMemoryAsync(memoryId ) != null
90
122
}
91
123
92
124
else -> throw UnsupportedOperationException (
@@ -202,10 +234,10 @@ internal class AiServiceOrchestrator<T : Any>(
202
234
@Suppress(" LongParameterList" )
203
235
private suspend fun handleNonStreamingCall (
204
236
returnType : Type ,
205
- messages : MutableList <ChatMessage ? >,
237
+ messages : MutableList <ChatMessage >,
206
238
toolServiceContext : ToolServiceContext ,
207
239
augmentationResult : AugmentationResult ? ,
208
- moderationFuture : Future <Moderation ? >? ,
240
+ moderationFuture : Future <Moderation >? ,
209
241
chatMemory : ChatMemory ? ,
210
242
memoryId : ChatMemoryId ,
211
243
supportsJsonSchema : Boolean ,
@@ -229,14 +261,10 @@ internal class AiServiceOrchestrator<T : Any>(
229
261
.responseFormat(responseFormat)
230
262
.build()
231
263
232
- val chatRequest =
233
- ChatRequest
234
- .builder()
235
- .messages(messages)
236
- .parameters(parameters)
237
- .build()
238
-
239
- var chatResponse = context.chatModel.chatAsync(chatRequest)
264
+ var chatResponse = context.chatModel.chat(coroutineContext) {
265
+ this .messages = messages
266
+ this .parameters = parameters
267
+ }
240
268
241
269
AiServices .verifyModerationIfNeeded(moderationFuture)
242
270
@@ -301,11 +329,11 @@ internal class AiServiceOrchestrator<T : Any>(
301
329
302
330
private fun triggerModerationIfNeeded (
303
331
method : Method ,
304
- messages : MutableList <ChatMessage ? >,
305
- ): Future <Moderation ? >? =
332
+ messages : MutableList <ChatMessage >,
333
+ ): Future <Moderation >? =
306
334
if (method.isAnnotationPresent(Moderate ::class .java)) {
307
335
executor.submit(
308
- Callable < Moderation ?> {
336
+ Callable {
309
337
val messagesToModerate = AiServices .removeToolMessages(messages)
310
338
context.moderationModel
311
339
.moderate(messagesToModerate)
@@ -322,11 +350,11 @@ internal class AiServiceOrchestrator<T : Any>(
322
350
args : Array <Any ?>,
323
351
): SystemMessage ? =
324
352
findSystemMessageTemplate(memoryId, method)
325
- .map< SystemMessage > { systemMessageTemplate: String ->
353
+ .map { systemMessageTemplate: String ->
326
354
PromptTemplate
327
355
.from(systemMessageTemplate)
328
356
.apply (
329
- ReflectionVariableResolver . findTemplateVariables(
357
+ findTemplateVariables(
330
358
systemMessageTemplate,
331
359
method,
332
360
args,
@@ -339,7 +367,7 @@ internal class AiServiceOrchestrator<T : Any>(
339
367
method : Method ,
340
368
): Optional <String > {
341
369
val annotation =
342
- method.getAnnotation< dev.langchain4j.service. SystemMessage > (
370
+ method.getAnnotation(
343
371
dev.langchain4j.service.SystemMessage ::class .java,
344
372
)
345
373
if (annotation != null ) {
@@ -364,9 +392,9 @@ internal class AiServiceOrchestrator<T : Any>(
364
392
value : Array <String >,
365
393
delimiter : String ,
366
394
): String {
367
- var messageTemplate: String =
395
+ val messageTemplate: String =
368
396
if (! resource.trim { it <= ' ' }.isEmpty()) {
369
- val resourceText = getResourceText(method.getDeclaringClass() , resource)
397
+ val resourceText = getResourceText(method.declaringClass , resource)
370
398
if (resourceText == null ) {
371
399
throw IllegalConfigurationException .illegalConfiguration(
372
400
" @%sMessage's resource '%s' not found" ,
@@ -393,7 +421,7 @@ internal class AiServiceOrchestrator<T : Any>(
393
421
): String? {
394
422
var inputStream = clazz.getResourceAsStream(resource)
395
423
if (inputStream == null ) {
396
- inputStream = clazz.getResourceAsStream(" /" + resource)
424
+ inputStream = clazz.getResourceAsStream(" /$ resource" )
397
425
}
398
426
return getText(inputStream)
399
427
}
@@ -418,9 +446,9 @@ internal class AiServiceOrchestrator<T : Any>(
418
446
419
447
val prompt = PromptTemplate .from(template).apply (variables)
420
448
421
- val maybeUserName = findUserName(method.getParameters() , args)
449
+ val maybeUserName = findUserName(method.parameters , args)
422
450
return maybeUserName
423
- .map< UserMessage > { userName: String? ->
451
+ .map { userName: String? ->
424
452
UserMessage .from(
425
453
userName,
426
454
prompt.text(),
@@ -436,48 +464,48 @@ internal class AiServiceOrchestrator<T : Any>(
436
464
findUserMessageTemplateFromMethodAnnotation(method)
437
465
val templateFromParameterAnnotation =
438
466
findUserMessageTemplateFromAnnotatedParameter(
439
- method.getParameters() ,
467
+ method.parameters ,
440
468
args,
441
469
)
442
470
443
- if (templateFromMethodAnnotation.isPresent() &&
444
- templateFromParameterAnnotation.isPresent()
471
+ if (templateFromMethodAnnotation.isPresent &&
472
+ templateFromParameterAnnotation.isPresent
445
473
) {
446
474
throw IllegalConfigurationException .illegalConfiguration(
447
475
" Error: The method '%s' has multiple @UserMessage annotations. Please use only one." ,
448
- method.getName() ,
476
+ method.name ,
449
477
)
450
478
}
451
479
452
- if (templateFromMethodAnnotation.isPresent() ) {
480
+ if (templateFromMethodAnnotation.isPresent) {
453
481
return templateFromMethodAnnotation.get()
454
482
}
455
- if (templateFromParameterAnnotation.isPresent() ) {
483
+ if (templateFromParameterAnnotation.isPresent) {
456
484
return templateFromParameterAnnotation.get()
457
485
}
458
486
459
487
val templateFromTheOnlyArgument =
460
488
findUserMessageTemplateFromTheOnlyArgument(
461
- method.getParameters() ,
489
+ method.parameters ,
462
490
args,
463
491
)
464
- if (templateFromTheOnlyArgument.isPresent() ) {
492
+ if (templateFromTheOnlyArgument.isPresent) {
465
493
return templateFromTheOnlyArgument.get()
466
494
}
467
495
468
496
throw IllegalConfigurationException .illegalConfiguration(
469
497
" Error: The method '%s' does not have a user message defined." ,
470
- method.getName() ,
498
+ method.name ,
471
499
)
472
500
}
473
501
474
502
private fun findUserMessageTemplateFromMethodAnnotation (method : Method ): Optional <String > =
475
503
Optional
476
504
.ofNullable< dev.langchain4j.service.UserMessage > (
477
- method.getAnnotation< dev.langchain4j.service. UserMessage > (
505
+ method.getAnnotation(
478
506
dev.langchain4j.service.UserMessage ::class .java,
479
507
),
480
- ).map< String > { userMessage ->
508
+ ).map { userMessage ->
481
509
getTemplate(
482
510
method,
483
511
" User" ,
0 commit comments