@@ -3312,20 +3312,298 @@ trait Reflection { reflection =>
3312
3312
// UTILS //
3313
3313
// /////////////
3314
3314
3315
- /** TASTy Reflect tree accumulator */
3316
- trait TreeAccumulator [X ] extends reflect.TreeAccumulator [X ] {
3317
- val reflect : reflection.type = reflection
3318
- }
3315
+ /** TASTy Reflect tree accumulator.
3316
+ *
3317
+ * Usage:
3318
+ * ```
3319
+ * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3320
+ * extends scala.tasty.reflect.TreeAccumulator[X] {
3321
+ * import reflect._
3322
+ * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3323
+ * }
3324
+ * ```
3325
+ */
3326
+ trait TreeAccumulator [X ]:
3327
+
3328
+ // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3329
+ def foldTree (x : X , tree : Tree )(using ctx : Context ): X
3330
+
3331
+ def foldTrees (x : X , trees : Iterable [Tree ])(using ctx : Context ): X = trees.foldLeft(x)(foldTree)
3332
+
3333
+ def foldOverTree (x : X , tree : Tree )(using ctx : Context ): X = {
3334
+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3335
+ tree match {
3336
+ case Ident (_) =>
3337
+ x
3338
+ case Select (qualifier, _) =>
3339
+ foldTree(x, qualifier)
3340
+ case This (qual) =>
3341
+ x
3342
+ case Super (qual, _) =>
3343
+ foldTree(x, qual)
3344
+ case Apply (fun, args) =>
3345
+ foldTrees(foldTree(x, fun), args)
3346
+ case TypeApply (fun, args) =>
3347
+ foldTrees(foldTree(x, fun), args)
3348
+ case Literal (const) =>
3349
+ x
3350
+ case New (tpt) =>
3351
+ foldTree(x, tpt)
3352
+ case Typed (expr, tpt) =>
3353
+ foldTree(foldTree(x, expr), tpt)
3354
+ case NamedArg (_, arg) =>
3355
+ foldTree(x, arg)
3356
+ case Assign (lhs, rhs) =>
3357
+ foldTree(foldTree(x, lhs), rhs)
3358
+ case Block (stats, expr) =>
3359
+ foldTree(foldTrees(x, stats), expr)
3360
+ case If (cond, thenp, elsep) =>
3361
+ foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3362
+ case While (cond, body) =>
3363
+ foldTree(foldTree(x, cond), body)
3364
+ case Closure (meth, tpt) =>
3365
+ foldTree(x, meth)
3366
+ case Match (selector, cases) =>
3367
+ foldTrees(foldTree(x, selector), cases)
3368
+ case Return (expr, _) =>
3369
+ foldTree(x, expr)
3370
+ case Try (block, handler, finalizer) =>
3371
+ foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3372
+ case Repeated (elems, elemtpt) =>
3373
+ foldTrees(foldTree(x, elemtpt), elems)
3374
+ case Inlined (call, bindings, expansion) =>
3375
+ foldTree(foldTrees(x, bindings), expansion)
3376
+ case vdef @ ValDef (_, tpt, rhs) =>
3377
+ val ctx = localCtx(vdef)
3378
+ given Context = ctx
3379
+ foldTrees(foldTree(x, tpt), rhs)
3380
+ case ddef @ DefDef (_, tparams, vparamss, tpt, rhs) =>
3381
+ val ctx = localCtx(ddef)
3382
+ given Context = ctx
3383
+ foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3384
+ case tdef @ TypeDef (_, rhs) =>
3385
+ val ctx = localCtx(tdef)
3386
+ given Context = ctx
3387
+ foldTree(x, rhs)
3388
+ case cdef @ ClassDef (_, constr, parents, derived, self, body) =>
3389
+ val ctx = localCtx(cdef)
3390
+ given Context = ctx
3391
+ foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3392
+ case Import (expr, _) =>
3393
+ foldTree(x, expr)
3394
+ case clause @ PackageClause (pid, stats) =>
3395
+ foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3396
+ case Inferred () => x
3397
+ case TypeIdent (_) => x
3398
+ case TypeSelect (qualifier, _) => foldTree(x, qualifier)
3399
+ case Projection (qualifier, _) => foldTree(x, qualifier)
3400
+ case Singleton (ref) => foldTree(x, ref)
3401
+ case Refined (tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3402
+ case Applied (tpt, args) => foldTrees(foldTree(x, tpt), args)
3403
+ case ByName (result) => foldTree(x, result)
3404
+ case Annotated (arg, annot) => foldTree(foldTree(x, arg), annot)
3405
+ case LambdaTypeTree (typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3406
+ case TypeBind (_, tbt) => foldTree(x, tbt)
3407
+ case TypeBlock (typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3408
+ case MatchTypeTree (boundopt, selector, cases) =>
3409
+ foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3410
+ case WildcardTypeTree () => x
3411
+ case TypeBoundsTree (lo, hi) => foldTree(foldTree(x, lo), hi)
3412
+ case CaseDef (pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3413
+ case TypeCaseDef (pat, body) => foldTree(foldTree(x, pat), body)
3414
+ case Bind (_, body) => foldTree(x, body)
3415
+ case Unapply (fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3416
+ case Alternatives (patterns) => foldTrees(x, patterns)
3417
+ }
3418
+ }
3419
+ end TreeAccumulator
3319
3420
3320
- /** TASTy Reflect tree traverser */
3321
- trait TreeTraverser extends reflect.TreeTraverser {
3322
- val reflect : reflection.type = reflection
3323
- }
3324
3421
3325
- /** TASTy Reflect tree map */
3326
- trait TreeMap extends reflect.TreeMap {
3327
- val reflect : reflection.type = reflection
3328
- }
3422
+ /** TASTy Reflect tree traverser.
3423
+ *
3424
+ * Usage:
3425
+ * ```
3426
+ * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3427
+ * extends scala.tasty.reflect.TreeTraverser {
3428
+ * import reflect._
3429
+ * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3430
+ * }
3431
+ * ```
3432
+ */
3433
+ trait TreeTraverser extends TreeAccumulator [Unit ]:
3434
+
3435
+ def traverseTree (tree : Tree )(using ctx : Context ): Unit = traverseTreeChildren(tree)
3436
+
3437
+ def foldTree (x : Unit , tree : Tree )(using ctx : Context ): Unit = traverseTree(tree)
3438
+
3439
+ protected def traverseTreeChildren (tree : Tree )(using ctx : Context ): Unit = foldOverTree((), tree)
3440
+
3441
+ end TreeTraverser
3442
+
3443
+ /** TASTy Reflect tree map.
3444
+ *
3445
+ * Usage:
3446
+ * ```
3447
+ * import qctx.reflect._
3448
+ * class MyTreeMap extends TreeMap {
3449
+ * override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3450
+ * }
3451
+ * ```
3452
+ */
3453
+ trait TreeMap :
3454
+
3455
+ def transformTree (tree : Tree )(using ctx : Context ): Tree = {
3456
+ tree match {
3457
+ case tree : PackageClause =>
3458
+ PackageClause .copy(tree)(transformTerm(tree.pid).asInstanceOf [Ref ], transformTrees(tree.stats)(using tree.symbol.localContext))
3459
+ case tree : Import =>
3460
+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3461
+ case tree : Statement =>
3462
+ transformStatement(tree)
3463
+ case tree : TypeTree => transformTypeTree(tree)
3464
+ case tree : TypeBoundsTree => tree // TODO traverse tree
3465
+ case tree : WildcardTypeTree => tree // TODO traverse tree
3466
+ case tree : CaseDef =>
3467
+ transformCaseDef(tree)
3468
+ case tree : TypeCaseDef =>
3469
+ transformTypeCaseDef(tree)
3470
+ case pattern : Bind =>
3471
+ Bind .copy(pattern)(pattern.name, pattern.pattern)
3472
+ case pattern : Unapply =>
3473
+ Unapply .copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3474
+ case pattern : Alternatives =>
3475
+ Alternatives .copy(pattern)(transformTrees(pattern.patterns))
3476
+ }
3477
+ }
3478
+
3479
+ def transformStatement (tree : Statement )(using ctx : Context ): Statement = {
3480
+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3481
+ tree match {
3482
+ case tree : Term =>
3483
+ transformTerm(tree)
3484
+ case tree : ValDef =>
3485
+ val ctx = localCtx(tree)
3486
+ given Context = ctx
3487
+ val tpt1 = transformTypeTree(tree.tpt)
3488
+ val rhs1 = tree.rhs.map(x => transformTerm(x))
3489
+ ValDef .copy(tree)(tree.name, tpt1, rhs1)
3490
+ case tree : DefDef =>
3491
+ val ctx = localCtx(tree)
3492
+ given Context = ctx
3493
+ DefDef .copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3494
+ case tree : TypeDef =>
3495
+ val ctx = localCtx(tree)
3496
+ given Context = ctx
3497
+ TypeDef .copy(tree)(tree.name, transformTree(tree.rhs))
3498
+ case tree : ClassDef =>
3499
+ ClassDef .copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3500
+ case tree : Import =>
3501
+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3502
+ }
3503
+ }
3504
+
3505
+ def transformTerm (tree : Term )(using ctx : Context ): Term = {
3506
+ tree match {
3507
+ case Ident (name) =>
3508
+ tree
3509
+ case Select (qualifier, name) =>
3510
+ Select .copy(tree)(transformTerm(qualifier), name)
3511
+ case This (qual) =>
3512
+ tree
3513
+ case Super (qual, mix) =>
3514
+ Super .copy(tree)(transformTerm(qual), mix)
3515
+ case Apply (fun, args) =>
3516
+ Apply .copy(tree)(transformTerm(fun), transformTerms(args))
3517
+ case TypeApply (fun, args) =>
3518
+ TypeApply .copy(tree)(transformTerm(fun), transformTypeTrees(args))
3519
+ case Literal (const) =>
3520
+ tree
3521
+ case New (tpt) =>
3522
+ New .copy(tree)(transformTypeTree(tpt))
3523
+ case Typed (expr, tpt) =>
3524
+ Typed .copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3525
+ case tree : NamedArg =>
3526
+ NamedArg .copy(tree)(tree.name, transformTerm(tree.value))
3527
+ case Assign (lhs, rhs) =>
3528
+ Assign .copy(tree)(transformTerm(lhs), transformTerm(rhs))
3529
+ case Block (stats, expr) =>
3530
+ Block .copy(tree)(transformStats(stats), transformTerm(expr))
3531
+ case If (cond, thenp, elsep) =>
3532
+ If .copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3533
+ case Closure (meth, tpt) =>
3534
+ Closure .copy(tree)(transformTerm(meth), tpt)
3535
+ case Match (selector, cases) =>
3536
+ Match .copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3537
+ case Return (expr, from) =>
3538
+ Return .copy(tree)(transformTerm(expr), from)
3539
+ case While (cond, body) =>
3540
+ While .copy(tree)(transformTerm(cond), transformTerm(body))
3541
+ case Try (block, cases, finalizer) =>
3542
+ Try .copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3543
+ case Repeated (elems, elemtpt) =>
3544
+ Repeated .copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3545
+ case Inlined (call, bindings, expansion) =>
3546
+ Inlined .copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/* ()call.symbol.localContext)*/ )
3547
+ }
3548
+ }
3549
+
3550
+ def transformTypeTree (tree : TypeTree )(using ctx : Context ): TypeTree = tree match {
3551
+ case Inferred () => tree
3552
+ case tree : TypeIdent => tree
3553
+ case tree : TypeSelect =>
3554
+ TypeSelect .copy(tree)(tree.qualifier, tree.name)
3555
+ case tree : Projection =>
3556
+ Projection .copy(tree)(tree.qualifier, tree.name)
3557
+ case tree : Annotated =>
3558
+ Annotated .copy(tree)(tree.arg, tree.annotation)
3559
+ case tree : Singleton =>
3560
+ Singleton .copy(tree)(transformTerm(tree.ref))
3561
+ case tree : Refined =>
3562
+ Refined .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf [List [Definition ]])
3563
+ case tree : Applied =>
3564
+ Applied .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3565
+ case tree : MatchTypeTree =>
3566
+ MatchTypeTree .copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3567
+ case tree : ByName =>
3568
+ ByName .copy(tree)(transformTypeTree(tree.result))
3569
+ case tree : LambdaTypeTree =>
3570
+ LambdaTypeTree .copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3571
+ case tree : TypeBind =>
3572
+ TypeBind .copy(tree)(tree.name, tree.body)
3573
+ case tree : TypeBlock =>
3574
+ TypeBlock .copy(tree)(tree.aliases, tree.tpt)
3575
+ }
3576
+
3577
+ def transformCaseDef (tree : CaseDef )(using ctx : Context ): CaseDef = {
3578
+ CaseDef .copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3579
+ }
3580
+
3581
+ def transformTypeCaseDef (tree : TypeCaseDef )(using ctx : Context ): TypeCaseDef = {
3582
+ TypeCaseDef .copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3583
+ }
3584
+
3585
+ def transformStats (trees : List [Statement ])(using ctx : Context ): List [Statement ] =
3586
+ trees mapConserve (transformStatement(_))
3587
+
3588
+ def transformTrees (trees : List [Tree ])(using ctx : Context ): List [Tree ] =
3589
+ trees mapConserve (transformTree(_))
3590
+
3591
+ def transformTerms (trees : List [Term ])(using ctx : Context ): List [Term ] =
3592
+ trees mapConserve (transformTerm(_))
3593
+
3594
+ def transformTypeTrees (trees : List [TypeTree ])(using ctx : Context ): List [TypeTree ] =
3595
+ trees mapConserve (transformTypeTree(_))
3596
+
3597
+ def transformCaseDefs (trees : List [CaseDef ])(using ctx : Context ): List [CaseDef ] =
3598
+ trees mapConserve (transformCaseDef(_))
3599
+
3600
+ def transformTypeCaseDefs (trees : List [TypeCaseDef ])(using ctx : Context ): List [TypeCaseDef ] =
3601
+ trees mapConserve (transformTypeCaseDef(_))
3602
+
3603
+ def transformSubTrees [Tr <: Tree ](trees : List [Tr ])(using ctx : Context ): List [Tr ] =
3604
+ transformTrees(trees).asInstanceOf [List [Tr ]]
3605
+
3606
+ end TreeMap
3329
3607
3330
3608
// TODO: extract from Reflection
3331
3609
0 commit comments