Skip to content

Commit dd6da26

Browse files
authored
Merge pull request #13509 from dotty-staging/fix-spotted-leopards
Fix specifity comparison for extensions in polymorphic givens
2 parents 8d9542c + 3bcd253 commit dd6da26

File tree

5 files changed

+89
-15
lines changed

5 files changed

+89
-15
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,11 @@ object Contexts {
526526
final def withOwner(owner: Symbol): Context =
527527
if (owner ne this.owner) fresh.setOwner(owner) else this
528528

529+
final def withTyperState(typerState: TyperState): Context =
530+
if typerState ne this.typerState then fresh.setTyperState(typerState) else this
531+
529532
final def withUncommittedTyperState: Context =
530-
val ts = typerState.uncommittedAncestor
531-
if ts ne typerState then fresh.setTyperState(ts) else this
533+
withTyperState(typerState.uncommittedAncestor)
532534

533535
final def withProperty[T](key: Key[T], value: Option[T]): Context =
534536
if (property(key) == value) this
@@ -599,8 +601,8 @@ object Contexts {
599601
this.scope = newScope
600602
this
601603
def setTyperState(typerState: TyperState): this.type = { this.typerState = typerState; this }
602-
def setNewTyperState(): this.type = setTyperState(typerState.fresh().setCommittable(true))
603-
def setExploreTyperState(): this.type = setTyperState(typerState.fresh().setCommittable(false))
604+
def setNewTyperState(): this.type = setTyperState(typerState.fresh(committable = true))
605+
def setExploreTyperState(): this.type = setTyperState(typerState.fresh(committable = false))
604606
def setReporter(reporter: Reporter): this.type = setTyperState(typerState.fresh().setReporter(reporter))
605607
def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) }
606608
def setGadt(gadt: GadtConstraint): this.type =

compiler/src/dotty/tools/dotc/core/TyperState.scala

