Skip to content

Commit ee1fd92

Browse files
dbatomic authored and cloud-fan committed
[SPARK-46331][SQL] Removing CodegenFallback from subset of DateTime expressions and version() expression
### What changes were proposed in this pull request?

This PR moves us a bit closer to removing the CodegenFallback class, relying instead on RuntimeReplaceable with StaticInvoke. It makes the following changes:

- Implements the Spark version expression using StaticInvoke + RuntimeReplaceable.
- Adds the Unevaluable trait to the DateTime expressions. These expressions need to be replaced during analysis anyway, so we explicitly forbid eval from being called.

### Why are the changes needed?

The direction is to move away from CodegenFallback. This PR moves us closer to that destination.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

By running the existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#44261 from dbatomic/codegenfallback_removal.

Lead-authored-by: Aleksandar Tomic <[email protected]>
Co-authored-by: Aleksandar Tomic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 8fa794b commit ee1fd92

File tree

12 files changed

+96
-97
lines changed

12 files changed

+96
-97
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions;
1919

20+
import org.apache.spark.SparkBuildInfo;
2021
import org.apache.spark.sql.errors.QueryExecutionErrors;
2122
import org.apache.spark.unsafe.types.UTF8String;
23+
import org.apache.spark.util.VersionUtils;
2224

2325
import javax.crypto.Cipher;
2426
import javax.crypto.spec.GCMParameterSpec;
@@ -143,6 +145,17 @@ public static byte[] aesDecrypt(byte[] input,
143145
);
144146
}
145147

