Skip to content

Commit e545cf9

Browse files
committed
#672 Add unit tests for Spark schema generation.
1 parent a39c527 commit e545cf9

File tree

1 file changed

+77
-3
lines changed

1 file changed

+77
-3
lines changed

spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/CobolSchemaSpec.scala

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package za.co.absa.cobrix.spark.cobol
1818

19-
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
19+
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
2020
import org.scalatest.wordspec.AnyWordSpec
2121
import org.slf4j.{Logger, LoggerFactory}
2222
import za.co.absa.cobrix.cobol.parser.CopybookParser
@@ -409,14 +409,14 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {
409409

410410
"fromSparkOptions" should {
411411
"return a schema for a copybook" in {
412-
val copyBook: String =
412+
val copybook: String =
413413
""" 01 RECORD.
414414
| 05 STR1 PIC X(10).
415415
| 05 STR2 PIC A(7).
416416
| 05 NUM3 PIC 9(7).
417417
|""".stripMargin
418418

419-
val cobolSchema = CobolSchema.fromSparkOptions(Seq(copyBook), Map.empty)
419+
val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook), Map.empty)
420420

421421
val sparkSchema = cobolSchema.getSparkSchema
422422

@@ -428,6 +428,80 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {
428428
assert(sparkSchema.fields(2).name == "NUM3")
429429
assert(sparkSchema.fields(2).dataType == IntegerType)
430430
}
431+
432+
"return a schema for multiple copybooks" in {
433+
val copybook1: String =
434+
""" 01 RECORD1.
435+
| 05 STR1 PIC X(10).
436+
| 05 STR2 PIC A(7).
437+
| 05 NUM3 PIC 9(7).
438+
|""".stripMargin
439+
440+
val copybook2: String =
441+
""" 01 RECORD2.
442+
| 05 STR4 PIC X(10).
443+
| 05 STR5 PIC A(7).
444+
| 05 NUM6 PIC 9(7).
445+
|""".stripMargin
446+
447+
val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook1, copybook2), Map("schema_retention_policy" -> "keep_original"))
448+
449+
val sparkSchema = cobolSchema.getSparkSchema
450+
451+
assert(sparkSchema.fields.length == 2)
452+
assert(sparkSchema.fields.head.name == "RECORD1")
453+
assert(sparkSchema.fields.head.dataType.isInstanceOf[StructType])
454+
assert(sparkSchema.fields(1).name == "RECORD2")
455+
assert(sparkSchema.fields(1).dataType.isInstanceOf[StructType])
456+
assert(cobolSchema.getCobolSchema.ast.children.head.isRedefined)
457+
assert(cobolSchema.getCobolSchema.ast.children(1).redefines.contains("RECORD1"))
458+
}
459+
460+
"return a schema for a hierarchical copybook" in {
461+
val copybook: String =
462+
""" 01 RECORD.
463+
| 05 HEADER PIC X(5).
464+
| 05 SEGMENT-ID PIC X(2).
465+
| 05 SEG1.
466+
| 10 FIELD1 PIC 9(7).
467+
| 05 SEG2 REDEFINES SEG1.
468+
| 10 FIELD3 PIC X(7).
469+
| 05 SEG3 REDEFINES SEG1.
470+
| 10 FIELD4 PIC S9(7).
471+
|""".stripMargin
472+
473+
val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook),
474+
Map(
475+
"segment_field" -> "SEGMENT-ID",
476+
"redefine-segment-id-map:0" -> "SEG1 => 01",
477+
"redefine-segment-id-map:1" -> "SEG2 => 02",
478+
"redefine-segment-id-map:2" -> "SEG3 => 03,0A",
479+
"segment-children:1" -> "SEG1 => SEG2",
480+
"segment-children:2" -> "SEG1 => SEG3"
481+
)
482+
)
483+
484+
val sparkSchema = cobolSchema.getSparkSchema
485+
486+
sparkSchema.printTreeString()
487+
488+
assert(sparkSchema.fields.length == 3)
489+
assert(sparkSchema.fields.head.name == "HEADER")
490+
assert(sparkSchema.fields.head.dataType == StringType)
491+
assert(sparkSchema.fields(1).name == "SEGMENT_ID")
492+
assert(sparkSchema.fields(1).dataType == StringType)
493+
assert(sparkSchema.fields(2).name == "SEG1")
494+
assert(sparkSchema.fields(2).dataType.isInstanceOf[StructType])
495+
496+
val seg1 = sparkSchema.fields(2).dataType.asInstanceOf[StructType]
497+
assert(seg1.fields.length == 3)
498+
assert(seg1.fields.head.name == "FIELD1")
499+
assert(seg1.fields.head.dataType == IntegerType)
500+
assert(seg1.fields(1).name == "SEG2")
501+
assert(seg1.fields(1).dataType.isInstanceOf[ArrayType])
502+
assert(seg1.fields(2).name == "SEG3")
503+
assert(seg1.fields(2).dataType.isInstanceOf[ArrayType])
504+
}
431505
}
432506

433507
}

0 commit comments

Comments
 (0)