Skip to content

Commit

Permalink
#375 Add workaround for timestamp overflows.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Mar 19, 2024
1 parent cc1c6fd commit fce4fc8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

import java.sql.ResultSet
import java.sql.Types._
import java.sql.{ResultSet, Timestamp}
import java.time.{LocalDateTime, ZoneOffset}

class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
import ResultSetToRowIterator._

private var didHasNext = false
private var item: Option[Row] = None
private val columnCount = rs.getMetaData.getColumnCount
Expand Down Expand Up @@ -62,7 +65,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
rs.close()
}

private def fetchNext(): Unit = {
private[core] def fetchNext(): Unit = {
didHasNext = true
if (rs.next()) {
val data = new Array[Any](columnCount)
Expand All @@ -78,7 +81,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
}
}

private def getStructField(columnIndex: Int): StructField = {
private[core] def getStructField(columnIndex: Int): StructField = {
val columnName = rs.getMetaData.getColumnName(columnIndex)
val dataType = rs.getMetaData.getColumnType(columnIndex)

Expand All @@ -98,7 +101,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
}
}

private def getCell(columnIndex: Int): Any = {
private[core] def getCell(columnIndex: Int): Any = {
val dataType = rs.getMetaData.getColumnType(columnIndex)

dataType match {
Expand All @@ -112,9 +115,26 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
case REAL => rs.getBigDecimal(columnIndex)
case NUMERIC => rs.getBigDecimal(columnIndex)
case DATE => rs.getDate(columnIndex)
case TIMESTAMP => rs.getTimestamp(columnIndex)
case TIMESTAMP => sanitizeTimestamp(rs.getTimestamp(columnIndex))
case _ => rs.getString(columnIndex)
}
}

/**
  * Clamps a timestamp into the range [MIN_SAFE_TIMESTAMP, MAX_SAFE_TIMESTAMP].
  *
  * Values later than MAX_SAFE_TIMESTAMP_MILLI are replaced by MAX_SAFE_TIMESTAMP, values earlier than
  * MIN_SAFE_TIMESTAMP_MILLI are replaced by MIN_SAFE_TIMESTAMP, and everything in between is returned as is.
  *
  * @param timestamp the value obtained from the JDBC result set; may be `null` when the column is SQL NULL.
  * @return `null` for a `null` input, otherwise the original timestamp or the nearest safe boundary.
  */
private[core] def sanitizeTimestamp(timestamp: Timestamp): Timestamp = {
  // ResultSet.getTimestamp() yields null for SQL NULL. Pass it through so nullable
  // columns keep their null instead of failing with a NullPointerException on getTime.
  if (timestamp == null) {
    null
  } else {
    val timeMilli = timestamp.getTime
    if (timeMilli > MAX_SAFE_TIMESTAMP_MILLI)
      MAX_SAFE_TIMESTAMP
    else if (timeMilli < MIN_SAFE_TIMESTAMP_MILLI)
      MIN_SAFE_TIMESTAMP
    else
      timestamp
  }
}
}

/**
  * Boundary values used when clamping timestamps read through JDBC.
  *
  * The millisecond constants are epoch-based and computed at UTC; the Timestamp constants wrap the same instants.
  */
object ResultSetToRowIterator {
  /** 9999-12-31 23:59:59 UTC expressed as milliseconds since the epoch. */
  val MAX_SAFE_TIMESTAMP_MILLI: Long =
    LocalDateTime.of(9999, 12, 31, 23, 59, 59).toInstant(ZoneOffset.UTC).toEpochMilli

  /** 0001-01-01 00:00:00 UTC expressed as milliseconds since the epoch. */
  val MIN_SAFE_TIMESTAMP_MILLI: Long =
    LocalDateTime.of(1, 1, 1, 0, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli

  /** Upper boundary as a Timestamp. */
  val MAX_SAFE_TIMESTAMP: Timestamp = new Timestamp(MAX_SAFE_TIMESTAMP_MILLI)

  /** Lower boundary as a Timestamp. */
  val MIN_SAFE_TIMESTAMP: Timestamp = new Timestamp(MIN_SAFE_TIMESTAMP_MILLI)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
package za.co.absa.pramen.core.tests.utils

import org.apache.spark.sql.types.IntegerType
import org.mockito.Mockito.{mock, when}
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.core.base.SparkTestBase
import za.co.absa.pramen.core.fixtures.{RelationalDbFixture, TextComparisonFixture}
import za.co.absa.pramen.core.reader.model.JdbcConfig
import za.co.absa.pramen.core.samples.RdbExampleTable
import za.co.absa.pramen.core.utils.impl.ResultSetToRowIterator
import za.co.absa.pramen.core.utils.{JdbcNativeUtils, SparkUtils}

import java.sql.{DriverManager, ResultSet, SQLSyntaxErrorException}
import java.sql._
import java.time.{Instant, ZoneId}
import java.util.{Calendar, GregorianCalendar, TimeZone}

class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with SparkTestBase with TextComparisonFixture {
private val tableName = RdbExampleTable.Company.tableName
Expand Down Expand Up @@ -189,4 +193,55 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}
}

"sanitizeTimestamp" should {
  // Variable names come from PostgreSQL "constant field docs":
  // https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
  val POSTGRESQL_DATE_NEGATIVE_INFINITY: Long = -9223372036832400000L
  val POSTGRESQL_DATE_POSITIVE_INFINITY: Long = 9223372036825200000L

  // The iterator's constructor reads rs.getMetaData.getColumnCount, so both calls are stubbed.
  val resultSet = mock(classOf[ResultSet])
  val resultSetMetaData = mock(classOf[ResultSetMetaData])

  when(resultSetMetaData.getColumnCount).thenReturn(1)
  when(resultSet.getMetaData).thenReturn(resultSetMetaData)

  val iterator = new ResultSetToRowIterator(resultSet)

  // Extracts the calendar year of a timestamp as seen in UTC, so the assertions
  // do not depend on the default time zone of the machine running the tests.
  def utcYearOf(timestamp: Timestamp): Int = {
    val calendar = new GregorianCalendar(TimeZone.getTimeZone(ZoneId.of("UTC")))
    calendar.setTime(timestamp)
    calendar.get(Calendar.YEAR)
  }

  "convert PostgreSql positive infinity value" in {
    val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY))

    val fixedTs = iterator.sanitizeTimestamp(timestamp)

    assert(utcYearOf(fixedTs) == 9999)
  }

  "convert PostgreSql negative infinity value" in {
    val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY))

    val fixedTs = iterator.sanitizeTimestamp(timestamp)

    assert(utcYearOf(fixedTs) == 1)
  }

  // The overflowed value is clamped to the upper boundary (year 9999), not turned into null.
  "clamp an overflowed value to the maximum safe timestamp" in {
    val timestamp = Timestamp.from(Instant.ofEpochMilli(1000000000000000L))

    val actual = iterator.sanitizeTimestamp(timestamp)

    assert(utcYearOf(actual) == 9999)
  }

  "keep values inside the safe range unchanged" in {
    val timestamp = Timestamp.from(Instant.ofEpochMilli(0L))

    val actual = iterator.sanitizeTimestamp(timestamp)

    assert(actual == timestamp)
  }
}

}

0 comments on commit fce4fc8

Please sign in to comment.