Skip to content

Commit a8d44bc

Browse files
authored
Merge pull request #32558 from xedin/fix-req-conformace-assessment
[ConstraintSystem] Adjust recording of "fixed" requirements to avoid conflicts
2 parents ebccb92 + 866835d commit a8d44bc

11 files changed

+192
-41
lines changed

lib/Sema/CSDiagnostics.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,32 @@ bool MissingConformanceFailure::diagnoseAsError() {
446446
}
447447
}
448448

449+
// If the problem has been (unambiguously) determined to be related
450+
// to one of of the standard comparison operators and argument is
451+
// enum with associated values, let's produce a tailored note which
452+
// says that conformances for enums with associated values can't be
453+
// synthesized.
454+
if (isStandardComparisonOperator(anchor)) {
455+
auto isEnumWithAssociatedValues = [](Type type) -> bool {
456+
if (auto *enumType = type->getAs<EnumType>())
457+
return !enumType->getDecl()->hasOnlyCasesWithoutAssociatedValues();
458+
return false;
459+
};
460+
461+
// Limit this to `Equatable` and `Comparable` protocols for now.
462+
auto *protocol = getRHS()->castTo<ProtocolType>()->getDecl();
463+
if (isEnumWithAssociatedValues(getLHS()) &&
464+
(protocol->isSpecificProtocol(KnownProtocolKind::Equatable) ||
465+
protocol->isSpecificProtocol(KnownProtocolKind::Comparable))) {
466+
if (RequirementFailure::diagnoseAsError()) {
467+
auto opName = getOperatorName(anchor);
468+
emitDiagnostic(diag::no_binary_op_overload_for_enum_with_payload,
469+
opName->str());
470+
return true;
471+
}
472+
}
473+
}
474+
449475
if (diagnoseAsAmbiguousOperatorRef())
450476
return true;
451477

lib/Sema/CSSimplify.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -2339,7 +2339,7 @@ ConstraintSystem::matchExistentialTypes(Type type1, Type type2,
23392339

23402340
auto *fix = fixRequirementFailure(*this, type1, type2, locator);
23412341
if (fix && !recordFix(fix)) {
2342-
recordFixedRequirement(type1, RequirementKind::Layout, type2);
2342+
recordFixedRequirement(getConstraintLocator(locator), type2);
23432343
return getTypeMatchSuccess();
23442344
}
23452345
}
@@ -3875,14 +3875,13 @@ bool ConstraintSystem::repairFailures(
38753875
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality))
38763876
break;
38773877

3878-
auto reqElt = elt.castTo<LocatorPathElt::AnyRequirement>();
3879-
auto reqKind = reqElt.getRequirementKind();
3878+
auto *reqLoc = getConstraintLocator(locator);
38803879

3881-
if (hasFixedRequirement(lhs, reqKind, rhs))
3880+
if (isFixedRequirement(reqLoc, rhs))
38823881
return true;
38833882

38843883
if (auto *fix = fixRequirementFailure(*this, lhs, rhs, anchor, path)) {
3885-
recordFixedRequirement(lhs, reqKind, rhs);
3884+
recordFixedRequirement(reqLoc, rhs);
38863885
conversionsOrFixes.push_back(fix);
38873886
}
38883887
break;
@@ -5404,7 +5403,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
54045403
auto protocolTy = protocol->getDeclaredType();
54055404

54065405
// If this conformance has been fixed already, let's just consider this done.
5407-
if (hasFixedRequirement(type, RequirementKind::Conformance, protocolTy))
5406+
if (isFixedRequirement(getConstraintLocator(locator), protocolTy))
54085407
return SolutionKind::Solved;
54095408

