diff --git a/runtime/compute/cker/include/cker/Types.h b/runtime/compute/cker/include/cker/Types.h index ba661f6094d..ff00a0d25be 100644 --- a/runtime/compute/cker/include/cker/Types.h +++ b/runtime/compute/cker/include/cker/Types.h @@ -545,23 +545,23 @@ void ValidateGemmParams( [[maybe_unused]] const GemmParams ¶ms) { // Guard consistency of the quantized multiplier fields. - if (quantization_flavor == QuantizationFlavor::kFloatingPoint) + if constexpr (quantization_flavor == QuantizationFlavor::kFloatingPoint) { assert(!params.multiplier_fixedpoint); assert(!params.multiplier_exponent); assert(!params.multiplier_fixedpoint_perchannel); assert(!params.multiplier_exponent_perchannel); } - else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier && - !std::is_same::value) + else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier && + !std::is_same_v) { assert(params.multiplier_fixedpoint); // Nothing to check about multiplier_exponent assert(!params.multiplier_fixedpoint_perchannel); assert(!params.multiplier_exponent_perchannel); } - else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier && - !std::is_same::value) + else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier && + !std::is_same_v) { assert(!params.multiplier_fixedpoint); assert(!params.multiplier_exponent); diff --git a/runtime/compute/cker/include/cker/operation/BinaryArithmeticOps.h b/runtime/compute/cker/include/cker/operation/BinaryArithmeticOps.h index 6c6e01ce024..3d3f133d5b3 100644 --- a/runtime/compute/cker/include/cker/operation/BinaryArithmeticOps.h +++ b/runtime/compute/cker/include/cker/operation/BinaryArithmeticOps.h @@ -37,57 +37,55 @@ template ::value, bool> = true> const std::function GetBinaryArtithmeticFn() { - switch (op_type) + if constexpr (op_type == BinaryArithmeticOpType::ADD) { - case BinaryArithmeticOpType::ADD: - { - return [](const T &a, const T &b) -> T { return a + b; }; - } - case BinaryArithmeticOpType::MUL: - { - return [](const T &a, const T &b) -> T { return a * b; }; - } - case BinaryArithmeticOpType::SUB: - { - return [](const T &a, const T &b) -> T { return a - b; }; - } - case BinaryArithmeticOpType::DIV: - { - if (std::is_floating_point::value) - return [](const T &a, const T &b) -> T { return a / b; }; - else - return [](const T &a, const T &b) -> T { - if (b == 0) - throw std::runtime_error("Divide by zero"); - return a / b; - }; - } - case BinaryArithmeticOpType::POW: + return [](const T &a, const T &b) -> T { return a + b; }; + } + else if constexpr (op_type == BinaryArithmeticOpType::MUL) + { + return [](const T &a, const T &b) -> T { return a * b; }; + } + else if constexpr (op_type == BinaryArithmeticOpType::SUB) + { + return [](const T &a, const T &b) -> T { return a - b; }; + } + else if constexpr (op_type == BinaryArithmeticOpType::DIV) + { + if constexpr (std::is_floating_point::value) { - return [](const T &a, const T &b) -> T { return std::pow(a, b); }; + return [](const T &a, const T &b) -> T { return a / b; }; } - default: + else { - assert(false); - return nullptr; + return [](const T &a, const T &b) -> T { + if (b == 0) + throw std::runtime_error("Divide by zero"); + return a / b; + }; } } + else if constexpr (op_type == BinaryArithmeticOpType::POW) + { + return [](const T &a, const T &b) -> T { return std::pow(a, b); }; + } + else + { + assert(false); + return nullptr; + } } template ::value, bool> = true> const std::function GetBinaryArtithmeticFn() { - switch (op_type) + if constexpr (op_type == BinaryArithmeticOpType::MUL) { - case BinaryArithmeticOpType::MUL: - { - return [](const bool &a, const bool &b) -> bool { return a && b; }; - } - default: - { - throw std::runtime_error("GetBinaryArtithmeticFn: Unsupported OpType with Bool8"); - } + return [](const bool &a, const bool &b) -> bool { return a && b; }; + } + else + { + throw std::runtime_error("GetBinaryArtithmeticFn: Unsupported OpType with Bool8"); } } } // namespace @@ -234,22 +232,24 @@ BinaryArithmeticOp(const BinaryArithmeticOpParam ¶ms, const Shape &input1_sh const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data) { - switch (op_type) + if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::ADD || + op_type == nnfw::cker::BinaryArithmeticOpType::SUB) { - case nnfw::cker::BinaryArithmeticOpType::ADD: - case nnfw::cker::BinaryArithmeticOpType::SUB: - optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, - output_data); - break; - case nnfw::cker::BinaryArithmeticOpType::MUL: - optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, - output_data); - break; - case nnfw::cker::BinaryArithmeticOpType::DIV: - throw std::runtime_error{"Quant8 Asymm NYI"}; - default: - assert(false); - break; + optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, + output_data); + } + else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::MUL) + { + optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, + output_data); + } + else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::DIV) + { + throw std::runtime_error{"Quant8 Asymm NYI"}; + } + else + { + assert(false); } } @@ -301,23 +301,25 @@ BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam ¶ms, const Shape &input1 const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data) { - switch (op_type) + if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::ADD || + op_type == nnfw::cker::BinaryArithmeticOpType::SUB) { - case nnfw::cker::BinaryArithmeticOpType::ADD: - case nnfw::cker::BinaryArithmeticOpType::SUB: - optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data, - output_shape, output_data); - break; - case nnfw::cker::BinaryArithmeticOpType::MUL: - optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data, - output_shape, output_data); - break; - case nnfw::cker::BinaryArithmeticOpType::DIV: - case nnfw::cker::BinaryArithmeticOpType::POW: - throw std::runtime_error{"Quant8 Asymm NYI"}; - default: - assert(false); - break; + optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data, + output_shape, output_data); + } + else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::MUL) + { + optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data, + output_shape, output_data); + } + else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::DIV || + op_type == nnfw::cker::BinaryArithmeticOpType::POW) + { + throw std::runtime_error{"Quant8 Asymm NYI"}; + } + else + { + assert(false); } } diff --git a/runtime/compute/cker/include/cker/operation/DepthwiseConv.h b/runtime/compute/cker/include/cker/operation/DepthwiseConv.h index 5be36557ee8..7c3d951f87e 100644 --- a/runtime/compute/cker/include/cker/operation/DepthwiseConv.h +++ b/runtime/compute/cker/include/cker/operation/DepthwiseConv.h @@ -139,7 +139,7 @@ inline void DepthwiseConv(const DepthwiseConvParams ¶ms, const Shape &input_ thread_count = std::max(1, std::min(thread_count, max_threads)); // Cap the number of threads to 2 for float path to avoid regression in // performance (b/132294857). - if (std::is_floating_point::value) + if constexpr (std::is_floating_point::value) { thread_count = std::min(thread_count, 2); } diff --git a/runtime/compute/cker/include/cker/operation/FloorMod.h b/runtime/compute/cker/include/cker/operation/FloorMod.h index f40b573322f..0839ceeb7b4 100644 --- a/runtime/compute/cker/include/cker/operation/FloorMod.h +++ b/runtime/compute/cker/include/cker/operation/FloorMod.h @@ -40,9 +40,6 @@ inline void FloorModBroadcast(const Shape &unextended_input1_shape, const T *inp float operator()(const float lhs, const float rhs) const { return std::fmod(lhs, rhs); } }; - using ModFunc = - typename std::conditional::value, std::modulus, FloatMod>::type; - if (unextended_output_shape.DimensionsCount() > 4) throw std::runtime_error(std::string("cker::FloorModBroadcast: Unsupported rank size : ") + std::to_string(unextended_output_shape.DimensionsCount())); @@ -67,8 +64,15 @@ inline void FloorModBroadcast(const Shape &unextended_input1_shape, const T *inp auto in1_val = input1_data[in1_idx]; auto in2_val = input2_data[in2_idx]; - ModFunc mod_func; - T trunc_mod = mod_func(in1_val, in2_val); + T trunc_mod; + if constexpr (std::is_integral_v) + { + trunc_mod = std::modulus()(in1_val, in2_val); + } + else + { + trunc_mod = FloatMod{}(in1_val, in2_val); + } output_data[out_idx] = (trunc_mod != 0) && ((in2_val < 0) != (trunc_mod < 0)) ? (trunc_mod + in2_val) : trunc_mod; @@ -82,21 +86,20 @@ template inline void FloorModElementwise(const Shape &shape, const T *input1_data, const T *input2_data, T *output_data) { - struct FloatMod - { - float operator()(const float lhs, const float rhs) const { return std::fmod(lhs, rhs); } - }; - - using ModFunc = - typename std::conditional::value, std::modulus, FloatMod>::type; - int num_elements = shape.FlatSize(); for (int t = 0; t < num_elements; t++) { - ModFunc mod_func; - auto in1_val = input1_data[t]; - auto in2_val = input2_data[t]; - T trunc_mod = mod_func(in1_val, in2_val); + T in1_val = input1_data[t]; + T in2_val = input2_data[t]; + T trunc_mod; + if constexpr (std::is_integral_v) + { + trunc_mod = std::modulus()(in1_val, in2_val); + } + else + { + trunc_mod = std::fmod(in1_val, in2_val); + } output_data[t] = (trunc_mod != 0) && ((in2_val < 0) != (trunc_mod < 0)) ? (trunc_mod + in2_val) : trunc_mod; } diff --git a/runtime/compute/cker/include/cker/operation/Range.h b/runtime/compute/cker/include/cker/operation/Range.h index d6ccc68c83a..2e6426be771 100644 --- a/runtime/compute/cker/include/cker/operation/Range.h +++ b/runtime/compute/cker/include/cker/operation/Range.h @@ -34,10 +34,14 @@ template inline int GetSize(T start, T limit, T delta) throw std::runtime_error("Range: invalid input values"); } - int size = (std::is_integral::value - ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) - : std::ceil(std::abs((limit - start) / delta))); - return size; + if constexpr (std::is_integral_v) + { + return ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)); + } + else + { + return static_cast(std::ceil(std::abs((limit - start) / delta))); + } } template diff --git a/runtime/compute/cker/include/cker/operation/SoftMax.h b/runtime/compute/cker/include/cker/operation/SoftMax.h index 35ecde4ba9d..3bc8cfd7e03 100644 --- a/runtime/compute/cker/include/cker/operation/SoftMax.h +++ b/runtime/compute/cker/include/cker/operation/SoftMax.h @@ -296,11 +296,10 @@ inline void SoftmaxInt8LUT(const SoftmaxParams ¶ms, const Shape &input_shape // Offset is used to interpret the input data "correctly". // If the input is uint8, the data will be unchanged. - // If the input is int8, since it will be reinterpret as uint8. - // e.g., - // int8 127 will be applied "offset" to become 255 in uint8. + // If the input is int8, since it will be reinterpreted as uint8. + // e.g., int8 127 will be offset to become 255 in uint8. uint8_t offset = 0; - if (std::is_same::value) + if constexpr (std::is_same_v) { offset = 0x80; } diff --git a/runtime/compute/ruy/include/ruy/Types.h b/runtime/compute/ruy/include/ruy/Types.h index 7fd5218c67a..967919414a0 100644 --- a/runtime/compute/ruy/include/ruy/Types.h +++ b/runtime/compute/ruy/include/ruy/Types.h @@ -218,6 +218,7 @@ struct GemmParams DstScalar clamp_min = std::is_floating_point::value ? -std::numeric_limits::infinity() : std::numeric_limits::lowest(); + // max clamp bound of destination values. DstScalar clamp_max = std::is_floating_point::value ? std::numeric_limits::infinity() @@ -230,23 +231,23 @@ void ValidateGemmParams( [[maybe_unused]] const GemmParams ¶ms) { // Guard consistency of the quantized multiplier fields. - if (quantization_flavor == QuantizationFlavor::kFloatingPoint) + if constexpr (quantization_flavor == QuantizationFlavor::kFloatingPoint) { assert(!params.multiplier_fixedpoint); assert(!params.multiplier_exponent); assert(!params.multiplier_fixedpoint_perchannel); assert(!params.multiplier_exponent_perchannel); } - else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier && - !std::is_same::value) + else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier && + !std::is_same_v) { assert(params.multiplier_fixedpoint); // Nothing to check about multiplier_exponent assert(!params.multiplier_fixedpoint_perchannel); assert(!params.multiplier_exponent_perchannel); } - else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier && - !std::is_same::value) + else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier && + !std::is_same_v) { assert(!params.multiplier_fixedpoint); assert(!params.multiplier_exponent);