Skip to content

Commit dfbb958

Browse files
committed
Sema: Try to derive type witnesses before running inference
Sema can infer type witnesses for a small set of known conformances, to RawRepresentable, CaseIterable, and Differentiable. Previously, we would try to compute the type witness in this order: 1) First, via name lookup, to find an explicit nested type with the same name as an associated type. 2) Second, we would attempt inference. 3) Third, we would attempt derivation. Instead, let's do 3) before 2). This avoids circularity errors in situations where the witness can be derived, but inference fails. This breaks source compatibility with enum declarations where the raw type in the inheritance clause is a lie, and the user defines their own witnesses with mismatched types. However, I suspect this does not come up in practice, because if you don't synthesize witnesses, there is no way to access the actual raw literal values of the enum cases.
1 parent 61640fe commit dfbb958

8 files changed

+73
-40
lines changed

Diff for: lib/Sema/DerivedConformanceDifferentiable.cpp

+14-10
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
814814
}
815815

816816
/// Get or synthesize `TangentVector` struct type.
817-
static Type
817+
static std::pair<Type, TypeDecl *>
818818
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
819819
auto *parentDC = derived.getConformanceContext();
820820
auto *nominal = derived.Nominal;
@@ -824,25 +824,28 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
824824
auto *tangentStruct =
825825
getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector);
826826
if (!tangentStruct)
827-
return nullptr;
827+
return std::make_pair(nullptr, nullptr);
828+
828829
// Check and emit warnings for implicit `@noDerivative` members.
829830
checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC);
830831

831832
// Return the `TangentVector` struct type.
832-
return parentDC->mapTypeIntoContext(
833-
tangentStruct->getDeclaredInterfaceType());
833+
return std::make_pair(
834+
parentDC->mapTypeIntoContext(
835+
tangentStruct->getDeclaredInterfaceType()),
836+
tangentStruct);
834837
}
835838

836839
/// Synthesize the `TangentVector` struct type.
837-
static Type
840+
static std::pair<Type, TypeDecl *>
838841
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
839842
auto *parentDC = derived.getConformanceContext();
840843
auto *nominal = derived.Nominal;
841844

842845
// If nominal type can derive `TangentVector` as the contextual `Self` type,
843846
// return it.
844847
if (canDeriveTangentVectorAsSelf(nominal, parentDC))
845-
return parentDC->getSelfTypeInContext();
848+
return std::make_pair(parentDC->getSelfTypeInContext(), nullptr);
846849

847850
// Otherwise, get or synthesize `TangentVector` struct type.
848851
return getOrSynthesizeTangentVectorStructType(derived);
@@ -883,16 +886,17 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
883886
return nullptr;
884887
}
885888

886-
Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
889+
std::pair<Type, TypeDecl *>
890+
DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
887891
// Diagnose unknown requirements.
888892
if (requirement->getBaseName() != Context.Id_TangentVector) {
889893
Context.Diags.diagnose(requirement->getLoc(),
890894
diag::broken_differentiable_requirement);
891-
return nullptr;
895+
return std::make_pair(nullptr, nullptr);
892896
}
893897
// Diagnose conformances in disallowed contexts.
894898
if (checkAndDiagnoseDisallowedContext(requirement))
895-
return nullptr;
899+
return std::make_pair(nullptr, nullptr);
896900

897901
// Start an error diagnostic before attempting derivation.
898902
// If derivation succeeds, cancel the diagnostic.
@@ -908,5 +912,5 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
908912
}
909913

910914
// Otherwise, return nullptr.
911-
return nullptr;
915+
return std::make_pair(nullptr, nullptr);
912916
}

