Skip to content

Commit c20be38

Browse files
Merge pull request #10184 from dotty-staging/move-tree-map-to-refelction
Move TreeMap, TreeAccumulator and TreeTraverser into Reflection
2 parents a9f9e96 + 38973e2 commit c20be38

File tree

5 files changed

+294
-322
lines changed

5 files changed

+294
-322
lines changed

library/src/scala/tasty/Reflection.scala

+290-12
Original file line numberDiff line numberDiff line change
@@ -3312,20 +3312,298 @@ trait Reflection { reflection =>
33123312
// UTILS //
33133313
///////////////
33143314

3315-
/** TASTy Reflect tree accumulator */
3316-
trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] {
3317-
val reflect: reflection.type = reflection
3318-
}
3315+
/** TASTy Reflect tree accumulator.
3316+
*
3317+
* Usage:
3318+
* ```
3319+
* class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3320+
* extends scala.tasty.reflect.TreeAccumulator[X] {
3321+
* import reflect._
3322+
* def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3323+
* }
3324+
* ```
3325+
*/
3326+
trait TreeAccumulator[X]:
3327+
3328+
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3329+
def foldTree(x: X, tree: Tree)(using ctx: Context): X
3330+
3331+
def foldTrees(x: X, trees: Iterable[Tree])(using ctx: Context): X = trees.foldLeft(x)(foldTree)
3332+
3333+
def foldOverTree(x: X, tree: Tree)(using ctx: Context): X = {
3334+
def localCtx(definition: Definition): Context = definition.symbol.localContext
3335+
tree match {
3336+
case Ident(_) =>
3337+
x
3338+
case Select(qualifier, _) =>
3339+
foldTree(x, qualifier)
3340+
case This(qual) =>
3341+
x
3342+
case Super(qual, _) =>
3343+
foldTree(x, qual)
3344+
case Apply(fun, args) =>
3345+
foldTrees(foldTree(x, fun), args)
3346+
case TypeApply(fun, args) =>
3347+
foldTrees(foldTree(x, fun), args)
3348+
case Literal(const) =>
3349+
x
3350+
case New(tpt) =>
3351+
foldTree(x, tpt)
3352+
case Typed(expr, tpt) =>
3353+
foldTree(foldTree(x, expr), tpt)
3354+
case NamedArg(_, arg) =>
3355+
foldTree(x, arg)
3356+
case Assign(lhs, rhs) =>
3357+
foldTree(foldTree(x, lhs), rhs)
3358+
case Block(stats, expr) =>
3359+
foldTree(foldTrees(x, stats), expr)
3360+
case If(cond, thenp, elsep) =>
3361+
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3362+
case While(cond, body) =>
3363+
foldTree(foldTree(x, cond), body)
3364+
case Closure(meth, tpt) =>
3365+
foldTree(x, meth)
3366+
case Match(selector, cases) =>
3367+
foldTrees(foldTree(x, selector), cases)
3368+
case Return(expr, _) =>
3369+
foldTree(x, expr)
3370+
case Try(block, handler, finalizer) =>
3371+
foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3372+
case Repeated(elems, elemtpt) =>
3373+
foldTrees(foldTree(x, elemtpt), elems)
3374+
case Inlined(call, bindings, expansion) =>
3375+
foldTree(foldTrees(x, bindings), expansion)
3376+
case vdef @ ValDef(_, tpt, rhs) =>
3377+
val ctx = localCtx(vdef)
3378+
given Context = ctx
3379+
foldTrees(foldTree(x, tpt), rhs)
3380+
case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) =>
3381+
val ctx = localCtx(ddef)
3382+
given Context = ctx
3383+
foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3384+
case tdef @ TypeDef(_, rhs) =>
3385+
val ctx = localCtx(tdef)
3386+
given Context = ctx
3387+
foldTree(x, rhs)
3388+
case cdef @ ClassDef(_, constr, parents, derived, self, body) =>
3389+
val ctx = localCtx(cdef)
3390+
given Context = ctx
3391+
foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3392+
case Import(expr, _) =>
3393+
foldTree(x, expr)
3394+
case clause @ PackageClause(pid, stats) =>
3395+
foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3396+
case Inferred() => x
3397+
case TypeIdent(_) => x
3398+
case TypeSelect(qualifier, _) => foldTree(x, qualifier)
3399+
case Projection(qualifier, _) => foldTree(x, qualifier)
3400+
case Singleton(ref) => foldTree(x, ref)
3401+
case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3402+
case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
3403+
case ByName(result) => foldTree(x, result)
3404+
case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
3405+
case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3406+
case TypeBind(_, tbt) => foldTree(x, tbt)
3407+
case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3408+
case MatchTypeTree(boundopt, selector, cases) =>
3409+
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3410+
case WildcardTypeTree() => x
3411+
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
3412+
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3413+
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
3414+
case Bind(_, body) => foldTree(x, body)
3415+
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3416+
case Alternatives(patterns) => foldTrees(x, patterns)
3417+
}
3418+
}
3419+
end TreeAccumulator
33193420

