Skip to content

Commit 3e774ed

Browse files
committed
Refactor AsyncAiServices
1 parent cfb76a0 commit 3e774ed

File tree

4 files changed

+97
-118
lines changed

4 files changed

+97
-118
lines changed

langchain4j-kotlin/pom.xml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
<?xml version="1.0" encoding="UTF-8"?>
2-
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
34
<modelVersion>4.0.0</modelVersion>
45
<parent>
56
<groupId>me.kpavlov.langchain4j.kotlin</groupId>
@@ -17,14 +18,19 @@
1718
<artifactId>langchain4j-core</artifactId>
1819
</dependency>
1920
<dependency>
20-
<groupId>org.jetbrains.kotlin</groupId>
21-
<artifactId>kotlin-reflect</artifactId>
21+
<groupId>dev.langchain4j</groupId>
22+
<artifactId>langchain4j</artifactId>
23+
<optional>true</optional>
2224
</dependency>
2325
<dependency>
2426
<groupId>dev.langchain4j</groupId>
25-
<artifactId>langchain4j</artifactId>
27+
<artifactId>langchain4j-kotlin</artifactId>
2628
<optional>true</optional>
2729
</dependency>
30+
<dependency>
31+
<groupId>org.jetbrains.kotlin</groupId>
32+
<artifactId>kotlin-reflect</artifactId>
33+
</dependency>
2834
<dependency>
2935
<groupId>org.jetbrains.kotlinx</groupId>
3036
<artifactId>kotlinx-coroutines-core-jvm</artifactId>

langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AiServiceOrchestrator.kt

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import dev.langchain4j.data.message.ChatMessage
44
import dev.langchain4j.data.message.SystemMessage
55
import dev.langchain4j.data.message.UserMessage
66
import dev.langchain4j.internal.Utils
7+
import dev.langchain4j.kotlin.model.chat.chat
78
import dev.langchain4j.memory.ChatMemory
89
import dev.langchain4j.model.chat.Capability
9-
import dev.langchain4j.model.chat.request.ChatRequest
1010
import dev.langchain4j.model.chat.request.ChatRequestParameters
1111
import dev.langchain4j.model.chat.request.ResponseFormat
1212
import dev.langchain4j.model.chat.request.ResponseFormatType
@@ -30,8 +30,8 @@ import dev.langchain4j.service.memory.ChatMemoryService
3030
import dev.langchain4j.service.output.ServiceOutputParser
3131
import dev.langchain4j.service.tool.ToolServiceContext
3232
import dev.langchain4j.spi.services.TokenStreamAdapter
33+
import kotlinx.coroutines.asCoroutineDispatcher
3334
import me.kpavlov.langchain4j.kotlin.ChatMemoryId
34-
import me.kpavlov.langchain4j.kotlin.model.chat.chatAsync
3535
import me.kpavlov.langchain4j.kotlin.service.ReflectionHelper.validateParameters
3636
import me.kpavlov.langchain4j.kotlin.service.ReflectionVariableResolver.asString
3737
import me.kpavlov.langchain4j.kotlin.service.ReflectionVariableResolver.findMemoryId
@@ -53,16 +53,47 @@ import java.util.concurrent.ExecutorService
5353
import java.util.concurrent.Executors
5454
import java.util.concurrent.Future
5555
import java.util.function.Supplier
56-
56+
import kotlin.coroutines.CoroutineContext
57+
58+
59+
/**
60+
* AiServiceOrchestrator is an internal class responsible for coordinating interactions with AI services.
61+
* It handles the flow of messages, memory management, augmentation, moderation, and parsing outputs
62+
* from the AI services. This class orchestrates the functional components to provide responses based
63+
* on the provided inputs, ensuring the required AI operations are performed.
64+
*
65+
* @constructor
66+
* Initializes the orchestrator with the specified context, parser, adapters, and optional configurations.
67+
*
68+
* @param T The type constraint applied to instances handled by this orchestrator.
69+
* @param context The context for the AI service, providing access to shared configurations and tools.
70+
* @param serviceOutputParser A parser to process and transform the raw service outputs into desired formats.
71+
* @param tokenStreamAdapters A collection of adapters to facilitate converting token streams into different formats.
72+
* @param executor The executor service used for managing asynchronous tasks, defaulted to a cached thread pool.
73+
* @param coroutineContext An optional coroutine context for managing coroutine-based execution.
74+
* If not provided, the coroutine dispatcher, created from [executor] will be used.
75+
* This parameter is experimental and subject to change.
76+
*
77+
* @see AiServiceContext
78+
* @see ServiceOutputParser
79+
* @see TokenStreamAdapter
80+
* @see ExecutorService
81+
* @see CoroutineContext
82+
*
83+
* @author Konstantin Pavlov
84+
*
85+
* @throws Exception General exception thrown during various processing tasks and execution points
86+
* depending on the scenario.
87+
*/
5788
@ApiStatus.Internal
5889
@Suppress("TooManyFunctions", "detekt:all")
59-
internal class AiServiceOrchestrator<T : Any>(
90+
internal class AiServiceOrchestrator<T : Any> @JvmOverloads constructor(
6091
private val context: AiServiceContext,
6192
private val serviceOutputParser: ServiceOutputParser,
6293
private val tokenStreamAdapters: Collection<TokenStreamAdapter>,
94+
private val executor: ExecutorService = Executors.newCachedThreadPool(),
95+
private val coroutineContext: CoroutineContext = executor.asCoroutineDispatcher(),
6396
) {
64-
private val executor: ExecutorService = Executors.newCachedThreadPool()
65-
6697
@Throws(Exception::class)
6798
@Suppress(
6899
"LongMethod",
@@ -82,11 +113,12 @@ internal class AiServiceOrchestrator<T : Any>(
82113
}
83114

84115
val chatMemoryService = context.chatMemoryService
85-
if (method.declaringClass == ChatMemoryAccess::class.java && args.size >= 1) {
116+
if (method.declaringClass == ChatMemoryAccess::class.java && args.isNotEmpty()) {
117+
val memoryId = args[0]!!
86118
return when (method.name) {
87-
"getChatMemory" -> chatMemoryService.getChatMemoryAsync(args[0]!!)
119+
"getChatMemory" -> chatMemoryService.getChatMemoryAsync(memoryId)
88120
"evictChatMemory" -> {
89-
chatMemoryService.evictChatMemoryAsync(args[0]!!) != null
121+
chatMemoryService.evictChatMemoryAsync(memoryId) != null
90122
}
91123

92124
else -> throw UnsupportedOperationException(
@@ -202,10 +234,10 @@ internal class AiServiceOrchestrator<T : Any>(
202234
@Suppress("LongParameterList")
203235
private suspend fun handleNonStreamingCall(
204236
returnType: Type,
205-
messages: MutableList<ChatMessage?>,
237+
messages: MutableList<ChatMessage>,
206238
toolServiceContext: ToolServiceContext,
207239
augmentationResult: AugmentationResult?,
208-
moderationFuture: Future<Moderation?>?,
240+
moderationFuture: Future<Moderation>?,
209241
chatMemory: ChatMemory?,
210242
memoryId: ChatMemoryId,
211243
supportsJsonSchema: Boolean,
@@ -229,14 +261,10 @@ internal class AiServiceOrchestrator<T : Any>(
229261
.responseFormat(responseFormat)
230262
.build()
231263

232-
val chatRequest =
233-
ChatRequest
234-
.builder()
235-
.messages(messages)
236-
.parameters(parameters)
237-
.build()
238-
239-
var chatResponse = context.chatModel.chatAsync(chatRequest)
264+
var chatResponse = context.chatModel.chat(coroutineContext) {
265+
this.messages = messages
266+
this.parameters = parameters
267+
}
240268

241269
AiServices.verifyModerationIfNeeded(moderationFuture)
242270

@@ -301,11 +329,11 @@ internal class AiServiceOrchestrator<T : Any>(
301329

302330
private fun triggerModerationIfNeeded(
303331
method: Method,
304-
messages: MutableList<ChatMessage?>,
305-
): Future<Moderation?>? =
332+
messages: MutableList<ChatMessage>,
333+
): Future<Moderation>? =
306334
if (method.isAnnotationPresent(Moderate::class.java)) {
307335
executor.submit(
308-
Callable<Moderation?> {
336+
Callable {
309337
val messagesToModerate = AiServices.removeToolMessages(messages)
310338
context.moderationModel
311339
.moderate(messagesToModerate)
@@ -322,11 +350,11 @@ internal class AiServiceOrchestrator<T : Any>(
322350
args: Array<Any?>,
323351
): SystemMessage? =
324352
findSystemMessageTemplate(memoryId, method)
325-
.map<SystemMessage> { systemMessageTemplate: String ->
353+
.map { systemMessageTemplate: String ->
326354
PromptTemplate
327355
.from(systemMessageTemplate)
328356
.apply(
329-
ReflectionVariableResolver.findTemplateVariables(
357+
findTemplateVariables(
330358
systemMessageTemplate,
331359
method,
332360
args,
@@ -339,7 +367,7 @@ internal class AiServiceOrchestrator<T : Any>(
339367
method: Method,
340368
): Optional<String> {
341369
val annotation =
342-
method.getAnnotation<dev.langchain4j.service.SystemMessage>(
370+
method.getAnnotation(
343371
dev.langchain4j.service.SystemMessage::class.java,
344372
)
345373
if (annotation != null) {
@@ -364,9 +392,9 @@ internal class AiServiceOrchestrator<T : Any>(
364392
value: Array<String>,
365393
delimiter: String,
366394
): String {
367-
var messageTemplate: String =
395+
val messageTemplate: String =
368396
if (!resource.trim { it <= ' ' }.isEmpty()) {
369-
val resourceText = getResourceText(method.getDeclaringClass(), resource)
397+
val resourceText = getResourceText(method.declaringClass, resource)
370398
if (resourceText == null) {
371399
throw IllegalConfigurationException.illegalConfiguration(
372400
"@%sMessage's resource '%s' not found",
@@ -393,7 +421,7 @@ internal class AiServiceOrchestrator<T : Any>(
393421
): String? {
394422
var inputStream = clazz.getResourceAsStream(resource)
395423
if (inputStream == null) {
396-
inputStream = clazz.getResourceAsStream("/" + resource)
424+
inputStream = clazz.getResourceAsStream("/$resource")
397425
}
398426
return getText(inputStream)
399427
}
@@ -418,9 +446,9 @@ internal class AiServiceOrchestrator<T : Any>(
418446

419447
val prompt = PromptTemplate.from(template).apply(variables)
420448

421-
val maybeUserName = findUserName(method.getParameters(), args)
449+
val maybeUserName = findUserName(method.parameters, args)
422450
return maybeUserName
423-
.map<UserMessage> { userName: String? ->
451+
.map { userName: String? ->
424452
UserMessage.from(
425453
userName,
426454
prompt.text(),
@@ -436,48 +464,48 @@ internal class AiServiceOrchestrator<T : Any>(
436464
findUserMessageTemplateFromMethodAnnotation(method)
437465
val templateFromParameterAnnotation =
438466
findUserMessageTemplateFromAnnotatedParameter(
439-
method.getParameters(),
467+
method.parameters,
440468
args,
441469
)
442470

443-
if (templateFromMethodAnnotation.isPresent() &&
444-
templateFromParameterAnnotation.isPresent()
471+
if (templateFromMethodAnnotation.isPresent &&
472+
templateFromParameterAnnotation.isPresent
445473
) {
446474
throw IllegalConfigurationException.illegalConfiguration(
447475
"Error: The method '%s' has multiple @UserMessage annotations. Please use only one.",
448-
method.getName(),
476+
method.name,
449477
)
450478
}
451479

452-
if (templateFromMethodAnnotation.isPresent()) {
480+
if (templateFromMethodAnnotation.isPresent) {
453481
return templateFromMethodAnnotation.get()
454482
}
455-
if (templateFromParameterAnnotation.isPresent()) {
483+
if (templateFromParameterAnnotation.isPresent) {
456484
return templateFromParameterAnnotation.get()
457485
}
458486

459487
val templateFromTheOnlyArgument =
460488
findUserMessageTemplateFromTheOnlyArgument(
461-
method.getParameters(),
489+
method.parameters,
462490
args,
463491
)
464-
if (templateFromTheOnlyArgument.isPresent()) {
492+
if (templateFromTheOnlyArgument.isPresent) {
465493
return templateFromTheOnlyArgument.get()
466494
}
467495

468496
throw IllegalConfigurationException.illegalConfiguration(
469497
"Error: The method '%s' does not have a user message defined.",
470-
method.getName(),
498+
method.name,
471499
)
472500
}
473501

474502
private fun findUserMessageTemplateFromMethodAnnotation(method: Method): Optional<String> =
475503
Optional
476504
.ofNullable<dev.langchain4j.service.UserMessage>(
477-
method.getAnnotation<dev.langchain4j.service.UserMessage>(
505+
method.getAnnotation(
478506
dev.langchain4j.service.UserMessage::class.java,
479507
),
480-
).map<String> { userMessage ->
508+
).map { userMessage ->
481509
getTemplate(
482510
method,
483511
"User",

langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AsyncAiServices.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package me.kpavlov.langchain4j.kotlin.service
22

3+
import dev.langchain4j.internal.VirtualThreadUtils
34
import dev.langchain4j.service.AiServiceContext
45
import dev.langchain4j.service.AiServices
56
import dev.langchain4j.service.IllegalConfigurationException.illegalConfiguration
@@ -11,10 +12,14 @@ import dev.langchain4j.service.memory.ChatMemoryAccess
1112
import dev.langchain4j.service.output.ServiceOutputParser
1213
import dev.langchain4j.spi.ServiceHelper
1314
import dev.langchain4j.spi.services.TokenStreamAdapter
15+
import java.util.concurrent.ExecutorService
16+
import java.util.concurrent.Executors
17+
import kotlin.coroutines.CoroutineContext
1418

1519
public class AsyncAiServices<T : Any>(
1620
context: AiServiceContext,
1721
) : AiServices<T>(context) {
22+
private val chatModelCoroutineContext: CoroutineContext? = null
1823
private val serviceOutputParser = ServiceOutputParser()
1924
private val tokenStreamAdapters =
2025
ServiceHelper.loadFactories<TokenStreamAdapter>(TokenStreamAdapter::class.java)
@@ -71,7 +76,16 @@ public class AsyncAiServices<T : Any>(
7176
}
7277
}
7378

74-
val handler = AiServiceOrchestrator<T>(context, serviceOutputParser, tokenStreamAdapters)
79+
val executor: ExecutorService = VirtualThreadUtils.createVirtualThreadExecutor {
80+
Executors.newCachedThreadPool()
81+
}!!
82+
83+
val handler = AiServiceOrchestrator<T>(
84+
context = context,
85+
serviceOutputParser = serviceOutputParser,
86+
tokenStreamAdapters = tokenStreamAdapters,
87+
executor = executor,
88+
)
7589
@Suppress("UNCHECKED_CAST", "unused")
7690
return ReflectionHelper.createSuspendProxy(context.aiServiceClass) { method, args ->
7791
return@createSuspendProxy handler.execute(method, args)

0 commit comments

Comments
 (0)