Skip to content

Commit 38973e2

Browse files
committed
Move TreeMap, TreeAccumulator and TreeTraverser into Reflection
1 parent df52911 commit 38973e2

File tree

5 files changed

+294
-322
lines changed

5 files changed

+294
-322
lines changed

library/src/scala/tasty/Reflection.scala

+290-12
Original file line numberDiff line numberDiff line change
@@ -3313,20 +3313,298 @@ trait Reflection { reflection =>
33133313
// UTILS //
33143314
///////////////
33153315

3316-
/** TASTy Reflect tree accumulator */
3317-
trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] {
3318-
val reflect: reflection.type = reflection
3319-
}
3316+
/** TASTy Reflect tree accumulator.
3317+
*
3318+
* Usage:
3319+
* ```
3320+
* class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3321+
* extends scala.tasty.reflect.TreeAccumulator[X] {
3322+
* import reflect._
3323+
* def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3324+
* }
3325+
* ```
3326+
*/
3327+
trait TreeAccumulator[X]:
3328+
3329+
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3330+
def foldTree(x: X, tree: Tree)(using ctx: Context): X
3331+
3332+
def foldTrees(x: X, trees: Iterable[Tree])(using ctx: Context): X = trees.foldLeft(x)(foldTree)
3333+
3334+
def foldOverTree(x: X, tree: Tree)(using ctx: Context): X = {
3335+
def localCtx(definition: Definition): Context = definition.symbol.localContext
3336+
tree match {
3337+
case Ident(_) =>
3338+
x
3339+
case Select(qualifier, _) =>
3340+
foldTree(x, qualifier)
3341+
case This(qual) =>
3342+
x
3343+
case Super(qual, _) =>
3344+
foldTree(x, qual)
3345+
case Apply(fun, args) =>
3346+
foldTrees(foldTree(x, fun), args)
3347+
case TypeApply(fun, args) =>
3348+
foldTrees(foldTree(x, fun), args)
3349+
case Literal(const) =>
3350+
x
3351+
case New(tpt) =>
3352+
foldTree(x, tpt)
3353+
case Typed(expr, tpt) =>
3354+
foldTree(foldTree(x, expr), tpt)
3355+
case NamedArg(_, arg) =>
3356+
foldTree(x, arg)
3357+
case Assign(lhs, rhs) =>
3358+
foldTree(foldTree(x, lhs), rhs)
3359+
case Block(stats, expr) =>
3360+
foldTree(foldTrees(x, stats), expr)
3361+
case If(cond, thenp, elsep) =>
3362+
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3363+
case While(cond, body) =>
3364+
foldTree(foldTree(x, cond), body)
3365+
case Closure(meth, tpt) =>
3366+
foldTree(x, meth)
3367+
case Match(selector, cases) =>
3368+
foldTrees(foldTree(x, selector), cases)
3369+
case Return(expr, _) =>
3370+
foldTree(x, expr)
3371+
case Try(block, handler, finalizer) =>
3372+
foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3373+
case Repeated(elems, elemtpt) =>
3374+
foldTrees(foldTree(x, elemtpt), elems)
3375+
case Inlined(call, bindings, expansion) =>
3376+
foldTree(foldTrees(x, bindings), expansion)
3377+
case vdef @ ValDef(_, tpt, rhs) =>
3378+
val ctx = localCtx(vdef)
3379+
given Context = ctx
3380+
foldTrees(foldTree(x, tpt), rhs)
3381+
case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) =>
3382+
val ctx = localCtx(ddef)
3383+
given Context = ctx
3384+
foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3385+
case tdef @ TypeDef(_, rhs) =>
3386+
val ctx = localCtx(tdef)
3387+
given Context = ctx
3388+
foldTree(x, rhs)
3389+
case cdef @ ClassDef(_, constr, parents, derived, self, body) =>
3390+
val ctx = localCtx(cdef)
3391+
given Context = ctx
3392+
foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3393+
case Import(expr, _) =>
3394+
foldTree(x, expr)
3395+
case clause @ PackageClause(pid, stats) =>
3396+
foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3397+
case Inferred() => x
3398+
case TypeIdent(_) => x
3399+
case TypeSelect(qualifier, _) => foldTree(x, qualifier)
3400+
case Projection(qualifier, _) => foldTree(x, qualifier)
3401+
case Singleton(ref) => foldTree(x, ref)
3402+
case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3403+
case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
3404+
case ByName(result) => foldTree(x, result)
3405+
case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
3406+
case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3407+
case TypeBind(_, tbt) => foldTree(x, tbt)
3408+
case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3409+
case MatchTypeTree(boundopt, selector, cases) =>
3410+
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3411+
case WildcardTypeTree() => x
3412+
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
3413+
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3414+
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
3415+
case Bind(_, body) => foldTree(x, body)
3416+
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3417+
case Alternatives(patterns) => foldTrees(x, patterns)
3418+
}
3419+
}
3420+
end TreeAccumulator
33203421

