Skip to content

inherit required protocols during TangentVector synthesis #34893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/DifferentiableProgramming.md
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,12 @@ The synthesized `TangentVector` has the same effective access level as the
original type declaration. Properties in the synthesized `TangentVector` have
the same effective access level as their corresponding original properties.

The synthesized `TangentVector` adopts protocols from all `TangentVector`
conformance constraints implied by the declaration that triggers synthesis. For
example, synthesized `TangentVector`s always adopt the `AdditiveArithmetic` and
`Differentiable` protocols because the `Differentiable` protocol requires that
`TangentVector` conforms to `AdditiveArithmetic` and `Differentiable`.

The synthesized `move(along:)` method calls `move(along:)` for each pair of a
differentiable variable and its corresponding property in `TangentVector`.

Expand Down
86 changes: 77 additions & 9 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "TypeCheckType.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
Expand Down Expand Up @@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
return propDecl;
}

/// Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
///
/// Precondition: `decl` is a nominal type decl or an extension decl.
void getInheritedProtocols(Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
ArrayRef<TypeLoc> inheritedTypeLocs;
if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
inheritedTypeLocs = nominalDecl->getInherited();
else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
inheritedTypeLocs = extDecl->getInherited();
else
llvm_unreachable("conformance is not a nominal or an extension");

std::function<void(Type)> handleInheritedType;

auto handleProto = [&](ProtocolType *proto) -> void {
proto->getDecl()->walkInheritedProtocols([&](ProtocolDecl *p) -> TypeWalker::Action {
protos.insert(p);
return TypeWalker::Action::Continue;
});
};

auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
for (auto ty : comp->getMembers())
handleInheritedType(ty);
};

handleInheritedType = [&](Type ty) -> void {
if (auto *proto = ty->getAs<ProtocolType>())
handleProto(proto);
else if (auto *comp = ty->getAs<ProtocolCompositionType>())
handleProtoComp(comp);
};

for (auto loc : inheritedTypeLocs) {
if (loc.getTypeRepr())
handleInheritedType(TypeResolution::forStructural(
cast<DeclContext>(decl), None, /*unboundTyOpener*/ nullptr)
.resolveType(loc.getTypeRepr()));
else
handleInheritedType(loc.getType());
}
}

/// Return associated `TangentVector` struct for a nominal type, if it exists.
/// If not, synthesize the struct.
static StructDecl *
Expand All @@ -646,23 +691,46 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
}

// Otherwise, synthesize a new struct.
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType());
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType());

// By definition, `TangentVector` must conform to `Differentiable` and
// `AdditiveArithmetic`.
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};
// Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
// inherit, by collecting all the `TangentVector` conformance requirements imposed by the
// protocols that `derived.ConformanceDecl` inherits.
//
// Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
// the `Differentiable` protocol itself requires that its `TangentVector` conforms to
// `AdditiveArithmetic` and `Differentiable`.
llvm::SmallPtrSet<ProtocolType *, 4> tvDesiredProtos;
llvm::SmallPtrSet<ProtocolDecl *, 4> conformanceInheritedProtos;
getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos);
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
for (auto proto : conformanceInheritedProtos) {
for (auto req : proto->getRequirementSignature()) {
if (req.getKind() != RequirementKind::Conformance)
continue;
auto *firstType = req.getFirstType()->getAs<DependentMemberType>();
if (!firstType || firstType->getAssocType() != tvAssocType)
continue;
auto tvRequiredProto = req.getSecondType()->getAs<ProtocolType>();
if (!tvRequiredProto)
continue;
tvDesiredProtos.insert(tvRequiredProto);
}
}
SmallVector<TypeLoc, 4> tvDesiredProtoTypeLocs;
for (auto *p : tvDesiredProtos)
tvDesiredProtoTypeLocs.push_back(TypeLoc::withoutLoc(p));

// Cache original members and their associated types for later use.
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);

