@@ -20,10 +20,13 @@ import org.apache.spark.sql.Row
20
20
import org .apache .spark .sql .catalyst .expressions .GenericRow
21
21
import org .apache .spark .sql .types ._
22
22
23
- import java .sql .ResultSet
24
23
import java .sql .Types ._
24
+ import java .sql .{ResultSet , Timestamp }
25
+ import java .time .{LocalDateTime , ZoneOffset }
25
26
26
27
class ResultSetToRowIterator (rs : ResultSet ) extends Iterator [Row ] {
28
+ import ResultSetToRowIterator ._
29
+
27
30
private var didHasNext = false
28
31
private var item : Option [Row ] = None
29
32
private val columnCount = rs.getMetaData.getColumnCount
@@ -62,7 +65,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
62
65
rs.close()
63
66
}
64
67
65
- private def fetchNext (): Unit = {
68
+ private [core] def fetchNext (): Unit = {
66
69
didHasNext = true
67
70
if (rs.next()) {
68
71
val data = new Array [Any ](columnCount)
@@ -78,7 +81,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
78
81
}
79
82
}
80
83
81
- private def getStructField (columnIndex : Int ): StructField = {
84
+ private [core] def getStructField (columnIndex : Int ): StructField = {
82
85
val columnName = rs.getMetaData.getColumnName(columnIndex)
83
86
val dataType = rs.getMetaData.getColumnType(columnIndex)
84
87
@@ -98,7 +101,7 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
98
101
}
99
102
}
100
103
101
- private def getCell (columnIndex : Int ): Any = {
104
+ private [core] def getCell (columnIndex : Int ): Any = {
102
105
val dataType = rs.getMetaData.getColumnType(columnIndex)
103
106
104
107
dataType match {
@@ -112,9 +115,26 @@ class ResultSetToRowIterator(rs: ResultSet) extends Iterator[Row] {
112
115
case REAL => rs.getBigDecimal(columnIndex)
113
116
case NUMERIC => rs.getBigDecimal(columnIndex)
114
117
case DATE => rs.getDate(columnIndex)
115
- case TIMESTAMP => rs.getTimestamp(columnIndex)
118
+ case TIMESTAMP => sanitizeTimestamp( rs.getTimestamp(columnIndex) )
116
119
case _ => rs.getString(columnIndex)
117
120
}
118
121
}
119
122
123
+ private [core] def sanitizeTimestamp (timestamp : Timestamp ): Timestamp = {
124
+ val timeMilli = timestamp.getTime
125
+ if (timeMilli > MAX_SAFE_TIMESTAMP_MILLI )
126
+ MAX_SAFE_TIMESTAMP
127
+ else if (timeMilli < MIN_SAFE_TIMESTAMP_MILLI )
128
+ MIN_SAFE_TIMESTAMP
129
+ else
130
+ timestamp
131
+ }
132
+ }
133
+
134
+ object ResultSetToRowIterator {
135
+ val MAX_SAFE_TIMESTAMP_MILLI : Long = LocalDateTime .of(9999 , 12 , 31 , 23 , 59 , 59 ).toEpochSecond(ZoneOffset .UTC ) * 1000
136
+ val MAX_SAFE_TIMESTAMP = new Timestamp (MAX_SAFE_TIMESTAMP_MILLI )
137
+
138
+ val MIN_SAFE_TIMESTAMP_MILLI : Long = LocalDateTime .of(1 , 1 , 1 , 0 , 0 , 0 ).toEpochSecond(ZoneOffset .UTC ) * 1000
139
+ val MIN_SAFE_TIMESTAMP = new Timestamp (MIN_SAFE_TIMESTAMP_MILLI )
120
140
}
0 commit comments