1616
1717package za .co .absa .pramen .core .tests .utils
1818
19- import org .apache .spark .sql .types .IntegerType
19+ import org .apache .spark .sql .types .{ DecimalType , IntegerType , StringType }
2020import org .mockito .Mockito .{mock , when }
2121import org .scalatest .wordspec .AnyWordSpec
2222import za .co .absa .pramen .core .base .SparkTestBase
@@ -26,6 +26,7 @@ import za.co.absa.pramen.core.samples.RdbExampleTable
2626import za .co .absa .pramen .core .utils .impl .ResultSetToRowIterator
2727import za .co .absa .pramen .core .utils .{JdbcNativeUtils , SparkUtils }
2828
29+ import java .math .BigDecimal
2930import java .sql ._
3031import java .time .{Instant , ZoneId }
3132import java .util .{Calendar , GregorianCalendar , TimeZone }
@@ -193,6 +194,87 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
193194 }
194195 }
195196
197+ " getDecimalSparkSchema" should {
198+ val resultSet = mock(classOf [ResultSet ])
199+ val resultSetMetaData = mock(classOf [ResultSetMetaData ])
200+
201+ when(resultSetMetaData.getColumnCount).thenReturn(1 )
202+ when(resultSet.getMetaData).thenReturn(resultSetMetaData)
203+
204+ " return normal decimal for correct precision and scale" in {
205+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
206+
207+ assert(iterator.getDecimalSparkSchema(10 , 0 ) == DecimalType (10 , 0 ))
208+ assert(iterator.getDecimalSparkSchema(10 , 2 ) == DecimalType (10 , 2 ))
209+ assert(iterator.getDecimalSparkSchema(2 , 1 ) == DecimalType (2 , 1 ))
210+ assert(iterator.getDecimalSparkSchema(38 , 18 ) == DecimalType (38 , 18 ))
211+ }
212+
213+ " return fixed decimal for incorrect precision and scale" in {
214+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
215+
216+ assert(iterator.getDecimalSparkSchema(1 , - 1 ) == DecimalType (38 , 18 ))
217+ assert(iterator.getDecimalSparkSchema(0 , 0 ) == DecimalType (38 , 18 ))
218+ assert(iterator.getDecimalSparkSchema(0 , 2 ) == DecimalType (38 , 18 ))
219+ assert(iterator.getDecimalSparkSchema(2 , 2 ) == DecimalType (38 , 18 ))
220+ assert(iterator.getDecimalSparkSchema(39 , 0 ) == DecimalType (38 , 18 ))
221+ assert(iterator.getDecimalSparkSchema(38 , 19 ) == DecimalType (38 , 18 ))
222+ assert(iterator.getDecimalSparkSchema(20 , 19 ) == DecimalType (38 , 18 ))
223+ }
224+
225+ " return string type for incorrect precision and scale" in {
226+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = true )
227+
228+ assert(iterator.getDecimalSparkSchema(1 , - 1 ) == StringType )
229+ assert(iterator.getDecimalSparkSchema(0 , 0 ) == StringType )
230+ assert(iterator.getDecimalSparkSchema(0 , 2 ) == StringType )
231+ assert(iterator.getDecimalSparkSchema(2 , 2 ) == StringType )
232+ assert(iterator.getDecimalSparkSchema(39 , 0 ) == StringType )
233+ assert(iterator.getDecimalSparkSchema(38 , 19 ) == StringType )
234+ assert(iterator.getDecimalSparkSchema(20 , 19 ) == StringType )
235+ }
236+ }
237+
238+ " getDecimalData" should {
239+ val resultSet = mock(classOf [ResultSet ])
240+ val resultSetMetaData = mock(classOf [ResultSetMetaData ])
241+
242+ when(resultSetMetaData.getColumnCount).thenReturn(1 )
243+ when(resultSet.getMetaData).thenReturn(resultSetMetaData)
244+ when(resultSet.getBigDecimal(0 )).thenReturn(new BigDecimal (1.0 ))
245+ when(resultSet.getString(0 )).thenReturn(" 1" )
246+
247+ " return normal decimal for correct precision and scale" in {
248+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
249+ when(resultSetMetaData.getPrecision(0 )).thenReturn(10 )
250+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
251+
252+ val v = iterator.getDecimalData(0 )
253+ assert(v.isInstanceOf [BigDecimal ])
254+ assert(v.asInstanceOf [BigDecimal ].toString == " 1" )
255+ }
256+
257+ " return fixed decimal for incorrect precision and scale" in {
258+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
259+ when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
260+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
261+
262+ val v = iterator.getDecimalData(0 )
263+ assert(v.isInstanceOf [BigDecimal ])
264+ assert(v.asInstanceOf [BigDecimal ].toString == " 1" )
265+ }
266+
267+ " return string type for incorrect precision and scale" in {
268+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = true )
269+ when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
270+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
271+
272+ val v = iterator.getDecimalData(0 )
273+ assert(v.isInstanceOf [String ])
274+ assert(v.asInstanceOf [String ] == " 1" )
275+ }
276+ }
277+
196278 " sanitizeDateTime" when {
197279 // Variable names come from PostgreSQL "constant field docs":
198280 // https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
@@ -212,15 +294,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
212294 val maxTimestamp = 253402300799999L
213295
214296 " ignore null values" in {
215- val iterator = new ResultSetToRowIterator (resultSet, true )
297+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
216298
217299 val fixedTs = iterator.sanitizeTimestamp(null )
218300
219301 assert(fixedTs == null )
220302 }
221303
222304 " convert PostgreSql positive infinity value" in {
223- val iterator = new ResultSetToRowIterator (resultSet, true )
305+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
224306 val timestamp = Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY ))
225307
226308 val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -229,7 +311,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
229311 }
230312
231313 " convert PostgreSql negative infinity value" in {
232- val iterator = new ResultSetToRowIterator (resultSet, true )
314+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
233315 val timestamp = Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY ))
234316
235317 val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -238,7 +320,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
238320 }
239321
240322 " convert overflowed value to the maximum value supported" in {
241- val iterator = new ResultSetToRowIterator (resultSet, true )
323+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
242324 val timestamp = Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
243325
244326 val actual = iterator.sanitizeTimestamp(timestamp)
@@ -252,7 +334,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
252334 }
253335
254336 " do nothing if the feature is turned off" in {
255- val iterator = new ResultSetToRowIterator (resultSet, false )
337+ val iterator = new ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString = false )
256338 val timestamp = Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
257339
258340 val actual = iterator.sanitizeTimestamp(timestamp)
@@ -272,15 +354,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
272354 val maxDate = 253402214400000L
273355
274356 " ignore null values" in {
275- val iterator = new ResultSetToRowIterator (resultSet, true )
357+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
276358
277359 val fixedDate = iterator.sanitizeDate(null )
278360
279361 assert(fixedDate == null )
280362 }
281363
282364 " convert PostgreSql positive infinity value" in {
283- val iterator = new ResultSetToRowIterator (resultSet, true )
365+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
284366 val date = new Date (POSTGRESQL_DATE_POSITIVE_INFINITY )
285367
286368 val fixedDate = iterator.sanitizeDate(date)
@@ -289,7 +371,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
289371 }
290372
291373 " convert PostgreSql negative infinity value" in {
292- val iterator = new ResultSetToRowIterator (resultSet, true )
374+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
293375 val date = new Date (POSTGRESQL_DATE_NEGATIVE_INFINITY )
294376
295377 val fixedDate = iterator.sanitizeDate(date)
@@ -298,7 +380,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
298380 }
299381
300382 " convert overflowed value to the maximum value supported" in {
301- val iterator = new ResultSetToRowIterator (resultSet, true )
383+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
302384 val date = new Date (1000000000000000L )
303385
304386 val actual = iterator.sanitizeDate(date)
@@ -312,7 +394,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
312394 }
313395
314396 " do nothing if the feature is turned off" in {
315- val iterator = new ResultSetToRowIterator (resultSet, false )
397+ val iterator = new ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString = false )
316398 val date = new Date (1000000000000000L )
317399
318400 val actual = iterator.sanitizeDate(date)
0 commit comments