Skip to content

Commit fce4fc8

Browse files
committed
#375 Add workaround for timestamp overflows.
1 parent cc1c6fd commit fce4fc8

File tree

2 files changed

+81
-6
lines changed

2 files changed

+81
-6
lines changed

pramen/core/src/main/scala/za/co/absa/pramen/core/utils/impl/ResultSetToRowIterator.scala

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ import org.apache.spark.sql.Row
2020
import org.apache.spark.sql.catalyst.expressions.GenericRow
2121
import org.apache.spark.sql.types._
2222

23-
import java.sql.ResultSet
2423
import java.sql.Types._
24+
import java.sql.{ResultSet, Timestamp}
25+
import java.time.{LocalDateTime, ZoneOffset}
2526

2627
class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
28+
import ResultSetToRowIterator._
29+
2730
private var didHasNext = false
2831
private var item: Option[Row] = None
2932
private val columnCount = rs.getMetaData.getColumnCount
@@ -62,7 +65,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
6265
rs.close()
6366
}
6467

65-
private def fetchNext(): Unit = {
68+
private[core] def fetchNext(): Unit = {
6669
didHasNext = true
6770
if (rs.next()) {
6871
val data = new Array[Any](columnCount)
@@ -78,7 +81,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
7881
}
7982
}
8083

81-
private def getStructField(columnIndex: Int): StructField = {
84+
private[core] def getStructField(columnIndex: Int): StructField = {
8285
val columnName = rs.getMetaData.getColumnName(columnIndex)
8386
val dataType = rs.getMetaData.getColumnType(columnIndex)
8487

@@ -98,7 +101,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
98101
}
99102
}
100103

101-
private def getCell(columnIndex: Int): Any = {
104+
private[core] def getCell(columnIndex: Int): Any = {
102105
val dataType = rs.getMetaData.getColumnType(columnIndex)
103106

104107
dataType match {
@@ -112,9 +115,26 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
112115
case REAL => rs.getBigDecimal(columnIndex)
113116
case NUMERIC => rs.getBigDecimal(columnIndex)
114117
case DATE => rs.getDate(columnIndex)
115-
case TIMESTAMP => rs.getTimestamp(columnIndex)
118+
case TIMESTAMP => sanitizeTimestamp(rs.getTimestamp(columnIndex))
116119
case _ => rs.getString(columnIndex)
117120
}
118121
}
119122

123+
/**
  * Clamps a timestamp to the range bounded by MIN_SAFE_TIMESTAMP and MAX_SAFE_TIMESTAMP.
  *
  * A value later than MAX_SAFE_TIMESTAMP_MILLI is replaced by MAX_SAFE_TIMESTAMP, a value earlier than
  * MIN_SAFE_TIMESTAMP_MILLI is replaced by MIN_SAFE_TIMESTAMP, and any other value is returned unchanged.
  *
  * @param timestamp the timestamp to check, as read from the result set; may be `null`.
  * @return the clamped timestamp, or `null` when the input is `null`.
  */
private[core] def sanitizeTimestamp(timestamp: Timestamp): Timestamp = {
  // ResultSet.getTimestamp returns null for SQL NULL, so null has to pass through
  // untouched instead of failing with a NullPointerException on getTime.
  if (timestamp == null) {
    null
  } else {
    val timeMilli = timestamp.getTime
    if (timeMilli > MAX_SAFE_TIMESTAMP_MILLI)
      MAX_SAFE_TIMESTAMP
    else if (timeMilli < MIN_SAFE_TIMESTAMP_MILLI)
      MIN_SAFE_TIMESTAMP
    else
      timestamp
  }
}
132+
}
133+
134+
/**
  * Boundary values for timestamps, expressed both as epoch milliseconds and as `java.sql.Timestamp`.
  */
object ResultSetToRowIterator {
  /** 9999-12-31 23:59:59 UTC in epoch milliseconds. */
  val MAX_SAFE_TIMESTAMP_MILLI: Long = LocalDateTime.of(9999, 12, 31, 23, 59, 59).toEpochSecond(ZoneOffset.UTC) * 1000

  /** The timestamp corresponding to [[MAX_SAFE_TIMESTAMP_MILLI]]. */
  val MAX_SAFE_TIMESTAMP: Timestamp = new Timestamp(MAX_SAFE_TIMESTAMP_MILLI)

  /** 0001-01-01 00:00:00 UTC in epoch milliseconds. */
  val MIN_SAFE_TIMESTAMP_MILLI: Long = LocalDateTime.of(1, 1, 1, 0, 0, 0).toEpochSecond(ZoneOffset.UTC) * 1000

  /** The timestamp corresponding to [[MIN_SAFE_TIMESTAMP_MILLI]]. */
  val MIN_SAFE_TIMESTAMP: Timestamp = new Timestamp(MIN_SAFE_TIMESTAMP_MILLI)
}

pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcNativeUtilsSuite.scala

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
package za.co.absa.pramen.core.tests.utils
1818

1919
import org.apache.spark.sql.types.IntegerType
20+
import org.mockito.Mockito.{mock, when}
2021
import org.scalatest.wordspec.AnyWordSpec
2122
import za.co.absa.pramen.core.base.SparkTestBase
2223
import za.co.absa.pramen.core.fixtures.{RelationalDbFixture, TextComparisonFixture}
2324
import za.co.absa.pramen.core.reader.model.JdbcConfig
2425
import za.co.absa.pramen.core.samples.RdbExampleTable
26+
import za.co.absa.pramen.core.utils.impl.ResultSetToRowIterator
2527
import za.co.absa.pramen.core.utils.{JdbcNativeUtils, SparkUtils}
2628

27-
import java.sql.{DriverManager, ResultSet, SQLSyntaxErrorException}
29+
import java.sql._
30+
import java.time.{Instant, ZoneId}
31+
import java.util.{Calendar, GregorianCalendar, TimeZone}
2832

2933
class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with SparkTestBase with TextComparisonFixture {
3034
private val tableName = RdbExampleTable.Company.tableName
@@ -189,4 +193,55 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
189193
}
190194
}
191195