148+
/**
149+
* Function to return the Spark version.
150+
* @return
151+
* Space-separated version and revision.
152+
*/
153+
public static UTF8String getSparkVersion() {
154+
String shortVersion = VersionUtils.shortVersion(SparkBuildInfo.spark_version());
155+
String revision = SparkBuildInfo.spark_revision();
156+
return UTF8String.fromString(shortVersion + " " + revision);
157+
}
158+
146159
private static SecretKeySpec getSecretKeySpec(byte[] key) {
147160
return switch (key.length) {
148161
case 16, 24, 32 -> new SecretKeySpec(key, 0, key.length, "AES");

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,17 +68,16 @@ object ResolveInlineTables extends Rule[LogicalPlan]
6868
/**
6969
* Validates that all inline table data are valid expressions that can be evaluated
7070
* (in this case they must be foldable).
71-
*
71+
* Note that nondeterministic expressions are not supported since they are not foldable.
72+
* The exception is CURRENT_LIKE expressions, which are replaced by a literal in later stages.
7273
* This is package visible for unit testing.
7374
*/
7475
private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = {
7576
table.rows.foreach { row =>
7677
row.foreach { e =>
77-
// Note that nondeterministic expressions are not supported since they are not foldable.
78-
// Only exception are CURRENT_LIKE expressions, which are replaced by a literal
79-
// In later stages.
80-
if ((!e.resolved && !e.containsPattern(CURRENT_LIKE))
81-
|| !trimAliases(prepareForEval(e)).foldable) {
78+
if (e.containsPattern(CURRENT_LIKE)) {
79+
// Do nothing.
80+
} else if (!e.resolved || !trimAliases(prepareForEval(e)).foldable) {
8281
e.failAnalysis(
8382
errorClass = "INVALID_INLINE_TABLE.CANNOT_EVALUATE_EXPRESSION_IN_INLINE_TABLE",
8483
messageParameters = Map("expr" -> toSQLExpr(e)))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.SparkThrowable
2121
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
22+
import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
2223
import org.apache.spark.sql.catalyst.plans.logical._
2324
import org.apache.spark.sql.catalyst.rules.Rule
2425
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
2526
import org.apache.spark.sql.errors.QueryCompilationErrors
27+
import org.apache.spark.sql.internal.SQLConf
2628
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
2729

2830
/**
@@ -34,7 +36,15 @@ import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
3436
*/
3537
object ResolveTableSpec extends Rule[LogicalPlan] {
3638
override def apply(plan: LogicalPlan): LogicalPlan = {
37-
plan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) {
39+
val preparedPlan = if (SQLConf.get.legacyEvalCurrentTime && plan.containsPattern(COMMAND)) {
40+
AnalysisHelper.allowInvokingTransformsInAnalyzer {
41+
ComputeCurrentTime(ResolveTimeZone(plan))
42+
}
43+
} else {
44+
plan
45+
}
46+
47+
preparedPlan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) {
3848
case t: CreateTable =>
3949
resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
4050
case t: CreateTableAsSelect =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 3 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -134,22 +134,14 @@ case class CurrentTimeZone() extends LeafExpression with Unevaluable {
134134
since = "1.5.0")
135135
// scalastyle:on line.size.limit
136136
case class CurrentDate(timeZoneId: Option[String] = None)
137-
extends LeafExpression with TimeZoneAwareExpression with CodegenFallback {
138-
137+
extends LeafExpression with TimeZoneAwareExpression with Unevaluable {
139138
def this() = this(None)
140-
141-
override def foldable: Boolean = true
142139
override def nullable: Boolean = false
143-
144140
override def dataType: DataType = DateType
145-
146141
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
147-
148142
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
149143
copy(timeZoneId = Option(timeZoneId))
150144

151-
override def eval(input: InternalRow): Any = currentDate(zoneId)
152-
153145
override def prettyName: String = "current_date"
154146
}
155147

@@ -177,11 +169,9 @@ object CurDateExpressionBuilder extends ExpressionBuilder {
177169
}
178170
}
179171

180-
abstract class CurrentTimestampLike() extends LeafExpression with CodegenFallback {
181-
override def foldable: Boolean = true
172+
abstract class CurrentTimestampLike() extends LeafExpression with Unevaluable {
182173
override def nullable: Boolean = false
183174
override def dataType: DataType = TimestampType
184-
override def eval(input: InternalRow): Any = currentTimestamp()
185175
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
186176
}
187177

@@ -245,22 +235,13 @@ case class Now() extends CurrentTimestampLike {
245235
group = "datetime_funcs",
246236
since = "3.4.0")
247237
case class LocalTimestamp(timeZoneId: Option[String] = None) extends LeafExpression
248-
with TimeZoneAwareExpression with CodegenFallback {
249-
238+
with TimeZoneAwareExpression with Unevaluable {
250239
def this() = this(None)
251-
252-
override def foldable: Boolean = true
253240
override def nullable: Boolean = false
254-
255241
override def dataType: DataType = TimestampNTZType
256-
257242
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
258-
259243
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
260244
copy(timeZoneId = Option(timeZoneId))
261-
262-
override def eval(input: InternalRow): Any = localDateTimeToMicros(LocalDateTime.now(zoneId))
263-
264245
override def prettyName: String = "localtimestamp"
265246
}
266247

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
2120
import org.apache.spark.sql.catalyst.InternalRow
2221
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, UnresolvedSeed}
2322
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -288,14 +287,14 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non
288287
since = "3.0.0",
289288
group = "misc_funcs")
290289
// scalastyle:on line.size.limit
291-
case class SparkVersion() extends LeafExpression with CodegenFallback {
292-
override def nullable: Boolean = false
293-
override def foldable: Boolean = true
294-
override def dataType: DataType = StringType
290+
case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
295291
override def prettyName: String = "version"
296-
override def eval(input: InternalRow): Any = {
297-
UTF8String.fromString(SPARK_VERSION_SHORT + " " + SPARK_REVISION)
298-
}
292+
293+
override lazy val replacement: Expression = StaticInvoke(
294+
classOf[ExpressionImplUtils],
295+
StringType,
296+
"getSparkVersion",
297+
returnNullable = false)
299298
}
300299

301300
@ExpressionDescription(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -516,11 +516,6 @@ object DateTimeUtils extends SparkDateTimeUtils {
516516
convertTz(micros, getZoneId(timeZone), ZoneOffset.UTC)
517517
}
518518

519-
/**
520-
* Obtains the current instant as microseconds since the epoch at the UTC time zone.
521-
*/
522-
def currentTimestamp(): Long = instantToMicros(Instant.now())
523-
524519
/**
525520
* Obtains the current date as days since the epoch in the specified time-zone.
526521
*/
@@ -572,7 +567,7 @@ object DateTimeUtils extends SparkDateTimeUtils {
572567
def convertSpecialTimestamp(input: String, zoneId: ZoneId): Option[Long] = {
573568
extractSpecialValue(input.trim).flatMap {
574569
case "epoch" => Some(0)
575-
case "now" => Some(currentTimestamp())
570+
case "now" => Some(instantToMicros(Instant.now()))
576571
case "today" => Some(instantToMicros(today(zoneId).toInstant))
577572
case "tomorrow" => Some(instantToMicros(today(zoneId).plusDays(1).toInstant))
578573
case "yesterday" => Some(instantToMicros(today(zoneId).minusDays(1).toInstant))

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4612,6 +4612,18 @@ object SQLConf {
46124612
.booleanConf
46134613
.createWithDefault(false)
46144614

4615+
val LEGACY_EVAL_CURRENT_TIME = buildConf("spark.sql.legacy.earlyEvalCurrentTime")
4616+
.internal()
4617+
.doc("When set to true, evaluation and constant folding will happen for now() and " +
4618+
"current_timestamp() expressions before finish analysis phase. " +
4619+
"This flag will allow a bit more liberal syntax but it will sacrifice correctness - " +
4620+
"Results of now() and current_timestamp() can be different for different operations " +
4621+
"in a single query."
4622+
)
4623+
.version("4.0.0")
4624+
.booleanConf
4625+
.createWithDefault(false)
4626+
46154627
/**
46164628
* Holds information about keys that have been deprecated.
46174629
*
@@ -5516,6 +5528,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
55165528

55175529
def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)
55185530

5531+
def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)
5532+
55195533
/** ********************** SQLConf functionality methods ************ */
55205534

55215535
/** Set Spark SQL configuration properties. */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 11 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import scala.language.postfixOps
2828
import scala.reflect.ClassTag
2929
import scala.util.Random
3030

31-
import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkFunSuite, SparkUpgradeException}
31+
import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkException, SparkFunSuite, SparkUpgradeException}
3232
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3333
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
3434
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
@@ -78,33 +78,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
7878
}
7979
}
8080

81-
test("datetime function current_date") {
82-
val d0 = DateTimeUtils.currentDate(UTC)
83-
val cd = CurrentDate(UTC_OPT).eval(EmptyRow).asInstanceOf[Int]
84-
val d1 = DateTimeUtils.currentDate(UTC)
85-
assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1)
86-
87-
val cdjst = CurrentDate(JST_OPT).eval(EmptyRow).asInstanceOf[Int]
88-
val cdpst = CurrentDate(PST_OPT).eval(EmptyRow).asInstanceOf[Int]
89-
assert(cdpst <= cd && cd <= cdjst)
90-
}
91-
92-
test("datetime function current_timestamp") {
93-
val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long])
94-
val t1 = System.currentTimeMillis()
95-
assert(math.abs(t1 - ct.getTime) < 5000)
96-
}
97-
98-
test("datetime function localtimestamp") {
99-
// Verify with multiple outstanding time zones which has no daylight saving time.
100-
Seq("UTC", "Africa/Dakar", "Asia/Hong_Kong").foreach { zid =>
101-
val zoneId = DateTimeUtils.getZoneId(zid)
102-
val ct = LocalTimestamp(Some(zid)).eval(EmptyRow).asInstanceOf[Long]
103-
val t1 = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
104-
assert(math.abs(t1 - ct) < 1000000)
105-
}
106-
}
107-
10881
test("DayOfYear") {
10982
val sdfDay = new SimpleDateFormat("D", Locale.US)
11083

@@ -970,11 +943,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
970943
Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId),
971944
MICROSECONDS.toSeconds(DateTimeUtils.daysToMicros(
972945
DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId)))
973-
val t1 = UnixTimestamp(
974-
CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
975-
val t2 = UnixTimestamp(
976-
CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
977-
assert(t2 - t1 <= 1)
978946
checkEvaluation(
979947
UnixTimestamp(
980948
Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId),
@@ -1041,11 +1009,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
10411009
Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId),
10421010
MICROSECONDS.toSeconds(DateTimeUtils.daysToMicros(
10431011
DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), zid)))
1044-
val t1 = ToUnixTimestamp(
1045-
CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
1046-
val t2 = ToUnixTimestamp(
1047-
CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
1048-
assert(t2 - t1 <= 1)
10491012
checkEvaluation(ToUnixTimestamp(
10501013
Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null)
10511014
checkEvaluation(
@@ -1516,7 +1479,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
15161479
checkExceptionInExpression[T](ToUnixTimestamp(Literal("1"), Literal(c)), c)
15171480
checkExceptionInExpression[T](UnixTimestamp(Literal("1"), Literal(c)), c)
15181481
if (!Set("E", "F", "q", "Q").contains(c)) {
1519-
checkExceptionInExpression[T](DateFormatClass(CurrentTimestamp(), Literal(c)), c)
15201482
checkExceptionInExpression[T](FromUnixTime(Literal(0L), Literal(c)), c)
15211483
}
15221484
}
@@ -2124,4 +2086,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
21242086
}
21252087
}
21262088
}
2089+
2090+
test("datetime function CurrentDate and localtimestamp are Unevaluable") {
2091+
checkError(exception = intercept[SparkException] { CurrentDate(UTC_OPT).eval(EmptyRow) },
2092+
errorClass = "INTERNAL_ERROR",
2093+
parameters = Map("message" -> "Cannot evaluate expression: current_date(Some(UTC))"))
2094+
2095+
checkError(exception = intercept[SparkException] { LocalTimestamp(UTC_OPT).eval(EmptyRow) },
2096+
errorClass = "INTERNAL_ERROR",
2097+
parameters = Map("message" -> "Cannot evaluate expression: localtimestamp(Some(UTC))"))
2098+
}
21272099
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import java.lang.Thread.sleep
2021
import java.time.{LocalDateTime, ZoneId}
2122

2223
import scala.concurrent.duration._
@@ -51,6 +52,19 @@ class ComputeCurrentTimeSuite extends PlanTest {
5152
assert(lits(0) == lits(1))
5253
}
5354

55+
test("analyzer should respect time flow in current timestamp calls") {
56+
val in = Project(Alias(CurrentTimestamp(), "t1")() :: Nil, LocalRelation())
57+
58+
val planT1 = Optimize.execute(in.analyze).asInstanceOf[Project]
59+
sleep(1)
60+
val planT2 = Optimize.execute(in.analyze).asInstanceOf[Project]
61+
62+
val t1 = DateTimeUtils.microsToMillis(literals[Long](planT1)(0))
63+
val t2 = DateTimeUtils.microsToMillis(literals[Long](planT2)(0))
64+
65+
assert(t2 - t1 <= 1000 && t2 - t1 > 0)
66+
}
67+
5468
test("analyzer should replace current_date with literals") {
5569
val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
5670

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,11 +98,11 @@ class EliminateSortsSuite extends AnalysisTest {
9898
test("Remove no-op alias") {
9999
val x = testRelation
100100

101-
val query = x.select($"a".as("x"), Year(CurrentDate()).as("y"), $"b")
101+
val query = x.select($"a".as("x"), Literal(1).as("y"), $"b")
102102
.orderBy($"x".asc, $"y".asc, $"b".desc)
103103
val optimized = Optimize.execute(analyzer.execute(query))
104104
val correctAnswer = analyzer.execute(
105-
x.select($"a".as("x"), Year(CurrentDate()).as("y"), $"b")
105+
x.select($"a".as("x"), Literal(1).as("y"), $"b")
106106
.orderBy($"x".asc, $"b".desc))
107107

108108
comparePlans(optimized, correctAnswer)

0 commit comments

Comments
 (0)