3320-
/** TASTy Reflect tree traverser */
3321-
trait TreeTraverser extends reflect.TreeTraverser {
3322-
val reflect: reflection.type = reflection
3323-
}
33243421

3325-
/** TASTy Reflect tree map */
3326-
trait TreeMap extends reflect.TreeMap {
3327-
val reflect: reflection.type = reflection
3328-
}
3422+
/** TASTy Reflect tree traverser.
3423+
*
3424+
* Usage:
3425+
* ```
3426+
* class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3427+
* extends scala.tasty.reflect.TreeTraverser {
3428+
* import reflect._
3429+
* override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3430+
* }
3431+
* ```
3432+
*/
3433+
trait TreeTraverser extends TreeAccumulator[Unit]:
3434+
3435+
def traverseTree(tree: Tree)(using ctx: Context): Unit = traverseTreeChildren(tree)
3436+
3437+
def foldTree(x: Unit, tree: Tree)(using ctx: Context): Unit = traverseTree(tree)
3438+
3439+
protected def traverseTreeChildren(tree: Tree)(using ctx: Context): Unit = foldOverTree((), tree)
3440+
3441+
end TreeTraverser
3442+
3443+
/** TASTy Reflect tree map.
3444+
*
3445+
* Usage:
3446+
* ```
3447+
* import qctx.reflect._
3448+
* class MyTreeMap extends TreeMap {
3449+
* override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3450+
* }
3451+
* ```
3452+
*/
3453+
trait TreeMap:
3454+
3455+
def transformTree(tree: Tree)(using ctx: Context): Tree = {
3456+
tree match {
3457+
case tree: PackageClause =>
3458+
PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(using tree.symbol.localContext))
3459+
case tree: Import =>
3460+
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
3461+
case tree: Statement =>
3462+
transformStatement(tree)
3463+
case tree: TypeTree => transformTypeTree(tree)
3464+
case tree: TypeBoundsTree => tree // TODO traverse tree
3465+
case tree: WildcardTypeTree => tree // TODO traverse tree
3466+
case tree: CaseDef =>
3467+
transformCaseDef(tree)
3468+
case tree: TypeCaseDef =>
3469+
transformTypeCaseDef(tree)
3470+
case pattern: Bind =>
3471+
Bind.copy(pattern)(pattern.name, pattern.pattern)
3472+
case pattern: Unapply =>
3473+
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3474+
case pattern: Alternatives =>
3475+
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
3476+
}
3477+
}
3478+
3479+
def transformStatement(tree: Statement)(using ctx: Context): Statement = {
3480+
def localCtx(definition: Definition): Context = definition.symbol.localContext
3481+
tree match {
3482+
case tree: Term =>
3483+
transformTerm(tree)
3484+
case tree: ValDef =>
3485+
val ctx = localCtx(tree)
3486+
given Context = ctx
3487+
val tpt1 = transformTypeTree(tree.tpt)
3488+
val rhs1 = tree.rhs.map(x => transformTerm(x))
3489+
ValDef.copy(tree)(tree.name, tpt1, rhs1)
3490+
case tree: DefDef =>
3491+
val ctx = localCtx(tree)
3492+
given Context = ctx
3493+
DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3494+
case tree: TypeDef =>
3495+
val ctx = localCtx(tree)
3496+
given Context = ctx
3497+
TypeDef.copy(tree)(tree.name, transformTree(tree.rhs))
3498+
case tree: ClassDef =>
3499+
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3500+
case tree: Import =>
3501+
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
3502+
}
3503+
}
3504+
3505+
def transformTerm(tree: Term)(using ctx: Context): Term = {
3506+
tree match {
3507+
case Ident(name) =>
3508+
tree
3509+
case Select(qualifier, name) =>
3510+
Select.copy(tree)(transformTerm(qualifier), name)
3511+
case This(qual) =>
3512+
tree
3513+
case Super(qual, mix) =>
3514+
Super.copy(tree)(transformTerm(qual), mix)
3515+
case Apply(fun, args) =>
3516+
Apply.copy(tree)(transformTerm(fun), transformTerms(args))
3517+
case TypeApply(fun, args) =>
3518+
TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args))
3519+
case Literal(const) =>
3520+
tree
3521+
case New(tpt) =>
3522+
New.copy(tree)(transformTypeTree(tpt))
3523+
case Typed(expr, tpt) =>
3524+
Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3525+
case tree: NamedArg =>
3526+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value))
3527+
case Assign(lhs, rhs) =>
3528+
Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs))
3529+
case Block(stats, expr) =>
3530+
Block.copy(tree)(transformStats(stats), transformTerm(expr))
3531+
case If(cond, thenp, elsep) =>
3532+
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3533+
case Closure(meth, tpt) =>
3534+
Closure.copy(tree)(transformTerm(meth), tpt)
3535+
case Match(selector, cases) =>
3536+
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3537+
case Return(expr, from) =>
3538+
Return.copy(tree)(transformTerm(expr), from)
3539+
case While(cond, body) =>
3540+
While.copy(tree)(transformTerm(cond), transformTerm(body))
3541+
case Try(block, cases, finalizer) =>
3542+
Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3543+
case Repeated(elems, elemtpt) =>
3544+
Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3545+
case Inlined(call, bindings, expansion) =>
3546+
Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/)
3547+
}
3548+
}
3549+
3550+
def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree match {
3551+
case Inferred() => tree
3552+
case tree: TypeIdent => tree
3553+
case tree: TypeSelect =>
3554+
TypeSelect.copy(tree)(tree.qualifier, tree.name)
3555+
case tree: Projection =>
3556+
Projection.copy(tree)(tree.qualifier, tree.name)
3557+
case tree: Annotated =>
3558+
Annotated.copy(tree)(tree.arg, tree.annotation)
3559+
case tree: Singleton =>
3560+
Singleton.copy(tree)(transformTerm(tree.ref))
3561+
case tree: Refined =>
3562+
Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]])
3563+
case tree: Applied =>
3564+
Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3565+
case tree: MatchTypeTree =>
3566+
MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3567+
case tree: ByName =>
3568+
ByName.copy(tree)(transformTypeTree(tree.result))
3569+
case tree: LambdaTypeTree =>
3570+
LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3571+
case tree: TypeBind =>
3572+
TypeBind.copy(tree)(tree.name, tree.body)
3573+
case tree: TypeBlock =>
3574+
TypeBlock.copy(tree)(tree.aliases, tree.tpt)
3575+
}
3576+
3577+
def transformCaseDef(tree: CaseDef)(using ctx: Context): CaseDef = {
3578+
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3579+
}
3580+
3581+
def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = {
3582+
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3583+
}
3584+
3585+
def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] =
3586+
trees mapConserve (transformStatement(_))
3587+
3588+
def transformTrees(trees: List[Tree])(using ctx: Context): List[Tree] =
3589+
trees mapConserve (transformTree(_))
3590+
3591+
def transformTerms(trees: List[Term])(using ctx: Context): List[Term] =
3592+
trees mapConserve (transformTerm(_))
3593+
3594+
def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] =
3595+
trees mapConserve (transformTypeTree(_))
3596+
3597+
def transformCaseDefs(trees: List[CaseDef])(using ctx: Context): List[CaseDef] =
3598+
trees mapConserve (transformCaseDef(_))
3599+
3600+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] =
3601+
trees mapConserve (transformTypeCaseDef(_))
3602+
3603+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(using ctx: Context): List[Tr] =
3604+
transformTrees(trees).asInstanceOf[List[Tr]]
3605+
3606+
end TreeMap
33293607

33303608
// TODO: extract from Reflection
33313609

0 commit comments

Comments
 (0)