Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ContributeSubComponent: Support returning Super Type #83

3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
**Unreleased**
--------------

### Deprecated
- `ClassReference.functions` has been deprecated in favor of `ClassReference.memberFunctions` and `ClassReference.declaredMemberFunctions`

0.4.0
-----

Expand Down
4 changes: 4 additions & 0 deletions compiler-utils/api/compiler-utils.api
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,11 @@ public abstract class com/squareup/anvil/compiler/internal/reference/ClassRefere
public abstract fun getClassId ()Lorg/jetbrains/kotlin/name/ClassId;
public abstract fun getConstructors ()Ljava/util/List;
public abstract fun getContainingFileAsJavaFile ()Ljava/io/File;
public abstract fun getDeclaredMemberFunctions ()Ljava/util/List;
public abstract fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public abstract fun getFunctions ()Ljava/util/List;
protected abstract fun getInnerClassesAndObjects ()Ljava/util/List;
public final fun getMemberFunctions ()Ljava/util/List;
public abstract fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
public final fun getPackageFqName ()Lorg/jetbrains/kotlin/name/FqName;
public abstract fun getProperties ()Ljava/util/List;
Expand Down Expand Up @@ -330,6 +332,7 @@ public final class com/squareup/anvil/compiler/internal/reference/ClassReference
public final fun getClazz ()Lorg/jetbrains/kotlin/descriptors/ClassDescriptor;
public fun getConstructors ()Ljava/util/List;
public fun getContainingFileAsJavaFile ()Ljava/io/File;
public fun getDeclaredMemberFunctions ()Ljava/util/List;
public fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public fun getFunctions ()Ljava/util/List;
public fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
Expand All @@ -355,6 +358,7 @@ public final class com/squareup/anvil/compiler/internal/reference/ClassReference
public final fun getClazz ()Lorg/jetbrains/kotlin/psi/KtClassOrObject;
public fun getConstructors ()Ljava/util/List;
public fun getContainingFileAsJavaFile ()Ljava/io/File;
public fun getDeclaredMemberFunctions ()Ljava/util/List;
public fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public fun getFunctions ()Ljava/util/List;
public fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,30 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
public val packageFqName: FqName get() = classId.packageFqName

public abstract val constructors: List<MemberFunctionReference>

@Deprecated(
"renamed to `declaredMemberFunctions`. " +
"Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
public abstract val functions: List<MemberFunctionReference>

/**
* All functions that are declared in this class, including overrides.
* This list does not include inherited functions that are not overridden by this class.
*/
public abstract val declaredMemberFunctions: List<MemberFunctionReference>

/**
* All functions declared in this class or any of its super-types.
*/
public val memberFunctions: List<MemberFunctionReference> by lazy(NONE) {
declaredMemberFunctions + directSuperTypeReferences()
.flatMap { it.asClassReference().memberFunctions }
}

public abstract val properties: List<MemberPropertyReference>

public abstract val typeParameters: List<TypeParameterReference>

protected abstract val innerClassesAndObjects: List<ClassReference>
Expand Down Expand Up @@ -146,7 +168,13 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
clazz.containingFileAsJavaFile()
}

