diff --git a/src/commonMain/kotlin/app/cash/turbine/flow.kt b/src/commonMain/kotlin/app/cash/turbine/flow.kt index cbbf072e..2ddd9165 100644 --- a/src/commonMain/kotlin/app/cash/turbine/flow.kt +++ b/src/commonMain/kotlin/app/cash/turbine/flow.kt @@ -222,7 +222,7 @@ private fun Flow.collectTurbineIn(scope: CoroutineScope, timeout: Duratio } } -internal fun Flow.collectIntoChannel(scope: CoroutineScope): Channel { +private fun Flow.collectIntoChannel(scope: CoroutineScope): Channel { val output = Channel(UNLIMITED) val job = scope.launch(start = UNDISPATCHED) { try { diff --git a/src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt b/src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt index 8d86d7c7..d4d03571 100644 --- a/src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt +++ b/src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt @@ -20,16 +20,11 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertSame import kotlin.time.Duration.Companion.milliseconds -import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.launch import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.runTest @@ -41,13 +36,7 @@ class ChannelTest { val expected = CustomThrowable("hello") val actual = assertFailsWith { - val channel = flow { - emit(1) - emit(2) - emit(3) - throw expected - }.collectIntoChannel(this) - + val channel = channelOf(1, 2, 3, closeCause = expected) channel.expectMostRecentItem() } assertSame(expected, actual) @@ -56,7 +45,7 @@ class ChannelTest { @Test fun expectMostRecentItemButNoItemWasFoundThrows() = runTest { val actual = assertFailsWith { - val channel = emptyFlow().collectIntoChannel(this) + val channel = channelOf() channel.expectMostRecentItem() } assertEquals("No item was found", actual.message) @@ -64,51 +53,33 @@ class ChannelTest { @Test fun expectMostRecentItem() = runTest { - val onTwoSent = CompletableDeferred() - val onTwoContinue = CompletableDeferred() - val onCompleteSent = CompletableDeferred() - val onCompleteContinue = CompletableDeferred() - - val channel = flowOf(1, 2, 3, 4, 5) - .map { - if (it == 3) { - onTwoSent.complete(Unit) - onTwoContinue.await() - } - it - } - .onCompletion { - onCompleteSent.complete(Unit) - onCompleteContinue.await() - } - .collectIntoChannel(this) + val channel = Channel(UNLIMITED) + channel.trySend(1) + channel.trySend(2) - onTwoSent.await() assertEquals(2, channel.expectMostRecentItem()) - onTwoContinue.complete(Unit) - onCompleteSent.await() + channel.trySend(3) + channel.trySend(4) + channel.trySend(5) assertEquals(5, channel.expectMostRecentItem()) - onCompleteContinue.complete(Unit) - - channel.cancel() } @Test fun assertNullValuesWithExpectMostRecentItem() = runTest { - val channel = flowOf(1, 2, null).collectIntoChannel(this) + val channel = channelOf(1, 2, null) assertEquals(null, channel.expectMostRecentItem()) } @Test fun awaitItemsAreSkipped() = runTest { - val channel = flowOf(1, 2, 3).collectIntoChannel(this) + val channel = channelOf(1, 2, 3) channel.skipItems(2) assertEquals(3, channel.awaitItem()) } @Test fun skipItemsThrowsOnComplete() = runTest { - val channel = flowOf(1, 2).collectIntoChannel(this) + val channel = channelOf(1, 2) val message = assertFailsWith { channel.skipItems(3) }.message @@ -116,7 +87,7 @@ class ChannelTest { } @Test fun expectErrorOnCompletionBeforeAllItemsWereSkipped() = runTest { - val channel = flowOf(1).collectIntoChannel(this) + val channel = channelOf(1) assertFailsWith { channel.skipItems(2) } @@ -124,10 +95,7 @@ class ChannelTest { @Test fun expectErrorOnErrorReceivedBeforeAllItemsWereSkipped() = runTest { val error = CustomThrowable("hello") - val channel = flow { - emit(1) - throw error - }.collectIntoChannel(this) + val channel = channelOf(1, closeCause = error) val actual = assertFailsWith { channel.skipItems(2) } @@ -135,40 +103,40 @@ class ChannelTest { } @Test fun expectNoEvents() = runTest { - val channel = neverFlow().collectIntoChannel(this) + val channel = neverChannel() channel.expectNoEvents() channel.cancel() } @Test fun awaitItemEvent() = runTest { val item = Any() - val channel = flowOf(item).collectIntoChannel(this) + val channel = channelOf(item) val event = channel.awaitEvent() assertEquals(Event.Item(item), event) } @Test fun expectCompleteEvent() = runTest { - val channel = emptyFlow().collectIntoChannel(this) + val channel = emptyChannel() val event = channel.awaitEvent() assertEquals(Event.Complete, event) } @Test fun expectErrorEvent() = runTest { val exception = CustomThrowable("hello") - val channel = flow { throw exception }.collectIntoChannel(this) + val channel = channelOf(closeCause = exception) val event = channel.awaitEvent() assertEquals(Event.Error(exception), event) } @Test fun awaitItem() = runTest { val item = Any() - val channel = flowOf(item).collectIntoChannel(this) + val channel = channelOf(item) assertSame(item, channel.awaitItem()) } @Test fun awaitItemButWasCloseThrows() = runTest { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).awaitItem() + emptyChannel().awaitItem() } assertEquals("Expected item but found Complete", actual.message) } @@ -176,58 +144,55 @@ class ChannelTest { @Test fun awaitItemButWasErrorThrows() = runTest { val error = CustomThrowable("hello") val actual = assertFailsWith { - flow { throw error }.collectIntoChannel(this) - .awaitItem() + channelOf(closeCause = error).awaitItem() } assertEquals("Expected item but found Error(CustomThrowable)", actual.message) assertSame(error, actual.cause) } @Test fun awaitComplete() = runTest { - emptyFlow().collectIntoChannel(this).awaitComplete() + emptyChannel().awaitComplete() } @Test fun awaitCompleteButWasItemThrows() = runTest { val actual = assertFailsWith { - flowOf("item!").collectIntoChannel(this) - .awaitComplete() + channelOf("item!").awaitComplete() } assertEquals("Expected complete but found Item(item!)", actual.message) } @Test fun awaitCompleteButWasErrorThrows() = runTest { + val error = CustomThrowable("hello") val actual = assertFailsWith { - flow { throw RuntimeException() }.collectIntoChannel(this) - .awaitComplete() + channelOf(closeCause = error).awaitComplete() } - assertEquals("Expected complete but found Error(RuntimeException)", actual.message) + assertEquals("Expected complete but found Error(CustomThrowable)", actual.message) + assertSame(error, actual.cause) } @Test fun awaitError() = runTest { val error = CustomThrowable("hello") - val channel = flow { throw error }.collectIntoChannel(this) + val channel = channelOf(closeCause = error) assertSame(error, channel.awaitError()) } @Test fun awaitErrorButWasItemThrows() = runTest { val actual = assertFailsWith { - flowOf("item!").collectIntoChannel(this).awaitError() + channelOf("item!").awaitError() } assertEquals("Expected error but found Item(item!)", actual.message) } @Test fun awaitErrorButWasCompleteThrows() = runTest { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).awaitError() + emptyChannel().awaitError() } assertEquals("Expected error but found Complete", actual.message) } @Test fun failsOnDefaultTimeout() = runTest { val actual = assertFailsWith { - coroutineScope { - neverFlow().collectIntoChannel(this).awaitItem() - } + neverChannel().awaitItem() } assertEquals("No value produced in 3s", actual.message) assertCallSitePresentInStackTraceOnJvm( @@ -240,7 +205,7 @@ class ChannelTest { @Test fun awaitHonorsCoroutineContextTimeoutNoTimeout() = runTest { withTurbineTimeout(1500.milliseconds) { val job = launch { - neverFlow().collectIntoChannel(this).awaitItem() + neverChannel().awaitItem() } withContext(Dispatchers.Default) { @@ -253,7 +218,7 @@ class ChannelTest { @Test fun awaitHonorsCoroutineContextTimeoutTimeout() = runTest { val actual = assertFailsWith { withTurbineTimeout(10.milliseconds) { - neverFlow().collectIntoChannel(this).awaitItem() + neverChannel().awaitItem() } } assertEquals("No value produced in 10ms", actual.message) @@ -277,13 +242,13 @@ class ChannelTest { @Test fun takeItem() = withTestScope { val item = Any() - val channel = flowOf(item).collectIntoChannel(this) + val channel = channelOf(item) assertSame(item, channel.takeItem()) } @Test fun takeItemButWasCloseThrows() = withTestScope { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).takeItem() + emptyChannel().takeItem() } assertEquals("Expected item but found Complete", actual.message) } @@ -291,8 +256,7 @@ class ChannelTest { @Test fun takeItemButWasErrorThrows() = withTestScope { val error = CustomThrowable("hello") val actual = assertFailsWith { - flow { throw error }.collectIntoChannel(this) - .takeItem() + channelOf(closeCause = error).takeItem() } assertEquals("Expected item but found Error(CustomThrowable)", actual.message) assertSame(error, actual.cause) @@ -301,30 +265,28 @@ class ChannelTest { @Test fun expectMostRecentItemButNoItemWasFoundThrowsWithName() = runTest { val actual = assertFailsWith { - val channel = emptyFlow().collectIntoChannel(this) - channel.expectMostRecentItem(name = "empty flow") + emptyChannel().expectMostRecentItem(name = "empty flow") } assertEquals("No item was found for empty flow", actual.message) } @Test fun awaitItemButWasCloseThrowsWithName() = runTest { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).awaitItem(name = "closed flow") + emptyChannel().awaitItem(name = "closed flow") } assertEquals("Expected item for closed flow but found Complete", actual.message) } @Test fun awaitCompleteButWasItemThrowsWithName() = runTest { val actual = assertFailsWith { - flowOf("item!").collectIntoChannel(this) - .awaitComplete(name = "item flow") + channelOf("item!").awaitComplete(name = "item flow") } assertEquals("Expected complete for item flow but found Item(item!)", actual.message) } @Test fun awaitErrorButWasItemThrowsWithName() = runTest { val actual = assertFailsWith { - flowOf("item!").collectIntoChannel(this).awaitError(name = "item flow") + channelOf("item!").awaitError(name = "item flow") } assertEquals("Expected error for item flow but found Item(item!)", actual.message) } @@ -332,7 +294,7 @@ class ChannelTest { @Test fun awaitHonorsCoroutineContextTimeoutTimeoutWithName() = runTest { val actual = assertFailsWith { withTurbineTimeout(10.milliseconds) { - neverFlow().collectIntoChannel(this).awaitItem(name = "never flow") + neverChannel().awaitItem(name = "never flow") } } assertEquals("No value produced for never flow in 10ms", actual.message) @@ -340,13 +302,13 @@ class ChannelTest { @Test fun takeItemButWasCloseThrowsWithName() = withTestScope { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).takeItem(name = "empty flow") + emptyChannel().takeItem(name = "empty flow") } assertEquals("Expected item for empty flow but found Complete", actual.message) } @Test fun skipItemsThrowsOnCompleteWithName() = runTest { - val channel = flowOf(1, 2).collectIntoChannel(this) + val channel = channelOf(1, 2) val message = assertFailsWith { channel.skipItems(3, name = "two item channel") }.message diff --git a/src/commonTest/kotlin/app/cash/turbine/testUtil.common.kt b/src/commonTest/kotlin/app/cash/turbine/testUtil.common.kt index a867719d..bfd7db9f 100644 --- a/src/commonTest/kotlin/app/cash/turbine/testUtil.common.kt +++ b/src/commonTest/kotlin/app/cash/turbine/testUtil.common.kt @@ -16,6 +16,9 @@ package app.cash.turbine import kotlinx.coroutines.awaitCancellation +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED +import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow @@ -33,3 +36,15 @@ expect fun assertCallSitePresentInStackTraceOnJvm( entryPoint: String, callSite: String, ) + +fun channelOf(vararg items: T, closeCause: Throwable? = null): ReceiveChannel { + return Channel(UNLIMITED).also { channel -> + for (item in items) { + channel.trySend(item).getOrThrow() + } + channel.close(closeCause) + } +} + +fun emptyChannel(): ReceiveChannel = channelOf() +fun neverChannel(): ReceiveChannel = Channel() diff --git a/src/jvmTest/kotlin/app/cash/turbine/ChannelJvmTest.kt b/src/jvmTest/kotlin/app/cash/turbine/ChannelJvmTest.kt index eca40b83..0bfbcbc6 100644 --- a/src/jvmTest/kotlin/app/cash/turbine/ChannelJvmTest.kt +++ b/src/jvmTest/kotlin/app/cash/turbine/ChannelJvmTest.kt @@ -2,7 +2,6 @@ package app.cash.turbine import kotlin.test.assertEquals import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.test.runTest import org.junit.Test @@ -10,7 +9,7 @@ class ChannelJvmTest { @Test fun takeItemSuspendingThrows() = runTest { val actual = assertFailsWith { - emptyFlow().collectIntoChannel(this).takeItem() + emptyChannel().takeItem() } assertEquals("Calling context is suspending; use a suspending method instead", actual.message) }