[SPARK-46331][SQL] Removing CodegenFallback from subset of DateTime expressions and version() expression

### What changes were proposed in this pull request?

This PR moves us a bit closer to removing the CodegenFallback class, relying instead on RuntimeReplaceable with StaticInvoke.

This PR makes the following changes:
- Reimplementing the Spark version() expression with StaticInvoke + RuntimeReplaceable (a condensed sketch of the pattern follows this list).
- Adding the Unevaluable trait to the CURRENT_LIKE DateTime expressions. These expressions need to be replaced during analysis anyhow, so we now explicitly forbid eval from being called on them.
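
For illustration, a condensed sketch of the new pattern (assembled from the misc.scala and ExpressionImplUtils.java hunks below; the package line and imports are added here for self-containment, so this is not the verbatim patch):

```scala
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types.StringType

// The expression carries no eval() and no CodegenFallback path anymore.
// RuntimeReplaceable lets the planner swap it for `replacement`: a
// StaticInvoke that dispatches to a small static Java helper.
case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
  override def prettyName: String = "version"

  // Codegen only ever sees the StaticInvoke, never SparkVersion itself.
  override lazy val replacement: Expression = StaticInvoke(
    classOf[ExpressionImplUtils],
    StringType,
    "getSparkVersion",
    returnNullable = false)
}
```

The Unevaluable change is the complementary half: since CURRENT_LIKE expressions are always replaced by literals during analysis (see ComputeCurrentTime), their eval() bodies were dead code, and Unevaluable turns any accidental call into an explicit internal error.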

### Why are the changes needed?

The long-term direction is to get away from CodegenFallback; this PR moves us closer to that destination.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Running existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#44261 from dbatomic/codegenfallback_removal.

Lead-authored-by: Aleksandar Tomic <[email protected]>
Co-authored-by: Aleksandar Tomic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Jan 9, 2024
1 parent 8fa794b commit ee1fd92
Showing 12 changed files with 96 additions and 97 deletions.
ExpressionImplUtils.java
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import org.apache.spark.SparkBuildInfo;
 import org.apache.spark.sql.errors.QueryExecutionErrors;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.VersionUtils;
 
 import javax.crypto.Cipher;
 import javax.crypto.spec.GCMParameterSpec;
@@ -143,6 +145,17 @@ public static byte[] aesDecrypt(byte[] input,
     );
   }
 
+  /**
+   * Function to return the Spark version.
+   * @return
+   * Space separated version and revision.
+   */
+  public static UTF8String getSparkVersion() {
+    String shortVersion = VersionUtils.shortVersion(SparkBuildInfo.spark_version());
+    String revision = SparkBuildInfo.spark_revision();
+    return UTF8String.fromString(shortVersion + " " + revision);
+  }
+
   private static SecretKeySpec getSecretKeySpec(byte[] key) {
     return switch (key.length) {
       case 16, 24, 32 -> new SecretKeySpec(key, 0, key.length, "AES");
ResolveInlineTables.scala
@@ -68,17 +68,16 @@ object ResolveInlineTables extends Rule[LogicalPlan]
   /**
    * Validates that all inline table data are valid expressions that can be evaluated
    * (in this they must be foldable).
+   *
+   * Note that nondeterministic expressions are not supported since they are not foldable.
+   * Exception are CURRENT_LIKE expressions, which are replaced by a literal in later stages.
    * This is package visible for unit testing.
    */
   private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = {
     table.rows.foreach { row =>
       row.foreach { e =>
-        // Note that nondeterministic expressions are not supported since they are not foldable.
-        // Only exception are CURRENT_LIKE expressions, which are replaced by a literal
-        // In later stages.
-        if ((!e.resolved && !e.containsPattern(CURRENT_LIKE))
-          || !trimAliases(prepareForEval(e)).foldable) {
+        if (e.containsPattern(CURRENT_LIKE)) {
+          // Do nothing.
+        } else if (!e.resolved || !trimAliases(prepareForEval(e)).foldable) {
           e.failAnalysis(
             errorClass = "INVALID_INLINE_TABLE.CANNOT_EVALUATE_EXPRESSION_IN_INLINE_TABLE",
             messageParameters = Map("expr" -> toSQLExpr(e)))
ResolveTableSpec.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}

@@ -34,7 +36,15 @@ import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
  */
 object ResolveTableSpec extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) {
+    val preparedPlan = if (SQLConf.get.legacyEvalCurrentTime && plan.containsPattern(COMMAND)) {
+      AnalysisHelper.allowInvokingTransformsInAnalyzer {
+        ComputeCurrentTime(ResolveTimeZone(plan))
+      }
+    } else {
+      plan
+    }
+
+    preparedPlan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) {
       case t: CreateTable =>
         resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
       case t: CreateTableAsSelect =>
datetimeExpressions.scala
@@ -134,22 +134,14 @@ case class CurrentTimeZone() extends LeafExpression with Unevaluable {
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class CurrentDate(timeZoneId: Option[String] = None)
-  extends LeafExpression with TimeZoneAwareExpression with CodegenFallback {
-
+  extends LeafExpression with TimeZoneAwareExpression with Unevaluable {
   def this() = this(None)
 
-  override def foldable: Boolean = true
   override def nullable: Boolean = false
 
   override def dataType: DataType = DateType
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
 
-  override def eval(input: InternalRow): Any = currentDate(zoneId)
-
   override def prettyName: String = "current_date"
 }

@@ -177,11 +169,9 @@ object CurDateExpressionBuilder extends ExpressionBuilder {
   }
 }
 
-abstract class CurrentTimestampLike() extends LeafExpression with CodegenFallback {
-  override def foldable: Boolean = true
+abstract class CurrentTimestampLike() extends LeafExpression with Unevaluable {
   override def nullable: Boolean = false
   override def dataType: DataType = TimestampType
-  override def eval(input: InternalRow): Any = currentTimestamp()
   final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }

@@ -245,22 +235,13 @@ case class Now() extends CurrentTimestampLike {
   group = "datetime_funcs",
   since = "3.4.0")
 case class LocalTimestamp(timeZoneId: Option[String] = None) extends LeafExpression
-  with TimeZoneAwareExpression with CodegenFallback {
-
+  with TimeZoneAwareExpression with Unevaluable {
   def this() = this(None)
 
-  override def foldable: Boolean = true
   override def nullable: Boolean = false
 
   override def dataType: DataType = TimestampNTZType
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
 
-  override def eval(input: InternalRow): Any = localDateTimeToMicros(LocalDateTime.now(zoneId))
-
   override def prettyName: String = "localtimestamp"
 }

misc.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -288,14 +287,14 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non
   since = "3.0.0",
   group = "misc_funcs")
 // scalastyle:on line.size.limit
-case class SparkVersion() extends LeafExpression with CodegenFallback {
-  override def nullable: Boolean = false
-  override def foldable: Boolean = true
-  override def dataType: DataType = StringType
+case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
   override def prettyName: String = "version"
-  override def eval(input: InternalRow): Any = {
-    UTF8String.fromString(SPARK_VERSION_SHORT + " " + SPARK_REVISION)
-  }
+
+  override lazy val replacement: Expression = StaticInvoke(
+    classOf[ExpressionImplUtils],
+    StringType,
+    "getSparkVersion",
+    returnNullable = false)
 }
 
 @ExpressionDescription(
DateTimeUtils.scala
@@ -516,11 +516,6 @@ object DateTimeUtils extends SparkDateTimeUtils {
     convertTz(micros, getZoneId(timeZone), ZoneOffset.UTC)
   }
 
-  /**
-   * Obtains the current instant as microseconds since the epoch at the UTC time zone.
-   */
-  def currentTimestamp(): Long = instantToMicros(Instant.now())
-
   /**
    * Obtains the current date as days since the epoch in the specified time-zone.
    */
@@ -572,7 +567,7 @@ object DateTimeUtils extends SparkDateTimeUtils {
   def convertSpecialTimestamp(input: String, zoneId: ZoneId): Option[Long] = {
     extractSpecialValue(input.trim).flatMap {
       case "epoch" => Some(0)
-      case "now" => Some(currentTimestamp())
+      case "now" => Some(instantToMicros(Instant.now()))
       case "today" => Some(instantToMicros(today(zoneId).toInstant))
       case "tomorrow" => Some(instantToMicros(today(zoneId).plusDays(1).toInstant))
       case "yesterday" => Some(instantToMicros(today(zoneId).minusDays(1).toInstant))
SQLConf.scala
@@ -4612,6 +4612,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_EVAL_CURRENT_TIME = buildConf("spark.sql.legacy.earlyEvalCurrentTime")
+    .internal()
+    .doc("When set to true, evaluation and constant folding will happen for now() and " +
+      "current_timestamp() expressions before finish analysis phase. " +
+      "This flag will allow a bit more liberal syntax but it will sacrifice correctness - " +
+      "Results of now() and current_timestamp() can be different for different operations " +
+      "in a single query."
+    )
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -5516,6 +5528,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)
 
+  def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
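
As a usage sketch for this new escape hatch (hypothetical snippet assuming an active `SparkSession` named `spark`; the conf is internal and defaults to false):

```scala
// Opt back into the legacy behavior: now()/current_timestamp() are
// evaluated and folded early, before analysis finishes. This permits
// slightly more liberal syntax at the cost of correctness: timestamps
// may differ between operations within a single query.
spark.conf.set("spark.sql.legacy.earlyEvalCurrentTime", "true")
```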
DateExpressionsSuite.scala
@@ -28,7 +28,7 @@ import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.Random
 
-import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkFunSuite, SparkUpgradeException}
+import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkException, SparkFunSuite, SparkUpgradeException}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
@@ -78,33 +78,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  test("datetime function current_date") {
-    val d0 = DateTimeUtils.currentDate(UTC)
-    val cd = CurrentDate(UTC_OPT).eval(EmptyRow).asInstanceOf[Int]
-    val d1 = DateTimeUtils.currentDate(UTC)
-    assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1)
-
-    val cdjst = CurrentDate(JST_OPT).eval(EmptyRow).asInstanceOf[Int]
-    val cdpst = CurrentDate(PST_OPT).eval(EmptyRow).asInstanceOf[Int]
-    assert(cdpst <= cd && cd <= cdjst)
-  }
-
-  test("datetime function current_timestamp") {
-    val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long])
-    val t1 = System.currentTimeMillis()
-    assert(math.abs(t1 - ct.getTime) < 5000)
-  }
-
-  test("datetime function localtimestamp") {
-    // Verify with multiple outstanding time zones which has no daylight saving time.
-    Seq("UTC", "Africa/Dakar", "Asia/Hong_Kong").foreach { zid =>
-      val zoneId = DateTimeUtils.getZoneId(zid)
-      val ct = LocalTimestamp(Some(zid)).eval(EmptyRow).asInstanceOf[Long]
-      val t1 = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
-      assert(math.abs(t1 - ct) < 1000000)
-    }
-  }
-
   test("DayOfYear") {
     val sdfDay = new SimpleDateFormat("D", Locale.US)
 
@@ -970,11 +943,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId),
       MICROSECONDS.toSeconds(DateTimeUtils.daysToMicros(
         DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId)))
-    val t1 = UnixTimestamp(
-      CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
-    val t2 = UnixTimestamp(
-      CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
-    assert(t2 - t1 <= 1)
     checkEvaluation(
       UnixTimestamp(
         Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId),
@@ -1041,11 +1009,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId),
       MICROSECONDS.toSeconds(DateTimeUtils.daysToMicros(
         DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), zid)))
-    val t1 = ToUnixTimestamp(
-      CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
-    val t2 = ToUnixTimestamp(
-      CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
-    assert(t2 - t1 <= 1)
     checkEvaluation(ToUnixTimestamp(
       Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null)
     checkEvaluation(
@@ -1516,7 +1479,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkExceptionInExpression[T](ToUnixTimestamp(Literal("1"), Literal(c)), c)
     checkExceptionInExpression[T](UnixTimestamp(Literal("1"), Literal(c)), c)
     if (!Set("E", "F", "q", "Q").contains(c)) {
-      checkExceptionInExpression[T](DateFormatClass(CurrentTimestamp(), Literal(c)), c)
       checkExceptionInExpression[T](FromUnixTime(Literal(0L), Literal(c)), c)
     }
   }
@@ -2124,4 +2086,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       }
     }
   }
+
+  test("datetime function CurrentDate and localtimestamp are Unevaluable") {
+    checkError(exception = intercept[SparkException] { CurrentDate(UTC_OPT).eval(EmptyRow) },
+      errorClass = "INTERNAL_ERROR",
+      parameters = Map("message" -> "Cannot evaluate expression: current_date(Some(UTC))"))
+
+    checkError(exception = intercept[SparkException] { LocalTimestamp(UTC_OPT).eval(EmptyRow) },
+      errorClass = "INTERNAL_ERROR",
+      parameters = Map("message" -> "Cannot evaluate expression: localtimestamp(Some(UTC))"))
+  }
 }
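
As a quick sanity check that nothing user-facing changes (hypothetical spark-shell snippet, not part of the diff), the rewritten expressions still work through the normal query path, because the analyzer replaces them before any evaluation or codegen happens:

```scala
// version() goes through RuntimeReplaceable/StaticInvoke; current_date()
// and localtimestamp() are folded to literals by ComputeCurrentTime.
spark.sql("SELECT version(), current_date(), localtimestamp()").show(truncate = false)
```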
ComputeCurrentTimeSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import java.lang.Thread.sleep
 import java.time.{LocalDateTime, ZoneId}
 
 import scala.concurrent.duration._
@@ -51,6 +52,19 @@ class ComputeCurrentTimeSuite extends PlanTest {
     assert(lits(0) == lits(1))
   }
 
+  test("analyzer should respect time flow in current timestamp calls") {
+    val in = Project(Alias(CurrentTimestamp(), "t1")() :: Nil, LocalRelation())
+
+    val planT1 = Optimize.execute(in.analyze).asInstanceOf[Project]
+    sleep(1)
+    val planT2 = Optimize.execute(in.analyze).asInstanceOf[Project]
+
+    val t1 = DateTimeUtils.microsToMillis(literals[Long](planT1)(0))
+    val t2 = DateTimeUtils.microsToMillis(literals[Long](planT2)(0))
+
+    assert(t2 - t1 <= 1000 && t2 - t1 > 0)
+  }
+
   test("analyzer should replace current_date with literals") {
     val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
 
EliminateSortsSuite.scala
@@ -98,11 +98,11 @@ class EliminateSortsSuite extends AnalysisTest {
   test("Remove no-op alias") {
     val x = testRelation
 
-    val query = x.select($"a".as("x"), Year(CurrentDate()).as("y"), $"b")
+    val query = x.select($"a".as("x"), Literal(1).as("y"), $"b")
       .orderBy($"x".asc, $"y".asc, $"b".desc)
     val optimized = Optimize.execute(analyzer.execute(query))
     val correctAnswer = analyzer.execute(
-      x.select($"a".as("x"), Year(CurrentDate()).as("y"), $"b")
+      x.select($"a".as("x"), Literal(1).as("y"), $"b")
         .orderBy($"x".asc, $"b".desc))
 
     comparePlans(optimized, correctAnswer)