Skip to content

Commit bd93937

Browse files
authored
Merge pull request #6810 from dotty-staging/fix-6720
Fix #6720: implement extractor for function literals
2 parents 4c674d2 + d93e10c commit bd93937

File tree

11 files changed

+129
-75
lines changed

11 files changed

+129
-75
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

+8-8
Original file line numberDiff line numberDiff line change
@@ -504,21 +504,21 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
504504
def Inlined_copy(original: Tree)(call: Option[Term | TypeTree], bindings: List[Definition], expansion: Term)(implicit ctx: Context): Inlined =
505505
tpd.cpy.Inlined(original)(call.getOrElse(tpd.EmptyTree), bindings.asInstanceOf[List[tpd.MemberDef]], expansion)
506506

507-
type Lambda = tpd.Closure
507+
type Closure = tpd.Closure
508508

509-
def matchLambda(x: Term)(implicit ctx: Context): Option[Lambda] = x match {
509+
def matchClosure(x: Term)(implicit ctx: Context): Option[Closure] = x match {
510510
case x: tpd.Closure => Some(x)
511511
case _ => None
512512
}
513513

514-
def Lambda_meth(self: Lambda)(implicit ctx: Context): Term = self.meth
515-
def Lambda_tptOpt(self: Lambda)(implicit ctx: Context): Option[TypeTree] = optional(self.tpt)
514+
def Closure_meth(self: Closure)(implicit ctx: Context): Term = self.meth
515+
def Closure_tpeOpt(self: Closure)(implicit ctx: Context): Option[Type] = optional(self.tpt).map(_.tpe)
516516

517-
def Lambda_apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
518-
withDefaultPos(ctx => tpd.Closure(Nil, meth, tpt.getOrElse(tpd.EmptyTree))(ctx))
517+
def Closure_apply(meth: Term, tpe: Option[Type])(implicit ctx: Context): Closure =
518+
withDefaultPos(ctx => tpd.Closure(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))(ctx))
519519

520-
def Lambda_copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
521-
tpd.cpy.Closure(original)(Nil, meth, tpt.getOrElse(tpd.EmptyTree))
520+
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(implicit ctx: Context): Closure =
521+
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))
522522

523523
type If = tpd.If
524524

