diff --git a/langchain4j-kotlin/pom.xml b/langchain4j-kotlin/pom.xml
index bb99bac..f98daad 100644
--- a/langchain4j-kotlin/pom.xml
+++ b/langchain4j-kotlin/pom.xml
@@ -56,6 +56,12 @@
             <artifactId>mockito-junit-jupiter</artifactId>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>me.kpavlov.aimocks</groupId>
+            <artifactId>ai-mocks-openai</artifactId>
+            <scope>test</scope>
+        </dependency>
diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt
index 8e65690..6840fb3 100644
--- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt
+++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt
@@ -1,7 +1,10 @@
package me.kpavlov.langchain4j.kotlin
+import me.kpavlov.aimocks.openai.MockOpenai
+
object TestEnvironment : me.kpavlov.finchly.BaseTestEnvironment(
dotEnvFileDir = "../",
) {
- val openaiApiKey = TestEnvironment.get("OPENAI_API_KEY", "demo")
+ val openaiApiKey = get("OPENAI_API_KEY", "demo")
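+    // OpenAI-compatible mock server (Ai-Mocks), shared by tests that run without a real OPENAI_API_KEY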
+ val mockOpenAi = MockOpenai()
}
diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt
index 6dc2db6..67302bd 100644
--- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt
+++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt
@@ -7,6 +7,7 @@ import assertk.assertions.hasMessage
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.UserMessage.userMessage
import dev.langchain4j.model.chat.StreamingChatLanguageModel
+import dev.langchain4j.model.chat.request.ChatRequest
import dev.langchain4j.model.chat.response.ChatResponse
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler
import kotlinx.coroutines.flow.toList
@@ -40,7 +41,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest {
handler.onPartialResponse(partialToken1)
handler.onPartialResponse(partialToken2)
handler.onCompleteResponse(completeResponse)
-            }.whenever(mockModel).chat(any(), any())
+            }.whenever(mockModel).chat(any<ChatRequest>(), any())
val flow =
mockModel.chatFlow {
@@ -58,7 +59,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest {
)
// Verify interactions
-        verify(mockModel).chat(any(), any())
+        verify(mockModel).chat(any<ChatRequest>(), any())
}
@Test
@@ -70,7 +71,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest {
doAnswer {
val handler = it.arguments[1] as StreamingChatResponseHandler
handler.onError(error)
-            }.whenever(mockModel).chat(any(), any())
+            }.whenever(mockModel).chat(any<ChatRequest>(), any())
val flow =
mockModel.chatFlow {
diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt
index 9a0509e..e2d207c 100644
--- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt
+++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt
@@ -8,33 +8,31 @@ import dev.langchain4j.data.message.SystemMessage.systemMessage
import dev.langchain4j.data.message.UserMessage.userMessage
import dev.langchain4j.model.chat.StreamingChatLanguageModel
import dev.langchain4j.model.chat.response.ChatResponse
-import dev.langchain4j.model.openai.OpenAiStreamingChatModel
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
+import kotlinx.coroutines.yield
import me.kpavlov.langchain4j.kotlin.TestEnvironment
+import me.kpavlov.langchain4j.kotlin.TestEnvironment.mockOpenAi
import me.kpavlov.langchain4j.kotlin.loadDocument
import me.kpavlov.langchain4j.kotlin.model.chat.StreamingChatLanguageModelReply.CompleteResponse
import me.kpavlov.langchain4j.kotlin.model.chat.StreamingChatLanguageModelReply.PartialResponse
+import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.fail
import org.junit.jupiter.api.Test
-import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable
import org.slf4j.LoggerFactory
+import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicReference
-@EnabledIfEnvironmentVariable(
- named = "OPENAI_API_KEY",
- matches = ".+",
-)
-class StreamingChatLanguageModelIT {
+internal class StreamingChatLanguageModelIT {
private val logger = LoggerFactory.getLogger(javaClass)
- private val model: StreamingChatLanguageModel =
- OpenAiStreamingChatModel
- .builder()
- .apiKey(TestEnvironment.openaiApiKey)
- .modelName("gpt-4o-mini")
- .temperature(0.0)
- .maxTokens(100)
- .build()
+ private val model: StreamingChatLanguageModel = createOpenAiStreamingModel()
+
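+    // After each test, fail if the mock server received a request that no stub matched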
+ @AfterEach
+ fun afterEach() {
+ mockOpenAi.verifyNoUnmatchedRequests()
+ }
@Test
fun `StreamingChatLanguageModel should generateFlow`() =
@@ -54,9 +52,15 @@ class StreamingChatLanguageModelIT {
""".trimIndent(),
)
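+        // Stub the mock OpenAI endpoint only when running without a real API key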
+ setupMockResponseIfNecessary(
+ systemMessage.text(),
+ "What does Blumblefang love",
+ "Blumblefang loves to help and cookies",
+ )
+
 val responseRef = AtomicReference<ChatResponse>()
- val collectedTokens = mutableListOf<String>()
+ val collectedTokens = ConcurrentLinkedQueue<String>()
model
.chatFlow {
@@ -66,10 +70,11 @@ class StreamingChatLanguageModelIT {
when (it) {
is PartialResponse -> {
println("Token: '${it.token}'")
- collectedTokens.add(it.token)
+ collectedTokens += it.token
}
is CompleteResponse -> responseRef.set(it.response)
+ is StreamingChatLanguageModelReply.Error -> fail("Error", it.cause)
else -> fail("Unsupported event: $it")
}
}
@@ -82,4 +87,31 @@ class StreamingChatLanguageModelIT {
assertThat(collectedTokens.joinToString("")).isEqualTo(textContent)
assertThat(textContent).contains("Blumblefang loves to help")
}
+
+ fun setupMockResponseIfNecessary(
+ expectedSystemMessage: String,
+ expectedUserMessage: String,
+ expectedAnswer: String,
+ ) {
+ if (TestEnvironment["OPENAI_API_KEY"] != null) {
+ logger.error("Running with real OpenAI API")
+ return
+ }
+ logger.error("Running with Mock OpenAI API (Ai-Mocks/Mokksy)")
+
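+        // Match requests containing both messages and stream the expected answer back word by word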
+ mockOpenAi.completion {
+ requestBodyContains(expectedSystemMessage)
+ requestBodyContains(expectedUserMessage)
+ } respondsStream {
+ responseFlow =
+ flow {
+ expectedAnswer.split(" ").forEach { token ->
+ emit("$token ")
+ yield()
+ delay(42)
+ }
+ }
+ sendDone = true
+ }
+ }
}
diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt
new file mode 100644
index 0000000..407cdfa
--- /dev/null
+++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt
@@ -0,0 +1,29 @@
+package me.kpavlov.langchain4j.kotlin.model.chat
+
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder
+import me.kpavlov.langchain4j.kotlin.TestEnvironment
+
+internal fun createOpenAiStreamingModel(
+ configurer: OpenAiStreamingChatModelBuilder.() -> Unit = {},
+): StreamingChatLanguageModel {
+ val modelBuilder =
+ OpenAiStreamingChatModel
+ .builder()
+ .modelName("gpt-4o-mini")
+ .temperature(0.1)
+ .maxTokens(100)
+
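+    // Use the real OpenAI API when a key is configured; otherwise point the client at the local Ai-Mocks server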
+ val apiKey = TestEnvironment["OPENAI_API_KEY"]
+ if (apiKey != null) {
+ modelBuilder.apiKey(apiKey)
+ } else {
+ modelBuilder
+ .apiKey("my-key")
+ .baseUrl("http://localhost:${TestEnvironment.mockOpenAi.port()}/v1")
+ }
+ configurer.invoke(modelBuilder)
+
+ return modelBuilder.build()
+}
diff --git a/pom.xml b/pom.xml
index 7eaa3d1..89f703c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -41,7 +41,7 @@
-
+
UTF-8
official
17
@@ -53,11 +53,12 @@
${java.version}
${java.version}
+        <ai-mocks.version>0.1.1</ai-mocks.version>
4.2.2
0.1.1
5.11.4
1.10.1
- 1.0.0-alpha1
+ 1.0.0-beta1
5.4.0
5.15.2
2.0.16
@@ -111,6 +112,13 @@
                 <type>pom</type>
                 <scope>import</scope>
             </dependency>
+            <dependency>
+                <groupId>me.kpavlov.aimocks</groupId>
+                <artifactId>bom</artifactId>
+                <version>${ai-mocks.version}</version>
+                <type>pom</type>
+                <scope>import</scope>
+            </dependency>
             <dependency>
                 <groupId>org.mockito.kotlin</groupId>
                 <artifactId>mockito-kotlin</artifactId>
@@ -135,6 +143,7 @@
                 <version>${finchly.version}</version>
                 <scope>test</scope>
+
@@ -148,7 +157,6 @@
                 <groupId>org.awaitility</groupId>
                 <artifactId>awaitility-kotlin</artifactId>
-                <version>${awaitility.version}</version>
                 <scope>test</scope>
@@ -198,6 +206,10 @@
                 <groupId>org.jetbrains.dokka</groupId>
                 <artifactId>dokka-maven-plugin</artifactId>
                 <version>2.0.0</version>
+
+ 1.9
+ false
+
com.github.ozsie
@@ -239,17 +251,17 @@
origin/main
-
+
-
+
**/*.md
-
+
diff --git a/samples/pom.xml b/samples/pom.xml
index 8de50d7..30e69dc 100644
--- a/samples/pom.xml
+++ b/samples/pom.xml
@@ -22,7 +22,7 @@
0.1.1
1.9.0
0.1.7
- 1.0.0-alpha1
+ 1.0.0-beta1
2.0.16