Skip to content

Commit af061e2

Browse files
authored
Fix synthesis of mirrors for GADT with dependent type parameters (#26080)
This extends the current logic in Synthesizer.scala to instead of doing a single replacement do it until a fixed point is reached. Fixes #23774 ## How much have you relied on LLM-based tools in this contribution? Extensively, for exploring the code base and determining where the issue was. ## How was the solution tested? New automated tests
1 parent 7441a44 commit af061e2

3 files changed

Lines changed: 117 additions & 2 deletions

File tree

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,14 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
598598
resType <:< target
599599
val tparams = poly.paramRefs
600600
val variances = childClass.typeParams.map(_.paramVarianceSign)
601-
val instanceTypes = tparams.lazyZip(variances).map: (tparam, variance) =>
602-
TypeComparer.instanceType(tparam, fromBelow = variance < 0, Widen.Unions)
601+
@tailrec def fixInstances(cur: List[Type]): List[Type] =
602+
val next = cur.mapConserve(_.substParams(poly, cur))
603+
if next eq cur then next else fixInstances(next)
604+
val instanceTypes = {
605+
val types0 = tparams.lazyZip(variances).map: (tparam, variance) =>
606+
TypeComparer.instanceType(tparam, fromBelow = variance < 0, Widen.Unions)
607+
fixInstances(types0)
608+
}
603609
val instanceType = resType.substParams(poly, instanceTypes)
604610
// this is broken in tests/run/i13332intersection.scala,
605611
// because type parameters are not correctly inferred.

tests/pos/i23774.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
trait Iterable[+A]
2+
3+
enum Expr {
4+
case Case[T, C <: Iterable[T]]()
5+
}
6+
object Expr {
7+
val m = summon[scala.deriving.Mirror.SumOf[Expr]]
8+
summon[m.MirroredElemTypes =:= Expr.Case[Any, Iterable[Any]] *: EmptyTuple]
9+
}
10+
11+
enum Trait {
12+
case Upcast[Base, Case <: Base]()
13+
}
14+
15+
object Trait {
16+
val m = summon[scala.deriving.Mirror.SumOf[Trait]]
17+
summon[m.MirroredElemTypes =:= Trait.Upcast[Any, Any] *: EmptyTuple]
18+
}
19+
20+
enum MoreDeps {
21+
case Case[T, C <: Iterable[T], D <: Iterable[C]]()
22+
}
23+
object MoreDeps {
24+
val m = summon[scala.deriving.Mirror.SumOf[MoreDeps]]
25+
summon[m.MirroredElemTypes =:= MoreDeps.Case[Any, Iterable[Any], Iterable[Iterable[Any]]] *: EmptyTuple]
26+
}
27+
28+
enum Circular {
29+
case Case[C <: Iterable[D], D <: C]()
30+
}
31+
object Circular {
32+
val m = summon[scala.deriving.Mirror.SumOf[Circular]]
33+
summon[m.MirroredElemTypes =:= Circular.Case[Iterable[Nothing], Nothing] *: EmptyTuple]
34+
}
35+

tests/pos/i23774typeclass.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.compiletime.{erasedValue, summonInline}
2+
import scala.deriving.Mirror
3+
4+
enum Expr[+T]:
5+
case UpcastToIterable[T, C <: Iterable[T]](v: Expr[C]) extends Expr[Iterable[T]]
6+
case Seq[T](elements: Expr[T]*) extends Expr[scala.Seq[T]]
7+
case Const(value: T)
8+
9+
trait Fold[E]:
10+
def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc
11+
12+
object Fold:
13+
private inline def summonAll[Elems <: Tuple]: List[Fold[?]] =
14+
inline erasedValue[Elems] match
15+
case _: (h *: tail) => summonInline[Fold[h]] :: summonAll[tail]
16+
case _: EmptyTuple => Nil
17+
18+
final class Leaf[E] extends Fold[E]:
19+
def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc = acc
20+
21+
given [T: Fold as fold] => Fold[Seq[T]] = new Fold[Seq[T]] {
22+
def apply[Acc](acc: Acc, expr: Seq[T], f: [t] => (Acc, Expr[t]) => Acc): Acc =
23+
expr.foldLeft(acc)((a, e) => fold(a, e, f))
24+
}
25+
26+
given Fold[EmptyTuple] = new Fold[EmptyTuple]:
27+
def apply[Acc](acc: Acc, expr: EmptyTuple, f: [t] => (Acc, Expr[t]) => Acc): Acc = acc
28+
29+
given [H: Fold as h, T <: Tuple: Fold as t] => Fold[H *: T] =
30+
new Fold[H *: T]:
31+
def apply[Acc](acc: Acc, expr: H *: T, f: [t] => (Acc, Expr[t]) => Acc): Acc =
32+
val acc1 = h(acc, expr.head, f)
33+
t(acc1, expr.tail, f)
34+
35+
private def product[E](m: Mirror.ProductOf[E], tupleFold: Fold[m.MirroredElemTypes]): Fold[E] =
36+
new Fold[E]:
37+
def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc =
38+
// The mirror makes this safe according to https://github.com/scala/scala3/issues/22382#issuecomment-2613187822
39+
tupleFold(acc, Tuple.fromProduct(expr.asInstanceOf[Product]).asInstanceOf[m.MirroredElemTypes], f)
40+
41+
private def sum[E](m: Mirror.SumOf[E], cases0: () => List[Fold[?]]): Fold[E] =
42+
new Fold[E]:
43+
lazy val cases = cases0()
44+
def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc =
45+
val ord = m.ordinal(expr)
46+
val caseFold = cases.apply(ord)
47+
caseFold.apply(acc, expr.asInstanceOf, f)
48+
49+
inline given derived[E](using m: Mirror.Of[E]): Fold[E] =
50+
inline m match
51+
case m: Mirror.SumOf[E] => sum(m, () => summonAll[m.MirroredElemTypes])
52+
case m: Mirror.ProductOf[E] => product[E](m, summonInline[Fold[m.MirroredElemTypes]])
53+
54+
given [T] => Fold[Expr.Const[T]] = Leaf()
55+
given Fold[Expr.UpcastToIterable[Any, Iterable[Any]]] = derived
56+
given [T] => Fold[Expr[T]] = new Fold[Expr[T]]:
57+
val default = derived[Expr[T]]
58+
def apply[Acc](acc: Acc, expr: Expr[T], f: [t] => (Acc, Expr[t]) => Acc): Acc =
59+
default(f(acc, expr), expr, f)
60+
61+
@main def test(): Unit =
62+
def count[T](expr: Expr[T], f: [t] => Expr[t] => Boolean)(using fold: Fold[Expr[T]]): Int =
63+
fold(0, expr, [t] => (acc, e) => if f(e) then acc + 1 else acc)
64+
65+
val ast: Expr[Iterable[Int]] = Expr.UpcastToIterable(Expr.Seq(Expr.Const(1), Expr.Const(2), Expr.Const(3)))
66+
val constCount = count(
67+
ast,
68+
[t] =>
69+
_ match {
70+
case Expr.Const(_) => true
71+
case _ => false
72+
}
73+
)
74+
println(s"Number of Const nodes: $constCount")

0 commit comments

Comments
 (0)