196+
"sanitizeTimestamp" should {
  // Variable names come from PostgreSQL "constant field docs":
  // https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
  val POSTGRESQL_DATE_NEGATIVE_INFINITY: Long = -9223372036832400000L
  val POSTGRESQL_DATE_POSITIVE_INFINITY: Long = 9223372036825200000L

  // The iterator reads the column count from the metadata on construction,
  // so the mocked result set must provide metadata before the iterator is created.
  val resultSet = mock(classOf[ResultSet])
  val resultSetMetaData = mock(classOf[ResultSetMetaData])

  when(resultSetMetaData.getColumnCount).thenReturn(1)
  when(resultSet.getMetaData).thenReturn(resultSetMetaData)

  val iterator = new ResultSetToRowIterator(resultSet)

  // Sanitizes the timestamp built from the given epoch milliseconds and
  // returns the calendar year of the result, evaluated in UTC.
  def sanitizedYearInUtc(epochMilli: Long): Int = {
    val timestamp = Timestamp.from(Instant.ofEpochMilli(epochMilli))

    val fixedTs = iterator.sanitizeTimestamp(timestamp)

    val calendar = new GregorianCalendar(TimeZone.getTimeZone(ZoneId.of("UTC")))
    calendar.setTime(fixedTs)
    calendar.get(Calendar.YEAR)
  }

  "convert PostgreSql positive infinity value" in {
    val year = sanitizedYearInUtc(POSTGRESQL_DATE_POSITIVE_INFINITY)

    assert(year == 9999)
  }

  "convert PostgreSql negative infinity value" in {
    val year = sanitizedYearInUtc(POSTGRESQL_DATE_NEGATIVE_INFINITY)

    assert(year == 1)
  }

  // The previous name said "to null", but the assertion checks clamping to the maximum safe year.
  "convert overflowed value to the maximum safe timestamp" in {
    val year = sanitizedYearInUtc(1000000000000000L)

    assert(year == 9999)
  }
}
246+
192247
}

0 commit comments

Comments
 (0)