Skip to content

Commit 2dcbb85

Browse files
authored
Merge pull request #421 from dcsobral/bug/257
Make RuleTransformer fully recursive [#257]
2 parents 13a595e + c17cd10 commit 2dcbb85

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/* __ *\
2+
** ________ ___ / / ___ Scala API **
3+
** / __/ __// _ | / / / _ | (c) 2002-2020, LAMP/EPFL **
4+
** __\ \/ /__/ __ |/ /__/ __ | (c) 2011-2020, Lightbend, Inc. **
5+
** /____/\___/_/ |_/____/_/ | | http://scala-lang.org/ **
6+
** |/ **
7+
\* */
8+
9+
package scala
10+
package xml
11+
package transform
12+
13+
import scala.collection.Seq
14+
15+
class NestingTransformer(rule: RewriteRule) extends BasicTransformer {
16+
override def transform(n: Node): Seq[Node] = {
17+
rule.transform(super.transform(n))
18+
}
19+
}

shared/src/main/scala/scala/xml/transform/RuleTransformer.scala

+5-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ package transform
1313
import scala.collection.Seq
1414

1515
class RuleTransformer(rules: RewriteRule*) extends BasicTransformer {
16-
override def transform(n: Node): Seq[Node] =
17-
rules.foldLeft(super.transform(n)) { (res, rule) => rule transform res }
16+
private val transformers = rules.map(new NestingTransformer(_))
17+
override def transform(n: Node): Seq[Node] = {
18+
if (transformers.isEmpty) n
19+
else transformers.tail.foldLeft(transformers.head.transform(n)) { (res, transformer) => transformer.transform(res) }
20+
}
1821
}

shared/src/test/scala-2.x/scala/xml/TransformersTest.scala

+16-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class TransformersTest {
6060
@Test
6161
def preserveReferentialComplexityInLinearComplexity = { // SI-4528
6262
var i = 0
63-
63+
6464
val xmlNode = <a><b><c><h1>Hello Example</h1></c></b></a>
6565

6666
new RuleTransformer(new RewriteRule {
@@ -77,4 +77,19 @@ class TransformersTest {
7777

7878
assertEquals(1, i)
7979
}
80+
81+
@Test
82+
def appliesRulesRecursivelyOnPreviousChanges = { // #257
83+
def add(outer: Elem, inner: Node) = new RewriteRule {
84+
override def transform(n: Node): Seq[Node] = n match {
85+
case e: Elem if e.label == outer.label => e.copy(child = e.child ++ inner)
86+
case other => other
87+
}
88+
}
89+
90+
def transformer = new RuleTransformer(add(<element/>, <new/>), add(<new/>, <thing/>))
91+
92+
assertEquals(<element><new><thing/></new></element>, transformer(<element/>))
93+
}
8094
}
95+

0 commit comments

Comments
 (0)