Skip to content

Commit 4e048bf

Browse files
author
marcrasi
authored
Merge pull request #34893 from marcrasi/make-it-automatically-inherit
inherit required protocols during TangentVector synthesis
2 parents c724153 + 6378c3e commit 4e048bf

File tree

4 files changed

+143
-15
lines changed

4 files changed

+143
-15
lines changed

docs/DifferentiableProgramming.md

+6
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,12 @@ The synthesized `TangentVector` has the same effective access level as the
12731273
original type declaration. Properties in the synthesized `TangentVector` have
12741274
the same effective access level as their corresponding original properties.
12751275

1276+
The synthesized `TangentVector` adopts protocols from all `TangentVector`
1277+
conformance constraints implied by the declaration that triggers synthesis. For
1278+
example, synthesized `TangentVector`s always adopt the `AdditiveArithmetic` and
1279+
`Differentiable` protocols because the `Differentiable` protocol requires that
1280+
`TangentVector` conforms to `AdditiveArithmetic` and `Differentiable`.
1281+
12761282
The synthesized `move(along:)` method calls `move(along:)` for each pair of a
12771283
differentiable variable and its corresponding property in `TangentVector`.
12781284

lib/Sema/DerivedConformanceDifferentiable.cpp

+77-9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
#include "CodeSynthesis.h"
1919
#include "TypeChecker.h"
20+
#include "TypeCheckType.h"
21+
#include "llvm/ADT/SmallPtrSet.h"
2022
#include "swift/AST/AutoDiff.h"
2123
#include "swift/AST/Decl.h"
2224
#include "swift/AST/Expr.h"
@@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
627629
return propDecl;
628630
}
629631

632+
/// Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
633+
///
634+
/// Precondition: `decl` is a nominal type decl or an extension decl.
635+
void getInheritedProtocols(Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
636+
ArrayRef<TypeLoc> inheritedTypeLocs;
637+
if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
638+
inheritedTypeLocs = nominalDecl->getInherited();
639+
else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
640+
inheritedTypeLocs = extDecl->getInherited();
641+
else
642+
llvm_unreachable("conformance is not a nominal or an extension");
643+
644+
std::function<void(Type)> handleInheritedType;
645+
646+
auto handleProto = [&](ProtocolType *proto) -> void {
647+
proto->getDecl()->walkInheritedProtocols([&](ProtocolDecl *p) -> TypeWalker::Action {
648+
protos.insert(p);
649+
return TypeWalker::Action::Continue;
650+
});
651+
};
652+
653+
auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
654+
for (auto ty : comp->getMembers())
655+
handleInheritedType(ty);
656+
};
657+
658+
handleInheritedType = [&](Type ty) -> void {
659+
if (auto *proto = ty->getAs<ProtocolType>())
660+
handleProto(proto);
661+
else if (auto *comp = ty->getAs<ProtocolCompositionType>())
662+
handleProtoComp(comp);
663+
};
664+
665+
for (auto loc : inheritedTypeLocs) {
666+
if (loc.getTypeRepr())
667+
handleInheritedType(TypeResolution::forStructural(
668+
cast<DeclContext>(decl), None, /*unboundTyOpener*/ nullptr)
669+
.resolveType(loc.getTypeRepr()));
670+
else
671+
handleInheritedType(loc.getType());
672+
}
673+
}
674+
630675
/// Return associated `TangentVector` struct for a nominal type, if it exists.
631676
/// If not, synthesize the struct.
632677
static StructDecl *
@@ -646,23 +691,46 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
646691
}
647692

648693
// Otherwise, synthesize a new struct.
649-
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
650-
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType());
651-
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
652-
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType());
653694

654-
// By definition, `TangentVector` must conform to `Differentiable` and
655-
// `AdditiveArithmetic`.
656-
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};
695+
// Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
696+
// inherit, by collecting all the `TangentVector` conformance requirements imposed by the
697+
// protocols that `derived.ConformanceDecl` inherits.
698+
//
699+
// Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
700+
// the `Differentiable` protocol itself requires that its `TangentVector` conforms to
701+
// `AdditiveArithmetic` and `Differentiable`.
702+
llvm::SmallPtrSet<ProtocolType *, 4> tvDesiredProtos;
703+
llvm::SmallPtrSet<ProtocolDecl *, 4> conformanceInheritedProtos;
704+
getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos);
705+
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
706+
auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
707+
for (auto proto : conformanceInheritedProtos) {
708+
for (auto req : proto->getRequirementSignature()) {
709+
if (req.getKind() != RequirementKind::Conformance)
710+
continue;
711+
auto *firstType = req.getFirstType()->getAs<DependentMemberType>();
712+
if (!firstType || firstType->getAssocType() != tvAssocType)
713+
continue;
714+
auto tvRequiredProto = req.getSecondType()->getAs<ProtocolType>();
715+
if (!tvRequiredProto)
716+
continue;
717+
tvDesiredProtos.insert(tvRequiredProto);
718+
}
719+
}
720+
SmallVector<TypeLoc, 4> tvDesiredProtoTypeLocs;
721+
for (auto *p : tvDesiredProtos)
722+
tvDesiredProtoTypeLocs.push_back(TypeLoc::withoutLoc(p));
657723

658724
// Cache original members and their associated types for later use.
659725
SmallVector<VarDecl *, 8> diffProperties;
660726
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
661727