library/src-bootstrapped/scala/internal/quoted/Matcher.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ object Matcher {
207207
tpt1 =#= tpt2 &&
208208
withEnv(rhsEnv)(rhs1 =#= rhs2)
209209

210-
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
210+
case (Closure(_, tpt1), Closure(_, tpt2)) =>
211211
// TODO match tpt1 with tpt2?
212212
matched
213213

library/src/scala/tasty/reflect/Core.scala

+11-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ package scala.tasty.reflect
2626
* | +- Typed
2727
* | +- Assign
2828
* | +- Block
29-
* | +- Lambda
29+
* | +- Closure
3030
* | +- If
3131
* | +- Match
3232
* | +- ImpliedMatch
@@ -200,8 +200,16 @@ trait Core {
200200
/** Tree representing a block `{ ... }` in the source code */
201201
type Block = kernel.Block
202202

203-
/** Tree representing a lambda `(...) => ...` in the source code */
204-
type Lambda = kernel.Lambda
203+
/** A lambda `(...) => ...` in the source code is represented as
204+
* a local method and a closure:
205+
*
206+
* {
207+
* def m(...) = ...
208+
* closure(m)
209+
* }
210+
*
211+
*/
212+
type Closure = kernel.Closure
205213

206214
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
207215
type If = kernel.If

library/src/scala/tasty/reflect/Kernel.scala

+16-8
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ package scala.tasty.reflect
2525
* | +- Typed
2626
* | +- Assign
2727
* | +- Block
28-
* | +- Lambda
28+
* | +- Closure
2929
* | +- If
3030
* | +- Match
3131
* | +- ImpliedMatch
@@ -436,16 +436,24 @@ trait Kernel {
436436
def Block_apply(stats: List[Statement], expr: Term)(implicit ctx: Context): Block
437437
def Block_copy(original: Tree)(stats: List[Statement], expr: Term)(implicit ctx: Context): Block
438438

439-
/** Tree representing a lambda `(...) => ...` in the source code */
440-
type Lambda <: Term
439+
/** A lambda `(...) => ...` in the source code is represented as
440+
* a local method and a closure:
441+
*
442+
* {
443+
* def m(...) = ...
444+
* closure(m)
445+
* }
446+
*
447+
*/
448+
type Closure <: Term
441449

442-
def matchLambda(tree: Tree)(implicit ctx: Context): Option[Lambda]
450+
def matchClosure(tree: Tree)(implicit ctx: Context): Option[Closure]
443451

444-
def Lambda_meth(self: Lambda)(implicit ctx: Context): Term
445-
def Lambda_tptOpt(self: Lambda)(implicit ctx: Context): Option[TypeTree]
452+
def Closure_meth(self: Closure)(implicit ctx: Context): Term
453+
def Closure_tpeOpt(self: Closure)(implicit ctx: Context): Option[Type]
446454

447-
def Lambda_apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda
448-
def Lambda_copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda
455+
def Closure_apply(meth: Term, tpe: Option[Type])(implicit ctx: Context): Closure
456+
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(implicit ctx: Context): Closure
449457

450458
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
451459
type If <: Term

library/src/scala/tasty/reflect/Printers.scala

+19-28
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ trait Printers
201201
this += "Block(" ++= stats += ", " += expr += ")"
202202
case If(cond, thenp, elsep) =>
203203
this += "If(" += cond += ", " += thenp += ", " += elsep += ")"
204-
case Lambda(meth, tpt) =>
205-
this += "Lambda(" += meth += ", " += tpt += ")"
204+
case Closure(meth, tpt) =>
205+
this += "Closure(" += meth += ", " += tpt += ")"
206206
case Match(selector, cases) =>
207207
this += "Match(" += selector += ", " ++= cases += ")"
208208
case ImpliedMatch(cases) =>
@@ -406,6 +406,7 @@ trait Printers
406406

407407
private implicit class TypeOps(buff: Buffer) {
408408
def +=(x: TypeOrBounds): Buffer = { visitType(x); buff }
409+
def +=(x: Option[TypeOrBounds]): Buffer = { visitOption(x, visitType); buff }
409410
def ++=(x: List[TypeOrBounds]): Buffer = { visitList(x, visitType); buff }
410411
}
411412

@@ -740,17 +741,6 @@ trait Printers
740741
printTree(body)
741742
}
742743

743-
case IsDefDef(ddef @ DefDef(name, targs, argss, _, rhsOpt)) if name.startsWith("$anonfun") =>
744-
// Decompile lambda definition
745-
assert(targs.isEmpty)
746-
val args :: Nil = argss
747-
val Some(rhs) = rhsOpt
748-
inParens {
749-
printArgsDefs(args)
750-
this += " => "
751-
printTree(rhs)
752-
}
753-
754744
case IsDefDef(ddef @ DefDef(name, targs, argss, tpt, rhs)) =>
755745
printDefAnnotations(ddef)
756746

@@ -901,6 +891,13 @@ trait Printers
901891
this += " = "
902892
printTree(rhs)
903893

894+
case Lambda(params, body) => // must come before `Block`
895+
inParens {
896+
printArgsDefs(params)
897+
this += " => "
898+
printTree(body)
899+
}
900+
904901
case Block(stats0, expr) =>
905902
val stats = stats0.filter {
906903
case IsValDef(tree) => !tree.symbol.flags.is(Flags.Object)
@@ -911,10 +908,6 @@ trait Printers
911908
case Inlined(_, bindings, expansion) =>
912909
printFlatBlock(bindings, expansion)
913910

914-
case Lambda(meth, tpt) =>
915-
// Printed in by it's DefDef
916-
this
917-
918911
case If(cond, thenp, elsep) =>
919912
this += highlightKeyword("if ")
920913
inParens(printTree(cond))
@@ -982,6 +975,8 @@ trait Printers
982975
def flatBlock(stats: List[Statement], expr: Term): (List[Statement], Term) = {
983976
val flatStats = List.newBuilder[Statement]
984977
def extractFlatStats(stat: Statement): Unit = stat match {
978+
case Lambda(_, _) => // must come before `Block`
979+
flatStats += stat
985980
case Block(stats1, expr1) =>
986981
val it = stats1.iterator
987982
while (it.hasNext)
@@ -996,6 +991,8 @@ trait Printers
996991
case stat => flatStats += stat
997992
}
998993
def extractFlatExpr(term: Term): Term = term match {
994+
case Lambda(_, _) => // must come before `Block`
995+
term
999996
case Block(stats1, expr1) =>
1000997
val it = stats1.iterator
1001998
while (it.hasNext)
@@ -1017,23 +1014,16 @@ trait Printers
10171014

10181015
def printFlatBlock(stats: List[Statement], expr: Term)(implicit elideThis: Option[Symbol]): Buffer = {
10191016
val (stats1, expr1) = flatBlock(stats, expr)
1020-
// Remove Lambda nodes, lambdas are printed by their definition
10211017
val stats2 = stats1.filter {
1022-
case Lambda(_, _) => false
1018+
case IsTypeDef(tree) => !tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.quoteTypeTag")
10231019
case _ => true
10241020
}
1025-
val (stats3, expr3) = expr1 match {
1026-
case Lambda(_, _) =>
1027-
val init :+ last = stats2
1028-
(init, last)
1029-
case _ => (stats2, expr1)
1030-
}
1031-
if (stats3.isEmpty) {
1032-
printTree(expr3)
1021+
if (stats2.isEmpty) {
1022+
printTree(expr1)
10331023
} else {
10341024
this += "{"
10351025
indented {
1036-
printStats(stats3, expr3)
1026+
printStats(stats2, expr1)
10371027
}
10381028
this += lineBreak() += "}"
10391029
}
@@ -1043,6 +1033,7 @@ trait Printers
10431033
def printSeparator(next: Tree): Unit = {
10441034
// Avoid accidental application of opening `{` on next line with a double break
10451035
def rec(next: Tree): Unit = next match {
1036+
case Lambda(_, _) => this += lineBreak()
10461037
case Block(stats, _) if stats.nonEmpty => this += doubleLineBreak()
10471038
case Inlined(_, bindings, _) if bindings.nonEmpty => this += doubleLineBreak()
10481039
case Select(qual, _) => rec(qual)

library/src/scala/tasty/reflect/TreeOps.scala

+35-13
Original file line numberDiff line numberDiff line change
@@ -552,26 +552,48 @@ trait TreeOps extends Core {
552552
def expr(implicit ctx: Context): Term = kernel.Block_expr(self)
553553
}
554554

555-
object IsLambda {
556-
/** Matches any Lambda and returns it */
557-
def unapply(tree: Tree)(implicit ctx: Context): Option[Lambda] = kernel.matchLambda(tree)
555+
object IsClosure {
556+
/** Matches any Closure and returns it */
557+
def unapply(tree: Tree)(implicit ctx: Context): Option[Closure] = kernel.matchClosure(tree)
558558
}
559559

560-
object Lambda {
560+
object Closure {
561+
562+
def apply(meth: Term, tpt: Option[Type])(implicit ctx: Context): Closure =
563+
kernel.Closure_apply(meth, tpt)
561564

562-
def apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
563-
kernel.Lambda_apply(meth, tpt)
565+
def copy(original: Tree)(meth: Tree, tpt: Option[Type])(implicit ctx: Context): Closure =
566+
kernel.Closure_copy(original)(meth, tpt)
564567

565-
def copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
566-
kernel.Lambda_copy(original)(meth, tpt)
568+
def unapply(tree: Tree)(implicit ctx: Context): Option[(Term, Option[Type])] =
569+
kernel.matchClosure(tree).map(x => (x.meth, x.tpeOpt))
570+
}
567571

568-
def unapply(tree: Tree)(implicit ctx: Context): Option[(Term, Option[TypeTree])] =
569-
kernel.matchLambda(tree).map(x => (x.meth, x.tptOpt))
572+
implicit class ClosureAPI(self: Closure) {
573+
def meth(implicit ctx: Context): Term = kernel.Closure_meth(self)
574+
def tpeOpt(implicit ctx: Context): Option[Type] = kernel.Closure_tpeOpt(self)
570575
}
571576

572-
implicit class LambdaAPI(self: Lambda) {
573-
def meth(implicit ctx: Context): Term = kernel.Lambda_meth(self)
574-
def tptOpt(implicit ctx: Context): Option[TypeTree] = kernel.Lambda_tptOpt(self)
577+
/** A lambda `(...) => ...` in the source code is represented as
578+
* a local method and a closure:
579+
*
580+
* {
581+
* def m(...) = ...
582+
* closure(m)
583+
* }
584+
*
585+
* @note Due to the encoding, in pattern matches the case for `Lambda`
586+
* should come before the case for `Block` to avoid mishandling
587+
* of `Lambda`.
588+
*/
589+
object Lambda {
590+
def unapply(tree: Tree)(implicit ctx: Context): Option[(List[ValDef], Term)] = tree match {
591+
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
592+
if ddef.symbol == meth.symbol =>
593+
Some(params, body)
594+
595+
case _ => None
596+
}
575597
}
576598

577599
object IsIf {

library/src/scala/tasty/reflect/TreeUtils.scala

+4-5
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ trait TreeUtils
4848
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
4949
case While(cond, body) =>
5050
foldTree(foldTree(x, cond), body)
51-
case Lambda(meth, tpt) =>
52-
val a = foldTree(x, meth)
53-
tpt.fold(a)(b => foldTree(a, b))
51+
case Closure(meth, tpt) =>
52+
foldTree(x, meth)
5453
case Match(selector, cases) =>
5554
foldTrees(foldTree(x, selector), cases)
5655
case Return(expr) =>
@@ -193,8 +192,8 @@ trait TreeUtils
193192
Block.copy(tree)(transformStats(stats), transformTerm(expr))
194193
case If(cond, thenp, elsep) =>
195194
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
196-
case Lambda(meth, tpt) =>
197-
Lambda.copy(tree)(transformTerm(meth), tpt.map(x => transformTypeTree(x)))
195+
case Closure(meth, tpt) =>
196+
Closure.copy(tree)(transformTerm(meth), tpt)
198197
case Match(selector, cases) =>
199198
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
200199
case Return(expr) =>

tests/run-macros/i5941/macro_1.scala

+1-7
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@ object Lens {
4040

4141
object Function {
4242
def unapply(t: Term): Option[(List[ValDef], Term)] = t match {
43-
case Inlined(
44-
None, Nil,
45-
Block(
46-
(ddef @ DefDef(_, Nil, params :: Nil, _, Some(body))) :: Nil,
47-
Lambda(meth, _)
48-
)
49-
) if meth.symbol == ddef.symbol => Some((params, body))
43+
case Inlined(None, Nil, Lambda(params, body)) => Some((params, body))
5044
case _ => None
5145
}
5246
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.tasty._
3+
4+
object lib {
5+
6+
inline def assert(condition: => Boolean): Unit = ${ assertImpl('condition, '{""}) }
7+
8+
def assertImpl(cond: Expr[Boolean], clue: Expr[Any])(implicit refl: Reflection): Expr[Unit] = {
9+
import refl._
10+
import util._
11+
12+
cond.unseal.underlyingArgument match {
13+
case t @ Apply(Select(lhs, op), Lambda(param :: Nil, Apply(Select(a, "=="), b :: Nil)) :: Nil)
14+
if a.symbol == param.symbol || b.symbol == param.symbol =>
15+
'{ scala.Predef.assert($cond) }
16+
}
17+
}
18+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
import lib._
3+
4+
case class IntList(args: Int*) {
5+
def exists(f: Int => Boolean): Boolean = args.exists(f)
6+
}
7+
8+
def main(args: Array[String]): Unit = {
9+
assert(IntList(3, 5).exists(_ == 3))
10+
assert(IntList(3, 5).exists(5 == _))
11+
assert(IntList(3, 5).exists(x => x == 3))
12+
assert(IntList(3, 5).exists(x => 5 == x))
13+
}
14+
}

tests/run-macros/tasty-extractors-2.check

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Inlined(None, Nil, Block(List(ValDef("x", Inferred(), Some(Literal(Constant(1))))), Assign(Ident("x"), Literal(Constant(2)))))
22
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))
33

4-
Inlined(None, Nil, Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", TypeIdent("Int"), None))), Inferred(), Some(Ident("x")))), Lambda(Ident("$anonfun"), None)))
4+
Inlined(None, Nil, Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", TypeIdent("Int"), None))), Inferred(), Some(Ident("x")))), Closure(Ident("$anonfun"), None)))
55
Type.AppliedType(Type.SymRef(IsClassDefSymbol(<scala.Function1>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix()))), List(Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix()))))))
66

77
Inlined(None, Nil, Ident("???"))
@@ -100,6 +100,6 @@ Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageD
100100
Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("<init>", Nil, List(Nil), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil)), Nil, None, List(TypeDef("X", TypeBoundsTree(Inferred(), Inferred())))), DefDef("f", Nil, List(List(ValDef("a", Refined(TypeIdent("Foo"), List(TypeDef("X", TypeIdent("Int")))), None))), TypeSelect(Ident("a"), "X"), Some(Ident("???")))), Literal(Constant(()))))
101101
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))
102102

103-
Inlined(None, Nil, Block(List(ValDef("lambda", Applied(Inferred(), List(TypeIdent("Int"), TypeIdent("Int"))), Some(Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", Inferred(), None))), Inferred(), Some(Ident("x")))), Lambda(Ident("$anonfun"), None))))), Literal(Constant(()))))
103+
Inlined(None, Nil, Block(List(ValDef("lambda", Applied(Inferred(), List(TypeIdent("Int"), TypeIdent("Int"))), Some(Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", Inferred(), None))), Inferred(), Some(Ident("x")))), Closure(Ident("$anonfun"), None))))), Literal(Constant(()))))
104104
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))
105105

0 commit comments

Comments
 (0)