+13-10
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,12 @@ class TyperState() {
103103
this
104104

105105
/** A fresh typer state with the same constraint as this one. */
106-
def fresh(reporter: Reporter = StoreReporter(this.reporter)): TyperState =
106+
def fresh(reporter: Reporter = StoreReporter(this.reporter),
107+
committable: Boolean = this.isCommittable): TyperState =
107108
util.Stats.record("TyperState.fresh")
108109
TyperState().init(this, this.constraint)
109110
.setReporter(reporter)
110-
.setCommittable(this.isCommittable)
111+
.setCommittable(committable)
111112

112113
/** The uninstantiated variables */
113114
def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars
@@ -182,24 +183,25 @@ class TyperState() {
182183

183184
/** Integrate the constraints from `that` into this TyperState.
184185
*
185-
* @pre If `that` is committable, it must not contain any type variable which
186+
* @pre If `this` and `that` are committable, `that` must not contain any type variable which
186187
* does not exist in `this` (in other words, all its type variables must
187188
* be owned by a common parent of `this` and `that`).
188189
*/
189-
def mergeConstraintWith(that: TyperState)(using Context): Unit =
190+
def mergeConstraintWith(that: TyperState)(using Context): this.type =
191+
if this eq that then return this
192+
190193
that.ensureNotConflicting(constraint)
191194

192-
val comparingCtx =
193-
if ctx.typerState == this then ctx
194-
else ctx.fresh.setTyperState(this)
195+
val comparingCtx = ctx.withTyperState(this)
195196

196-
comparing(typeComparer =>
197+
inContext(comparingCtx)(comparing(typeComparer =>
197198
val other = that.constraint
198199
val res = other.domainLambdas.forall(tl =>
199200
// Integrate the type lambdas from `other`
200201
constraint.contains(tl) || other.isRemovable(tl) || {
201202
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
202-
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
203+
if this.isCommittable then
204+
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
203205
typeComparer.addToConstraint(tl, tvars)
204206
}) &&
205207
// Integrate the additional constraints on type variables from `other`
@@ -220,10 +222,11 @@ class TyperState() {
220222
)
221223
)
222224
assert(res || ctx.reporter.errorsReported, i"cannot merge $constraint with $other.")
223-
)(using comparingCtx)
225+
))
224226

225227
for tl <- constraint.domainLambdas do
226228
if constraint.isRemovable(tl) then constraint = constraint.remove(tl)
229+
this
227230
end mergeConstraintWith
228231

229232
/** Take ownership of `tvar`.

compiler/src/dotty/tools/dotc/typer/Implicits.scala

+23-1
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,29 @@ trait Implicits:
11771177
// compare the extension methods instead of their wrappers.
11781178
def stripExtension(alt: SearchSuccess) = methPart(stripApply(alt.tree)).tpe
11791179
(stripExtension(alt1), stripExtension(alt2)) match
1180-
case (ref1: TermRef, ref2: TermRef) => diff = compare(ref1, ref2)
1180+
case (ref1: TermRef, ref2: TermRef) =>
1181+
// ref1 and ref2 might refer to type variables owned by
1182+
// alt1.tstate and alt2.tstate respectively, to compare the
1183+
// alternatives correctly we need a TyperState that includes
1184+
// constraints from both sides, see
1185+
// tests/*/extension-specificity2.scala for test cases.
1186+
val constraintsIn1 = alt1.tstate.constraint ne ctx.typerState.constraint
1187+
val constraintsIn2 = alt2.tstate.constraint ne ctx.typerState.constraint
1188+
def exploreState(alt: SearchSuccess): TyperState =
1189+
alt.tstate.fresh(committable = false)
1190+
val comparisonState =
1191+
if constraintsIn1 && constraintsIn2 then
1192+
exploreState(alt1).mergeConstraintWith(alt2.tstate)
1193+
else if constraintsIn1 then
1194+
exploreState(alt1)
1195+
else if constraintsIn2 then
1196+
exploreState(alt2)
1197+
else
1198+
ctx.typerState
1199+
1200+
diff = inContext(ctx.withTyperState(comparisonState)) {
1201+
compare(ref1, ref2)
1202+
}
11811203
case _ =>
11821204
if diff < 0 then alt2
11831205
else if diff > 0 then alt1
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
trait Bla1[A]:
2+
extension (x: A) def foo(y: A): Int
3+
trait Bla2[A]:
4+
extension (x: A) def foo(y: A): Int
5+
6+
def test =
7+
given bla1[T <: Int]: Bla1[T] = ???
8+
given bla2[S <: Int]: Bla2[S] = ???
9+
10+
1.foo(2) // error: never extension is more specific than the other
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
trait Foo[F[_]]:
2+
extension [A](fa: F[A])
3+
def foo[B](fb: F[B]): Int
4+
5+
def test1 =
6+
// Simplified from https://github.com/typelevel/spotted-leopards/issues/2
7+
given listFoo: Foo[List] with
8+
extension [A](fa: List[A])
9+
def foo[B](fb: List[B]): Int = 1
10+
11+
given functionFoo[T]: Foo[[A] =>> T => A] with
12+
extension [A](fa: T => A)
13+
def foo[B](fb: T => B): Int = 2
14+
15+
val x = List(1, 2).foo(List(3, 4))
16+
assert(x == 1, x)
17+
18+
def test2 =
19+
// This test case would fail if we used `wildApprox` on the method types
20+
// instead of using the correct typer state.
21+
trait Bar1[A]:
22+
extension (x: A => A) def bar(y: A): Int
23+
trait Bar2:
24+
extension (x: Int => 1) def bar(y: Int): Int
25+
26+
given bla1[T]: Bar1[T] with
27+
extension (x: T => T) def bar(y: T): Int = 1
28+
given bla2: Bar2 with
29+
extension (x: Int => 1) def bar(y: Int): Int = 2
30+
31+
val f: Int => 1 = x => 1
32+
val x = f.bar(1)
33+
assert(x == 2, x)
34+
35+
@main def Test =
36+
test1
37+
test2

0 commit comments

Comments
 (0)