728+
auto synthesizedLoc = derived.ConformanceDecl->getEndLoc();
662729
auto *structDecl =
663-
new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(),
664-
/*Inherited*/ C.AllocateCopy(inherited),
730+
new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc,
731+
/*Inherited*/ C.AllocateCopy(tvDesiredProtoTypeLocs),
665732
/*GenericParams*/ {}, parentDC);
733+
structDecl->setBraces({synthesizedLoc, synthesizedLoc});
666734
structDecl->setImplicit();
667735
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
668736

test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift

+45-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct GenericTangentVectorMember<T: Differentiable>: Differentiable,
88
var x: T.TangentVector
99
}
1010

11-
// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
11+
// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} where T : Differentiable
1212
// CHECK-AST: internal var x: T.TangentVector
1313
// CHECK-AST: internal init(x: T.TangentVector)
1414
// CHECK-AST: internal typealias TangentVector = GenericTangentVectorMember<T>
@@ -62,15 +62,15 @@ final class AdditiveArithmeticClass<T: AdditiveArithmetic & Differentiable>: Add
6262

6363
// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
6464
// CHECK-AST: final internal var x: T, y: T
65-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
65+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}}
6666
// CHECK-AST: }
6767

6868
@frozen
6969
public struct FrozenStruct: Differentiable {}
7070

7171
// CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable {
7272
// CHECK-AST: internal init()
73-
// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic {
73+
// CHECK-AST: @frozen public struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
7474

7575
@usableFromInline
7676
struct UsableFromInlineStruct: Differentiable {}
@@ -79,7 +79,7 @@ struct UsableFromInlineStruct: Differentiable {}
7979
// CHECK-AST: struct UsableFromInlineStruct : Differentiable {
8080
// CHECK-AST: internal init()
8181
// CHECK-AST: @usableFromInline
82-
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
82+
// CHECK-AST: struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
8383

8484
// Test property wrappers.
8585

@@ -96,7 +96,7 @@ struct WrappedPropertiesStruct: Differentiable {
9696
}
9797

9898
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
99-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
99+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
100100
// CHECK-AST: internal var x: Float.TangentVector
101101
// CHECK-AST: internal var y: Float.TangentVector
102102
// CHECK-AST: internal var z: Float.TangentVector
@@ -111,9 +111,48 @@ class WrappedPropertiesClass: Differentiable {
111111
}
112112

113113
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
114-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
114+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
115115
// CHECK-AST: internal var x: Float.TangentVector
116116
// CHECK-AST: internal var y: Float.TangentVector
117117
// CHECK-AST: internal var z: Float.TangentVector
118118
// CHECK-AST: }
119119
// CHECK-AST: }
120+
121+
protocol TangentVectorMustBeEncodable: Differentiable where TangentVector: Encodable {}
122+
123+
struct AutoDeriveEncodableTV1: TangentVectorMustBeEncodable {
124+
var x: Float
125+
}
126+
127+
// CHECK-AST-LABEL: internal struct AutoDeriveEncodableTV1 : TangentVectorMustBeEncodable {
128+
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {
129+
130+
struct AutoDeriveEncodableTV2 {
131+
var x: Float
132+
}
133+
134+
extension AutoDeriveEncodableTV2: TangentVectorMustBeEncodable {}
135+
136+
// CHECK-AST-LABEL: extension AutoDeriveEncodableTV2 : TangentVectorMustBeEncodable {
137+
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {
138+
139+
protocol TangentVectorP: Differentiable {
140+
var requirement: Int { get }
141+
}
142+
143+
protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {}
144+
145+
struct StructWithTangentVectorConstrained: TangentVectorConstrained {
146+
var x: Float
147+
}
148+
149+
// `extension StructWithTangentVectorConstrained.TangentVector: TangentVectorP` gives
150+
// "error: type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'",
151+
// maybe because it typechecks the conformance before seeing the extension. But this roundabout way
152+
// of stating the same thing works.
153+
extension TangentVectorP where Self == StructWithTangentVectorConstrained.TangentVector {
154+
var requirement: Int { 42 }
155+
}
156+
157+
// CHECK-AST-LABEL: internal struct StructWithTangentVectorConstrained : TangentVectorConstrained {
158+
// CHECK-AST: internal struct TangentVector : {{(TangentVectorP, Differentiable, AdditiveArithmetic)|(TangentVectorP, AdditiveArithmetic, Differentiable)|(Differentiable, TangentVectorP, AdditiveArithmetic)|(AdditiveArithmetic, TangentVectorP, Differentiable)|(Differentiable, AdditiveArithmetic, TangentVectorP)|(AdditiveArithmetic, Differentiable, TangentVectorP)}} {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %s
2+
3+
import _Differentiation
4+
5+
protocol TangentVectorP: Differentiable {
6+
// expected-note @+1 {{protocol requires property 'requirement' with type 'Int'; do you want to add a stub?}}
7+
var requirement: Int { get }
8+
}
9+
10+
protocol TangentVectorConstrained: Differentiable where TangentVector: TangentVectorP {}
11+
12+
struct StructWithTangentVectorConstrained: TangentVectorConstrained {
13+
var x: Float
14+
}
15+
// expected-error @-1 {{type 'StructWithTangentVectorConstrained.TangentVector' does not conform to protocol 'TangentVectorP'}}

0 commit comments

Comments
 (0)