diff --git a/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala b/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala index aa03bacdf8..64b8154e47 100644 --- a/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala +++ b/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala @@ -1,17 +1,38 @@ package io.getquill.norm import io.getquill.ast._ +import io.getquill.util.Messages + import scala.collection.immutable.{ Map => IMap } case class BetaReduction(replacements: Replacements) extends StatelessTransformer { + object TriviallyTrueCondition { + def unapply(ast: Ast): Boolean = + ast match { + case TriviallyCheckable(Constant(true)) => true + case _ => false + } + } + + object TriviallyFalseCondition { + def unapply(ast: Ast): Boolean = + ast match { + case TriviallyCheckable(Constant(false)) => true + case _ => false + } + } + override def apply(ast: Ast): Ast = ast match { case ast if replacements.contains(ast) => BetaReduction(replacements - ast - replacements(ast))(replacements(ast)) + case If(TriviallyTrueCondition(), thenClause, _) if (Messages.reduceTrivials) => apply(thenClause) + case If(TriviallyFalseCondition(), _, elseClause) if (Messages.reduceTrivials) => apply(elseClause) + case Property(Tuple(values), name) => apply(values(name.drop(1).toInt - 1)) diff --git a/quill-core/src/main/scala/io/getquill/norm/TriviallyCheckable.scala b/quill-core/src/main/scala/io/getquill/norm/TriviallyCheckable.scala new file mode 100644 index 0000000000..9060019dfc --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/norm/TriviallyCheckable.scala @@ -0,0 +1,25 @@ +package io.getquill.norm + +import io.getquill.ast.{ +!=+, +&&+, +==+, +||+, Ast, Constant } + +object TriviallyCheckable { + + def unapply(ast: Ast): Option[Ast] = ast match { + case c @ Constant(_) => Some(c) + + case TriviallyCheckable(Constant(true)) +&&+ TriviallyCheckable(Constant(true)) => Some(Constant(true)) + case TriviallyCheckable(Constant(false)) +&&+ _ => Some(Constant(false)) + case _ +&&+ TriviallyCheckable(Constant(false)) => Some(Constant(false)) + + case TriviallyCheckable(Constant(true)) +||+ _ => Some(Constant(true)) + case _ +||+ TriviallyCheckable(Constant(true)) => Some(Constant(true)) + case TriviallyCheckable(Constant(false)) +||+ TriviallyCheckable(Constant(false)) => Some(Constant(false)) + + case TriviallyCheckable(one) +==+ TriviallyCheckable(two) if (one == two) => Some(Constant(true)) + case TriviallyCheckable(one) +!=+ TriviallyCheckable(two) if (one != two) => Some(Constant(true)) + case TriviallyCheckable(one) +==+ TriviallyCheckable(two) if (one != two) => Some(Constant(false)) + case TriviallyCheckable(one) +!=+ TriviallyCheckable(two) if (one == two) => Some(Constant(false)) + + case _ => None + } +} diff --git a/quill-core/src/main/scala/io/getquill/util/Messages.scala b/quill-core/src/main/scala/io/getquill/util/Messages.scala index e0d71517dc..0a48f6a1f6 100644 --- a/quill-core/src/main/scala/io/getquill/util/Messages.scala +++ b/quill-core/src/main/scala/io/getquill/util/Messages.scala @@ -11,6 +11,8 @@ object Messages { private def variable(propName: String, envName: String, default: String) = Option(System.getProperty(propName)).orElse(sys.env.get(envName)).getOrElse(default) + val reduceTrivials = variable("quill.transform.reducetrivial", "quill_transform_reducetrivial", "true").toBoolean + private[util] val prettyPrint = variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean private[util] val debugEnabled = variable("quill.macro.log", "quill_macro_log", "true").toBoolean private[util] val traceEnabled = variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean diff --git a/quill-core/src/test/scala/io/getquill/norm/ConditionalReductionSpec.scala b/quill-core/src/test/scala/io/getquill/norm/ConditionalReductionSpec.scala new file mode 100644 index 0000000000..c94d033b8f --- /dev/null +++ b/quill-core/src/test/scala/io/getquill/norm/ConditionalReductionSpec.scala @@ -0,0 +1,64 @@ +package io.getquill.norm + +import io.getquill.Spec +import io.getquill.testContext._ + +class ConditionalReductionSpec extends Spec { + + "trivial conditionals must" - { + val q = quote { + (value: String) => + if (value == "foo") qr1.filter(r => r.i == 1) + else qr1.filter(r => r.s == "blah") + } + + "reduce to 'then' clause when true" in { + Normalize(quote(q("foo")).ast) mustEqual (quote(qr1.filter(r => r.i == 1)).ast) + } + + "reduce to 'else' clause when false" in { + Normalize(quote(q("bar")).ast) mustEqual (quote(qr1.filter(r => r.s == "blah")).ast) + } + } + + "compound condition must" - { + val q = quote { + (value: String) => + if (value == "foo" || value == "bar") qr1.filter(r => r.i == 1) + else qr1.filter(r => r.s == "blah") + } + + "reduce to 'then' clause when true - first" in { + Normalize(quote(q("foo")).ast) mustEqual (quote(qr1.filter(r => r.i == 1)).ast) + } + + "reduce to 'then' clause when true - second" in { + Normalize(quote(q("bar")).ast) mustEqual (quote(qr1.filter(r => r.i == 1)).ast) + } + + "reduce to 'else' clause when false" in { + Normalize(quote(q("baz")).ast) mustEqual (quote(qr1.filter(r => r.s == "blah")).ast) + } + } + + "recursive compound condition must" - { + val q = quote { + (value: String) => + if (value == "foo") qr1.filter(r => r.i == 1) + else if (value == "bar") qr1.filter(r => r.i == 2) + else qr1.filter(r => r.s == "blah") + } + + "reduce to 'then' clause when true - first" in { + Normalize(quote(q("foo")).ast) mustEqual (quote(qr1.filter(r => r.i == 1)).ast) + } + + "reduce to 'then' clause when true - second" in { + Normalize(quote(q("bar")).ast) mustEqual (quote(qr1.filter(r => r.i == 2)).ast) + } + + "reduce to 'else' clause when false" in { + Normalize(quote(q("baz")).ast) mustEqual (quote(qr1.filter(r => r.s == "blah")).ast) + } + } +} diff --git a/quill-core/src/test/scala/io/getquill/norm/TriviallyCheckableSpec.scala b/quill-core/src/test/scala/io/getquill/norm/TriviallyCheckableSpec.scala new file mode 100644 index 0000000000..811d281a20 --- /dev/null +++ b/quill-core/src/test/scala/io/getquill/norm/TriviallyCheckableSpec.scala @@ -0,0 +1,66 @@ +package io.getquill.norm + +import io.getquill.Spec +import io.getquill.ast.Implicits._ +import io.getquill.ast.{ Ast, Constant, Ident } + +class TriviallyCheckableSpec extends Spec { + + def unapply(ast: Ast) = TriviallyCheckable.unapply(ast) + val other = Ident("something") + + val True = Constant(true) + val False = Constant(false) + + "constants must trivially reduce correctly" - { + "const == const => true" in { unapply(Constant(123) +==+ Constant(123)) mustEqual Some(True) } + "const == const => false" in { unapply(Constant(123) +==+ Constant(456)) mustEqual Some(False) } + "const != const => true" in { unapply(Constant(123) +!=+ Constant(123)) mustEqual Some(False) } + "const != const => false" in { unapply(Constant(123) +!=+ Constant(456)) mustEqual Some(True) } + + "const == other => none" in { unapply(Constant(123) +==+ other) mustEqual None } + "const != other => none" in { unapply(Constant(123) +!=+ other) mustEqual None } + "other == const => none" in { unapply(other +==+ Constant(123)) mustEqual None } + "other != const => none" in { unapply(other +!=+ Constant(123)) mustEqual None } + () + } + + "expressions must trivially reduce correctly" - { + "true || true => true" in { unapply(True +||+ True) mustEqual Some(True) } + "true || false => false" in { unapply(True +||+ False) mustEqual Some(True) } + "const || const => false" in { unapply(False +||+ True) mustEqual Some(True) } + "true && true => true" in { unapply(True +&&+ True) mustEqual Some(True) } + "true && false => false" in { unapply(True +&&+ False) mustEqual Some(False) } + "false && true => false" in { unapply(False +&&+ True) mustEqual Some(False) } + + "false && other => false" in { unapply(False +&&+ other) mustEqual Some(False) } + "other && false => false" in { unapply(other +&&+ False) mustEqual Some(False) } + "true && const => none" in { unapply(True +&&+ other) mustEqual None } + "other && true => none" in { unapply(other +&&+ True) mustEqual None } + "other && other => none" in { unapply(other +&&+ other) mustEqual None } + + "true || other => true" in { unapply(True +||+ other) mustEqual Some(True) } + "other || true => true" in { unapply(other +||+ True) mustEqual Some(True) } + "false || other => none" in { unapply(False +||+ other) mustEqual None } + "other || false => none" in { unapply(other +||+ False) mustEqual None } + "other || other => none" in { unapply(other +||+ other) mustEqual None } + () + } + + "compound expressions must trivially reduce correctly" - { + val compoundTrue = Constant(123) +==+ Constant(123) + val compoundFalse = Constant(123) +==+ Constant(456) + val compoundUnknown = Constant(123) +==+ other + + "true && true => true" in { unapply(compoundTrue +&&+ compoundTrue) mustEqual Some(True) } + "true || true => true" in { unapply(compoundTrue +||+ compoundTrue) mustEqual Some(True) } + "true && false => false" in { unapply(compoundTrue +&&+ compoundFalse) mustEqual Some(False) } + "true || false => true" in { unapply(compoundTrue +||+ compoundFalse) mustEqual Some(True) } + + "true && none => none" in { unapply(compoundTrue +&&+ compoundUnknown) mustEqual None } + "true || none => true" in { unapply(compoundTrue +||+ compoundUnknown) mustEqual Some(True) } + "false && none => false" in { unapply(compoundFalse +&&+ compoundUnknown) mustEqual Some(False) } + "false || none => true" in { unapply(compoundFalse +||+ compoundUnknown) mustEqual None } + () + } +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/SqlStaticReductionSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/SqlStaticReductionSpec.scala new file mode 100644 index 0000000000..9535ef0093 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/SqlStaticReductionSpec.scala @@ -0,0 +1,25 @@ +package io.getquill.context.sql + +import io.getquill.Spec +import io.getquill.context.sql.testContext._ + +class SqlStaticReductionSpec extends Spec { + + "trivial conditionals must" - { + val q = quote { + (value: String) => + if (value == "foo") qr1.filter(r => r.i == 1) + else qr1.filter(r => r.s == "blah") + } + + "reduce to 'the' clause when true" in { + testContext.run(q("foo")).string mustEqual testContext.run(qr1.filter(r => r.i == 1)).string + } + + "reduce to 'else' clause when false" in { + testContext.run(q("bar")).string mustEqual testContext.run(qr1.filter(r => r.s == "blah")).string + } + + // TODO Test some conditionals + } +}