Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#740 extend binary type support to make sure unsigned binary fields can fit Spark data types #742

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ object BinaryNumberDecoders {
if (v<0) null else v
}

/**
  * Decodes a 4-byte big-endian unsigned binary integer, widening it to a long so that
  * values with the top bit set are still representable.
  *
  * @param bytes the raw field bytes; only the first four are read, most significant first
  * @return the decoded value, or `null` when fewer than four bytes are supplied
  */
def decodeBinaryUnsignedIntBigEndianAsLong(bytes: Array[Byte]): java.lang.Long = {
  if (bytes.length >= 4) {
    // Mask each byte with a Long so that sign extension never leaks into the result.
    val b3 = (bytes(0) & 0xFFL) << 24
    val b2 = (bytes(1) & 0xFFL) << 16
    val b1 = (bytes(2) & 0xFFL) << 8
    val b0 = bytes(3) & 0xFFL
    b3 | b2 | b1 | b0
  } else {
    null
  }
}

def decodeBinaryUnsignedIntLittleEndian(bytes: Array[Byte]): Integer = {
if (bytes.length < 4) {
return null
Expand All @@ -90,6 +97,13 @@ object BinaryNumberDecoders {
if (v<0) null else v
}

/**
  * Decodes a 4-byte little-endian unsigned binary integer, widening it to a long so that
  * values with the top bit set are still representable.
  *
  * @param bytes the raw field bytes; only the first four are read, least significant first
  * @return the decoded value, or `null` when fewer than four bytes are supplied
  */
def decodeBinaryUnsignedIntLittleEndianAsLong(bytes: Array[Byte]): java.lang.Long = {
  if (bytes.length >= 4) {
    // Mask each byte with a Long so that sign extension never leaks into the result.
    val b0 = bytes(0) & 0xFFL
    val b1 = (bytes(1) & 0xFFL) << 8
    val b2 = (bytes(2) & 0xFFL) << 16
    val b3 = (bytes(3) & 0xFFL) << 24
    b3 | b2 | b1 | b0
  } else {
    null
  }
}

def decodeBinarySignedLongBigEndian(bytes: Array[Byte]): java.lang.Long = {
if (bytes.length < 8) {
return null
Expand All @@ -112,6 +126,13 @@ object BinaryNumberDecoders {
if (v < 0L) null else v
}

/**
  * Decodes an 8-byte big-endian unsigned binary integer as a decimal, so the full unsigned
  * 64-bit range (up to 18446744073709551615) can be represented.
  *
  * Only the first eight bytes take part in the number, matching the fixed-width behaviour of
  * the 4-byte decoders; any trailing bytes are ignored.
  *
  * @param bytes the raw field bytes, most significant first
  * @return the decoded non-negative value, or `null` when fewer than eight bytes are supplied
  */
def decodeBinaryUnsignedLongBigEndianAsDecimal(bytes: Array[Byte]): BigDecimal = {
  if (bytes.length < 8) {
    null
  } else {
    // Avoid a copy in the common exact-width case; trim only when extra bytes are present.
    val magnitude = if (bytes.length == 8) bytes else bytes.take(8)
    // signum = 1 makes BigInt treat the bytes as an unsigned magnitude.
    BigDecimal(BigInt(1, magnitude))
  }
}

def decodeBinaryUnsignedLongLittleEndian(bytes: Array[Byte]): java.lang.Long = {
if (bytes.length < 8) {
return null
Expand All @@ -120,6 +141,13 @@ object BinaryNumberDecoders {
if (v < 0L) null else v
}

/**
  * Decodes an 8-byte little-endian unsigned binary integer as a decimal, so the full unsigned
  * 64-bit range (up to 18446744073709551615) can be represented.
  *
  * Only the first eight bytes take part in the number, matching the fixed-width behaviour of
  * the 4-byte decoders; any trailing bytes are ignored instead of being reversed into the
  * most significant positions.
  *
  * @param bytes the raw field bytes, least significant first
  * @return the decoded non-negative value, or `null` when fewer than eight bytes are supplied
  */
def decodeBinaryUnsignedLongLittleEndianAsDecimal(bytes: Array[Byte]): BigDecimal = {
  if (bytes.length < 8) {
    null
  } else {
    // Trim before reversing so that surplus bytes never become high-order digits.
    val littleEndian = if (bytes.length == 8) bytes else bytes.take(8)
    // BigInt expects a big-endian magnitude; signum = 1 marks it as unsigned.
    BigDecimal(BigInt(1, littleEndian.reverse))
  }
}

def decodeBinaryAribtraryPrecision(bytes: Array[Byte], isBigEndian: Boolean, isSigned: Boolean): BigDecimal = {
if (bytes.length == 0) {
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package za.co.absa.cobrix.cobol.parser.decoders
import java.nio.charset.{Charset, StandardCharsets}
import za.co.absa.cobrix.cobol.parser.ast.datatype._
import za.co.absa.cobrix.cobol.parser.common.Constants
import za.co.absa.cobrix.cobol.parser.common.Constants.{maxIntegerPrecision, maxLongPrecision}
import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat.FloatingPointFormat
import za.co.absa.cobrix.cobol.parser.encoding._
import za.co.absa.cobrix.cobol.parser.encoding.codepage.{CodePage, CodePageCommon}
Expand Down Expand Up @@ -255,26 +256,32 @@ object DecoderSelector {
val isSigned = signPosition.nonEmpty

val numOfBytes = BinaryUtils.getBytesCount(compact, precision, isSigned, isExplicitDecimalPt = false, isSignSeparate = false)
val isMaxUnsignedPrecision = precision == maxIntegerPrecision || precision == maxLongPrecision

val decoder = if (strictIntegralPrecision) {
(a: Array[Byte]) => BinaryNumberDecoders.decodeBinaryAribtraryPrecision(a, isBigEndian, isSigned)
} else {
(isSigned, isBigEndian, numOfBytes) match {
case (true, true, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, true, 2) => BinaryNumberDecoders.decodeBinarySignedShortBigEndian _
case (true, true, 4) => BinaryNumberDecoders.decodeBinarySignedIntBigEndian _
case (true, true, 8) => BinaryNumberDecoders.decodeBinarySignedLongBigEndian _
case (true, false, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, false, 2) => BinaryNumberDecoders.decodeBinarySignedShortLittleEndian _
case (true, false, 4) => BinaryNumberDecoders.decodeBinarySignedIntLittleEndian _
case (true, false, 8) => BinaryNumberDecoders.decodeBinarySignedLongLittleEndian _
case (false, true, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, true, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortBigEndian _
case (false, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndian _
case (false, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndian _
case (false, false, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, false, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortLittleEndian _
case (false, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndian _
case (false, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndian _
(isSigned, isBigEndian, isMaxUnsignedPrecision, numOfBytes) match {
case (true, true, _, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, true, _, 2) => BinaryNumberDecoders.decodeBinarySignedShortBigEndian _
case (true, true, _, 4) => BinaryNumberDecoders.decodeBinarySignedIntBigEndian _
case (true, true, _, 8) => BinaryNumberDecoders.decodeBinarySignedLongBigEndian _
case (true, false, _, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, false, _, 2) => BinaryNumberDecoders.decodeBinarySignedShortLittleEndian _
case (true, false, _, 4) => BinaryNumberDecoders.decodeBinarySignedIntLittleEndian _
case (true, false, _, 8) => BinaryNumberDecoders.decodeBinarySignedLongLittleEndian _
case (false, true, _, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, true, _, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortBigEndian _
case (false, true, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndian _
case (false, true, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndianAsLong _
case (false, true, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndian _
case (false, true, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndianAsDecimal _
case (false, false, _, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, false, _, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortLittleEndian _
case (false, false, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndian _
case (false, false, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndianAsLong _
case (false, false, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndian _
case (false, false, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndianAsDecimal _
case _ =>
(a: Array[Byte]) => BinaryNumberDecoders.decodeBinaryAribtraryPrecision(a, isBigEndian, isSigned)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,14 @@ class BinaryDecoderSpec extends AnyFunSuite {
val decoderUnsignedShort = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 3, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderSignedInt = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP4())), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntBe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntBeAsLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 9, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntLe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntLeAsLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 9, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderSignedLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP4())), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongBe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongBeAsBig = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 18, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongLe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongLeAsBig = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 18, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)

val num1 = decoderSignedByte(Array(0x10).map(_.toByte))
assert(num1.isInstanceOf[Integer])
Expand Down Expand Up @@ -501,10 +505,18 @@ class BinaryDecoderSpec extends AnyFunSuite {
assert(num9.isInstanceOf[Integer])
assert(num9.asInstanceOf[Integer] == 9437184)

val num9a = decoderUnsignedIntBeAsLong(Array(0x00, 0x90, 0x00, 0x00).map(_.toByte))
assert(num9a.isInstanceOf[java.lang.Long])
assert(num9a.asInstanceOf[java.lang.Long] == 9437184L)

val num10 = decoderUnsignedIntLe(Array(0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num10.isInstanceOf[Integer])
assert(num10.asInstanceOf[Integer] == 9437184)

val num10a = decoderUnsignedIntLeAsLong(Array(0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num10a.isInstanceOf[java.lang.Long])
assert(num10a.asInstanceOf[java.lang.Long] == 9437184L)

val num11 = decoderSignedLong(Array(0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00).map(_.toByte))
assert(num11.isInstanceOf[Long])
assert(num11.asInstanceOf[Long] == 72057594037927936L)
Expand All @@ -517,9 +529,17 @@ class BinaryDecoderSpec extends AnyFunSuite {
assert(num13.isInstanceOf[Long])
assert(num13.asInstanceOf[Long] == 40532396646334464L)

val num13a = decoderUnsignedLongBeAsBig(Array(0x00, 0x90, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00).map(_.toByte))
assert(num13a.isInstanceOf[BigDecimal])
assert(num13a.asInstanceOf[BigDecimal] == BigDecimal("40532396646334464"))

val num14 = decoderUnsignedLongLe(Array(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num14.isInstanceOf[Long])
assert(num14.asInstanceOf[Long] == 40532396646334464L)

val num14a = decoderUnsignedLongLeAsBig(Array(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num14a.isInstanceOf[BigDecimal])
assert(num14a.asInstanceOf[BigDecimal] == BigDecimal("40532396646334464"))
}

test("Test Binary strict integral precision numbers") {
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17a_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17b_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17c_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
2 changes: 1 addition & 1 deletion data/test18 special_char_expected/test18a_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
4 changes: 2 additions & 2 deletions data/test24_expected/test24_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@
}
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down Expand Up @@ -760,7 +760,7 @@
}
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
4 changes: 2 additions & 2 deletions data/test24_expected/test24b_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@
}
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down Expand Up @@ -760,7 +760,7 @@
}
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
4 changes: 2 additions & 2 deletions data/test6_expected/test6_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"metadata" : { }
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand All @@ -319,7 +319,7 @@
"metadata" : { }
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.sql.types._
import za.co.absa.cobrix.cobol.internal.Logging
import za.co.absa.cobrix.cobol.parser.Copybook
import za.co.absa.cobrix.cobol.parser.ast._
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP1, COMP2, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP1, COMP2, COMP4, COMP5, COMP9, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.common.Constants
import za.co.absa.cobrix.cobol.parser.encoding.RAW
import za.co.absa.cobrix.cobol.parser.policies.MetadataPolicy
Expand Down Expand Up @@ -66,23 +66,10 @@ class CobolSchema(copybook: Copybook,
// Spark schema for the copybook; computed once, on first access, by createSparkSchema().
@throws(classOf[IllegalStateException])
private[this] lazy val sparkSchema = createSparkSchema()

// Flattened Spark schema: every root record of the copybook AST is expanded by
// parseGroupFlat, with the record name (plus an underscore) used as the column prefix.
@throws(classOf[IllegalStateException])
private[this] lazy val sparkFlatSchema = {
  val flattenedFields = copybook.ast.children.toArray.flatMap { root =>
    parseGroupFlat(root.asInstanceOf[Group], s"${root.name}_")
  }
  StructType(flattenedFields)
}

/** Returns the Spark schema held in `sparkSchema`. */
def getSparkSchema: StructType = sparkSchema

/** Returns the flattened Spark schema held in `sparkFlatSchema`. */
def getSparkFlatSchema: StructType = sparkFlatSchema

@throws(classOf[IllegalStateException])
private def createSparkSchema(): StructType = {
val records = for (record <- copybook.getRootRecords) yield {
Expand Down Expand Up @@ -200,12 +187,16 @@ class CobolSchema(copybook: Copybook,
case dt: Integral if strictIntegralPrecision =>
DecimalType(precision = dt.precision, scale = 0)
case dt: Integral =>
val isBinary = dt.compact.exists(c => c == COMP4() || c == COMP5() || c == COMP9())
if (dt.precision > Constants.maxLongPrecision) {
DecimalType(precision = dt.precision, scale = 0)
} else if (dt.precision == Constants.maxLongPrecision && isBinary && dt.signPosition.isEmpty) { // promoting unsigned long to decimal(20) to be able to fit any value
DecimalType(precision = dt.precision + 2, scale = 0)
} else if (dt.precision > Constants.maxIntegerPrecision) {
LongType
}
else {
} else if (dt.precision == Constants.maxIntegerPrecision && isBinary && dt.signPosition.isEmpty) { // promoting unsigned int to long to be able to fit any value
LongType
} else {
IntegerType
}
case _ => throw new IllegalStateException("Unknown AST object")
Expand Down Expand Up @@ -290,53 +281,6 @@ class CobolSchema(copybook: Copybook,
})
childSegments
}

/**
  * Flattens a copybook group into a list of Spark fields whose names are prefixed with the
  * path of the enclosing groups.
  *
  * Filler fields are skipped. Nested groups are traversed recursively; a group that is an
  * array is expanded once per element (1 to `arrayMaxSize`), with the element index embedded
  * in the path. Primitive arrays produce one field per element, suffixed with the index.
  *
  * @param group      the group whose children are flattened
  * @param structPath the name prefix accumulated from the enclosing groups
  * @return the flattened fields, in declaration order
  * @throws IllegalStateException if a primitive has a data type that is not handled here
  */
@throws(classOf[IllegalStateException])
private def parseGroupFlat(group: Group, structPath: String = ""): ArrayBuffer[StructField] = {
  val fields = new ArrayBuffer[StructField]()
  for (field <- group.children if !field.isFiller) {
    field match {
      case group: Group =>
        if (group.isArray) {
          for (i <- Range(1, group.arrayMaxSize + 1)) {
            val path = s"$structPath${group.name}_${i}_"
            fields ++= parseGroupFlat(group, path)
          }
        } else {
          val path = s"$structPath${group.name}_"
          fields ++= parseGroupFlat(group, path)
        }
      case s: Primitive =>
        val dataType: DataType = s.dataType match {
          case d: Decimal =>
            DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
          case a: AlphaNumeric =>
            a.enc match {
              case Some(RAW) => BinaryType
              case _ => StringType
            }
          case dt: Integral =>
            if (dt.precision > Constants.maxIntegerPrecision) {
              LongType
            }
            else {
              IntegerType
            }
          case _ => throw new IllegalStateException("Unknown AST object")
        }
        if (s.isArray) {
          for (i <- Range(1, s.arrayMaxSize + 1)) {
            // The field name must interpolate s.name; without the leading '$' the literal
            // text "{s.name}" ended up in every generated column name.
            // NOTE(review): each per-element column is still typed ArrayType(dataType), as in
            // the original - confirm whether a scalar dataType was intended here.
            fields += StructField(s"$structPath${s.name}_$i", ArrayType(dataType), nullable = true)
          }
        } else {
          fields += StructField(s"$structPath${s.name}", dataType, nullable = true)
        }
    }
  }

  fields
}
}

object CobolSchema {
Expand Down
Loading
Loading