Skip to content

Commit 4f062c0

Browse files
authored
Refactor async aiservices, Upgrade coroutines to 1.9.0 (#147)
- Refactor async aiservices - Upgrade coroutines to 1.9.0 (max version compatible with Kotlin 1.9)
1 parent cfb76a0 commit 4f062c0

File tree

5 files changed

+86
-133
lines changed

5 files changed

+86
-133
lines changed

langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AiServiceOrchestrator.kt

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ internal class AiServiceOrchestrator<T : Any>(
175175

176176
private fun handleStreamingCall(
177177
returnType: Type,
178-
messages: MutableList<ChatMessage?>,
178+
messages: MutableList<ChatMessage>,
179179
toolServiceContext: ToolServiceContext,
180180
augmentationResult: AugmentationResult?,
181181
memoryId: Any,
@@ -202,10 +202,10 @@ internal class AiServiceOrchestrator<T : Any>(
202202
@Suppress("LongParameterList")
203203
private suspend fun handleNonStreamingCall(
204204
returnType: Type,
205-
messages: MutableList<ChatMessage?>,
205+
messages: MutableList<ChatMessage>,
206206
toolServiceContext: ToolServiceContext,
207207
augmentationResult: AugmentationResult?,
208-
moderationFuture: Future<Moderation?>?,
208+
moderationFuture: Future<Moderation>?,
209209
chatMemory: ChatMemory?,
210210
memoryId: ChatMemoryId,
211211
supportsJsonSchema: Boolean,
@@ -302,10 +302,10 @@ internal class AiServiceOrchestrator<T : Any>(
302302
private fun triggerModerationIfNeeded(
303303
method: Method,
304304
messages: MutableList<ChatMessage?>,
305-
): Future<Moderation?>? =
305+
): Future<Moderation>? =
306306
if (method.isAnnotationPresent(Moderate::class.java)) {
307307
executor.submit(
308-
Callable<Moderation?> {
308+
Callable {
309309
val messagesToModerate = AiServices.removeToolMessages(messages)
310310
context.moderationModel
311311
.moderate(messagesToModerate)
@@ -322,11 +322,11 @@ internal class AiServiceOrchestrator<T : Any>(
322322
args: Array<Any?>,
323323
): SystemMessage? =
324324
findSystemMessageTemplate(memoryId, method)
325-
.map<SystemMessage> { systemMessageTemplate: String ->
325+
.map { systemMessageTemplate: String ->
326326
PromptTemplate
327327
.from(systemMessageTemplate)
328328
.apply(
329-
ReflectionVariableResolver.findTemplateVariables(
329+
findTemplateVariables(
330330
systemMessageTemplate,
331331
method,
332332
args,
@@ -339,7 +339,7 @@ internal class AiServiceOrchestrator<T : Any>(
339339
method: Method,
340340
): Optional<String> {
341341
val annotation =
342-
method.getAnnotation<dev.langchain4j.service.SystemMessage>(
342+
method.getAnnotation(
343343
dev.langchain4j.service.SystemMessage::class.java,
344344
)
345345
if (annotation != null) {
@@ -364,9 +364,9 @@ internal class AiServiceOrchestrator<T : Any>(
364364
value: Array<String>,
365365
delimiter: String,
366366
): String {
367-
var messageTemplate: String =
367+
val messageTemplate: String =
368368
if (!resource.trim { it <= ' ' }.isEmpty()) {
369-
val resourceText = getResourceText(method.getDeclaringClass(), resource)
369+
val resourceText = getResourceText(method.declaringClass, resource)
370370
if (resourceText == null) {
371371
throw IllegalConfigurationException.illegalConfiguration(
372372
"@%sMessage's resource '%s' not found",
@@ -393,7 +393,7 @@ internal class AiServiceOrchestrator<T : Any>(
393393
): String? {
394394
var inputStream = clazz.getResourceAsStream(resource)
395395
if (inputStream == null) {
396-
inputStream = clazz.getResourceAsStream("/" + resource)
396+
inputStream = clazz.getResourceAsStream("/$resource")
397397
}
398398
return getText(inputStream)
399399
}
@@ -418,9 +418,9 @@ internal class AiServiceOrchestrator<T : Any>(
418418

419419
val prompt = PromptTemplate.from(template).apply(variables)
420420

421-
val maybeUserName = findUserName(method.getParameters(), args)
421+
val maybeUserName = findUserName(method.parameters, args)
422422
return maybeUserName
423-
.map<UserMessage> { userName: String? ->
423+
.map { userName: String? ->
424424
UserMessage.from(
425425
userName,
426426
prompt.text(),
@@ -436,48 +436,48 @@ internal class AiServiceOrchestrator<T : Any>(
436436
findUserMessageTemplateFromMethodAnnotation(method)
437437
val templateFromParameterAnnotation =
438438
findUserMessageTemplateFromAnnotatedParameter(
439-
method.getParameters(),
439+
method.parameters,
440440
args,
441441
)
442442

443-
if (templateFromMethodAnnotation.isPresent() &&
444-
templateFromParameterAnnotation.isPresent()
443+
if (templateFromMethodAnnotation.isPresent &&
444+
templateFromParameterAnnotation.isPresent
445445
) {
446446
throw IllegalConfigurationException.illegalConfiguration(
447447
"Error: The method '%s' has multiple @UserMessage annotations. Please use only one.",
448-
method.getName(),
448+
method.name,
449449
)
450450
}
451451

452-
if (templateFromMethodAnnotation.isPresent()) {
452+
if (templateFromMethodAnnotation.isPresent) {
453453
return templateFromMethodAnnotation.get()
454454
}
455-
if (templateFromParameterAnnotation.isPresent()) {
455+
if (templateFromParameterAnnotation.isPresent) {
456456
return templateFromParameterAnnotation.get()
457457
}
458458

459459
val templateFromTheOnlyArgument =
460460
findUserMessageTemplateFromTheOnlyArgument(
461-
method.getParameters(),
461+
method.parameters,
462462
args,
463463
)
464-
if (templateFromTheOnlyArgument.isPresent()) {
464+
if (templateFromTheOnlyArgument.isPresent) {
465465
return templateFromTheOnlyArgument.get()
466466
}
467467

468468
throw IllegalConfigurationException.illegalConfiguration(
469469
"Error: The method '%s' does not have a user message defined.",
470-
method.getName(),
470+
method.name,
471471
)
472472
}
473473

474474
private fun findUserMessageTemplateFromMethodAnnotation(method: Method): Optional<String> =
475475
Optional
476476
.ofNullable<dev.langchain4j.service.UserMessage>(
477-
method.getAnnotation<dev.langchain4j.service.UserMessage>(
477+
method.getAnnotation(
478478
dev.langchain4j.service.UserMessage::class.java,
479479
),
480-
).map<String> { userMessage ->
480+
).map { userMessage ->
481481
getTemplate(
482482
method,
483483
"User",

langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/ReflectionVariableResolver.kt

Lines changed: 43 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@ package me.kpavlov.langchain4j.kotlin.service
33
import dev.langchain4j.internal.Exceptions
44
import dev.langchain4j.model.input.structured.StructuredPrompt
55
import dev.langchain4j.model.input.structured.StructuredPromptProcessor
6-
import dev.langchain4j.service.IllegalConfigurationException
6+
import dev.langchain4j.service.InternalReflectionVariableResolver
77
import dev.langchain4j.service.MemoryId
8-
import dev.langchain4j.service.UserMessage
98
import dev.langchain4j.service.UserName
10-
import dev.langchain4j.service.V
119
import me.kpavlov.langchain4j.kotlin.ChatMemoryId
1210
import java.lang.reflect.Method
1311
import java.lang.reflect.Parameter
@@ -23,95 +21,24 @@ import java.util.Optional
2321
*
2422
* @see https://github.com/langchain4j/langchain4j/pull/2951
2523
*/
26-
@Suppress("detekt:all")
2724
internal object ReflectionVariableResolver {
2825
public fun findTemplateVariables(
2926
template: String,
3027
method: Method,
3128
args: Array<Any?>?,
32-
): MutableMap<String?, Any?> {
33-
if (args == null) {
34-
return mutableMapOf<String?, Any?>()
35-
}
36-
val parameters = method.getParameters()
37-
38-
val variables: MutableMap<String?, Any?> = HashMap<String?, Any?>()
39-
for (i in args.indices) {
40-
val variableName = getVariableName(parameters[i])
41-
val variableValue = args[i]
42-
variables.put(variableName, variableValue)
43-
}
44-
45-
if (template.contains("{{it}}") && !variables.containsKey("it")) {
46-
val itValue = getValueOfVariableIt(parameters, args)
47-
variables.put("it", itValue)
48-
}
49-
50-
return variables
51-
}
52-
53-
private fun getVariableName(parameter: Parameter): String? {
54-
val annotation = parameter.getAnnotation<V?>(V::class.java)
55-
if (annotation != null) {
56-
return annotation.value
57-
} else {
58-
return parameter.getName()
59-
}
60-
}
61-
62-
private fun getValueOfVariableIt(
63-
parameters: Array<Parameter>,
64-
args: Array<Any?>?,
65-
): String? {
66-
if (args != null) {
67-
if (args.size == 1) {
68-
val parameter = parameters[0]
69-
if (!parameter.isAnnotationPresent(MemoryId::class.java) &&
70-
!parameter.isAnnotationPresent(
71-
UserMessage::class.java,
72-
) &&
73-
!parameter.isAnnotationPresent(
74-
UserName::class.java,
75-
) &&
76-
(
77-
!parameter.isAnnotationPresent(V::class.java) ||
78-
isAnnotatedWithIt(
79-
parameter,
80-
)
81-
)
82-
) {
83-
return asString(args[0])
84-
}
85-
}
86-
87-
for (i in args.indices) {
88-
if (isAnnotatedWithIt(parameters[i])) {
89-
return asString(args[i])
90-
}
91-
}
92-
}
29+
): MutableMap<String?, Any?> =
30+
InternalReflectionVariableResolver.findTemplateVariables(template, method, args)
9331

94-
throw IllegalConfigurationException.illegalConfiguration(
95-
"Error: cannot find the value of the prompt template variable \"{{it}}\".",
96-
)
97-
}
98-
99-
private fun isAnnotatedWithIt(parameter: Parameter): Boolean {
100-
val annotation = parameter.getAnnotation<V?>(V::class.java)
101-
return annotation != null && "it" == annotation.value
102-
}
103-
104-
public fun asString(arg: Any?): String? {
32+
public fun asString(arg: Any?): String =
10533
if (arg == null) {
106-
return "null"
34+
"null"
10735
} else if (arg is Array<*>?) {
108-
return arrayAsString(arg)
36+
arrayAsString(arg)
10937
} else if (arg.javaClass.isAnnotationPresent(StructuredPrompt::class.java)) {
110-
return StructuredPromptProcessor.toPrompt(arg).text()
38+
StructuredPromptProcessor.toPrompt(arg).text()
11139
} else {
112-
return arg.toString()
40+
arg.toString()
11341
}
114-
}
11542

11643
private fun arrayAsString(arg: Array<*>?): String =
11744
if (arg == null) {
@@ -132,50 +59,65 @@ internal object ReflectionVariableResolver {
13259
fun findUserMessageTemplateFromTheOnlyArgument(
13360
parameters: Array<Parameter>?,
13461
args: Array<Any?>,
135-
): Optional<String> {
136-
if (parameters != null &&
62+
): Optional<String> =
63+
if (
64+
parameters != null &&
13765
parameters.size == 1 &&
13866
parameters[0].getAnnotations().size == 0
13967
) {
140-
return Optional.ofNullable<String>(asString(args[0]))
68+
Optional.ofNullable<String>(asString(args[0]))
69+
} else {
70+
Optional.empty()
14171
}
142-
return Optional.empty()
143-
}
72+
14473

14574
fun findUserName(
14675
parameters: Array<Parameter>,
14776
args: Array<Any?>,
14877
): Optional<String> {
149-
for (i in parameters.indices) {
78+
var result = Optional.empty<String>()
79+
for (i in args.indices) {
15080
if (parameters[i].isAnnotationPresent(UserName::class.java)) {
151-
return Optional.of<String>(args[i].toString())
81+
result = Optional.of(args[i].toString())
82+
break
15283
}
15384
}
154-
return Optional.empty<String>()
85+
return result
15586
}
15687

88+
@Suppress("ReturnCount")
15789
fun findMemoryId(
15890
method: Method,
15991
args: Array<Any?>?,
16092
): Optional<ChatMemoryId> {
16193
if (args == null) {
162-
return Optional.empty<ChatMemoryId>()
94+
return Optional.empty()
16395
}
96+
97+
val memoryIdParam = findMemoryIdParameter(method, args)
98+
if (memoryIdParam != null) {
99+
val (parameter, memoryId) = memoryIdParam
100+
if (memoryId is ChatMemoryId) {
101+
return Optional.of(memoryId)
102+
} else {
103+
throw Exceptions.illegalArgument(
104+
"The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null",
105+
parameter.getName(),
106+
method.getName(),
107+
)
108+
}
109+
}
110+
111+
return Optional.empty()
112+
}
113+
114+
private fun findMemoryIdParameter(method: Method, args: Array<Any?>): Pair<Parameter, Any?>? {
164115
for (i in args.indices) {
165116
val parameter = method.parameters[i]
166117
if (parameter.isAnnotationPresent(MemoryId::class.java)) {
167-
val memoryId = args[i]
168-
if (memoryId is ChatMemoryId) {
169-
return Optional.of(memoryId)
170-
} else {
171-
throw Exceptions.illegalArgument(
172-
"The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null",
173-
parameter.getName(),
174-
method.getName(),
175-
)
176-
}
118+
return Pair(parameter, args[i])
177119
}
178120
}
179-
return Optional.empty()
121+
return null
180122
}
181123
}

langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/memory/ChatMemoryServiceExtensions.kt

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,26 @@ package me.kpavlov.langchain4j.kotlin.service.memory
22

33
import dev.langchain4j.memory.ChatMemory
44
import dev.langchain4j.service.memory.ChatMemoryService
5-
import kotlinx.coroutines.coroutineScope
5+
import kotlinx.coroutines.Dispatchers
6+
import kotlinx.coroutines.withContext
67
import me.kpavlov.langchain4j.kotlin.ChatMemoryId
8+
import kotlin.coroutines.CoroutineContext
79

810
public suspend fun ChatMemoryService.getOrCreateChatMemoryAsync(
911
memoryId: ChatMemoryId,
10-
): ChatMemory = coroutineScope { this@getOrCreateChatMemoryAsync.getOrCreateChatMemory(memoryId) }
12+
context: CoroutineContext = Dispatchers.IO,
13+
): ChatMemory =
14+
withContext(context) { this@getOrCreateChatMemoryAsync.getOrCreateChatMemory(memoryId) }
1115

12-
public suspend fun ChatMemoryService.getChatMemoryAsync(memoryId: ChatMemoryId): ChatMemory? =
13-
coroutineScope { this@getChatMemoryAsync.getChatMemory(memoryId) }
16+
public suspend fun ChatMemoryService.getChatMemoryAsync(
17+
memoryId: ChatMemoryId,
18+
context: CoroutineContext = Dispatchers.IO,
19+
): ChatMemory? =
20+
withContext(context)
21+
{ this@getChatMemoryAsync.getChatMemory(memoryId) }
1422

15-
public suspend fun ChatMemoryService.evictChatMemoryAsync(memoryId: ChatMemoryId): ChatMemory? =
16-
coroutineScope { this@evictChatMemoryAsync.evictChatMemory(memoryId) }
23+
public suspend fun ChatMemoryService.evictChatMemoryAsync(
24+
memoryId: ChatMemoryId,
25+
context: CoroutineContext = Dispatchers.IO,
26+
): ChatMemory? =
27+
withContext(context) { this@evictChatMemoryAsync.evictChatMemory(memoryId) }

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
<awaitility.version>4.2.2</awaitility.version>
5858
<finchly.version>0.1.1</finchly.version>
5959
<junit.version>5.12.2</junit.version>
60-
<kotlinx-coroutines.version>1.8.1</kotlinx-coroutines.version>
60+
<kotlinx-coroutines.version>1.9.0</kotlinx-coroutines.version>
6161
<langchain4j.version>1.0.0</langchain4j.version>
6262
<mockito-kotlin.version>5.4.0</mockito-kotlin.version>
6363
<mockito.version>5.17.0</mockito.version>

0 commit comments

Comments
 (0)