54105409
// If this is a generic requirement let's try to record that
@@ -5491,7 +5490,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
54915490
auto impact = assessRequirementFailureImpact(*this, rawType, locator);
54925491
if (!recordFix(fix, impact)) {
54935492
// Record this conformance requirement as "fixed".
5494-
recordFixedRequirement(type, RequirementKind::Conformance,
5493+
recordFixedRequirement(getConstraintLocator(anchor, path),
54955494
protocolTy);
54965495
return SolutionKind::Solved;
54975496
}

lib/Sema/ConstraintSystem.cpp

+112-7
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,8 @@ Type ConstraintSystem::openUnboundGenericType(
669669
openGeneric(decl->getDeclContext(), decl->getGenericSignature(), locator,
670670
replacements);
671671

672+
recordOpenedTypes(locator, replacements);
673+
672674
if (parentTy) {
673675
auto subs = parentTy->getContextSubstitutions(decl->getDeclContext());
674676
for (auto pair : subs) {
@@ -2806,12 +2808,6 @@ static void diagnoseOperatorAmbiguity(ConstraintSystem &cs,
28062808
auto *anchor = castToExpr(locator->getAnchor());
28072809
auto *applyExpr = cast<ApplyExpr>(cs.getParentExpr(anchor));
28082810

2809-
auto isNameOfStandardComparisonOperator = [](Identifier opName) -> bool {
2810-
return opName.is("==") || opName.is("!=") || opName.is("===") ||
2811-
opName.is("!==") || opName.is("<") || opName.is(">") ||
2812-
opName.is("<=") || opName.is(">=");
2813-
};
2814-
28152811
auto isEnumWithAssociatedValues = [](Type type) -> bool {
28162812
if (auto *enumType = type->getAs<EnumType>())
28172813
return !enumType->getDecl()->hasOnlyCasesWithoutAssociatedValues();
@@ -2834,7 +2830,7 @@ static void diagnoseOperatorAmbiguity(ConstraintSystem &cs,
28342830
.highlight(lhs->getSourceRange())
28352831
.highlight(rhs->getSourceRange());
28362832

2837-
if (isNameOfStandardComparisonOperator(operatorName) &&
2833+
if (isStandardComparisonOperator(binaryOp->getFn()) &&
28382834
isEnumWithAssociatedValues(lhsType)) {
28392835
DE.diagnose(applyExpr->getLoc(),
28402836
diag::no_binary_op_overload_for_enum_with_payload,
@@ -4257,6 +4253,15 @@ bool constraints::isPatternMatchingOperator(Expr *expr) {
42574253
return isOperator(expr, "~=");
42584254
}
42594255

4256+
bool constraints::isStandardComparisonOperator(Expr *expr) {
4257+
if (auto opName = getOperatorName(expr)) {
4258+
return opName->is("==") || opName->is("!=") || opName->is("===") ||
4259+
opName->is("!==") || opName->is("<") || opName->is(">") ||
4260+
opName->is("<=") || opName->is(">=");
4261+
}
4262+
return false;
4263+
}
4264+
42604265
bool constraints::isOperatorArgument(ConstraintLocator *locator,
42614266
StringRef expectedOperator) {
42624267
if (!locator->findLast<LocatorPathElt::ApplyArgToParam>())
@@ -4839,3 +4844,103 @@ SourceLoc constraints::getLoc(ASTNode anchor) {
48394844
SourceRange constraints::getSourceRange(ASTNode anchor) {
48404845
return anchor.getSourceRange();
48414846
}
4847+
4848+
static Optional<Requirement> getRequirement(ConstraintSystem &cs,
4849+
ConstraintLocator *reqLocator) {
4850+
auto reqLoc = reqLocator->getLastElementAs<LocatorPathElt::AnyRequirement>();
4851+
if (!reqLoc)
4852+
return None;
4853+
4854+
if (reqLoc->isConditionalRequirement()) {
4855+
auto path = reqLocator->getPath();
4856+
auto *typeReqLoc =
4857+
cs.getConstraintLocator(reqLocator->getAnchor(), path.drop_back());
4858+
4859+
auto conformances = cs.getCheckedConformances();
4860+
auto result = llvm::find_if(
4861+
conformances,
4862+
[&typeReqLoc](
4863+
const std::pair<ConstraintLocator *, ProtocolConformanceRef>
4864+
&conformance) { return conformance.first == typeReqLoc; });
4865+
assert(result != conformances.end());
4866+
4867+
auto conformance = result->second;
4868+
assert(conformance.isConcrete());
4869+
4870+
return conformance.getConditionalRequirements()[reqLoc->getIndex()];
4871+
}
4872+
4873+
if (auto openedGeneric =
4874+
reqLocator->findLast<LocatorPathElt::OpenedGeneric>()) {
4875+
auto signature = openedGeneric->getSignature();
4876+
return signature->getRequirements()[reqLoc->getIndex()];
4877+
}
4878+
4879+
return None;
4880+
}
4881+
4882+
static Optional<std::pair<GenericTypeParamType *, RequirementKind>>
4883+
getRequirementInfo(ConstraintSystem &cs, ConstraintLocator *reqLocator) {
4884+
auto requirement = getRequirement(cs, reqLocator);
4885+
if (!requirement)
4886+
return None;
4887+
4888+
auto *GP = requirement->getFirstType()->getAs<GenericTypeParamType>();
4889+
if (!GP)
4890+
return None;
4891+
4892+
auto path = reqLocator->getPath();
4893+
auto iter = path.rbegin();
4894+
auto openedGeneric =
4895+
reqLocator->findLast<LocatorPathElt::OpenedGeneric>(iter);
4896+
assert(openedGeneric);
4897+
4898+
auto newPath = path.drop_back(iter - path.rbegin() + 1);
4899+
auto *baseLoc = cs.getConstraintLocator(reqLocator->getAnchor(), newPath);
4900+
4901+
auto openedTypes = cs.getOpenedTypes();
4902+
auto substitutions = llvm::find_if(
4903+
openedTypes,
4904+
[&baseLoc](
4905+
const std::pair<ConstraintLocator *, ArrayRef<OpenedType>> &entry) {
4906+
return entry.first == baseLoc;
4907+
});
4908+
4909+
if (substitutions == openedTypes.end())
4910+
return None;
4911+
4912+
auto replacement =
4913+
llvm::find_if(substitutions->second, [&GP](const OpenedType &entry) {
4914+
auto *typeVar = entry.second;
4915+
return typeVar->getImpl().getGenericParameter() == GP;
4916+
});
4917+
4918+
if (replacement == substitutions->second.end())
4919+
return None;
4920+
4921+
auto *repr = cs.getRepresentative(replacement->second);
4922+
return std::make_pair(repr->getImpl().getGenericParameter(),
4923+
requirement->getKind());
4924+
}
4925+
4926+
bool ConstraintSystem::isFixedRequirement(ConstraintLocator *reqLocator,
4927+
Type requirementTy) {
4928+
if (auto reqInfo = getRequirementInfo(*this, reqLocator)) {
4929+
auto *GP = reqInfo->first;
4930+
auto reqKind = static_cast<unsigned>(reqInfo->second);
4931+
return FixedRequirements.count(
4932+
std::make_tuple(GP, reqKind, requirementTy.getPointer()));
4933+
}
4934+
4935+
return false;
4936+
}
4937+
4938+
void ConstraintSystem::recordFixedRequirement(ConstraintLocator *reqLocator,
4939+
Type requirementTy) {
4940+
if (auto reqInfo = getRequirementInfo(*this, reqLocator)) {
4941+
auto *GP = reqInfo->first;
4942+
auto reqKind = static_cast<unsigned>(reqInfo->second);
4943+
FixedRequirements.insert(
4944+
std::make_tuple(GP, reqKind, requirementTy.getPointer()));
4945+
}
4946+
}

lib/Sema/ConstraintSystem.h

+23-13
Original file line numberDiff line numberDiff line change
@@ -1949,20 +1949,13 @@ class ConstraintSystem {
19491949

19501950
/// The list of all generic requirements fixed along the current
19511951
/// solver path.
1952-
using FixedRequirement = std::tuple<TypeBase *, RequirementKind, TypeBase *>;
1953-
SmallVector<FixedRequirement, 4> FixedRequirements;
1952+
using FixedRequirement =
1953+
std::tuple<GenericTypeParamType *, unsigned, TypeBase *>;
1954+
llvm::SmallSetVector<FixedRequirement, 4> FixedRequirements;
19541955

1955-
bool hasFixedRequirement(Type lhs, RequirementKind kind, Type rhs) {
1956-
auto reqInfo = std::make_tuple(lhs.getPointer(), kind, rhs.getPointer());
1957-
return llvm::any_of(
1958-
FixedRequirements,
1959-
[&reqInfo](const FixedRequirement &entry) { return entry == reqInfo; });
1960-
}
1961-
1962-
void recordFixedRequirement(Type lhs, RequirementKind kind, Type rhs) {
1963-
FixedRequirements.push_back(
1964-
std::make_tuple(lhs.getPointer(), kind, rhs.getPointer()));
1965-
}
1956+
bool isFixedRequirement(ConstraintLocator *reqLocator, Type requirementTy);
1957+
void recordFixedRequirement(ConstraintLocator *reqLocator,
1958+
Type requirementTy);
19661959

19671960
/// A mapping from constraint locators to the opened existential archetype
19681961
/// used for the 'self' of an existential type.
@@ -3548,6 +3541,19 @@ class ConstraintSystem {
35483541
const DeclRefExpr *base = nullptr,
35493542
OpenedTypeMap *replacements = nullptr);
35503543

3544+
/// Retrieve a list of conformances established along the current solver path.
3545+
ArrayRef<std::pair<ConstraintLocator *, ProtocolConformanceRef>>
3546+
getCheckedConformances() const {
3547+
return CheckedConformances;
3548+
}
3549+
3550+
/// Retrieve a list of generic parameter types solver has "opened" (replaced
3551+
/// with a type variable) along the current path.
3552+
ArrayRef<std::pair<ConstraintLocator *, ArrayRef<OpenedType>>>
3553+
getOpenedTypes() const {
3554+
return OpenedTypes;
3555+
}
3556+
35513557
private:
35523558
/// Adjust the constraint system to accomodate the given selected overload, and
35533559
/// recompute the type of the referenced declaration.
@@ -5165,6 +5171,10 @@ bool isArgumentOfReferenceEqualityOperator(ConstraintLocator *locator);
51655171
/// pattern-matching operator `~=`
51665172
bool isPatternMatchingOperator(Expr *expr);
51675173

5174+
/// Determine whether given expression is a reference to a
5175+
/// "standard" comparison operator such as "==", "!=", ">" etc.
5176+
bool isStandardComparisonOperator(Expr *expr);
5177+
51685178
/// If given expression references operator overlaod(s)
51695179
/// extract and produce name of the operator.
51705180
Optional<Identifier> getOperatorName(Expr *expr);

test/AutoDiff/Sema/differentiable_func_type.swift

+8-4
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,21 @@ extension Vector: Differentiable where T: Differentiable {
175175
mutating func move(along direction: TangentVector) { fatalError() }
176176
}
177177

178-
// expected-note@+1 2 {{candidate requires that 'Int' conform to 'Differentiable' (requirement specified as 'T' == 'Differentiable')}}
178+
// expected-note@+1 2 {{found this candidate}}
179179
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
180180

181-
// expected-note @+5 2 {{candidate requires that 'Int' conform to 'Differentiable' (requirement specified as 'T' == 'Differentiable')}}
181+
// expected-note @+5 2 {{found this candidate}}
182182
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
183183
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
184184
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
185185
// expected-error @+1 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
186186
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Vector<T>) -> Vector<U>) {}
187187

188188
func nondiff(x: Vector<Int>) -> Vector<Int> {}
189+
190+
// TODO(diagnostics): Ambiguity notes for two following calls should talk about `T` and `U` both not conforming to `Differentiable`
191+
// but we currently have to way to coalesce notes multiple fixes in to a single note.
192+
189193
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGeneric'}}
190194
inferredConformancesGeneric(nondiff)
191195
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGenericLinear'}}
@@ -210,10 +214,10 @@ extension Linear: Differentiable where T: Differentiable, T == T.TangentVector {
210214
typealias TangentVector = Self
211215
}
212216

213-
// expected-note @+1 2 {{candidate requires that 'Int' conform to 'Differentiable' (requirement specified as 'T' == 'Differentiable')}}
217+
// expected-note @+1 2 {{found this candidate}}
214218
func inferredConformancesGeneric<T, U>(_: @differentiable (Linear<T>) -> Linear<U>) {}
215219

216-
// expected-note @+1 2 {{candidate requires that 'Int' conform to 'Differentiable' (requirement specified as 'T' == 'Differentiable')}}
220+
// expected-note @+1 2 {{found this candidate}}
217221
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Linear<T>) -> Linear<U>) {}
218222

219223
func nondiff(x: Linear<Int>) -> Linear<Int> {}

test/Constraints/diagnostics.swift

+7-3
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ enum AssocTest {
687687
case one(Int)
688688
}
689689

690-
if AssocTest.one(1) == AssocTest.one(1) {} // expected-error{{binary operator '==' cannot be applied to two 'AssocTest' operands}}
690+
if AssocTest.one(1) == AssocTest.one(1) {} // expected-error{{referencing operator function '==' on 'Equatable' requires that 'AssocTest' conform to 'Equatable'}}
691691
// expected-note @-1 {{binary operator '==' cannot be synthesized for enums with associated values}}
692692

693693

@@ -1100,9 +1100,13 @@ func rdar17170728() {
11001100
}
11011101

11021102
let _ = [i, j, k].reduce(0 as Int?) {
1103+
// expected-error@-1 3 {{cannot convert value of type 'Int?' to expected element type 'Int'}}
11031104
$0 && $1 ? $0 + $1 : ($0 ? $0 : ($1 ? $1 : nil))
1104-
// expected-error@-1 {{binary operator '+' cannot be applied to two 'Int?' operands}}
1105-
// expected-error@-2 4 {{optional type 'Int?' cannot be used as a boolean; test for '!= nil' instead}}
1105+
// expected-error@-1 2 {{type 'Int' cannot be used as a boolean; test for '!= 0' instead}}
1106+
// expected-error@-2 {{value of optional type 'Int?' must be unwrapped to a value of type 'Int'}}
1107+
// expected-error@-3 2 {{optional type 'Int?' cannot be used as a boolean; test for '!= nil' instead}}
1108+
// expected-note@-4:16 {{coalesce using '??' to provide a default when the optional value contains 'nil'}}
1109+
// expected-note@-5:16 {{force-unwrap using '!' to abort execution if the optional value contains 'nil'}}
11061110
}
11071111
}
11081112

test/Constraints/operator.swift

+1
Original file line numberDiff line numberDiff line change
@@ -288,5 +288,6 @@ func rdar_62054241() {
288288

289289
func test(_ arr: [Foo]) -> [Foo] {
290290
return arr.sorted(by: <) // expected-error {{no exact matches in reference to operator function '<'}}
291+
// expected-note@-1 {{found candidate with type '(Foo, Foo) -> Bool'}}
291292
}
292293
}

test/Generics/deduction.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,7 @@ protocol Addable {
247247
static func +(x: Self, y: Self) -> Self
248248
}
249249
func addAddables<T : Addable, U>(_ x: T, y: T, u: U) -> T {
250-
// FIXME(diagnostics): This should report the "no exact matches" diagnostic.
251-
u + u // expected-error{{referencing operator function '+' on 'RangeReplaceableCollection' requires that 'U' conform to 'RangeReplaceableCollection'}}
250+
u + u // expected-error{{binary operator '+' cannot be applied to two 'U' operands}}
252251
return x+y
253252
}
254253

test/Sema/struct_equatable_hashable.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ struct StructWithoutExplicitConformance {
114114
}
115115

116116
func structWithoutExplicitConformance() {
117-
if StructWithoutExplicitConformance(a: 1, b: "b") == StructWithoutExplicitConformance(a: 2, b: "a") { } // expected-error{{binary operator '==' cannot be applied to two 'StructWithoutExplicitConformance' operands}}
117+
// This diagnostic is about `Equatable` because it's considered the best possible solution among other ones for operator `==`.
118+
if StructWithoutExplicitConformance(a: 1, b: "b") == StructWithoutExplicitConformance(a: 2, b: "a") { } // expected-error{{referencing operator function '==' on 'Equatable' requires that 'StructWithoutExplicitConformance' conform to 'Equatable'}}
118119
}
119120

120121
// Structs with non-hashable/equatable stored properties don't derive conformance.

0 commit comments

Comments
 (0)