override val functions: List<MemberFunctionReference.Psi> by lazy(NONE) {
@Deprecated(
"renamed to `declaredMemberFunctions`. Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
override val functions: List<MemberFunctionReference.Psi> get() = declaredMemberFunctions

override val declaredMemberFunctions: List<MemberFunctionReference.Psi> by lazy(NONE) {
clazz
.children
.filterIsInstance<KtClassBody>()
Expand Down Expand Up @@ -263,9 +291,14 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
)
}

override val functions: List<MemberFunctionReference.Descriptor> by lazy(NONE) {
@Deprecated(
"renamed to `declaredMemberFunctions`. Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
override val functions: List<MemberFunctionReference.Descriptor> get() = declaredMemberFunctions
override val declaredMemberFunctions: List<MemberFunctionReference.Descriptor> by lazy(NONE) {
clazz.unsubstitutedMemberScope
.getContributedDescriptors(kindFilter = DescriptorKindFilter.FUNCTIONS)
.getDescriptorsFiltered(kindFilter = DescriptorKindFilter.FUNCTIONS)
.filterIsInstance<FunctionDescriptor>()
.filterNot { it is ConstructorDescriptor }
.map { it.toFunctionReference(this) }
Expand All @@ -279,10 +312,8 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
clazz.unsubstitutedMemberScope
.getDescriptorsFiltered(kindFilter = DescriptorKindFilter.VARIABLES)
.filterIsInstance<PropertyDescriptor>()
.filter {
// Remove inherited properties that aren't overridden in this class.
it.kind == DECLARATION
}
// Remove inherited properties that aren't overridden in this class.
.filter { it.kind == DECLARATION }
.map { it.toPropertyReference(this) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
.filter { it.isAbstract }
.toList()

if (functions.size != 1 || functions[0].returnType?.resolve()
?.resolveKSClassDeclaration() != this
) {
val returnType = functions.singleOrNull()?.returnType?.resolve()?.resolveKSClassDeclaration()
if (returnType != this) {

val isReturnSuperType = returnType != null && this.superTypes
.any { type -> type.resolve().resolveKSClassDeclaration() == returnType }
if (isReturnSuperType) return

throw KspAnvilException(
node = factory,
message = "A factory must have exactly one abstract function returning the " +
Expand Down Expand Up @@ -325,7 +329,7 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
)
}

val functions = componentInterface.functions
val functions = componentInterface.memberFunctions
.filter { it.returnType().asClassReference() == this }

if (functions.size >= 2) {
Expand Down Expand Up @@ -378,7 +382,7 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
)
}

val functions = factory.functions
val functions = factory.memberFunctions
.let { functions ->
if (factory.isInterface()) {
functions
Expand All @@ -387,7 +391,13 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
}
}

if (functions.size != 1 || functions[0].returnType().asClassReference() != this) {
val returnType = functions.singleOrNull()?.returnType()?.asClassReference()
if (returnType != this) {

val isReturnSuperType = returnType != null && this.directSuperTypeReferences()
.any { it.asClassReference() == returnType }
if (isReturnSuperType) return

throw AnvilCompilationExceptionClassReference(
classReference = factory,
message = "A factory must have exactly one abstract function returning the " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ internal class ContributesSubcomponentHandlerGenerator(
)
}

val functions = componentInterface.functions
val functions = componentInterface.memberFunctions
.filter { it.isAbstract() && it.visibility() == PUBLIC }
.filter {
val returnType = it.returnType().asClassReference()
Expand Down Expand Up @@ -333,9 +333,14 @@ internal class ContributesSubcomponentHandlerGenerator(
)
}

val createComponentFunctions = factory.functions
val createComponentFunctions = factory.memberFunctions
.filter { it.isAbstract() }
.filter { it.returnType().asClassReference().fqName == contributionFqName }
.filter {
val returnType = it.returnType().asClassReference()
returnType.fqName == contributionFqName ||
contribution.clazz.directSuperTypeReferences()
.any { type -> type.asClassReference() == returnType }
}

if (createComponentFunctions.size != 1) {
throw AnvilCompilationExceptionClassReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,14 @@ internal class KspContributesSubcomponentHandlerSymbolProcessor(
} else {
function.asMemberOf(implementingType).returnTypeOrNull()
}
returnTypeToCheck
?.resolveKSClassDeclaration()
?.toClassName() == contributionClassName

if (returnTypeToCheck != null) {
val returnTypeClassName = returnTypeToCheck.resolveKSClassDeclaration()?.toClassName()
returnTypeClassName == contributionClassName ||
returnTypeToCheck.isAssignableFrom(contribution.clazz.asType(emptyList()))
} else {
false
}
}
.toList()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ internal object AssistedFactoryCodeGen : AnvilApplicabilityChecker {
val assistedFunctions = allSuperTypeClassReferences(includeSelf = true)
.distinctBy { it.fqName }
.flatMap { clazz ->
clazz.functions
clazz.declaredMemberFunctions
.filter {
it.isAbstract() &&
(it.visibility() == Visibility.PUBLIC || it.visibility() == Visibility.PROTECTED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {
.forEach { clazz ->
(clazz.companionObjects() + clazz)
.asSequence()
.flatMap { it.functions }
.flatMap { it.declaredMemberFunctions }
.filter { it.isAnnotatedWith(daggerBindsFqName) }
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ internal object ProvidesMethodFactoryCodeGen : AnvilApplicabilityChecker {
.asSequence()

val functions = types
.flatMap { it.functions }
.flatMap { it.declaredMemberFunctions }
.filter { it.isAnnotatedWith(daggerProvidesFqName) }
.onEach { function ->
checkFunctionIsNotAbstract(clazz, function)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ContributesSubcomponentGeneratorTest(
}
}

@Test fun `there is a hint for contributed subcomponents with an interace factory`() {
@Test fun `there is a hint for contributed subcomponents with an interface factory`() {
compile(
"""
package com.squareup.test
Expand Down Expand Up @@ -493,6 +493,84 @@ class ContributesSubcomponentGeneratorTest(
}
}

@Test fun `a factory function may be defined in a super interface`() {
compile(
"""
package com.squareup.test

import com.squareup.anvil.annotations.ContributesSubcomponent
import com.squareup.anvil.annotations.ContributesSubcomponent.Factory
import com.squareup.anvil.annotations.ContributesTo
import com.squareup.anvil.annotations.MergeComponent

@ContributesSubcomponent(Any::class, parentScope = Unit::class)
interface SubcomponentInterface {
@ContributesTo(Unit::class)
interface AnyParentComponent {
fun createFactory(): ComponentFactory
}

interface Creator {
fun createComponent(): SubcomponentInterface
}

@Factory
interface ComponentFactory : Creator
}

@MergeComponent(Unit::class)
interface ComponentInterface
""",
mode = mode,
) {
assertThat(subcomponentInterface.hintSubcomponent?.java).isEqualTo(subcomponentInterface)
assertThat(subcomponentInterface.hintSubcomponentParentScope).isEqualTo(Unit::class)

assertThat(subcomponentInterface.componentFactoryInterface.methods.map { it.name })
.containsExactly("createComponent")
}
}

@Test fun `a factory function may returns the component super type`() {
compile(
"""
package com.squareup.test

import com.squareup.anvil.annotations.ContributesSubcomponent
import com.squareup.anvil.annotations.ContributesSubcomponent.Factory
import com.squareup.anvil.annotations.ContributesTo
import com.squareup.anvil.annotations.MergeComponent

interface BaseSubcomponentInterface {
interface Factory {
fun createComponent(): BaseSubcomponentInterface
}
}

@ContributesSubcomponent(Any::class, parentScope = Unit::class)
interface SubcomponentInterface : BaseSubcomponentInterface {
@Factory
interface ComponentFactory: BaseSubcomponentInterface.Factory

@ContributesTo(Unit::class)
interface ParentComponent {
fun createFactory(): ComponentFactory
}
}

@MergeComponent(Unit::class)
interface ComponentInterface
""",
mode = mode,
) {
assertThat(subcomponentInterface.hintSubcomponent?.java).isEqualTo(subcomponentInterface)
assertThat(subcomponentInterface.hintSubcomponentParentScope).isEqualTo(Unit::class)

assertThat(subcomponentInterface.componentFactoryInterface.methods.map { it.name })
.containsExactly("createComponent")
}
}

@Test
fun `using Dagger's @Subcomponent_Factory is an error`() {
compile(
Expand Down Expand Up @@ -616,6 +694,9 @@ class ContributesSubcomponentGeneratorTest(
private val Class<*>.parentComponentInterface: Class<*>
get() = classLoader.loadClass("$canonicalName\$AnyParentComponent")

private val Class<*>.componentFactoryInterface: Class<*>
get() = classLoader.loadClass("$canonicalName\$ComponentFactory")

private val JvmCompilationResult.subcomponentInterface1: Class<*>
get() = classLoader.loadClass("com.squareup.test.SubcomponentInterface1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,13 @@ class ClassReferenceTest {
assertThat(psiRef.isGenericClass()).isFalse()
assertThat(descriptorRef.isGenericClass()).isFalse()

assertThat(psiRef.functions.single().returnType().isGenericType()).isFalse()
assertThat(
descriptorRef.functions.single { it.name == "string" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isFalse()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "string" }
.returnType()
.isGenericType(),
).isFalse()
Expand All @@ -223,17 +227,25 @@ class ClassReferenceTest {
).isTrue()
}
"SomeClass3" -> {
assertThat(psiRef.functions.single().returnType().isGenericType()).isTrue()
assertThat(
descriptorRef.functions.single { it.name == "list" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isTrue()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "list" }
.returnType()
.isGenericType(),
).isTrue()
}
"SomeClass4" -> {
assertThat(psiRef.functions.single().returnType().isGenericType()).isTrue()
assertThat(
descriptorRef.functions.single { it.name == "list" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isTrue()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "list" }
.returnType()
.isGenericType(),
).isTrue()
Expand Down
Loading
Loading