auto synthesizedLoc = derived.ConformanceDecl->getEndLoc();
auto *structDecl =
new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(),
/*Inherited*/ C.AllocateCopy(inherited),
new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc,
/*Inherited*/ C.AllocateCopy(tvDesiredProtoTypeLocs),
/*GenericParams*/ {}, parentDC);
structDecl->setBraces({synthesizedLoc, synthesizedLoc});
structDecl->setImplicit();
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct GenericTangentVectorMember<T: Differentiable>: Differentiable,
var x: T.TangentVector
}

// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} where T : Differentiable
// CHECK-AST: internal var x: T.TangentVector
// CHECK-AST: internal init(x: T.TangentVector)
// CHECK-AST: internal typealias TangentVector = GenericTangentVectorMember<T>
Expand Down Expand Up @@ -62,15 +62,15 @@ final class AdditiveArithmeticClass<T: AdditiveArithmetic & Differentiable>: Add

// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
// CHECK-AST: final internal var x: T, y: T
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}}
// CHECK-AST: }

@frozen
public struct FrozenStruct: Differentiable {}

// CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable {
// CHECK-AST: internal init()
// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: @frozen public struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {

@usableFromInline
struct UsableFromInlineStruct: Differentiable {}
Expand All @@ -79,7 +79,7 @@ struct UsableFromInlineStruct: Differentiable {}
// CHECK-AST: struct UsableFromInlineStruct : Differentiable {
// CHECK-AST: internal init()
// CHECK-AST: @usableFromInline
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {

// Test property wrappers.

Expand All @@ -96,7 +96,7 @@ struct WrappedPropertiesStruct: Differentiable {
}

// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
Expand All @@ -111,9 +111,48 @@ class WrappedPropertiesClass: Differentiable {
}

// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
// CHECK-AST: }
// CHECK-AST: }

protocol TangentVectorMustBeEncodable: Differentiable where TangentVector: Encodable {}

struct AutoDeriveEncodableTV1: TangentVectorMustBeEncodable {
var x: Float
}

// CHECK-AST-LABEL: internal struct AutoDeriveEncodableTV1 : TangentVectorMustBeEncodable {
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {

struct AutoDeriveEncodableTV2 {
var x: Float
}

extension AutoDeriveEncodableTV2: TangentVectorMustBeEncodable {}

// CHECK-AST-LABEL: extension AutoDeriveEncodableTV2 : TangentVectorMustBeEncodable {
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {

protocol TangentVectorP: Differentiable {
var requirement: Int { get }
}

protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {}

struct StructWithTangentVectorConstrained: TangentVectorConstrained {
var x: Float
}

// `extension StructWithTangentVectorConstrained.TangentVector: TangentVectorP` gives
// "error: type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'",
// maybe because it typechecks the conformance before seeing the extension. But this roundabout way
// of stating the same thing works.
extension TangentVectorP where Self == StructWithTangentVectorConstrained.TangentVector {
var requirement: Int { 42 }
}

// CHECK-AST-LABEL: internal struct StructWithTangentVectorConstrained : TangentVectorConstrained {
// CHECK-AST: internal struct TangentVector : {{(TangentVectorP, Differentiable, AdditiveArithmetic)|(TangentVectorP, AdditiveArithmetic, Differentiable)|(Differentiable, TangentVectorP, AdditiveArithmetic)|(AdditiveArithmetic, TangentVectorP, Differentiable)|(Differentiable, AdditiveArithmetic, TangentVectorP)|(AdditiveArithmetic, Differentiable, TangentVectorP)}} {
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: %target-swift-frontend -typecheck -verify %s

import _Differentiation

protocol TangentVectorP: Differentiable {
// expected-note @+1 {{protocol requires property 'requirement' with type 'Int'; do you want to add a stub?}}
var requirement: Int { get }
}

protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {}

struct StructWithTangentVectorConstrained: TangentVectorConstrained {
var x: Float
}
// expected-error @-1 {{type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'}}