|
| 1 | +import scala.compiletime.{erasedValue, summonInline} |
| 2 | +import scala.deriving.Mirror |
| 3 | + |
| 4 | +enum Expr[+T]: |
| 5 | + case UpcastToIterable[T, C <: Iterable[T]](v: Expr[C]) extends Expr[Iterable[T]] |
| 6 | + case Seq[T](elements: Expr[T]*) extends Expr[scala.Seq[T]] |
| 7 | + case Const(value: T) |
| 8 | + |
| 9 | +trait Fold[E]: |
| 10 | + def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc |
| 11 | + |
| 12 | +object Fold: |
| 13 | + private inline def summonAll[Elems <: Tuple]: List[Fold[?]] = |
| 14 | + inline erasedValue[Elems] match |
| 15 | + case _: (h *: tail) => summonInline[Fold[h]] :: summonAll[tail] |
| 16 | + case _: EmptyTuple => Nil |
| 17 | + |
| 18 | + final class Leaf[E] extends Fold[E]: |
| 19 | + def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc = acc |
| 20 | + |
| 21 | + given [T: Fold as fold] => Fold[Seq[T]] = new Fold[Seq[T]] { |
| 22 | + def apply[Acc](acc: Acc, expr: Seq[T], f: [t] => (Acc, Expr[t]) => Acc): Acc = |
| 23 | + expr.foldLeft(acc)((a, e) => fold(a, e, f)) |
| 24 | + } |
| 25 | + |
| 26 | + given Fold[EmptyTuple] = new Fold[EmptyTuple]: |
| 27 | + def apply[Acc](acc: Acc, expr: EmptyTuple, f: [t] => (Acc, Expr[t]) => Acc): Acc = acc |
| 28 | + |
| 29 | + given [H: Fold as h, T <: Tuple: Fold as t] => Fold[H *: T] = |
| 30 | + new Fold[H *: T]: |
| 31 | + def apply[Acc](acc: Acc, expr: H *: T, f: [t] => (Acc, Expr[t]) => Acc): Acc = |
| 32 | + val acc1 = h(acc, expr.head, f) |
| 33 | + t(acc1, expr.tail, f) |
| 34 | + |
| 35 | + private def product[E](m: Mirror.ProductOf[E], tupleFold: Fold[m.MirroredElemTypes]): Fold[E] = |
| 36 | + new Fold[E]: |
| 37 | + def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc = |
| 38 | + // The mirror makes this safe according to https://github.com/scala/scala3/issues/22382#issuecomment-2613187822 |
| 39 | + tupleFold(acc, Tuple.fromProduct(expr.asInstanceOf[Product]).asInstanceOf[m.MirroredElemTypes], f) |
| 40 | + |
| 41 | + private def sum[E](m: Mirror.SumOf[E], cases0: () => List[Fold[?]]): Fold[E] = |
| 42 | + new Fold[E]: |
| 43 | + lazy val cases = cases0() |
| 44 | + def apply[Acc](acc: Acc, expr: E, f: [t] => (Acc, Expr[t]) => Acc): Acc = |
| 45 | + val ord = m.ordinal(expr) |
| 46 | + val caseFold = cases.apply(ord) |
| 47 | + caseFold.apply(acc, expr.asInstanceOf, f) |
| 48 | + |
| 49 | + inline given derived[E](using m: Mirror.Of[E]): Fold[E] = |
| 50 | + inline m match |
| 51 | + case m: Mirror.SumOf[E] => sum(m, () => summonAll[m.MirroredElemTypes]) |
| 52 | + case m: Mirror.ProductOf[E] => product[E](m, summonInline[Fold[m.MirroredElemTypes]]) |
| 53 | + |
| 54 | + given [T] => Fold[Expr.Const[T]] = Leaf() |
| 55 | + given Fold[Expr.UpcastToIterable[Any, Iterable[Any]]] = derived |
| 56 | + given [T] => Fold[Expr[T]] = new Fold[Expr[T]]: |
| 57 | + val default = derived[Expr[T]] |
| 58 | + def apply[Acc](acc: Acc, expr: Expr[T], f: [t] => (Acc, Expr[t]) => Acc): Acc = |
| 59 | + default(f(acc, expr), expr, f) |
| 60 | + |
| 61 | +@main def test(): Unit = |
| 62 | + def count[T](expr: Expr[T], f: [t] => Expr[t] => Boolean)(using fold: Fold[Expr[T]]): Int = |
| 63 | + fold(0, expr, [t] => (acc, e) => if f(e) then acc + 1 else acc) |
| 64 | + |
| 65 | + val ast: Expr[Iterable[Int]] = Expr.UpcastToIterable(Expr.Seq(Expr.Const(1), Expr.Const(2), Expr.Const(3))) |
| 66 | + val constCount = count( |
| 67 | + ast, |
| 68 | + [t] => |
| 69 | + _ match { |
| 70 | + case Expr.Const(_) => true |
| 71 | + case _ => false |
| 72 | + } |
| 73 | + ) |
| 74 | + println(s"Number of Const nodes: $constCount") |
0 commit comments