Skip to content

Commit

Permalink
#394 Fix PostgreSQL with numeric fields with no precision and scale …
Browse files Browse the repository at this point in the history
…specified in JdbcNative.
  • Loading branch information
yruslan committed Apr 24, 2024
1 parent dcffa1a commit 1076b36
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ case class JdbcConfig(
retries: Option[Int] = None,
connectionTimeoutSeconds: Option[Int] = None,
sanitizeDateTime: Boolean = true,
incorrectDecimalsAsString: Boolean = false,
extraOptions: Map[String, String] = Map.empty[String, String]
)

Expand All @@ -46,6 +47,7 @@ object JdbcConfig {
val JDBC_RETRIES = "jdbc.retries"
val JDBC_CONNECTION_TIMEOUT = "jdbc.connection.timeout"
val JDBC_SANITIZE_DATETIME = "jdbc.sanitize.datetime"
val JDBC_INCORRECT_PRECISION_AS_STRING = "jdbc.incorrect.precision.as.string"
val JDBC_EXTRA_OPTIONS_PREFIX = "jdbc.option"

def load(conf: Config, parent: String = ""): JdbcConfig = {
Expand Down Expand Up @@ -75,6 +77,7 @@ object JdbcConfig {
retries = ConfigUtils.getOptionInt(conf, JDBC_RETRIES),
connectionTimeoutSeconds = ConfigUtils.getOptionInt(conf, JDBC_CONNECTION_TIMEOUT),
sanitizeDateTime = ConfigUtils.getOptionBoolean(conf, JDBC_SANITIZE_DATETIME).getOrElse(true),
incorrectDecimalsAsString = ConfigUtils.getOptionBoolean(conf, JDBC_INCORRECT_PRECISION_AS_STRING).getOrElse(true),
extraOptions = ConfigUtils.getExtraOptions(conf, JDBC_EXTRA_OPTIONS_PREFIX)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ object JdbcNativeUtils {

// Executing the query
val rs = getResultSet(jdbcConfig, url, query)
val driverIterator = new ResultSetToRowIterator(rs, jdbcConfig.sanitizeDateTime)
val driverIterator = new ResultSetToRowIterator(rs, jdbcConfig.sanitizeDateTime, jdbcConfig.incorrectDecimalsAsString)
val schema = JdbcSparkUtils.addMetadataFromJdbc(driverIterator.getSchema, rs.getMetaData)

driverIterator.close()

val rdd = spark.sparkContext.parallelize(Seq(query)).flatMap(q => {
new ResultSetToRowIterator(getResultSet(jdbcConfig, url, q), jdbcConfig.sanitizeDateTime)
new ResultSetToRowIterator(getResultSet(jdbcConfig, url, q), jdbcConfig.sanitizeDateTime, jdbcConfig.incorrectDecimalsAsString)
})

spark.createDataFrame(rdd, schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.sql.Types._
import java.sql.{Date, ResultSet, Timestamp}
import java.time.{LocalDateTime, ZoneOffset}

class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean) extends Iterator[Row] {
class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean, incorrectDecimalsAsString: Boolean) extends Iterator[Row] {
import ResultSetToRowIterator._

private var didHasNext = false
Expand Down Expand Up @@ -93,8 +93,8 @@ class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean) extends I
case BIGINT => StructField(columnName, LongType)
case FLOAT => StructField(columnName, FloatType)
case DOUBLE => StructField(columnName, DoubleType)
case REAL => StructField(columnName, DecimalType(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
case NUMERIC => StructField(columnName, DecimalType(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
case REAL => StructField(columnName, getDecimalSparkSchema(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
case NUMERIC => StructField(columnName, getDecimalSparkSchema(rs.getMetaData.getPrecision(columnIndex), rs.getMetaData.getScale(columnIndex)))
case DATE => StructField(columnName, DateType)
case TIMESTAMP => StructField(columnName, TimestampType)
case _ => StructField(columnName, StringType)
Expand All @@ -113,14 +113,42 @@ class ResultSetToRowIterator(rs: ResultSet, sanitizeDateTime: Boolean) extends I
case BIGINT => rs.getLong(columnIndex)
case FLOAT => rs.getFloat(columnIndex)
case DOUBLE => rs.getDouble(columnIndex)
case REAL => rs.getBigDecimal(columnIndex)
case NUMERIC => rs.getBigDecimal(columnIndex)
case REAL => getDecimalData(columnIndex)
case NUMERIC => getDecimalData(columnIndex)
case DATE => sanitizeDate(rs.getDate(columnIndex))
case TIMESTAMP => sanitizeTimestamp(rs.getTimestamp(columnIndex))
case _ => rs.getString(columnIndex)
}
}

/**
  * Checks whether a precision/scale pair reported by the JDBC driver is rejected
  * by this iterator as a decimal definition.
  *
  * A pair is rejected when any of the following holds:
  *  - scale is greater than or equal to precision,
  *  - precision is zero or negative,
  *  - scale is negative,
  *  - precision is greater than 38,
  *  - precision + scale is greater than 38.
  *
  * This is the single source of truth for both the schema type and the way values
  * are read, so that the declared column type and the row values always agree.
  *
  * NOTE(review): because of the last rule a pair such as (38, 18) is rejected as well
  * (38 + 18 = 56) - confirm that this is intended.
  *
  * @param precision the precision reported for the column.
  * @param scale     the scale reported for the column.
  * @return true if the pair is rejected, false if it is used as is.
  */
private def isIncorrectDecimal(precision: Int, scale: Int): Boolean = {
  scale >= precision || precision <= 0 || scale < 0 || precision > 38 || (precision + scale) > 38
}

/**
  * Returns the Spark data type to use for a column with the given precision and scale.
  *
  * @param precision the precision reported for the column.
  * @param scale     the scale reported for the column.
  * @return `DecimalType(precision, scale)` when the pair is accepted; otherwise `StringType`
  *         if `incorrectDecimalsAsString` is set, or `DecimalType(38, 18)` if it is not.
  */
private[core] def getDecimalSparkSchema(precision: Int, scale: Int): DataType = {
  if (!isIncorrectDecimal(precision, scale)) {
    DecimalType(precision, scale)
  } else if (incorrectDecimalsAsString) {
    StringType
  } else {
    DecimalType(38, 18)
  }
}

/**
  * Reads the value of a decimal column from the current row of the result set.
  *
  * The value is read as a string only when `incorrectDecimalsAsString` is set and the
  * precision/scale reported by the result set metadata is rejected by the same rule that
  * `getDecimalSparkSchema` applies. In all other cases the value is read with `getBigDecimal`.
  * Column metadata is consulted only when `incorrectDecimalsAsString` is set.
  *
  * @param columnIndex the index of the column, passed to the result set and its metadata as is.
  * @return the result of `rs.getString` or `rs.getBigDecimal` for the column.
  */
private[core] def getDecimalData(columnIndex: Int): Any = {
  if (incorrectDecimalsAsString) {
    val metaData = rs.getMetaData
    val precision = metaData.getPrecision(columnIndex)
    val scale = metaData.getScale(columnIndex)

    if (isIncorrectDecimal(precision, scale)) {
      rs.getString(columnIndex)
    } else {
      rs.getBigDecimal(columnIndex)
    }
  } else {
    rs.getBigDecimal(columnIndex)
  }
}


private[core] def sanitizeDate(date: Date): Date = {
// This check against null is important since date=null is a valid value.
if (sanitizeDateTime && date != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package za.co.absa.pramen.core.tests.utils

import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
import org.mockito.Mockito.{mock, when}
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.core.base.SparkTestBase
Expand All @@ -26,6 +26,7 @@ import za.co.absa.pramen.core.samples.RdbExampleTable
import za.co.absa.pramen.core.utils.impl.ResultSetToRowIterator
import za.co.absa.pramen.core.utils.{JdbcNativeUtils, SparkUtils}

import java.math.BigDecimal
import java.sql._
import java.time.{Instant, ZoneId}
import java.util.{Calendar, GregorianCalendar, TimeZone}
Expand Down Expand Up @@ -193,6 +194,87 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}
}

"getDecimalSparkSchema" should {
  // Mocked JDBC objects used only to construct the iterator;
  // getDecimalSparkSchema() receives precision and scale as arguments.
  val resultSet = mock(classOf[ResultSet])
  val resultSetMetaData = mock(classOf[ResultSetMetaData])

  when(resultSetMetaData.getColumnCount).thenReturn(1)
  when(resultSet.getMetaData).thenReturn(resultSetMetaData)

  // (precision, scale) pairs that getDecimalSparkSchema() rejects.
  // (38, 18) belongs here too: it is rejected because precision + scale > 38. With
  // incorrectDecimalsAsString = false it still maps to DecimalType(38, 18), but only
  // through the fallback branch, so it is not a 'correct precision and scale' case.
  val incorrectPairs = Seq((1, -1), (0, 0), (0, 2), (2, 2), (39, 0), (38, 19), (20, 19), (38, 18))

  "return normal decimal for correct precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

    assert(iterator.getDecimalSparkSchema(10, 0) == DecimalType(10, 0))
    assert(iterator.getDecimalSparkSchema(10, 2) == DecimalType(10, 2))
    assert(iterator.getDecimalSparkSchema(2, 1) == DecimalType(2, 1))
    // Boundaries that are still accepted: precision = 38 and precision + scale = 38.
    assert(iterator.getDecimalSparkSchema(38, 0) == DecimalType(38, 0))
    assert(iterator.getDecimalSparkSchema(20, 18) == DecimalType(20, 18))
  }

  "return fixed decimal for incorrect precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

    incorrectPairs.foreach { case (precision, scale) =>
      assert(iterator.getDecimalSparkSchema(precision, scale) == DecimalType(38, 18),
        s"for precision=$precision, scale=$scale")
    }
  }

  "return string type for incorrect precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = true)

    incorrectPairs.foreach { case (precision, scale) =>
      assert(iterator.getDecimalSparkSchema(precision, scale) == StringType,
        s"for precision=$precision, scale=$scale")
    }
  }
}

"getDecimalData" should {
  // Mocked result set: column 0 yields the decimal 1 via getBigDecimal() and "1" via getString().
  // Precision and scale are stubbed by each test before the value is read.
  val resultSet = mock(classOf[ResultSet])
  val resultSetMetaData = mock(classOf[ResultSetMetaData])

  when(resultSetMetaData.getColumnCount).thenReturn(1)
  when(resultSet.getMetaData).thenReturn(resultSetMetaData)
  when(resultSet.getBigDecimal(0)).thenReturn(new BigDecimal(1.0))
  when(resultSet.getString(0)).thenReturn("1")

  "return normal decimal for correct precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
    when(resultSetMetaData.getPrecision(0)).thenReturn(10)
    when(resultSetMetaData.getScale(0)).thenReturn(2)

    iterator.getDecimalData(0) match {
      case decimal: BigDecimal => assert(decimal.toString == "1")
      case other               => fail(s"Expected a BigDecimal, got: $other")
    }
  }

  // Covers the branch where string fallback is enabled but the column definition is accepted.
  "return normal decimal for correct precision and scale when string fallback is enabled" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = true)
    when(resultSetMetaData.getPrecision(0)).thenReturn(10)
    when(resultSetMetaData.getScale(0)).thenReturn(2)

    iterator.getDecimalData(0) match {
      case decimal: BigDecimal => assert(decimal.toString == "1")
      case other               => fail(s"Expected a BigDecimal, got: $other")
    }
  }

  "return fixed decimal for incorrect precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
    when(resultSetMetaData.getPrecision(0)).thenReturn(0)
    when(resultSetMetaData.getScale(0)).thenReturn(2)

    iterator.getDecimalData(0) match {
      case decimal: BigDecimal => assert(decimal.toString == "1")
      case other               => fail(s"Expected a BigDecimal, got: $other")
    }
  }

  "return string type for incorrect precision and scale" in {
    val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = true)
    when(resultSetMetaData.getPrecision(0)).thenReturn(0)
    when(resultSetMetaData.getScale(0)).thenReturn(2)

    iterator.getDecimalData(0) match {
      case str: String => assert(str == "1")
      case other       => fail(s"Expected a String, got: $other")
    }
  }
}