Diff for: lib/Sema/DerivedConformances.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ class DerivedConformance {
122122
/// Derive a Differentiable type witness for a nominal type.
123123
///
124124
/// \returns the derived member, which will also be added to the type.
125-
Type deriveDifferentiable(AssociatedTypeDecl *assocType);
125+
std::pair<Type, TypeDecl *>
126+
deriveDifferentiable(AssociatedTypeDecl *assocType);
126127

127128
/// Derive a CaseIterable requirement for an enum if it has no associated
128129
/// values for any of its cases.

Diff for: lib/Sema/TypeCheckProtocol.cpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -5579,29 +5579,30 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
55795579
llvm_unreachable("unknown derivable protocol kind");
55805580
}
55815581

5582-
Type TypeChecker::deriveTypeWitness(DeclContext *DC,
5583-
NominalTypeDecl *TypeDecl,
5584-
AssociatedTypeDecl *AssocType) {
5582+
std::pair<Type, TypeDecl *>
5583+
TypeChecker::deriveTypeWitness(DeclContext *DC,
5584+
NominalTypeDecl *TypeDecl,
5585+
AssociatedTypeDecl *AssocType) {
55855586
auto *protocol = cast<ProtocolDecl>(AssocType->getDeclContext());
55865587

55875588
auto knownKind = protocol->getKnownProtocolKind();
55885589

55895590
if (!knownKind)
5590-
return nullptr;
5591+
return std::make_pair(nullptr, nullptr);
55915592

55925593
auto Decl = DC->getInnermostDeclarationDeclContext();
55935594

55945595
DerivedConformance derived(TypeDecl->getASTContext(), Decl, TypeDecl,
55955596
protocol);
55965597
switch (*knownKind) {
55975598
case KnownProtocolKind::RawRepresentable:
5598-
return derived.deriveRawRepresentable(AssocType);
5599+
return std::make_pair(derived.deriveRawRepresentable(AssocType), nullptr);
55995600
case KnownProtocolKind::CaseIterable:
5600-
return derived.deriveCaseIterable(AssocType);
5601+
return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr);
56015602
case KnownProtocolKind::Differentiable:
56025603
return derived.deriveDifferentiable(AssocType);
56035604
default:
5604-
return nullptr;
5605+
return std::make_pair(nullptr, nullptr);
56055606
}
56065607
}
56075608

