diff --git a/docs/DifferentiableProgramming.md b/docs/DifferentiableProgramming.md index 0c80821f4e030..fdf116302a19d 100644 --- a/docs/DifferentiableProgramming.md +++ b/docs/DifferentiableProgramming.md @@ -1271,6 +1271,12 @@ The synthesized `TangentVector` has the same effective access level as the original type declaration. Properties in the synthesized `TangentVector` have the same effective access level as their corresponding original properties. +The synthesized `TangentVector` adopts protocols from all `TangentVector` +conformance constraints implied by the declaration that triggers synthesis. For +example, synthesized `TangentVector`s always adopt the `AdditiveArithmetic` and +`Differentiable` protocols because the `Differentiable` protocol requires that +`TangentVector` conforms to `AdditiveArithmetic` and `Differentiable`. + The synthesized `move(along:)` method calls `move(along:)` for each pair of a differentiable variable and its corresponding property in `TangentVector`. diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index c56f3e58753b3..ae5ea062fc7d9 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -17,6 +17,8 @@ #include "CodeSynthesis.h" #include "TypeChecker.h" +#include "TypeCheckType.h" +#include "llvm/ADT/SmallPtrSet.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/Decl.h" #include "swift/AST/Expr.h" @@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) { return propDecl; } +/// Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`. +/// +/// Precondition: `decl` is a nominal type decl or an extension decl. +void getInheritedProtocols(Decl *decl, SmallPtrSetImpl &protos) { + ArrayRef inheritedTypeLocs; + if (auto *nominalDecl = dyn_cast(decl)) + inheritedTypeLocs = nominalDecl->getInherited(); + else if (auto *extDecl = dyn_cast(decl)) + inheritedTypeLocs = extDecl->getInherited(); + else + llvm_unreachable("conformance is not a nominal or an extension"); + + std::function handleInheritedType; + + auto handleProto = [&](ProtocolType *proto) -> void { + proto->getDecl()->walkInheritedProtocols([&](ProtocolDecl *p) -> TypeWalker::Action { + protos.insert(p); + return TypeWalker::Action::Continue; + }); + }; + + auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void { + for (auto ty : comp->getMembers()) + handleInheritedType(ty); + }; + + handleInheritedType = [&](Type ty) -> void { + if (auto *proto = ty->getAs()) + handleProto(proto); + else if (auto *comp = ty->getAs()) + handleProtoComp(comp); + }; + + for (auto loc : inheritedTypeLocs) { + if (loc.getTypeRepr()) + handleInheritedType(TypeResolution::forStructural( + cast(decl), None, /*unboundTyOpener*/ nullptr) + .resolveType(loc.getTypeRepr())); + else + handleInheritedType(loc.getType()); + } +} + /// Return associated `TangentVector` struct for a nominal type, if it exists. /// If not, synthesize the struct. static StructDecl * @@ -646,23 +691,46 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { } // Otherwise, synthesize a new struct. - auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType()); - auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType()); - // By definition, `TangentVector` must conform to `Differentiable` and - // `AdditiveArithmetic`. - SmallVector inherited{diffableType, addArithType}; + // Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must + // inherit, by collecting all the `TangentVector` conformance requirements imposed by the + // protocols that `derived.ConformanceDecl` inherits. + // + // Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because + // the `Differentiable` protocol itself requires that its `TangentVector` conforms to + // `AdditiveArithmetic` and `Differentiable`. + llvm::SmallPtrSet tvDesiredProtos; + llvm::SmallPtrSet conformanceInheritedProtos; + getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); + auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector); + for (auto proto : conformanceInheritedProtos) { + for (auto req : proto->getRequirementSignature()) { + if (req.getKind() != RequirementKind::Conformance) + continue; + auto *firstType = req.getFirstType()->getAs(); + if (!firstType || firstType->getAssocType() != tvAssocType) + continue; + auto tvRequiredProto = req.getSecondType()->getAs(); + if (!tvRequiredProto) + continue; + tvDesiredProtos.insert(tvRequiredProto); + } + } + SmallVector tvDesiredProtoTypeLocs; + for (auto *p : tvDesiredProtos) + tvDesiredProtoTypeLocs.push_back(TypeLoc::withoutLoc(p)); // Cache original members and their associated types for later use. SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); + auto synthesizedLoc = derived.ConformanceDecl->getEndLoc(); auto *structDecl = - new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(), - /*Inherited*/ C.AllocateCopy(inherited), + new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc, + /*Inherited*/ C.AllocateCopy(tvDesiredProtoTypeLocs), /*GenericParams*/ {}, parentDC); + structDecl->setBraces({synthesizedLoc, synthesizedLoc}); structDecl->setImplicit(); structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift index e933855276a4b..d99cf0b775565 100644 --- a/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift +++ b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift @@ -8,7 +8,7 @@ struct GenericTangentVectorMember: Differentiable, var x: T.TangentVector } -// CHECK-AST-LABEL: internal struct GenericTangentVectorMember : Differentiable, AdditiveArithmetic where T : Differentiable +// CHECK-AST-LABEL: internal struct GenericTangentVectorMember : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} where T : Differentiable // CHECK-AST: internal var x: T.TangentVector // CHECK-AST: internal init(x: T.TangentVector) // CHECK-AST: internal typealias TangentVector = GenericTangentVectorMember @@ -62,7 +62,7 @@ final class AdditiveArithmeticClass: Add // CHECK-AST-LABEL: final internal class AdditiveArithmeticClass : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable { // CHECK-AST: final internal var x: T, y: T -// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic +// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} // CHECK-AST: } @frozen @@ -70,7 +70,7 @@ public struct FrozenStruct: Differentiable {} // CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable { // CHECK-AST: internal init() -// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic { +// CHECK-AST: @frozen public struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} { @usableFromInline struct UsableFromInlineStruct: Differentiable {} @@ -79,7 +79,7 @@ struct UsableFromInlineStruct: Differentiable {} // CHECK-AST: struct UsableFromInlineStruct : Differentiable { // CHECK-AST: internal init() // CHECK-AST: @usableFromInline -// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic { +// CHECK-AST: struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} { // Test property wrappers. @@ -96,7 +96,7 @@ struct WrappedPropertiesStruct: Differentiable { } // CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable { -// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic { +// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} { // CHECK-AST: internal var x: Float.TangentVector // CHECK-AST: internal var y: Float.TangentVector // CHECK-AST: internal var z: Float.TangentVector @@ -111,9 +111,48 @@ class WrappedPropertiesClass: Differentiable { } // CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable { -// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic { +// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} { // CHECK-AST: internal var x: Float.TangentVector // CHECK-AST: internal var y: Float.TangentVector // CHECK-AST: internal var z: Float.TangentVector // CHECK-AST: } // CHECK-AST: } + +protocol TangentVectorMustBeEncodable: Differentiable where TangentVector: Encodable {} + +struct AutoDeriveEncodableTV1: TangentVectorMustBeEncodable { + var x: Float +} + +// CHECK-AST-LABEL: internal struct AutoDeriveEncodableTV1 : TangentVectorMustBeEncodable { +// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} { + +struct AutoDeriveEncodableTV2 { + var x: Float +} + +extension AutoDeriveEncodableTV2: TangentVectorMustBeEncodable {} + +// CHECK-AST-LABEL: extension AutoDeriveEncodableTV2 : TangentVectorMustBeEncodable { +// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} { + +protocol TangentVectorP: Differentiable { + var requirement: Int { get } +} + +protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {} + +struct StructWithTangentVectorConstrained: TangentVectorConstrained { + var x: Float +} + +// `extension StructWithTangentVectorConstrained.TangentVector: TangentVectorP` gives +// "error: type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'", +// maybe because it typechecks the conformance before seeing the extension. But this roundabout way +// of stating the same thing works. +extension TangentVectorP where Self == StructWithTangentVectorConstrained.TangentVector { + var requirement: Int { 42 } +} + +// CHECK-AST-LABEL: internal struct StructWithTangentVectorConstrained : TangentVectorConstrained { +// CHECK-AST: internal struct TangentVector : {{(TangentVectorP, Differentiable, AdditiveArithmetic)|(TangentVectorP, AdditiveArithmetic, Differentiable)|(Differentiable, TangentVectorP, AdditiveArithmetic)|(AdditiveArithmetic, TangentVectorP, Differentiable)|(Differentiable, AdditiveArithmetic, TangentVectorP)|(AdditiveArithmetic, Differentiable, TangentVectorP)}} { diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_differentiable_diagnostics.swift b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable_diagnostics.swift new file mode 100644 index 0000000000000..889f489e339c2 --- /dev/null +++ b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable_diagnostics.swift @@ -0,0 +1,15 @@ +// RUN: %target-swift-frontend -typecheck -verify %s + +import _Differentiation + +protocol TangentVectorP: Differentiable { + // expected-note @+1 {{protocol requires property 'requirement' with type 'Int'; do you want to add a stub?}} + var requirement: Int { get } +} + +protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {} + +struct StructWithTangentVectorConstrained: TangentVectorConstrained { + var x: Float +} +// expected-error @-1 {{type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'}}