Skip to content

Commit 4a0adc1

Browse files
authored
[WIP] Add asynchronous AI service infrastructure for Kotlin (#128)
# [WIP] Add asynchronous AI service infrastructure for Kotlin Implemented `AsyncAiServices` and supporting components to enable asynchronous AI service handling in Kotlin. Added tests, a service factory, and _InvocationHandler_ to integrate and validate the functionality.
1 parent da2319a commit 4a0adc1

File tree

12 files changed

+1021
-1
lines changed

12 files changed

+1021
-1
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ apidocs:
88

99
lint:prepare
1010
ktlint && \
11-
mvn spotless:check
11+
mvn spotless:check detekt:check
1212

1313
# https://docs.openrewrite.org/recipes/maven/bestpractices
1414
format:prepare
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package dev.langchain4j.service
2+
3+
import me.kpavlov.langchain4j.kotlin.ChatMemoryId
4+
import java.lang.reflect.Method
5+
import java.util.Optional
6+
7+
/**
8+
* This is a hack to access package-private methods in [DefaultAiServices].
9+
* It is not supposed to be used directly.
10+
*/
11+
internal object DefaultAiServicesOpener {
12+
/**
13+
* This class is used to open package-private methods in [DefaultAiServices].
14+
* It is not supposed to be used directly.
15+
*/
16+
@Suppress("UNCHECKED_CAST")
17+
internal fun findMemoryId(
18+
method: Method,
19+
args: Array<Any?>,
20+
): Optional<ChatMemoryId> {
21+
val findMemoryId =
22+
DefaultAiServices::class.java.getDeclaredMethod(
23+
"findMemoryId",
24+
Method::class.java,
25+
Array<Any?>::class.java,
26+
)
27+
findMemoryId.isAccessible = true
28+
@Suppress("UNCHECKED_CAST")
29+
return findMemoryId.invoke(null, method, args) as Optional<Any>
30+
}
31+
32+
@Suppress("TooGenericExceptionCaught")
33+
internal fun validateParameters(method: Method) {
34+
val validateParameters =
35+
DefaultAiServices::class.java.getDeclaredMethod(
36+
"validateParameters",
37+
Method::class.java,
38+
)
39+
validateParameters.isAccessible = true
40+
try {
41+
validateParameters.invoke(null, method)
42+
} catch (e: Exception) {
43+
if (e.cause is dev.langchain4j.service.IllegalConfigurationException) {
44+
val illegalConfigurationException =
45+
e.cause as dev.langchain4j.service.IllegalConfigurationException
46+
illegalConfigurationException.printStackTrace()
47+
return
48+
}
49+
throw e
50+
}
51+
}
52+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package me.kpavlov.langchain4j.kotlin.rag
2+
3+
import dev.langchain4j.rag.AugmentationRequest
4+
import dev.langchain4j.rag.AugmentationResult
5+
import dev.langchain4j.rag.RetrievalAugmentor
6+
import kotlinx.coroutines.coroutineScope
7+
8+
public suspend fun RetrievalAugmentor.augmentAsync(
9+
request: AugmentationRequest,
10+
): AugmentationResult = coroutineScope { this@augmentAsync.augment(request) }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package me.kpavlov.langchain4j.kotlin.service
2+
3+
import dev.langchain4j.service.AiServiceContext
4+
import dev.langchain4j.service.AiServices
5+
import dev.langchain4j.spi.services.AiServicesFactory
6+
7+
/**
8+
* Creates an [AiServices] instance using the provided [AiServicesFactory].
9+
*/
10+
public fun <T> createAiService(
11+
serviceClass: Class<T>,
12+
factory: AiServicesFactory,
13+
): AiServices<T> = AiServiceContext(serviceClass).let { context -> factory.create<T>(context) }
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package me.kpavlov.langchain4j.kotlin.service
2+
3+
import dev.langchain4j.service.AiServiceContext
4+
import dev.langchain4j.service.AiServices
5+
import dev.langchain4j.service.ChatMemoryAccess
6+
import dev.langchain4j.service.IllegalConfigurationException.illegalConfiguration
7+
import dev.langchain4j.service.MemoryId
8+
import dev.langchain4j.service.Moderate
9+
import dev.langchain4j.service.Result
10+
import dev.langchain4j.service.TypeUtils
11+
import dev.langchain4j.service.output.ServiceOutputParser
12+
import dev.langchain4j.spi.ServiceHelper
13+
import dev.langchain4j.spi.services.TokenStreamAdapter
14+
15+
public class AsyncAiServices<T : Any>(
16+
context: AiServiceContext,
17+
) : AiServices<T>(context) {
18+
private val serviceOutputParser = ServiceOutputParser()
19+
private val tokenStreamAdapters =
20+
ServiceHelper.loadFactories<TokenStreamAdapter>(TokenStreamAdapter::class.java)
21+
22+
@Suppress("NestedBlockDepth")
23+
override fun build(): T {
24+
performBasicValidation()
25+
26+
if (!context.hasChatMemory() &&
27+
ChatMemoryAccess::class.java.isAssignableFrom(context.aiServiceClass)
28+
) {
29+
throw illegalConfiguration(
30+
"In order to have a service implementing ChatMemoryAccess, " +
31+
"please configure the ChatMemoryProvider on the '%s'.",
32+
context.aiServiceClass.name,
33+
)
34+
}
35+
36+
for (method in context.aiServiceClass.methods) {
37+
if (method.isAnnotationPresent(Moderate::class.java) &&
38+
context.moderationModel == null
39+
) {
40+
throw illegalConfiguration(
41+
"The @Moderate annotation is present, but the moderationModel is not set up. " +
42+
"Please ensure a valid moderationModel is configured " +
43+
"before using the @Moderate annotation.",
44+
)
45+
}
46+
if (method.returnType in
47+
arrayOf(
48+
// supported collection types
49+
Result::class.java,
50+
MutableList::class.java,
51+
MutableSet::class.java,
52+
)
53+
) {
54+
TypeUtils.validateReturnTypesAreProperlyParametrized(
55+
method.name,
56+
method.genericReturnType,
57+
)
58+
}
59+
60+
if (!context.hasChatMemory()) {
61+
for (parameter in method.parameters) {
62+
if (parameter.isAnnotationPresent(MemoryId::class.java)) {
63+
throw illegalConfiguration(
64+
"In order to use @MemoryId, please configure " +
65+
"ChatMemoryProvider on the '%s'.",
66+
context.aiServiceClass.name,
67+
)
68+
}
69+
}
70+
}
71+
}
72+
73+
val handler = ServiceInvocationHandler<T>(context, serviceOutputParser, tokenStreamAdapters)
74+
@Suppress("UNCHECKED_CAST", "unused")
75+
return ReflectionHelper.createSuspendProxy(context.aiServiceClass) { method, args ->
76+
return@createSuspendProxy handler.invoke(method, args)
77+
} as T
78+
}
79+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package me.kpavlov.langchain4j.kotlin.service
2+
3+
import dev.langchain4j.service.AiServiceContext
4+
import dev.langchain4j.service.AiServices
5+
import dev.langchain4j.spi.services.AiServicesFactory
6+
7+
public class AsyncAiServicesFactory : AiServicesFactory {
8+
override fun <T : Any> create(context: AiServiceContext): AiServices<T> =
9+
AsyncAiServices(context)
10+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package me.kpavlov.langchain4j.kotlin.service
2+
3+
import kotlinx.coroutines.DelicateCoroutinesApi
4+
import kotlinx.coroutines.GlobalScope
5+
import kotlinx.coroutines.asCoroutineDispatcher
6+
import kotlinx.coroutines.launch
7+
import kotlinx.coroutines.runBlocking
8+
import java.lang.reflect.InvocationHandler
9+
import java.lang.reflect.Method
10+
import java.lang.reflect.ParameterizedType
11+
import java.lang.reflect.Proxy
12+
import java.lang.reflect.Type
13+
import java.lang.reflect.WildcardType
14+
import java.util.concurrent.Executors
15+
import kotlin.coroutines.Continuation
16+
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
17+
import kotlin.coroutines.resume
18+
import kotlin.coroutines.resumeWithException
19+
20+
@OptIn(DelicateCoroutinesApi::class)
21+
internal object ReflectionHelper {
22+
private val vtDispatcher = Executors.newVirtualThreadPerTaskExecutor().asCoroutineDispatcher()
23+
24+
@Throws(kotlin.IllegalStateException::class)
25+
private fun getReturnType(method: Method): Type {
26+
val continuationParam = method.parameterTypes.findLast { it.kotlin is Continuation<*> }
27+
if (continuationParam != null) {
28+
@Suppress("UseCheckOrError")
29+
return continuationParam.genericInterfaces[0]
30+
?: throw IllegalStateException(
31+
"Can't find generic interface of continuation parameter",
32+
)
33+
}
34+
return method.getGenericReturnType()
35+
}
36+
37+
fun getSuspendReturnType(method: Method): java.lang.reflect.Type {
38+
val parameters = method.genericParameterTypes
39+
if (parameters.isEmpty()) return getReturnType(method)
40+
val lastParameter = parameters.last()
41+
// Check if the last parameter is Continuation<T>
42+
return (
43+
if (lastParameter is ParameterizedType &&
44+
(lastParameter.rawType as? Class<*>)?.name == Continuation::class.java.name
45+
) {
46+
// T is the first (and only) type argument
47+
val type = lastParameter.actualTypeArguments.first()
48+
if (type is WildcardType) {
49+
type.lowerBounds.first()
50+
} else {
51+
type
52+
}
53+
} else {
54+
getReturnType(method) // Not a suspend function, or not detectably so
55+
}
56+
)
57+
}
58+
59+
@Suppress("UNCHECKED_CAST")
60+
fun <T : Any> createSuspendProxy(
61+
iface: Class<T>,
62+
handler: suspend (method: java.lang.reflect.Method, args: Array<Any?>) -> Any?,
63+
): T {
64+
return Proxy.newProxyInstance(
65+
iface.classLoader,
66+
arrayOf(iface),
67+
InvocationHandler { _, method, args ->
68+
// If not a suspend method, optionally fall back
69+
val cont =
70+
args.lastOrNull() as? Continuation<Any?>
71+
?: return@InvocationHandler method.invoke(this, args)
72+
73+
// Remove Continuation for our handler
74+
val argsForSuspend = args.dropLast(1).toTypedArray()
75+
76+
// Launch coroutine for the suspend implementation
77+
// (here, for demonstration, using a helper)
78+
// Use coroutine machinery to start the suspend block
79+
// If using Kotlin 1.3+, this is the correct way
80+
// Uses GlobalScope (be sure that's okay for your use-case!)
81+
GlobalScope.launch(vtDispatcher) {
82+
@Suppress("TooGenericExceptionCaught")
83+
try {
84+
val result = handler(method, argsForSuspend)
85+
cont.resume(result)
86+
} catch (e: Throwable) {
87+
cont.resumeWithException(e)
88+
}
89+
}
90+
COROUTINE_SUSPENDED
91+
},
92+
) as T
93+
}
94+
95+
@FunctionalInterface
96+
interface MyApi {
97+
suspend fun greet(name: String): String
98+
}
99+
100+
internal fun dropContinuationArg(args: Array<Any?>): Array<Any?> =
101+
args
102+
.dropLastWhile {
103+
it is Continuation<*>
104+
}.toTypedArray()
105+
}
106+
107+
public fun main() {
108+
val proxy =
109+
ReflectionHelper.createSuspendProxy(ReflectionHelper.MyApi::class.java) { method, args ->
110+
"${Thread.currentThread()}: Hello, $method(${args[0]} )"
111+
}
112+
113+
runBlocking {
114+
println(proxy.greet("world")) // Prints: Hello, world
115+
}
116+
}

0 commit comments

Comments
 (0)