@@ -3313,20 +3313,298 @@ trait Reflection { reflection =>
3313
3313
// UTILS //
3314
3314
// /////////////
3315
3315
3316
- /** TASTy Reflect tree accumulator */
3317
- trait TreeAccumulator [X ] extends reflect.TreeAccumulator [X ] {
3318
- val reflect : reflection.type = reflection
3319
- }
3316
+ /** TASTy Reflect tree accumulator.
3317
+ *
3318
+ * Usage:
3319
+ * ```
3320
+ * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3321
+ * extends scala.tasty.reflect.TreeAccumulator[X] {
3322
+ * import reflect._
3323
+ * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3324
+ * }
3325
+ * ```
3326
+ */
3327
+ trait TreeAccumulator [X ]:
3328
+
3329
+ // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3330
+ def foldTree (x : X , tree : Tree )(using ctx : Context ): X
3331
+
3332
+ def foldTrees (x : X , trees : Iterable [Tree ])(using ctx : Context ): X = trees.foldLeft(x)(foldTree)
3333
+
3334
+ def foldOverTree (x : X , tree : Tree )(using ctx : Context ): X = {
3335
+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3336
+ tree match {
3337
+ case Ident (_) =>
3338
+ x
3339
+ case Select (qualifier, _) =>
3340
+ foldTree(x, qualifier)
3341
+ case This (qual) =>
3342
+ x
3343
+ case Super (qual, _) =>
3344
+ foldTree(x, qual)
3345
+ case Apply (fun, args) =>
3346
+ foldTrees(foldTree(x, fun), args)
3347
+ case TypeApply (fun, args) =>
3348
+ foldTrees(foldTree(x, fun), args)
3349
+ case Literal (const) =>
3350
+ x
3351
+ case New (tpt) =>
3352
+ foldTree(x, tpt)
3353
+ case Typed (expr, tpt) =>
3354
+ foldTree(foldTree(x, expr), tpt)
3355
+ case NamedArg (_, arg) =>
3356
+ foldTree(x, arg)
3357
+ case Assign (lhs, rhs) =>
3358
+ foldTree(foldTree(x, lhs), rhs)
3359
+ case Block (stats, expr) =>
3360
+ foldTree(foldTrees(x, stats), expr)
3361
+ case If (cond, thenp, elsep) =>
3362
+ foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3363
+ case While (cond, body) =>
3364
+ foldTree(foldTree(x, cond), body)
3365
+ case Closure (meth, tpt) =>
3366
+ foldTree(x, meth)
3367
+ case Match (selector, cases) =>
3368
+ foldTrees(foldTree(x, selector), cases)
3369
+ case Return (expr, _) =>
3370
+ foldTree(x, expr)
3371
+ case Try (block, handler, finalizer) =>
3372
+ foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3373
+ case Repeated (elems, elemtpt) =>
3374
+ foldTrees(foldTree(x, elemtpt), elems)
3375
+ case Inlined (call, bindings, expansion) =>
3376
+ foldTree(foldTrees(x, bindings), expansion)
3377
+ case vdef @ ValDef (_, tpt, rhs) =>
3378
+ val ctx = localCtx(vdef)
3379
+ given Context = ctx
3380
+ foldTrees(foldTree(x, tpt), rhs)
3381
+ case ddef @ DefDef (_, tparams, vparamss, tpt, rhs) =>
3382
+ val ctx = localCtx(ddef)
3383
+ given Context = ctx
3384
+ foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3385
+ case tdef @ TypeDef (_, rhs) =>
3386
+ val ctx = localCtx(tdef)
3387
+ given Context = ctx
3388
+ foldTree(x, rhs)
3389
+ case cdef @ ClassDef (_, constr, parents, derived, self, body) =>
3390
+ val ctx = localCtx(cdef)
3391
+ given Context = ctx
3392
+ foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3393
+ case Import (expr, _) =>
3394
+ foldTree(x, expr)
3395
+ case clause @ PackageClause (pid, stats) =>
3396
+ foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3397
+ case Inferred () => x
3398
+ case TypeIdent (_) => x
3399
+ case TypeSelect (qualifier, _) => foldTree(x, qualifier)
3400
+ case Projection (qualifier, _) => foldTree(x, qualifier)
3401
+ case Singleton (ref) => foldTree(x, ref)
3402
+ case Refined (tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3403
+ case Applied (tpt, args) => foldTrees(foldTree(x, tpt), args)
3404
+ case ByName (result) => foldTree(x, result)
3405
+ case Annotated (arg, annot) => foldTree(foldTree(x, arg), annot)
3406
+ case LambdaTypeTree (typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3407
+ case TypeBind (_, tbt) => foldTree(x, tbt)
3408
+ case TypeBlock (typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3409
+ case MatchTypeTree (boundopt, selector, cases) =>
3410
+ foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3411
+ case WildcardTypeTree () => x
3412
+ case TypeBoundsTree (lo, hi) => foldTree(foldTree(x, lo), hi)
3413
+ case CaseDef (pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3414
+ case TypeCaseDef (pat, body) => foldTree(foldTree(x, pat), body)
3415
+ case Bind (_, body) => foldTree(x, body)
3416
+ case Unapply (fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3417
+ case Alternatives (patterns) => foldTrees(x, patterns)
3418
+ }
3419
+ }
3420
+ end TreeAccumulator
3320
3421
3321
- /** TASTy Reflect tree traverser */
3322
- trait TreeTraverser extends reflect.TreeTraverser {
3323
- val reflect : reflection.type = reflection
3324
- }
3325
3422
3326
- /** TASTy Reflect tree map */
3327
- trait TreeMap extends reflect.TreeMap {
3328
- val reflect : reflection.type = reflection
3329
- }
3423
+ /** TASTy Reflect tree traverser.
3424
+ *
3425
+ * Usage:
3426
+ * ```
3427
+ * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3428
+ * extends scala.tasty.reflect.TreeTraverser {
3429
+ * import reflect._
3430
+ * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3431
+ * }
3432
+ * ```
3433
+ */
3434
+ trait TreeTraverser extends TreeAccumulator [Unit ]:
3435
+
3436
+ def traverseTree (tree : Tree )(using ctx : Context ): Unit = traverseTreeChildren(tree)
3437
+
3438
+ def foldTree (x : Unit , tree : Tree )(using ctx : Context ): Unit = traverseTree(tree)
3439
+
3440
+ protected def traverseTreeChildren (tree : Tree )(using ctx : Context ): Unit = foldOverTree((), tree)
3441
+
3442
+ end TreeTraverser
3443
+
3444
+ /** TASTy Reflect tree map.
3445
+ *
3446
+ * Usage:
3447
+ * ```
3448
+ * import qctx.reflect._
3449
+ * class MyTreeMap extends TreeMap {
3450
+ * override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3451
+ * }
3452
+ * ```
3453
+ */
3454
+ trait TreeMap :
3455
+
3456
+ def transformTree (tree : Tree )(using ctx : Context ): Tree = {
3457
+ tree match {
3458
+ case tree : PackageClause =>
3459
+ PackageClause .copy(tree)(transformTerm(tree.pid).asInstanceOf [Ref ], transformTrees(tree.stats)(using tree.symbol.localContext))
3460
+ case tree : Import =>
3461
+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3462
+ case tree : Statement =>
3463
+ transformStatement(tree)
3464
+ case tree : TypeTree => transformTypeTree(tree)
3465
+ case tree : TypeBoundsTree => tree // TODO traverse tree
3466
+ case tree : WildcardTypeTree => tree // TODO traverse tree
3467
+ case tree : CaseDef =>
3468
+ transformCaseDef(tree)
3469
+ case tree : TypeCaseDef =>
3470
+ transformTypeCaseDef(tree)
3471
+ case pattern : Bind =>
3472
+ Bind .copy(pattern)(pattern.name, pattern.pattern)
3473
+ case pattern : Unapply =>
3474
+ Unapply .copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3475
+ case pattern : Alternatives =>
3476
+ Alternatives .copy(pattern)(transformTrees(pattern.patterns))
3477
+ }
3478
+ }
3479
+
3480
+ def transformStatement (tree : Statement )(using ctx : Context ): Statement = {
3481
+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3482
+ tree match {
3483
+ case tree : Term =>
3484
+ transformTerm(tree)
3485
+ case tree : ValDef =>
3486
+ val ctx = localCtx(tree)
3487
+ given Context = ctx
3488
+ val tpt1 = transformTypeTree(tree.tpt)
3489
+ val rhs1 = tree.rhs.map(x => transformTerm(x))
3490
+ ValDef .copy(tree)(tree.name, tpt1, rhs1)
3491
+ case tree : DefDef =>
3492
+ val ctx = localCtx(tree)
3493
+ given Context = ctx
3494
+ DefDef .copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3495
+ case tree : TypeDef =>
3496
+ val ctx = localCtx(tree)
3497
+ given Context = ctx
3498
+ TypeDef .copy(tree)(tree.name, transformTree(tree.rhs))
3499
+ case tree : ClassDef =>
3500
+ ClassDef .copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3501
+ case tree : Import =>
3502
+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3503
+ }
3504
+ }
3505
+
3506
+ def transformTerm (tree : Term )(using ctx : Context ): Term = {
3507
+ tree match {
3508
+ case Ident (name) =>
3509
+ tree
3510
+ case Select (qualifier, name) =>
3511
+ Select .copy(tree)(transformTerm(qualifier), name)
3512
+ case This (qual) =>
3513
+ tree
3514
+ case Super (qual, mix) =>
3515
+ Super .copy(tree)(transformTerm(qual), mix)
3516
+ case Apply (fun, args) =>
3517
+ Apply .copy(tree)(transformTerm(fun), transformTerms(args))
3518
+ case TypeApply (fun, args) =>
3519
+ TypeApply .copy(tree)(transformTerm(fun), transformTypeTrees(args))
3520
+ case Literal (const) =>
3521
+ tree
3522
+ case New (tpt) =>
3523
+ New .copy(tree)(transformTypeTree(tpt))
3524
+ case Typed (expr, tpt) =>
3525
+ Typed .copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3526
+ case tree : NamedArg =>
3527
+ NamedArg .copy(tree)(tree.name, transformTerm(tree.value))
3528
+ case Assign (lhs, rhs) =>
3529
+ Assign .copy(tree)(transformTerm(lhs), transformTerm(rhs))
3530
+ case Block (stats, expr) =>
3531
+ Block .copy(tree)(transformStats(stats), transformTerm(expr))
3532
+ case If (cond, thenp, elsep) =>
3533
+ If .copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3534
+ case Closure (meth, tpt) =>
3535
+ Closure .copy(tree)(transformTerm(meth), tpt)
3536
+ case Match (selector, cases) =>
3537
+ Match .copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3538
+ case Return (expr, from) =>
3539
+ Return .copy(tree)(transformTerm(expr), from)
3540
+ case While (cond, body) =>
3541
+ While .copy(tree)(transformTerm(cond), transformTerm(body))
3542
+ case Try (block, cases, finalizer) =>
3543
+ Try .copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3544
+ case Repeated (elems, elemtpt) =>
3545
+ Repeated .copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3546
+ case Inlined (call, bindings, expansion) =>
3547
+ Inlined .copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/* ()call.symbol.localContext)*/ )
3548
+ }
3549
+ }
3550
+
3551
+ def transformTypeTree (tree : TypeTree )(using ctx : Context ): TypeTree = tree match {
3552
+ case Inferred () => tree
3553
+ case tree : TypeIdent => tree
3554
+ case tree : TypeSelect =>
3555
+ TypeSelect .copy(tree)(tree.qualifier, tree.name)
3556
+ case tree : Projection =>
3557
+ Projection .copy(tree)(tree.qualifier, tree.name)
3558
+ case tree : Annotated =>
3559
+ Annotated .copy(tree)(tree.arg, tree.annotation)
3560
+ case tree : Singleton =>
3561
+ Singleton .copy(tree)(transformTerm(tree.ref))
3562
+ case tree : Refined =>
3563
+ Refined .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf [List [Definition ]])
3564
+ case tree : Applied =>
3565
+ Applied .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3566
+ case tree : MatchTypeTree =>
3567
+ MatchTypeTree .copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3568
+ case tree : ByName =>
3569
+ ByName .copy(tree)(transformTypeTree(tree.result))
3570
+ case tree : LambdaTypeTree =>
3571
+ LambdaTypeTree .copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3572
+ case tree : TypeBind =>
3573
+ TypeBind .copy(tree)(tree.name, tree.body)
3574
+ case tree : TypeBlock =>
3575
+ TypeBlock .copy(tree)(tree.aliases, tree.tpt)
3576
+ }
3577
+
3578
+ def transformCaseDef (tree : CaseDef )(using ctx : Context ): CaseDef = {
3579
+ CaseDef .copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3580
+ }
3581
+
3582
+ def transformTypeCaseDef (tree : TypeCaseDef )(using ctx : Context ): TypeCaseDef = {
3583
+ TypeCaseDef .copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3584
+ }
3585
+
3586
+ def transformStats (trees : List [Statement ])(using ctx : Context ): List [Statement ] =
3587
+ trees mapConserve (transformStatement(_))
3588
+
3589
+ def transformTrees (trees : List [Tree ])(using ctx : Context ): List [Tree ] =
3590
+ trees mapConserve (transformTree(_))
3591
+
3592
+ def transformTerms (trees : List [Term ])(using ctx : Context ): List [Term ] =
3593
+ trees mapConserve (transformTerm(_))
3594
+
3595
+ def transformTypeTrees (trees : List [TypeTree ])(using ctx : Context ): List [TypeTree ] =
3596
+ trees mapConserve (transformTypeTree(_))
3597
+
3598
+ def transformCaseDefs (trees : List [CaseDef ])(using ctx : Context ): List [CaseDef ] =
3599
+ trees mapConserve (transformCaseDef(_))
3600
+
3601
+ def transformTypeCaseDefs (trees : List [TypeCaseDef ])(using ctx : Context ): List [TypeCaseDef ] =
3602
+ trees mapConserve (transformTypeCaseDef(_))
3603
+
3604
+ def transformSubTrees [Tr <: Tree ](trees : List [Tr ])(using ctx : Context ): List [Tr ] =
3605
+ transformTrees(trees).asInstanceOf [List [Tr ]]
3606
+
3607
+ end TreeMap
3330
3608
3331
3609
// TODO: extract from Reflection
3332
3610
0 commit comments