diff --git a/langchain4j-kotlin/pom.xml b/langchain4j-kotlin/pom.xml index bb99bac..f98daad 100644 --- a/langchain4j-kotlin/pom.xml +++ b/langchain4j-kotlin/pom.xml @@ -56,6 +56,12 @@ mockito-junit-jupiter test + + + me.kpavlov.aimocks + ai-mocks-openai + test + diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt index 8e65690..6840fb3 100644 --- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt +++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/TestEnvironment.kt @@ -1,7 +1,10 @@ package me.kpavlov.langchain4j.kotlin +import me.kpavlov.aimocks.openai.MockOpenai + object TestEnvironment : me.kpavlov.finchly.BaseTestEnvironment( dotEnvFileDir = "../", ) { - val openaiApiKey = TestEnvironment.get("OPENAI_API_KEY", "demo") + val openaiApiKey = get("OPENAI_API_KEY", "demo") + val mockOpenAi = MockOpenai() } diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt index 6dc2db6..67302bd 100644 --- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt +++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelExtensionsKtTest.kt @@ -7,6 +7,7 @@ import assertk.assertions.hasMessage import dev.langchain4j.data.message.AiMessage import dev.langchain4j.data.message.UserMessage.userMessage import dev.langchain4j.model.chat.StreamingChatLanguageModel +import dev.langchain4j.model.chat.request.ChatRequest import dev.langchain4j.model.chat.response.ChatResponse import dev.langchain4j.model.chat.response.StreamingChatResponseHandler import kotlinx.coroutines.flow.toList @@ -40,7 +41,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest { handler.onPartialResponse(partialToken1) handler.onPartialResponse(partialToken2) handler.onCompleteResponse(completeResponse) - }.whenever(mockModel).chat(any(), any()) + }.whenever(mockModel).chat(any(), any()) val flow = mockModel.chatFlow { @@ -58,7 +59,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest { ) // Verify interactions - verify(mockModel).chat(any(), any()) + verify(mockModel).chat(any(), any()) } @Test @@ -70,7 +71,7 @@ internal class StreamingChatLanguageModelExtensionsKtTest { doAnswer { val handler = it.arguments[1] as StreamingChatResponseHandler handler.onError(error) - }.whenever(mockModel).chat(any(), any()) + }.whenever(mockModel).chat(any(), any()) val flow = mockModel.chatFlow { diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt index 9a0509e..e2d207c 100644 --- a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt +++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/StreamingChatLanguageModelIT.kt @@ -8,33 +8,31 @@ import dev.langchain4j.data.message.SystemMessage.systemMessage import dev.langchain4j.data.message.UserMessage.userMessage import dev.langchain4j.model.chat.StreamingChatLanguageModel import dev.langchain4j.model.chat.response.ChatResponse -import dev.langchain4j.model.openai.OpenAiStreamingChatModel +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield import me.kpavlov.langchain4j.kotlin.TestEnvironment +import me.kpavlov.langchain4j.kotlin.TestEnvironment.mockOpenAi import me.kpavlov.langchain4j.kotlin.loadDocument import me.kpavlov.langchain4j.kotlin.model.chat.StreamingChatLanguageModelReply.CompleteResponse import me.kpavlov.langchain4j.kotlin.model.chat.StreamingChatLanguageModelReply.PartialResponse +import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.fail import org.junit.jupiter.api.Test -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.slf4j.LoggerFactory +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicReference -@EnabledIfEnvironmentVariable( - named = "OPENAI_API_KEY", - matches = ".+", -) -class StreamingChatLanguageModelIT { +internal class StreamingChatLanguageModelIT { private val logger = LoggerFactory.getLogger(javaClass) - private val model: StreamingChatLanguageModel = - OpenAiStreamingChatModel - .builder() - .apiKey(TestEnvironment.openaiApiKey) - .modelName("gpt-4o-mini") - .temperature(0.0) - .maxTokens(100) - .build() + private val model: StreamingChatLanguageModel = createOpenAiStreamingModel() + + @AfterEach + fun afterEach() { + mockOpenAi.verifyNoUnmatchedRequests() + } @Test fun `StreamingChatLanguageModel should generateFlow`() = @@ -54,9 +52,15 @@ class StreamingChatLanguageModelIT { """.trimIndent(), ) + setupMockResponseIfNecessary( + systemMessage.text(), + "What does Blumblefang love", + "Blumblefang loves to help and cookies", + ) + val responseRef = AtomicReference() - val collectedTokens = mutableListOf() + val collectedTokens = ConcurrentLinkedQueue() model .chatFlow { @@ -66,10 +70,11 @@ class StreamingChatLanguageModelIT { when (it) { is PartialResponse -> { println("Token: '${it.token}'") - collectedTokens.add(it.token) + collectedTokens += it.token } is CompleteResponse -> responseRef.set(it.response) + is StreamingChatLanguageModelReply.Error -> fail("Error", it.cause) else -> fail("Unsupported event: $it") } } @@ -82,4 +87,31 @@ class StreamingChatLanguageModelIT { assertThat(collectedTokens.joinToString("")).isEqualTo(textContent) assertThat(textContent).contains("Blumblefang loves to help") } + + fun setupMockResponseIfNecessary( + expectedSystemMessage: String, + expectedUserMessage: String, + expectedAnswer: String, + ) { + if (TestEnvironment["OPENAI_API_KEY"] != null) { + logger.error("Running with real OpenAI API") + return + } + logger.error("Running with Mock OpenAI API (Ai-Mocks/Mokksy)") + + mockOpenAi.completion { + requestBodyContains(expectedSystemMessage) + requestBodyContains(expectedUserMessage) + } respondsStream { + responseFlow = + flow { + expectedAnswer.split(" ").forEach { token -> + emit("$token ") + yield() + delay(42) + } + } + sendDone = true + } + } } diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt new file mode 100644 index 0000000..407cdfa --- /dev/null +++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/model/chat/TestSetup.kt @@ -0,0 +1,29 @@ +package me.kpavlov.langchain4j.kotlin.model.chat + +import dev.langchain4j.model.chat.StreamingChatLanguageModel +import dev.langchain4j.model.openai.OpenAiStreamingChatModel +import dev.langchain4j.model.openai.OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder +import me.kpavlov.langchain4j.kotlin.TestEnvironment + +internal fun createOpenAiStreamingModel( + configurer: OpenAiStreamingChatModelBuilder.() -> Unit = {}, +): StreamingChatLanguageModel { + val modelBuilder = + OpenAiStreamingChatModel + .builder() + .modelName("gpt-4o-mini") + .temperature(0.1) + .maxTokens(100) + + val apiKey = TestEnvironment["OPENAI_API_KEY"] + if (apiKey != null) { + modelBuilder.apiKey(apiKey) + } else { + modelBuilder + .apiKey("my-key") + .baseUrl("http://localhost:${TestEnvironment.mockOpenAi.port()}/v1") + } + configurer.invoke(modelBuilder) + + return modelBuilder.build() +} diff --git a/pom.xml b/pom.xml index 7eaa3d1..89f703c 100644 --- a/pom.xml +++ b/pom.xml @@ -41,7 +41,7 @@ - + UTF-8 official 17 @@ -53,11 +53,12 @@ ${java.version} ${java.version} + 0.1.1 4.2.2 0.1.1 5.11.4 1.10.1 - 1.0.0-alpha1 + 1.0.0-beta1 5.4.0 5.15.2 2.0.16 @@ -111,6 +112,13 @@ pom import + + me.kpavlov.aimocks + bom + ${ai-mocks.version} + pom + import + org.mockito.kotlin mockito-kotlin @@ -135,6 +143,7 @@ ${finchly.version} test + @@ -148,7 +157,6 @@ org.awaitility awaitility-kotlin - ${awaitility.version} test @@ -198,6 +206,10 @@ org.jetbrains.dokka dokka-maven-plugin 2.0.0 + + 1.9 + false + com.github.ozsie @@ -239,17 +251,17 @@ origin/main - + - + **/*.md - + diff --git a/samples/pom.xml b/samples/pom.xml index 8de50d7..30e69dc 100644 --- a/samples/pom.xml +++ b/samples/pom.xml @@ -22,7 +22,7 @@ 0.1.1 1.9.0 0.1.7 - 1.0.0-alpha1 + 1.0.0-beta1 2.0.16