3321-
/** TASTy Reflect tree traverser */
3322-
trait TreeTraverser extends reflect.TreeTraverser {
3323-
val reflect: reflection.type = reflection
3324-
}
33253422

3326-
/** TASTy Reflect tree map */
3327-
trait TreeMap extends reflect.TreeMap {
3328-
val reflect: reflection.type = reflection
3329-
}
3423+
/** TASTy Reflect tree traverser.
3424+
*
3425+
* Usage:
3426+
* ```
3427+
* class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3428+
* extends scala.tasty.reflect.TreeTraverser {
3429+
* import reflect._
3430+
* override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3431+
* }
3432+
* ```
3433+
*/
3434+
trait TreeTraverser extends TreeAccumulator[Unit]:
3435+
3436+
def traverseTree(tree: Tree)(using ctx: Context): Unit = traverseTreeChildren(tree)
3437+
3438+
def foldTree(x: Unit, tree: Tree)(using ctx: Context): Unit = traverseTree(tree)
3439+
3440+
protected def traverseTreeChildren(tree: Tree)(using ctx: Context): Unit = foldOverTree((), tree)
3441+
3442+
end TreeTraverser
3443+
3444+
/** TASTy Reflect tree map.
3445+
*
3446+
* Usage:
3447+
* ```
3448+
* import qctx.reflect._
3449+
* class MyTreeMap extends TreeMap {
3450+
* override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3451+
* }
3452+
* ```
3453+
*/
3454+
trait TreeMap:
3455+
3456+
def transformTree(tree: Tree)(using ctx: Context): Tree = {
3457+
tree match {
3458+
case tree: PackageClause =>
3459+
PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(using tree.symbol.localContext))
3460+
case tree: Import =>
3461+
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
3462+
case tree: Statement =>
3463+
transformStatement(tree)
3464+
case tree: TypeTree => transformTypeTree(tree)
3465+
case tree: TypeBoundsTree => tree // TODO traverse tree
3466+
case tree: WildcardTypeTree => tree // TODO traverse tree
3467+
case tree: CaseDef =>
3468+
transformCaseDef(tree)
3469+
case tree: TypeCaseDef =>
3470+
transformTypeCaseDef(tree)
3471+
case pattern: Bind =>
3472+
Bind.copy(pattern)(pattern.name, pattern.pattern)
3473+
case pattern: Unapply =>
3474+
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3475+
case pattern: Alternatives =>
3476+
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
3477+
}
3478+
}
3479+
3480+
def transformStatement(tree: Statement)(using ctx: Context): Statement = {
3481+
def localCtx(definition: Definition): Context = definition.symbol.localContext
3482+
tree match {
3483+
case tree: Term =>
3484+
transformTerm(tree)
3485+
case tree: ValDef =>
3486+
val ctx = localCtx(tree)
3487+
given Context = ctx
3488+
val tpt1 = transformTypeTree(tree.tpt)
3489+
val rhs1 = tree.rhs.map(x => transformTerm(x))
3490+
ValDef.copy(tree)(tree.name, tpt1, rhs1)
3491+
case tree: DefDef =>
3492+
val ctx = localCtx(tree)
3493+
given Context = ctx
3494+
DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3495+
case tree: TypeDef =>
3496+
val ctx = localCtx(tree)
3497+
given Context = ctx
3498+
TypeDef.copy(tree)(tree.name, transformTree(tree.rhs))
3499+
case tree: ClassDef =>
3500+
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3501+
case tree: Import =>
3502+
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
3503+
}
3504+
}
3505+
3506+
def transformTerm(tree: Term)(using ctx: Context): Term = {
3507+
tree match {
3508+
case Ident(name) =>
3509+
tree
3510+
case Select(qualifier, name) =>
3511+
Select.copy(tree)(transformTerm(qualifier), name)
3512+
case This(qual) =>
3513+
tree
3514+
case Super(qual, mix) =>
3515+
Super.copy(tree)(transformTerm(qual), mix)
3516+
case Apply(fun, args) =>
3517+
Apply.copy(tree)(transformTerm(fun), transformTerms(args))
3518+
case TypeApply(fun, args) =>
3519+
TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args))
3520+
case Literal(const) =>
3521+
tree
3522+
case New(tpt) =>
3523+
New.copy(tree)(transformTypeTree(tpt))
3524+
case Typed(expr, tpt) =>
3525+
Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3526+
case tree: NamedArg =>
3527+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value))
3528+
case Assign(lhs, rhs) =>
3529+
Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs))
3530+
case Block(stats, expr) =>
3531+
Block.copy(tree)(transformStats(stats), transformTerm(expr))
3532+
case If(cond, thenp, elsep) =>
3533+
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3534+
case Closure(meth, tpt) =>
3535+
Closure.copy(tree)(transformTerm(meth), tpt)
3536+
case Match(selector, cases) =>
3537+
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3538+
case Return(expr, from) =>
3539+
Return.copy(tree)(transformTerm(expr), from)
3540+
case While(cond, body) =>
3541+
While.copy(tree)(transformTerm(cond), transformTerm(body))
3542+
case Try(block, cases, finalizer) =>
3543+
Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3544+
case Repeated(elems, elemtpt) =>
3545+
Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3546+
case Inlined(call, bindings, expansion) =>
3547+
Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/)
3548+
}
3549+
}
3550+
3551+
def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree match {
3552+
case Inferred() => tree
3553+
case tree: TypeIdent => tree
3554+
case tree: TypeSelect =>
3555+
TypeSelect.copy(tree)(tree.qualifier, tree.name)
3556+
case tree: Projection =>
3557+
Projection.copy(tree)(tree.qualifier, tree.name)
3558+
case tree: Annotated =>
3559+
Annotated.copy(tree)(tree.arg, tree.annotation)
3560+
case tree: Singleton =>
3561+
Singleton.copy(tree)(transformTerm(tree.ref))
3562+
case tree: Refined =>
3563+
Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]])
3564+
case tree: Applied =>
3565+
Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3566+
case tree: MatchTypeTree =>
3567+
MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3568+
case tree: ByName =>
3569+
ByName.copy(tree)(transformTypeTree(tree.result))
3570+
case tree: LambdaTypeTree =>
3571+
LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3572+
case tree: TypeBind =>
3573+
TypeBind.copy(tree)(tree.name, tree.body)
3574+
case tree: TypeBlock =>
3575+
TypeBlock.copy(tree)(tree.aliases, tree.tpt)
3576+
}
3577+
3578+
def transformCaseDef(tree: CaseDef)(using ctx: Context): CaseDef = {
3579+
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3580+
}
3581+
3582+
def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = {
3583+
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3584+
}
3585+
3586+
def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] =
3587+
trees mapConserve (transformStatement(_))
3588+
3589+
def transformTrees(trees: List[Tree])(using ctx: Context): List[Tree] =
3590+
trees mapConserve (transformTree(_))
3591+
3592+
def transformTerms(trees: List[Term])(using ctx: Context): List[Term] =
3593+
trees mapConserve (transformTerm(_))
3594+
3595+
def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] =
3596+
trees mapConserve (transformTypeTree(_))
3597+
3598+
def transformCaseDefs(trees: List[CaseDef])(using ctx: Context): List[CaseDef] =
3599+
trees mapConserve (transformCaseDef(_))
3600+
3601+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] =
3602+
trees mapConserve (transformTypeCaseDef(_))
3603+
3604+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(using ctx: Context): List[Tr] =
3605+
transformTrees(trees).asInstanceOf[List[Tr]]
3606+
3607+
end TreeMap
33303608

33313609
// TODO: extract from Reflection
33323610

0 commit comments

Comments
 (0)