16
16
17
17
package za .co .absa .pramen .core .tests .utils
18
18
19
- import org .apache .spark .sql .types .IntegerType
19
+ import org .apache .spark .sql .types .{ DecimalType , IntegerType , StringType }
20
20
import org .mockito .Mockito .{mock , when }
21
21
import org .scalatest .wordspec .AnyWordSpec
22
22
import za .co .absa .pramen .core .base .SparkTestBase
@@ -26,6 +26,7 @@ import za.co.absa.pramen.core.samples.RdbExampleTable
26
26
import za .co .absa .pramen .core .utils .impl .ResultSetToRowIterator
27
27
import za .co .absa .pramen .core .utils .{JdbcNativeUtils , SparkUtils }
28
28
29
+ import java .math .BigDecimal
29
30
import java .sql ._
30
31
import java .time .{Instant , ZoneId }
31
32
import java .util .{Calendar , GregorianCalendar , TimeZone }
@@ -193,6 +194,87 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
193
194
}
194
195
}
195
196
197
+ " getDecimalSparkSchema" should {
198
+ val resultSet = mock(classOf [ResultSet ])
199
+ val resultSetMetaData = mock(classOf [ResultSetMetaData ])
200
+
201
+ when(resultSetMetaData.getColumnCount).thenReturn(1 )
202
+ when(resultSet.getMetaData).thenReturn(resultSetMetaData)
203
+
204
+ " return normal decimal for correct precision and scale" in {
205
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
206
+
207
+ assert(iterator.getDecimalSparkSchema(10 , 0 ) == DecimalType (10 , 0 ))
208
+ assert(iterator.getDecimalSparkSchema(10 , 2 ) == DecimalType (10 , 2 ))
209
+ assert(iterator.getDecimalSparkSchema(2 , 1 ) == DecimalType (2 , 1 ))
210
+ assert(iterator.getDecimalSparkSchema(38 , 18 ) == DecimalType (38 , 18 ))
211
+ }
212
+
213
+ " return fixed decimal for incorrect precision and scale" in {
214
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
215
+
216
+ assert(iterator.getDecimalSparkSchema(1 , - 1 ) == DecimalType (38 , 18 ))
217
+ assert(iterator.getDecimalSparkSchema(0 , 0 ) == DecimalType (38 , 18 ))
218
+ assert(iterator.getDecimalSparkSchema(0 , 2 ) == DecimalType (38 , 18 ))
219
+ assert(iterator.getDecimalSparkSchema(2 , 2 ) == DecimalType (38 , 18 ))
220
+ assert(iterator.getDecimalSparkSchema(39 , 0 ) == DecimalType (38 , 18 ))
221
+ assert(iterator.getDecimalSparkSchema(38 , 19 ) == DecimalType (38 , 18 ))
222
+ assert(iterator.getDecimalSparkSchema(20 , 19 ) == DecimalType (38 , 18 ))
223
+ }
224
+
225
+ " return string type for incorrect precision and scale" in {
226
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = true )
227
+
228
+ assert(iterator.getDecimalSparkSchema(1 , - 1 ) == StringType )
229
+ assert(iterator.getDecimalSparkSchema(0 , 0 ) == StringType )
230
+ assert(iterator.getDecimalSparkSchema(0 , 2 ) == StringType )
231
+ assert(iterator.getDecimalSparkSchema(2 , 2 ) == StringType )
232
+ assert(iterator.getDecimalSparkSchema(39 , 0 ) == StringType )
233
+ assert(iterator.getDecimalSparkSchema(38 , 19 ) == StringType )
234
+ assert(iterator.getDecimalSparkSchema(20 , 19 ) == StringType )
235
+ }
236
+ }
237
+
238
+ " getDecimalData" should {
239
+ val resultSet = mock(classOf [ResultSet ])
240
+ val resultSetMetaData = mock(classOf [ResultSetMetaData ])
241
+
242
+ when(resultSetMetaData.getColumnCount).thenReturn(1 )
243
+ when(resultSet.getMetaData).thenReturn(resultSetMetaData)
244
+ when(resultSet.getBigDecimal(0 )).thenReturn(new BigDecimal (1.0 ))
245
+ when(resultSet.getString(0 )).thenReturn(" 1" )
246
+
247
+ " return normal decimal for correct precision and scale" in {
248
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
249
+ when(resultSetMetaData.getPrecision(0 )).thenReturn(10 )
250
+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
251
+
252
+ val v = iterator.getDecimalData(0 )
253
+ assert(v.isInstanceOf [BigDecimal ])
254
+ assert(v.asInstanceOf [BigDecimal ].toString == " 1" )
255
+ }
256
+
257
+ " return fixed decimal for incorrect precision and scale" in {
258
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
259
+ when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
260
+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
261
+
262
+ val v = iterator.getDecimalData(0 )
263
+ assert(v.isInstanceOf [BigDecimal ])
264
+ assert(v.asInstanceOf [BigDecimal ].toString == " 1" )
265
+ }
266
+
267
+ " return string type for incorrect precision and scale" in {
268
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = true )
269
+ when(resultSetMetaData.getPrecision(0 )).thenReturn(0 )
270
+ when(resultSetMetaData.getScale(0 )).thenReturn(2 )
271
+
272
+ val v = iterator.getDecimalData(0 )
273
+ assert(v.isInstanceOf [String ])
274
+ assert(v.asInstanceOf [String ] == " 1" )
275
+ }
276
+ }
277
+
196
278
" sanitizeDateTime" when {
197
279
// Variable names come from PostgreSQL "constant field docs":
198
280
// https://jdbc.postgresql.org/documentation/publicapi/index.html?constant-values.html
@@ -212,15 +294,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
212
294
val maxTimestamp = 253402300799999L
213
295
214
296
" ignore null values" in {
215
- val iterator = new ResultSetToRowIterator (resultSet, true )
297
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
216
298
217
299
val fixedTs = iterator.sanitizeTimestamp(null )
218
300
219
301
assert(fixedTs == null )
220
302
}
221
303
222
304
" convert PostgreSql positive infinity value" in {
223
- val iterator = new ResultSetToRowIterator (resultSet, true )
305
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
224
306
val timestamp = Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_POSITIVE_INFINITY ))
225
307
226
308
val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -229,7 +311,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
229
311
}
230
312
231
313
" convert PostgreSql negative infinity value" in {
232
- val iterator = new ResultSetToRowIterator (resultSet, true )
314
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
233
315
val timestamp = Timestamp .from(Instant .ofEpochMilli(POSTGRESQL_DATE_NEGATIVE_INFINITY ))
234
316
235
317
val fixedTs = iterator.sanitizeTimestamp(timestamp)
@@ -238,7 +320,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
238
320
}
239
321
240
322
" convert overflowed value to the maximum value supported" in {
241
- val iterator = new ResultSetToRowIterator (resultSet, true )
323
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
242
324
val timestamp = Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
243
325
244
326
val actual = iterator.sanitizeTimestamp(timestamp)
@@ -252,7 +334,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
252
334
}
253
335
254
336
" do nothing if the feature is turned off" in {
255
- val iterator = new ResultSetToRowIterator (resultSet, false )
337
+ val iterator = new ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString = false )
256
338
val timestamp = Timestamp .from(Instant .ofEpochMilli(1000000000000000L ))
257
339
258
340
val actual = iterator.sanitizeTimestamp(timestamp)
@@ -272,15 +354,15 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
272
354
val maxDate = 253402214400000L
273
355
274
356
" ignore null values" in {
275
- val iterator = new ResultSetToRowIterator (resultSet, true )
357
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
276
358
277
359
val fixedDate = iterator.sanitizeDate(null )
278
360
279
361
assert(fixedDate == null )
280
362
}
281
363
282
364
" convert PostgreSql positive infinity value" in {
283
- val iterator = new ResultSetToRowIterator (resultSet, true )
365
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
284
366
val date = new Date (POSTGRESQL_DATE_POSITIVE_INFINITY )
285
367
286
368
val fixedDate = iterator.sanitizeDate(date)
@@ -289,7 +371,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
289
371
}
290
372
291
373
" convert PostgreSql negative infinity value" in {
292
- val iterator = new ResultSetToRowIterator (resultSet, true )
374
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
293
375
val date = new Date (POSTGRESQL_DATE_NEGATIVE_INFINITY )
294
376
295
377
val fixedDate = iterator.sanitizeDate(date)
@@ -298,7 +380,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
298
380
}
299
381
300
382
" convert overflowed value to the maximum value supported" in {
301
- val iterator = new ResultSetToRowIterator (resultSet, true )
383
+ val iterator = new ResultSetToRowIterator (resultSet, true , incorrectDecimalsAsString = false )
302
384
val date = new Date (1000000000000000L )
303
385
304
386
val actual = iterator.sanitizeDate(date)
@@ -312,7 +394,7 @@ class JdbcNativeUtilsSuite extends AnyWordSpec with RelationalDbFixture with Spa
312
394
}
313
395
314
396
" do nothing if the feature is turned off" in {
315
- val iterator = new ResultSetToRowIterator (resultSet, false )
397
+ val iterator = new ResultSetToRowIterator (resultSet, false , incorrectDecimalsAsString = false )
316
398
val date = new Date (1000000000000000L )
317
399
318
400
val actual = iterator.sanitizeDate(date)
0 commit comments