Diff for: lib/Sema/TypeCheckProtocol.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,8 @@ class AssociatedTypeInference {
822822

823823
/// Compute the "derived" type witness for an associated type that is
824824
/// known to the compiler.
825-
Type computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
825+
std::pair<Type, TypeDecl *>
826+
computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
826827

827828
/// Compute a type witness without using a specific potential witness,
828829
/// e.g., using a fixed type (from a refined protocol), default type

Diff for: lib/Sema/TypeCheckProtocolInference.cpp

+27-18
Original file line numberDiff line numberDiff line change
@@ -868,32 +868,32 @@ Type AssociatedTypeInference::computeDefaultTypeWitness(
868868
return defaultType;
869869
}
870870

871-
Type AssociatedTypeInference::computeDerivedTypeWitness(
871+
std::pair<Type, TypeDecl *>
872+
AssociatedTypeInference::computeDerivedTypeWitness(
872873
AssociatedTypeDecl *assocType) {
873874
if (adoptee->hasError())
874-
return Type();
875+
return std::make_pair(Type(), nullptr);
875876

876877
// Can we derive conformances for this protocol and adoptee?
877878
NominalTypeDecl *derivingTypeDecl = adoptee->getAnyNominal();
878879
if (!DerivedConformance::derivesProtocolConformance(dc, derivingTypeDecl,
879880
proto))
880-
return Type();
881+
return std::make_pair(Type(), nullptr);
881882

882883
// Try to derive the type witness.
883-
Type derivedType =
884-
TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
885-
if (!derivedType)
886-
return Type();
884+
auto result = TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
885+
if (!result.first)
886+
return std::make_pair(Type(), nullptr);
887887

888-
// Make sure that the derived type is sane.
889-
if (checkTypeWitness(derivedType, assocType, conformance)) {
888+
// Make sure that the derived type satisfies requirements.
889+
if (checkTypeWitness(result.first, assocType, conformance)) {
890890
/// FIXME: Diagnose based on this.
891891
failedDerivedAssocType = assocType;
892-
failedDerivedWitness = derivedType;
893-
return Type();
892+
failedDerivedWitness = result.first;
893+
return std::make_pair(Type(), nullptr);
894894
}
895895

896-
return derivedType;
896+
return result;
897897
}
898898

899899
Type
@@ -908,10 +908,6 @@ AssociatedTypeInference::computeAbstractTypeWitness(
908908
if (Type defaultType = computeDefaultTypeWitness(assocType))
909909
return defaultType;
910910

911-
// If we can derive a type witness, do so.
912-
if (Type derivedType = computeDerivedTypeWitness(assocType))
913-
return derivedType;
914-
915911
// If there is a generic parameter of the named type, use that.
916912
if (auto genericSig = dc->getGenericSignatureOfContext()) {
917913
for (auto gp : genericSig->getInnermostGenericParams()) {
@@ -1876,10 +1872,23 @@ auto AssociatedTypeInference::solve(ConformanceChecker &checker)
18761872
continue;
18771873

18781874
case ResolveWitnessResult::Missing:
1879-
// Note that we haven't resolved this associated type yet.
1880-
unresolvedAssocTypes.insert(assocType);
1875+
// We did not find the witness via name lookup. Try to derive
1876+
// it below.
18811877
break;
18821878
}
1879+
1880+
// Finally, try to derive the witness if we know how.
1881+
auto derivedType = computeDerivedTypeWitness(assocType);
1882+
if (derivedType.first) {
1883+
checker.recordTypeWitness(assocType,
1884+
derivedType.first->mapTypeOutOfContext(),
1885+
derivedType.second);
1886+
continue;
1887+
}
1888+
1889+
// We failed to derive the witness. We're going to go on to try
1890+
// to infer it from potential value witnesses next.
1891+
unresolvedAssocTypes.insert(assocType);
18831892
}
18841893

18851894
// Result variable to use for returns so that we get NRVO.

Diff for: lib/Sema/TypeChecker.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,9 @@ ValueDecl *deriveProtocolRequirement(DeclContext *DC,
911911
/// Derive an implicit type witness for the given associated type in
912912
/// the conformance of the given nominal type to some known
913913
/// protocol.
914-
Type deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
915-
AssociatedTypeDecl *assocType);
914+
std::pair<Type, TypeDecl *>
915+
deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
916+
AssociatedTypeDecl *assocType);
916917

917918
/// \name Name lookup
918919
///

Diff for: test/Sema/enum_raw_representable.swift

+4-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ var doubles: [Double] = serialize([Bar.a, .b, .c])
4646
var foos: [Foo] = deserialize([1, 2, 3])
4747
var bars: [Bar] = deserialize([1.2, 3.4, 5.6])
4848

49-
// Infer RawValue from witnesses.
49+
// We reject enums where the raw type stated in the inheritance clause does not
50+
// match the types of the witnesses.
5051
enum Color : Int {
5152
case red
5253
case blue
@@ -56,11 +57,13 @@ enum Color : Int {
5657
}
5758

5859
var rawValue: Double {
60+
// expected-error@-1 {{invalid redeclaration of synthesized implementation for protocol requirement 'rawValue'}}
5961
return 1.0
6062
}
6163
}
6264

6365
var colorRaw: Color.RawValue = 7.5
66+
// expected-error@-1 {{cannot convert value of type 'Double' to specified type 'Color.RawValue' (aka 'Int')}}
6467

6568
// Mismatched case types
6669

Diff for: test/Sema/enum_raw_representable_circularity.swift

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
// This used to fail with "reference to invalid associated type 'RawValue' of type 'E'"
4+
_ = E(rawValue: 123)
5+
6+
enum E : Int {
7+
case a = 123
8+
9+
init?(rawValue: RawValue) {
10+
self = .a
11+
}
12+
}
13+

0 commit comments

Comments
 (0)