
Commit 1076b36

#394 Fix PostgreSQL with numeric fields with no precision and scale specified in JdbcNative.
1 parent dcffa1a commit 1076b36

File tree

4 files changed: +131 -18 lines changed

pramen/core/src/main/scala/za/co/absa/pramen/core/reader/model/JdbcConfig.scala

Lines changed: 3 additions & 0 deletions

@@ -31,6 +31,7 @@ case class JdbcConfig(
   retries: Option[Int] = None,
   connectionTimeoutSeconds: Option[Int] = None,
   sanitizeDateTime: Boolean = true,
+  incorrectDecimalsAsString: Boolean = false,
   extraOptions: Map[String, String] = Map.empty[String, String]
 )

@@ -46,6 +47,7 @@ object JdbcConfig {
   val JDBC_RETRIES = "jdbc.retries"
   val JDBC_CONNECTION_TIMEOUT = "jdbc.connection.timeout"
   val JDBC_SANITIZE_DATETIME = "jdbc.sanitize.datetime"
+  val JDBC_INCORRECT_PRECISION_AS_STRING = "jdbc.incorrect.precision.as.string"
   val JDBC_EXTRA_OPTIONS_PREFIX = "jdbc.option"

   def load(conf: Config, parent: String = ""): JdbcConfig = {
@@ -75,6 +77,7 @@
       retries = ConfigUtils.getOptionInt(conf, JDBC_RETRIES),
       connectionTimeoutSeconds = ConfigUtils.getOptionInt(conf, JDBC_CONNECTION_TIMEOUT),
       sanitizeDateTime = ConfigUtils.getOptionBoolean(conf, JDBC_SANITIZE_DATETIME).getOrElse(true),
+      incorrectDecimalsAsString = ConfigUtils.getOptionBoolean(conf, JDBC_INCORRECT_PRECISION_AS_STRING).getOrElse(true),
       extraOptions = ConfigUtils.getExtraOptions(conf, JDBC_EXTRA_OPTIONS_PREFIX)
     )
   }
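
For orientation, here is a minimal sketch of how the new key could be switched on from the Typesafe Config (HOCON) side and picked up by JdbcConfig.load. Only jdbc.incorrect.precision.as.string is introduced by this commit; the connection keys around it (driver, url, user, password) are assumed here as the usual mandatory reader settings and may differ in a real pipeline configuration.

import com.typesafe.config.ConfigFactory

import za.co.absa.pramen.core.reader.model.JdbcConfig

// Assumed connection settings; only the last key comes from this commit.
val conf = ConfigFactory.parseString(
  """jdbc.driver = "org.postgresql.Driver"
    |jdbc.url = "jdbc:postgresql://localhost:5432/mydb"
    |jdbc.user = "pramen"
    |jdbc.password = "changeme"
    |jdbc.incorrect.precision.as.string = true
    |""".stripMargin)

val jdbcConfig = JdbcConfig.load(conf)

// When enabled, NUMERIC/REAL columns whose reported precision and scale are unusable
// (e.g. PostgreSQL NUMERIC declared without precision) are read back as strings.
assert(jdbcConfig.incorrectDecimalsAsString)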

pramen/core/src/main/scala/za/co/absa/pramen/core/utils/JdbcNativeUtils.scala

Lines changed: 2 additions & 2 deletions

@@ -89,13 +89,13 @@ object JdbcNativeUtils {

     // Executing the query
     val rs = getResultSet(jdbcConfig, url, query)
-    val driverIterator = new ResultSetToRowIterator(rs, jdbcConfig.sanitizeDateTime)
+    val driverIterator = new ResultSetToRowIterator(rs, jdbcConfig.sanitizeDateTime, jdbcConfig.incorrectDecimalsAsString)
     val schema = JdbcSparkUtils.addMetadataFromJdbc(driverIterator.getSchema, rs.getMetaData)

     driverIterator.close()

     val rdd = spark.sparkContext.parallelize(Seq(query)).flatMap(q => {
-      new ResultSetToRowIterator(getResultSet(jdbcConfig, url, q), jdbcConfig.sanitizeDateTime)
+      new ResultSetToRowIterator(getResultSet(jdbcConfig, url, q), jdbcConfig.sanitizeDateTime, jdbcConfig.incorrectDecimalsAsString)
     })

     spark.createDataFrame(rdd, schema)
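
The hunk above keeps the existing two-pass structure: the iterator is created once on the driver purely to derive the Spark schema (and then closed), and re-created inside the flatMap closure so that each executor opens its own connection. For readers unfamiliar with that pattern, here is a minimal, self-contained sketch of the same idea; it is not the Pramen API, and the in-memory H2 database, table function and column name are assumptions made purely for illustration.

import java.sql.DriverManager

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object TwoPassJdbcSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("two-pass-sketch").getOrCreate()

    // Assumed in-memory H2 database (requires the H2 driver on the classpath).
    val url = "jdbc:h2:mem:demo;DB_CLOSE_DELAY=-1"
    val query = "SELECT X FROM SYSTEM_RANGE(1, 5)"

    // Pass 1 (driver): in Pramen this is where ResultSetToRowIterator.getSchema reads the
    // ResultSet metadata; here the schema is simply spelled out.
    val schema = StructType(Seq(StructField("X", LongType)))

    // Pass 2 (executor): the query string is shipped as data, and the closure opens its own
    // connection, mirroring the flatMap over Seq(query) in the hunk above.
    val rdd = spark.sparkContext.parallelize(Seq(query)).flatMap { q =>
      val conn = DriverManager.getConnection(url)
      val rs = conn.createStatement().executeQuery(q)
      val rows = Iterator.continually(rs).takeWhile(_.next()).map(r => Row(r.getLong(1))).toVector
      conn.close()
      rows
    }

    spark.createDataFrame(rdd, schema).show()
    spark.stop()
  }
}

Unlike this sketch, the real ResultSetToRowIterator streams rows lazily and handles the hasNext/close bookkeeping itself; the sketch simply materializes the rows to keep the example short.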

pramen/core/src/main/scala/za/co/absa/pramen/core/utils/impl/ResultSetToRowIterator.scala

Lines changed: 33 additions & 5 deletions

@@ -24,7 +24,7 @@
 import java.sql.Types._
 import java.sql.{Date, ResultSet, Timestamp}
 import java.time.{LocalDateTime, ZoneOffset}

-class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean) extends Iterator[Row] {
+class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean, incorrectDecimalsAsString: Boolean) extends Iterator[Row] {
   import ResultSetToRowIterator._

   private var didHasNext = false
@@ -93,8 +93,8 @@ class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean) extends Iterator[Row] {
       case BIGINT => StructField(columnName, LongType)
       case FLOAT => StructField(columnName, FloatType)
       case DOUBLE => StructField(columnName, DoubleType)
-      case REAL => StructField(columnName, DecimalType(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
-      case NUMERIC => StructField(columnName, DecimalType(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
+      case REAL => StructField(columnName, getDecimalSparkSchema(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
+      case NUMERIC => StructField(columnName, getDecimalSparkSchema(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
       case DATE => StructField(columnName, DateType)
       case TIMESTAMP => StructField(columnName, TimestampType)
       case _ => StructField(columnName, StringType)
@@ -113,14 +113,42 @@
       case BIGINT => rs.getLong(columnIndex)
       case FLOAT => rs.getFloat(columnIndex)
       case DOUBLE => rs.getDouble(columnIndex)
-      case REAL => rs.getBigDecimal(columnIndex)
-      case NUMERIC => rs.getBigDecimal(columnIndex)
+      case REAL => getDecimalData(columnIndex)
+      case NUMERIC => getDecimalData(columnIndex)
       case DATE => sanitizeDate(rs.getDate(columnIndex))
       case TIMESTAMP => sanitizeTimestamp(rs.getTimestamp(columnIndex))
       case _ => rs.getString(columnIndex)
     }
   }

+  private[core] def getDecimalSparkSchema(precision: Int, scale: Int): DataType = {
+    if (scale >= precision || precision <= 0 || scale < 0 || precision > 38 || (precision + scale) > 38) {
+      if (incorrectDecimalsAsString) {
+        StringType
+      } else {
+        DecimalType(38, 18)
+      }
+    } else {
+      DecimalType(precision, scale)
+    }
+  }
+
+  private[core] def getDecimalData(columnIndex: Int): Any = {
+    if (incorrectDecimalsAsString) {
+      val precision = rs.getMetaData.getPrecision(columnIndex)
+      val scale = rs.getMetaData.getScale(columnIndex)
+
+      if (scale >= precision || precision <= 0 || scale < 0 || precision > 38 || (precision + scale) > 38) {
+        rs.getString(columnIndex)
+      } else {
+        rs.getBigDecimal(columnIndex)
+      }
+    } else {
+      rs.getBigDecimal(columnIndex)
+    }
+  }
+
   private[core] def sanitizeDate(date: Date): Date = {
     // This check against null is important since date=null is a valid value.
     if (sanitizeDateTime && date != null) {
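
The new mapping is easiest to see with a few concrete precision/scale pairs. The following standalone sketch restates the rule from getDecimalSparkSchema as a pure function so the behaviour can be checked without a database; the sample pairs are illustrative, and (0, 0) is what the PostgreSQL JDBC driver typically reports for a NUMERIC column declared without precision and scale.

import org.apache.spark.sql.types.{DataType, DecimalType, StringType}

// Same validity rule as getDecimalSparkSchema above, extracted as a pure function.
def mapDecimal(precision: Int, scale: Int, incorrectDecimalsAsString: Boolean): DataType = {
  val isIncorrect =
    scale >= precision || precision <= 0 || scale < 0 || precision > 38 || (precision + scale) > 38
  if (isIncorrect) {
    if (incorrectDecimalsAsString) StringType else DecimalType(38, 18)
  } else {
    DecimalType(precision, scale)
  }
}

// NUMERIC with no precision/scale: previously mapped straight to DecimalType(0, 0),
// a decimal with no digits; now it falls back to a wide decimal or a string.
println(mapDecimal(0, 0, incorrectDecimalsAsString = false)) // DecimalType(38,18)
println(mapDecimal(0, 0, incorrectDecimalsAsString = true))  // StringType
// A well-formed NUMERIC(10, 2) is unaffected by the flag.
println(mapDecimal(10, 2, incorrectDecimalsAsString = true)) // DecimalType(10,2)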

pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcNativeUtilsSuite.scala

Lines changed: 93 additions & 11 deletions

@@ -16,7 +16,7 @@

 package za.co.absa.pramen.core.tests.utils

-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
 import org.mockito.Mockito.{mock, when}
 import org.scalatest.wordspec.AnyWordSpec
 import za.co.absa.pramen.core.base.SparkTestBase
@@ -26,6 +26,7 @@ import za.co.absa.pramen.core.samples.RdbExampleTable
 import za.co.absa.pramen.core.utils.impl.ResultSetToRowIterator
 import za.co.absa.pramen.core.utils.{JdbcNativeUtils, SparkUtils}

+import java.math.BigDecimal
 import java.sql._
 import java.time.{Instant, ZoneId}
 import java.util.{Calendar, GregorianCalendar, TimeZone}
@@ -193,6 +194,87 @@
     }
   }

+  "getDecimalSparkSchema" should {
+    val resultSet = mock(classOf[ResultSet])
+    val resultSetMetaData = mock(classOf[ResultSetMetaData])
+
+    when(resultSetMetaData.getColumnCount).thenReturn(1)
+    when(resultSet.getMetaData).thenReturn(resultSetMetaData)
+
+    "return normal decimal for correct precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
+
+      assert(iterator.getDecimalSparkSchema(10, 0) == DecimalType(10, 0))
+      assert(iterator.getDecimalSparkSchema(10, 2) == DecimalType(10, 2))
+      assert(iterator.getDecimalSparkSchema(2, 1) == DecimalType(2, 1))
+      assert(iterator.getDecimalSparkSchema(38, 18) == DecimalType(38, 18))
+    }
+
+    "return fixed decimal for incorrect precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
+
+      assert(iterator.getDecimalSparkSchema(1, -1) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(0, 0) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(0, 2) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(2, 2) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(39, 0) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(38, 19) == DecimalType(38, 18))
+      assert(iterator.getDecimalSparkSchema(20, 19) == DecimalType(38, 18))
+    }
+
+    "return string type for incorrect precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = true)
+
+      assert(iterator.getDecimalSparkSchema(1, -1) == StringType)
+      assert(iterator.getDecimalSparkSchema(0, 0) == StringType)
+      assert(iterator.getDecimalSparkSchema(0, 2) == StringType)
+      assert(iterator.getDecimalSparkSchema(2, 2) == StringType)
+      assert(iterator.getDecimalSparkSchema(39, 0) == StringType)
+      assert(iterator.getDecimalSparkSchema(38, 19) == StringType)
+      assert(iterator.getDecimalSparkSchema(20, 19) == StringType)
+    }
+  }
+
+  "getDecimalData" should {
+    val resultSet = mock(classOf[ResultSet])
+    val resultSetMetaData = mock(classOf[ResultSetMetaData])
+
+    when(resultSetMetaData.getColumnCount).thenReturn(1)
+    when(resultSet.getMetaData).thenReturn(resultSetMetaData)
+    when(resultSet.getBigDecimal(0)).thenReturn(new BigDecimal(1.0))
+    when(resultSet.getString(0)).thenReturn("1")
+
+    "return normal decimal for correct precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
+      when(resultSetMetaData.getPrecision(0)).thenReturn(10)
+      when(resultSetMetaData.getScale(0)).thenReturn(2)
+
+      val v = iterator.getDecimalData(0)
+      assert(v.isInstanceOf[BigDecimal])
+      assert(v.asInstanceOf[BigDecimal].toString == "1")
+    }
+
+    "return fixed decimal for incorrect precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
+      when(resultSetMetaData.getPrecision(0)).thenReturn(0)
+      when(resultSetMetaData.getScale(0)).thenReturn(2)
+
+      val v = iterator.getDecimalData(0)
+      assert(v.isInstanceOf[BigDecimal])
+      assert(v.asInstanceOf[BigDecimal].toString == "1")
+    }
+
+    "return string type for incorrect precision and scale" in {
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = true)
+      when(resultSetMetaData.getPrecision(0)).thenReturn(0)
+      when(resultSetMetaData.getScale(0)).thenReturn(2)
+
+      val v = iterator.getDecimalData(0)
+      assert(v.isInstanceOf[String])
+      assert(v.asInstanceOf[String] == "1")
+    }
+  }
+
   "sanitizeDateTime" when {
     // Variable names come from PostgreSQL "constant field docs":
     // https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
@@ -212,15 +294,15 @@
     val maxTimestamp = 253402300799999L

     "ignore null values" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

       val fixedTs = iterator.sanitizeTimestamp(null)

       assert(fixedTs == null)
     }

     "convert PostgreSql positive infinity value" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY))

       val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -229,7 +311,7 @@
     }

     "convert PostgreSql negative infinity value" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY))

       val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -238,7 +320,7 @@
     }

     "convert overflowed value to the maximum value supported" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val timestamp = Timestamp.from(Instant.ofEpochMilli(1000000000000000L))

       val actual = iterator.sanitizeTimestamp(timestamp)
@@ -252,7 +334,7 @@
     }

     "do nothing if the feature is turned off" in {
-      val iterator = new ResultSetToRowIterator(resultSet, false)
+      val iterator = new ResultSetToRowIterator(resultSet, false, incorrectDecimalsAsString = false)
       val timestamp = Timestamp.from(Instant.ofEpochMilli(1000000000000000L))

       val actual = iterator.sanitizeTimestamp(timestamp)
@@ -272,15 +354,15 @@
     val maxDate = 253402214400000L

     "ignore null values" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

       val fixedDate = iterator.sanitizeDate(null)

       assert(fixedDate == null)
     }

     "convert PostgreSql positive infinity value" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val date = new Date(POSTGRESQL_DATE_POSITIVE_INFINITY)

       val fixedDate = iterator.sanitizeDate(date)
@@ -289,7 +371,7 @@
     }

     "convert PostgreSql negative infinity value" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val date = new Date(POSTGRESQL_DATE_NEGATIVE_INFINITY)

       val fixedDate = iterator.sanitizeDate(date)
@@ -298,7 +380,7 @@
     }

     "convert overflowed value to the maximum value supported" in {
-      val iterator = new ResultSetToRowIterator(resultSet, true)
+      val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
       val date = new Date(1000000000000000L)

       val actual = iterator.sanitizeDate(date)
@@ -312,7 +394,7 @@
     }

     "do nothing if the feature is turned off" in {
-      val iterator = new ResultSetToRowIterator(resultSet, false)
+      val iterator = new ResultSetToRowIterator(resultSet, false, incorrectDecimalsAsString = false)
       val date = new Date(1000000000000000L)

       val actual = iterator.sanitizeDate(date)
