1616
1717package  za .co .absa .pramen .core .tests .utils 
1818
19- import  org .apache .spark .sql .types .IntegerType 
19+ import  org .apache .spark .sql .types .{ DecimalType ,  IntegerType ,  StringType } 
2020import  org .mockito .Mockito .{mock , when }
2121import  org .scalatest .wordspec .AnyWordSpec 
2222import  za .co .absa .pramen .core .base .SparkTestBase 
@@ -26,6 +26,7 @@ import za.co.absa.pramen.core.samples.RdbExampleTable
2626import  za .co .absa .pramen .core .utils .impl .ResultSetToRowIterator 
2727import  za .co .absa .pramen .core .utils .{JdbcNativeUtils , SparkUtils }
2828
29+ import  java .math .BigDecimal 
2930import  java .sql ._ 
3031import  java .time .{Instant , ZoneId }
3132import  java .util .{Calendar , GregorianCalendar , TimeZone }
@@ -193,6 +194,87 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
193194    }
194195  }
195196
197+   " getDecimalSparkSchema" 
198+     val  resultSet  =  mock(classOf [ResultSet ])
199+     val  resultSetMetaData  =  mock(classOf [ResultSetMetaData ])
200+ 
201+     when(resultSetMetaData.getColumnCount).thenReturn(1 )
202+     when(resultSet.getMetaData).thenReturn(resultSetMetaData)
203+ 
204+     " return normal decimal for correct precision and scale" 
205+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  false )
206+ 
207+       assert(iterator.getDecimalSparkSchema(10 , 0 ) ==  DecimalType (10 , 0 ))
208+       assert(iterator.getDecimalSparkSchema(10 , 2 ) ==  DecimalType (10 , 2 ))
209+       assert(iterator.getDecimalSparkSchema(2 , 1 ) ==  DecimalType (2 , 1 ))
210+       assert(iterator.getDecimalSparkSchema(38 , 18 ) ==  DecimalType (38 , 18 ))
211+     }
212+ 
213+     " return fixed decimal for incorrect precision and scale" 
214+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  false )
215+ 
216+       assert(iterator.getDecimalSparkSchema(1 , - 1 ) ==  DecimalType (38 , 18 ))
217+       assert(iterator.getDecimalSparkSchema(0 , 0 ) ==  DecimalType (38 , 18 ))
218+       assert(iterator.getDecimalSparkSchema(0 , 2 ) ==  DecimalType (38 , 18 ))
219+       assert(iterator.getDecimalSparkSchema(2 , 2 ) ==  DecimalType (38 , 18 ))
220+       assert(iterator.getDecimalSparkSchema(39 , 0 ) ==  DecimalType (38 , 18 ))
221+       assert(iterator.getDecimalSparkSchema(38 , 19 ) ==  DecimalType (38 , 18 ))
222+       assert(iterator.getDecimalSparkSchema(20 , 19 ) ==  DecimalType (38 , 18 ))
223+     }
224+ 
225+     " return string type for incorrect precision and scale" 
226+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  true )
227+ 
228+       assert(iterator.getDecimalSparkSchema(1 , - 1 ) ==  StringType )
229+       assert(iterator.getDecimalSparkSchema(0 , 0 ) ==  StringType )
230+       assert(iterator.getDecimalSparkSchema(0 , 2 ) ==  StringType )
231+       assert(iterator.getDecimalSparkSchema(2 , 2 ) ==  StringType )
232+       assert(iterator.getDecimalSparkSchema(39 , 0 ) ==  StringType )
233+       assert(iterator.getDecimalSparkSchema(38 , 19 ) ==  StringType )
234+       assert(iterator.getDecimalSparkSchema(20 , 19 ) ==  StringType )
235+     }
236+   }
237+ 
238+   " getDecimalData" 
239+     val  resultSet  =  mock(classOf [ResultSet ])
240+     val  resultSetMetaData  =  mock(classOf [ResultSetMetaData ])
241+ 
242+     when(resultSetMetaData.getColumnCount).thenReturn(1 )
243+     when(resultSet.getMetaData).thenReturn(resultSetMetaData)
244+     when(resultSet.getBigDecimal(0 )).thenReturn(new  BigDecimal (1.0 ))
245+     when(resultSet.getString(0 )).thenReturn(" 1" 
246+ 
247+     " return normal decimal for correct precision and scale" 
248+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  false )
249+       when(resultSetMetaData.getPrecision(0 )).thenReturn(10 )
250+       when(resultSetMetaData.getScale(0 )).thenReturn(2 )
251+ 
252+       val  v  =  iterator.getDecimalData(0 )
253+       assert(v.isInstanceOf [BigDecimal ])
254+       assert(v.asInstanceOf [BigDecimal ].toString ==  " 1" 
255+     }
256+ 
257+     " return fixed decimal for incorrect precision and scale" 
258+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  false )
259+       when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
260+       when(resultSetMetaData.getScale(0 )).thenReturn(2 )
261+ 
262+       val  v  =  iterator.getDecimalData(0 )
263+       assert(v.isInstanceOf [BigDecimal ])
264+       assert(v.asInstanceOf [BigDecimal ].toString ==  " 1" 
265+     }
266+ 
267+     " return string type for incorrect precision and scale" 
268+       val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString =  true )
269+       when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
270+       when(resultSetMetaData.getScale(0 )).thenReturn(2 )
271+ 
272+       val  v  =  iterator.getDecimalData(0 )
273+       assert(v.isInstanceOf [String ])
274+       assert(v.asInstanceOf [String ] ==  " 1" 
275+     }
276+   }
277+ 
196278  " sanitizeDateTime" 
197279    //  Variable names come from PostgreSQL "constant field docs":
198280    //  https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
@@ -212,15 +294,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
212294      val  maxTimestamp  =  253402300799999L 
213295
214296      " ignore null values" 
215-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
297+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
216298
217299        val  fixedTs  =  iterator.sanitizeTimestamp(null )
218300
219301        assert(fixedTs ==  null )
220302      }
221303
222304      " convert PostgreSql positive infinity value" 
223-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
305+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
224306        val  timestamp  =  Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY ))
225307
226308        val  fixedTs  =  iterator.sanitizeTimestamp(timestamp)
@@ -229,7 +311,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
229311      }
230312
231313      " convert PostgreSql negative infinity value" 
232-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
314+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
233315        val  timestamp  =  Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY ))
234316
235317        val  fixedTs  =  iterator.sanitizeTimestamp(timestamp)
@@ -238,7 +320,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
238320      }
239321
240322      " convert overflowed value to the maximum value supported" 
241-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
323+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
242324        val  timestamp  =  Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
243325
244326        val  actual  =  iterator.sanitizeTimestamp(timestamp)
@@ -252,7 +334,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
252334      }
253335
254336      " do nothing if the feature is turned off" 
255-         val  iterator  =  new  ResultSetToRowIterator (resultSet, false )
337+         val  iterator  =  new  ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString  =   false )
256338        val  timestamp  =  Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
257339
258340        val  actual  =  iterator.sanitizeTimestamp(timestamp)
@@ -272,15 +354,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
272354      val  maxDate  =  253402214400000L 
273355
274356      " ignore null values" 
275-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
357+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
276358
277359        val  fixedDate  =  iterator.sanitizeDate(null )
278360
279361        assert(fixedDate ==  null )
280362      }
281363
282364      " convert PostgreSql positive infinity value" 
283-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
365+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
284366        val  date  =  new  Date (POSTGRESQL_DATE_POSITIVE_INFINITY )
285367
286368        val  fixedDate  =  iterator.sanitizeDate(date)
@@ -289,7 +371,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
289371      }
290372
291373      " convert PostgreSql negative infinity value" 
292-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
374+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
293375        val  date  =  new  Date (POSTGRESQL_DATE_NEGATIVE_INFINITY )
294376
295377        val  fixedDate  =  iterator.sanitizeDate(date)
@@ -298,7 +380,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
298380      }
299381
300382      " convert overflowed value to the maximum value supported" 
301-         val  iterator  =  new  ResultSetToRowIterator (resultSet, true )
383+         val  iterator  =  new  ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString  =   false )
302384        val  date  =  new  Date (1000000000000000L )
303385
304386        val  actual  =  iterator.sanitizeDate(date)
@@ -312,7 +394,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
312394      }
313395
314396      " do nothing if the feature is turned off" 
315-         val  iterator  =  new  ResultSetToRowIterator (resultSet, false )
397+         val  iterator  =  new  ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString  =   false )
316398        val  date  =  new  Date (1000000000000000L )
317399
318400        val  actual  =  iterator.sanitizeDate(date)
0 commit comments