diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index 6329938..908a2b2 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -360,6 +360,110 @@ class BurstKotlinPluginTest { ) } + /** Confirm that inline function declarations are assigned parents. */ + @Test + fun burstValuesWithInlineFunctions() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import app.cash.burst.burstValues + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test( + greeting: () -> String = burstValues( + { "Hello" }, + { "Yo" }, + ), + subject: () -> String = burstValues( + { "Burst" }, + { "World" }, + ), + ) { + log += "${'$'}{greeting()} ${'$'}{subject()}" + } + } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test").invoke(baseInstance) + baseClass.getMethod("test_0_1").invoke(baseInstance) + baseClass.getMethod("test_1_0").invoke(baseInstance) + baseClass.getMethod("test_1_1").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "Hello Burst", + "Hello World", + "Yo Burst", + "Yo World", + ) + } + + /** Confirm that inline class declarations are assigned parents. */ + @Test + fun burstValuesWithInlineClassDeclarations() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import app.cash.burst.burstValues + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test( + greeting: StringFactory = burstValues( + StringFactory { "Hello" }, + StringFactory { "Yo" }, + ), + subject: StringFactory = burstValues( + StringFactory { "Burst" }, + StringFactory { "World" }, + ), + ) { + log += "${'$'}{greeting.create()} ${'$'}{subject.create()}" + } + } + + fun interface StringFactory { + fun create(): String + } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test").invoke(baseInstance) + baseClass.getMethod("test_0_1").invoke(baseInstance) + baseClass.getMethod("test_1_0").invoke(baseInstance) + baseClass.getMethod("test_1_1").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "Hello Burst", + "Hello World", + "Yo Burst", + "Yo World", + ) + } + private val Class<*>.testSuffixes: List get() = methods.mapNotNull { when { diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt index bfffb6f..b03cd96 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -19,6 +19,7 @@ import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.backend.js.utils.valueArguments import org.jetbrains.kotlin.ir.declarations.IrClass +import org.jetbrains.kotlin.ir.declarations.IrDeclarationParent import org.jetbrains.kotlin.ir.declarations.IrEnumEntry import org.jetbrains.kotlin.ir.declarations.IrValueParameter import org.jetbrains.kotlin.ir.expressions.IrCall @@ -61,11 +62,12 @@ private class EnumValueArgument( } private class BurstValuesArgument( + private val declarationParent: IrDeclarationParent, override val isDefault: Boolean, override val name: String, private val value: IrExpression, ) : Argument { - override fun expression() = value.deepCopyWithSymbols() + override fun expression() = value.deepCopyWithSymbols(declarationParent) } /** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ @@ -103,6 +105,7 @@ internal fun IrPluginContext.allPossibleArguments( unexpectedParameter(parameter) } +@UnsafeDuringIrConstructionAPI private fun burstValuesArguments( parameter: IrValueParameter, burstApisCall: IrCall, @@ -111,6 +114,7 @@ private fun burstValuesArguments( val defaultExpression = burstApisCall.valueArguments[0] ?: unexpectedParameter(parameter) add( BurstValuesArgument( + declarationParent = parameter.parent, isDefault = true, name = defaultExpression.suggestedName() ?: "0", value = defaultExpression, @@ -121,6 +125,7 @@ private fun burstValuesArguments( val varargExpression = element as? IrExpression ?: unexpectedParameter(parameter) add( BurstValuesArgument( + declarationParent = parameter.parent, isDefault = false, name = varargExpression.suggestedName() ?: (index + 1).toString(), value = varargExpression,