diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h index 3c47fb5479bd8..cf34b6a72c803 100644 --- a/cpp/core/config/GlutenConfig.h +++ b/cpp/core/config/GlutenConfig.h @@ -34,6 +34,8 @@ const std::string kLegacySize = "spark.sql.legacy.sizeOfNull"; const std::string kSessionTimezone = "spark.sql.session.timeZone"; +const std::string kAllowPrecisionLoss = "spark.sql.decimalOperations.allowPrecisionLoss"; + const std::string kIgnoreMissingFiles = "spark.sql.files.ignoreMissingFiles"; const std::string kDefaultSessionTimezone = "spark.gluten.sql.session.timeZone.default"; diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index 214d3cab94a57..1ef9636f3de1a 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -444,6 +444,9 @@ std::unordered_map WholeStageResultIterator::getQueryC veloxCfg_->get(kSessionTimezone, defaultTimezone); // Adjust timestamp according to the above configured session timezone. configs[velox::core::QueryConfig::kAdjustTimestampToTimezone] = std::to_string(true); + // To align with Spark's behavior, allow decimal precision loss or not. + configs[velox::core::QueryConfig::kAllowPrecisionLoss] = + veloxCfg_->get(kAllowPrecisionLoss, "true"); // Align Velox size function with Spark. configs[velox::core::QueryConfig::kSparkLegacySizeOfNull] = std::to_string(veloxCfg_->get(kLegacySize, true)); diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index eb7cd5d841a43..24beeeb6b8581 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -16,8 +16,8 @@ set -exu -VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_03_06 +VELOX_REPO=https://github.com/zhouyuan/velox.git +VELOX_BRANCH=wip_decimal_precision_loss VELOX_HOME="" #Set on run gluten on HDFS diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala index 5c994bdc0a281..0648cfc6ccc48 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala @@ -480,14 +480,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), expr) case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => - // PrecisionLoss=true: velox support / ch not support - // PrecisionLoss=false: velox not support / ch support - // TODO ch support PrecisionLoss=true - if (!BackendsApiManager.getSettings.allowDecimalArithmetic) { - throw new UnsupportedOperationException( - s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " + - s"${conf.decimalOperationsAllowPrecisionLoss} mode") - } val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) { DecimalArithmeticUtil.rescaleLiteral(b) } else { diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/DecimalArithmeticUtil.scala b/gluten-core/src/main/scala/io/glutenproject/utils/DecimalArithmeticUtil.scala index da1feab1e5705..5a43130adc1ea 100644 --- a/gluten-core/src/main/scala/io/glutenproject/utils/DecimalArithmeticUtil.scala +++ b/gluten-core/src/main/scala/io/glutenproject/utils/DecimalArithmeticUtil.scala @@ -21,6 +21,7 @@ import io.glutenproject.expression.{CheckOverflowTransformer, ChildTransformer, import org.apache.spark.sql.catalyst.analysis.DecimalPrecision import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} object DecimalArithmeticUtil { @@ -32,12 +33,14 @@ object DecimalArithmeticUtil { val MIN_ADJUSTED_SCALE = 6 val MAX_PRECISION = 38 + val MAX_SCALE = 38 // Returns the result decimal type of a decimal arithmetic computing. def getResultTypeForOperation( operationType: OperationType.Config, type1: DecimalType, type2: DecimalType): DecimalType = { + val allowPrecisionLoss = SQLConf.get.decimalOperationsAllowPrecisionLoss var resultScale = 0 var resultPrecision = 0 operationType match { @@ -53,8 +56,20 @@ object DecimalArithmeticUtil { resultScale = type1.scale + type2.scale resultPrecision = type1.precision + type2.precision + 1 case OperationType.DIVIDE => - resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1) - resultPrecision = type1.precision - type1.scale + type2.scale + resultScale + if (allowPrecisionLoss) { + resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1) + resultPrecision = type1.precision - type1.scale + type2.scale + resultScale + } else { + var intDig = Math.min(MAX_SCALE, type1.precision - type1.scale + type2.scale) + var decDig = Math.min(MAX_SCALE, Math.max(6, type1.scale + type2.precision + 1)) + val diff = (intDig + decDig) - MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = MAX_SCALE - decDig + } + resultScale = intDig + decDig + resultPrecision = decDig + } case OperationType.MOD => resultScale = Math.max(type1.scale, type2.scale) resultPrecision = @@ -62,7 +77,11 @@ object DecimalArithmeticUtil { case other => throw new UnsupportedOperationException(s"$other is not supported.") } - adjustScaleIfNeeded(resultPrecision, resultScale) + if (allowPrecisionLoss) { + adjustScaleIfNeeded(resultPrecision, resultScale) + } else { + bounded(resultPrecision, resultScale) + } } // Returns the adjusted decimal type when the precision is larger the maximum. @@ -78,6 +97,10 @@ object DecimalArithmeticUtil { DecimalType(typePrecision, typeScale) } + def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(Math.min(precision, MAX_PRECISION), Math.min(scale, MAX_SCALE)) + } + // If casting between DecimalType, unnecessary cast is skipped to avoid data loss, // because argument input type of "cast" is actually the res type of "+-*/". // Cast will use a wider input type, then calculates result type with less scale than expected. diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala index bba16aa8bda39..6f595419534f6 100644 --- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala +++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala @@ -523,6 +523,7 @@ object GlutenConfig { GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY, SQLConf.LEGACY_SIZE_OF_NULL.key, "spark.io.compression.codec", + "spark.sql.decimalOperations.allowPrecisionLoss", COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS.key, COLUMNAR_VELOX_BLOOM_FILTER_NUM_BITS.key, COLUMNAR_VELOX_BLOOM_FILTER_MAX_NUM_BITS.key, @@ -601,6 +602,7 @@ object GlutenConfig { ("spark.hadoop.input.write.timeout", "180000"), ("spark.hadoop.dfs.client.log.severity", "INFO"), ("spark.sql.orc.compression.codec", "snappy"), + ("spark.sql.decimalOperations.allowPrecisionLoss", "true"), ( COLUMNAR_VELOX_FILE_HANDLE_CACHE_ENABLED.key, COLUMNAR_VELOX_FILE_HANDLE_CACHE_ENABLED.defaultValueString),