diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 264f22b..cfc2e58 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,4 @@ Thank you for submitting a pull request! But first: - [ ] Can you back your code up with tests? - - [ ] Please run `./gradlew spotlessApply :tests:spotlessApply` for auto-formatting. + - [ ] Please run `./gradlew spotlessApply` for auto-formatting. diff --git a/build.gradle b/build.gradle index dd0a415..24f21e2 100644 --- a/build.gradle +++ b/build.gradle @@ -15,6 +15,10 @@ tasks.named("check") { dependsOn test } +tasks.register("spotlessApply") { + dependsOn gradle.includedBuild("tests").task(":spotlessApply") +} + def isSnapshot = version.endsWith("-SNAPSHOT") def githubTokenProvider = providers.environmentVariable("GITHUB_TOKEN").orElse("") def githubShaProvider = providers.environmentVariable("GITHUB_SHA").orElse("") diff --git a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Mocking.kt b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Mocking.kt index 206302f..28cafc7 100644 --- a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Mocking.kt +++ b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Mocking.kt @@ -27,6 +27,9 @@ package org.mockito.kotlin import kotlin.DeprecationLevel.ERROR import kotlin.reflect.KClass +import kotlin.reflect.KFunction +import kotlin.reflect.full.extensionReceiverParameter +import kotlin.reflect.jvm.javaMethod import org.mockito.MockSettings import org.mockito.MockedConstruction import org.mockito.MockedStatic @@ -230,6 +233,51 @@ inline fun mockConstruction( return Mockito.mockConstruction(T::class.java, mockInitializer) } +/** + * Creates a thread-local mock for the static methods of the class that contains this top-level + * extension function. + * + * Top-level Kotlin extension functions compile to static methods in a `*Kt` class. This helper + * simplifies creating a [MockedStatic] for them. + * + * Usage: + * ``` + * fun String.isHello(): Boolean = this == "Hello" + * + * mockExtensionFun(String::isHello).use { + * whenever("test".isHello()).thenReturn(true) + * println("test".isHello()) // "true" + * } + * ``` + * + * When using matchers, all arguments including the receiver must use matchers: + * ``` + * fun String.hasPrefix(prefix: String): Boolean = this.startsWith(prefix) + * + * mockExtensionFun(String::hasPrefix).use { + * whenever(any().hasPrefix(eq("foo"))).thenReturn(true) + * println("bar".hasPrefix("foo")) // "true" + * } + * ``` + * + * Note: member extension functions (extension functions declared inside a class) do not need this + * helper. They can be mocked by creating a regular [mock] of the containing class. + * + * @param function a reference to the top-level extension function to mock. + * @see Mockito.mockStatic + */ +fun mockExtensionFun(function: KFunction<*>): MockedStatic<*> { + requireNotNull(function.extensionReceiverParameter) { + "Expected an extension function reference, but $function has no extension receiver." + } + val declaringClass = + requireNotNull(function.javaMethod?.declaringClass) { + "Could not determine declaring class for function $function. " + + "Ensure this is a top-level extension function reference." + } + return Mockito.mockStatic(declaringClass) +} + class UseConstructor private constructor(val args: Array) { companion object { diff --git a/tests/src/test/kotlin/test/Classes.kt b/tests/src/test/kotlin/test/Classes.kt index 83bc437..d571512 100644 --- a/tests/src/test/kotlin/test/Classes.kt +++ b/tests/src/test/kotlin/test/Classes.kt @@ -262,3 +262,18 @@ object SomeObject { @JvmStatic fun aStaticMethodReturningString(): String = "Some Value" } + +// Top-level extension functions for testing mockExtensionFun + +fun String.isEqualTo(compare: String): Boolean = this == compare + +fun String.isHello(): Boolean = this == "Hello" + +fun String.isHello(mood: String): Boolean = this == "Hello" && mood == "happy" + +// Classes for member extension function test +class Foo { + fun Bar.foobar(): String = this@Foo.toString() + this.toString() +} + +class Bar diff --git a/tests/src/test/kotlin/test/MockingTest.kt b/tests/src/test/kotlin/test/MockingTest.kt index 1b59b24..2e01a2a 100644 --- a/tests/src/test/kotlin/test/MockingTest.kt +++ b/tests/src/test/kotlin/test/MockingTest.kt @@ -6,6 +6,7 @@ import com.nhaarman.expect.fail import java.io.PrintStream import java.io.Serializable import java.util.* +import kotlin.reflect.KFunction2 import kotlinx.coroutines.test.runTest import org.junit.Test import org.mockito.Mockito @@ -16,8 +17,10 @@ import org.mockito.kotlin.UseConstructor.Companion.withArguments import org.mockito.kotlin.any import org.mockito.kotlin.argumentCaptor import org.mockito.kotlin.doReturn +import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.mockConstruction +import org.mockito.kotlin.mockExtensionFun import org.mockito.kotlin.mockStatic import org.mockito.kotlin.verify import org.mockito.kotlin.whenever @@ -470,6 +473,52 @@ class MockingTest : TestBase() { } } + @Test + fun mockExtensionFun_returnsMockedValue() { + mockExtensionFun(String::isEqualTo).use { + whenever("a".isEqualTo("b")).thenReturn(true) + + expect("a".isEqualTo("b")).toBe(true) + } + } + + @Test + fun mockExtensionFun_withMatchers_returnsMockedValue() { + mockExtensionFun(String::isEqualTo).use { + whenever(any().isEqualTo(eq("b"))).thenReturn(true) + + expect("a".isEqualTo("b")).toBe(true) + expect("a".isEqualTo("c")).toBe(false) + } + } + + @Test + fun mockExtensionFun_overloaded_returnsMockedValue() { + val ref: KFunction2 = String::isHello + mockExtensionFun(ref).use { + whenever("test".isHello("sad")).thenReturn(true) + + expect("test".isHello("sad")).toBe(true) + } + } + + @Test + fun mockExtensionFun_nonExtensionFunction_throwsIllegalArgument() { + expectErrorWithMessage("has no extension receiver") on + { + mockExtensionFun(Open::stringResult) + } + } + + @Test + fun memberExtensionFunction_mockedByCreatingMockHost() { + val foo = mock() + val bar = Bar() + whenever(with(foo) { bar.foobar() }).thenReturn("mocked") + + expect(with(foo) { bar.foobar() }).toBe("mocked") + } + private interface MyInterface private open class MyClass