"sanitizeDateTime" when {
// Variable names come from PostgreSQL "constant field docs":
// https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
Expand All @@ -212,15 +294,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
val maxTimestamp = 253402300799999L

"ignore null values" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

val fixedTs = iterator.sanitizeTimestamp(null)

assert(fixedTs == null)
}

"convert PostgreSql positive infinity value" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY))

val fixedTs = iterator.sanitizeTimestamp(timestamp)
Expand All @@ -229,7 +311,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"convert PostgreSql negative infinity value" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val timestamp = Timestamp.from(Instant.ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY))

val fixedTs = iterator.sanitizeTimestamp(timestamp)
Expand All @@ -238,7 +320,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"convert overflowed value to the maximum value supported" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val timestamp = Timestamp.from(Instant.ofEpochMilli(1000000000000000L))

val actual = iterator.sanitizeTimestamp(timestamp)
Expand All @@ -252,7 +334,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"do nothing if the feature is turned off" in {
val iterator = new ResultSetToRowIterator(resultSet, false)
val iterator = new ResultSetToRowIterator(resultSet, false, incorrectDecimalsAsString = false)
val timestamp = Timestamp.from(Instant.ofEpochMilli(1000000000000000L))

val actual = iterator.sanitizeTimestamp(timestamp)
Expand All @@ -272,15 +354,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
val maxDate = 253402214400000L

"ignore null values" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)

val fixedDate = iterator.sanitizeDate(null)

assert(fixedDate == null)
}

"convert PostgreSql positive infinity value" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val date = new Date(POSTGRESQL_DATE_POSITIVE_INFINITY)

val fixedDate = iterator.sanitizeDate(date)
Expand All @@ -289,7 +371,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"convert PostgreSql negative infinity value" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val date = new Date(POSTGRESQL_DATE_NEGATIVE_INFINITY)

val fixedDate = iterator.sanitizeDate(date)
Expand All @@ -298,7 +380,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"convert overflowed value to the maximum value supported" in {
val iterator = new ResultSetToRowIterator(resultSet, true)
val iterator = new ResultSetToRowIterator(resultSet, true, incorrectDecimalsAsString = false)
val date = new Date(1000000000000000L)

val actual = iterator.sanitizeDate(date)
Expand All @@ -312,7 +394,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
}

"do nothing if the feature is turned off" in {
val iterator = new ResultSetToRowIterator(resultSet, false)
val iterator = new ResultSetToRowIterator(resultSet, false, incorrectDecimalsAsString = false)
val date = new Date(1000000000000000L)

val actual = iterator.sanitizeDate(date)
Expand Down

0 comments on commit 1076b36

Please sign in to comment.