Skip to content

Commit

Permalink
[onert] Apply if constexpr (Samsung#14675)
Browse files Browse the repository at this point in the history
This commit applies if constexpr for code optimization.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Feb 14, 2025
1 parent 2ac8fce commit 5b49cc6
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 105 deletions.
10 changes: 5 additions & 5 deletions runtime/compute/cker/include/cker/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,23 +545,23 @@ void ValidateGemmParams(
[[maybe_unused]] const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
{
// Guard consistency of the quantized multiplier fields.
if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
if constexpr (quantization_flavor == QuantizationFlavor::kFloatingPoint)
{
assert(!params.multiplier_fixedpoint);
assert(!params.multiplier_exponent);
assert(!params.multiplier_fixedpoint_perchannel);
assert(!params.multiplier_exponent_perchannel);
}
else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
!std::is_same<DstScalar, int32_t>::value)
else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
!std::is_same_v<DstScalar, int32_t>)
{
assert(params.multiplier_fixedpoint);
// Nothing to check about multiplier_exponent
assert(!params.multiplier_fixedpoint_perchannel);
assert(!params.multiplier_exponent_perchannel);
}
else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
!std::is_same<DstScalar, int32_t>::value)
else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
!std::is_same_v<DstScalar, int32_t>)
{
assert(!params.multiplier_fixedpoint);
assert(!params.multiplier_exponent);
Expand Down
140 changes: 71 additions & 69 deletions runtime/compute/cker/include/cker/operation/BinaryArithmeticOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,57 +37,55 @@ template <BinaryArithmeticOpType op_type, typename T,
typename std::enable_if_t<!std::is_same<T, bool>::value, bool> = true>
const std::function<T(const T &, const T &)> GetBinaryArtithmeticFn()
{
switch (op_type)
if constexpr (op_type == BinaryArithmeticOpType::ADD)
{
case BinaryArithmeticOpType::ADD:
{
return [](const T &a, const T &b) -> T { return a + b; };
}
case BinaryArithmeticOpType::MUL:
{
return [](const T &a, const T &b) -> T { return a * b; };
}
case BinaryArithmeticOpType::SUB:
{
return [](const T &a, const T &b) -> T { return a - b; };
}
case BinaryArithmeticOpType::DIV:
{
if (std::is_floating_point<T>::value)
return [](const T &a, const T &b) -> T { return a / b; };
else
return [](const T &a, const T &b) -> T {
if (b == 0)
throw std::runtime_error("Divide by zero");
return a / b;
};
}
case BinaryArithmeticOpType::POW:
return [](const T &a, const T &b) -> T { return a + b; };
}
else if constexpr (op_type == BinaryArithmeticOpType::MUL)
{
return [](const T &a, const T &b) -> T { return a * b; };
}
else if constexpr (op_type == BinaryArithmeticOpType::SUB)
{
return [](const T &a, const T &b) -> T { return a - b; };
}
else if constexpr (op_type == BinaryArithmeticOpType::DIV)
{
if constexpr (std::is_floating_point<T>::value)
{
return [](const T &a, const T &b) -> T { return std::pow(a, b); };
return [](const T &a, const T &b) -> T { return a / b; };
}
default:
else
{
assert(false);
return nullptr;
return [](const T &a, const T &b) -> T {
if (b == 0)
throw std::runtime_error("Divide by zero");
return a / b;
};
}
}
else if constexpr (op_type == BinaryArithmeticOpType::POW)
{
return [](const T &a, const T &b) -> T { return std::pow(a, b); };
}
else
{
assert(false);
return nullptr;
}
}

template <BinaryArithmeticOpType op_type, typename T,
typename std::enable_if_t<std::is_same<T, bool>::value, bool> = true>
const std::function<T(const bool &, const bool &)> GetBinaryArtithmeticFn()
{
switch (op_type)
if constexpr (op_type == BinaryArithmeticOpType::MUL)
{
case BinaryArithmeticOpType::MUL:
{
return [](const bool &a, const bool &b) -> bool { return a && b; };
}
default:
{
throw std::runtime_error("GetBinaryArtithmeticFn: Unsupported OpType with Bool8");
}
return [](const bool &a, const bool &b) -> bool { return a && b; };
}
else
{
throw std::runtime_error("GetBinaryArtithmeticFn: Unsupported OpType with Bool8");
}
}
} // namespace
Expand Down Expand Up @@ -234,22 +232,24 @@ BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_sh
const T *input1_data, const Shape &input2_shape, const T *input2_data,
const Shape &output_shape, T *output_data)
{
switch (op_type)
if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::ADD ||
op_type == nnfw::cker::BinaryArithmeticOpType::SUB)
{
case nnfw::cker::BinaryArithmeticOpType::ADD:
case nnfw::cker::BinaryArithmeticOpType::SUB:
optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
output_data);
break;
case nnfw::cker::BinaryArithmeticOpType::MUL:
optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
output_data);
break;
case nnfw::cker::BinaryArithmeticOpType::DIV:
throw std::runtime_error{"Quant8 Asymm NYI"};
default:
assert(false);
break;
optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
output_data);
}
else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::MUL)
{
optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
output_data);
}
else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::DIV)
{
throw std::runtime_error{"Quant8 Asymm NYI"};
}
else
{
assert(false);
}
}

Expand Down Expand Up @@ -301,23 +301,25 @@ BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1
const T *input1_data, const Shape &input2_shape, const T *input2_data,
const Shape &output_shape, T *output_data)
{
switch (op_type)
if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::ADD ||
op_type == nnfw::cker::BinaryArithmeticOpType::SUB)
{
case nnfw::cker::BinaryArithmeticOpType::ADD:
case nnfw::cker::BinaryArithmeticOpType::SUB:
optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
break;
case nnfw::cker::BinaryArithmeticOpType::MUL:
optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
break;
case nnfw::cker::BinaryArithmeticOpType::DIV:
case nnfw::cker::BinaryArithmeticOpType::POW:
throw std::runtime_error{"Quant8 Asymm NYI"};
default:
assert(false);
break;
optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
}
else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::MUL)
{
optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
}
else if constexpr (op_type == nnfw::cker::BinaryArithmeticOpType::DIV ||
op_type == nnfw::cker::BinaryArithmeticOpType::POW)
{
throw std::runtime_error{"Quant8 Asymm NYI"};
}
else
{
assert(false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ inline void DepthwiseConv(const DepthwiseConvParams &params, const Shape &input_
thread_count = std::max(1, std::min(thread_count, max_threads));
// Cap the number of threads to 2 for float path to avoid regression in
// performance (b/132294857).
if (std::is_floating_point<T>::value)
if constexpr (std::is_floating_point<T>::value)
{
thread_count = std::min(thread_count, 2);
}
Expand Down
37 changes: 20 additions & 17 deletions runtime/compute/cker/include/cker/operation/FloorMod.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ inline void FloorModBroadcast(const Shape &unextended_input1_shape, const T *inp
float operator()(const float lhs, const float rhs) const { return std::fmod(lhs, rhs); }
};

using ModFunc =
typename std::conditional<std::is_integral<T>::value, std::modulus<T>, FloatMod>::type;

if (unextended_output_shape.DimensionsCount() > 4)
throw std::runtime_error(std::string("cker::FloorModBroadcast: Unsupported rank size : ") +
std::to_string(unextended_output_shape.DimensionsCount()));
Expand All @@ -67,8 +64,15 @@ inline void FloorModBroadcast(const Shape &unextended_input1_shape, const T *inp
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];

ModFunc mod_func;
T trunc_mod = mod_func(in1_val, in2_val);
T trunc_mod;
if constexpr (std::is_integral_v<T>)
{
trunc_mod = std::modulus<T>()(in1_val, in2_val);
}
else
{
trunc_mod = FloatMod{}(in1_val, in2_val);
}
output_data[out_idx] = (trunc_mod != 0) && ((in2_val < 0) != (trunc_mod < 0))
? (trunc_mod + in2_val)
: trunc_mod;
Expand All @@ -82,21 +86,20 @@ template <typename T>
inline void FloorModElementwise(const Shape &shape, const T *input1_data, const T *input2_data,
T *output_data)
{
struct FloatMod
{
float operator()(const float lhs, const float rhs) const { return std::fmod(lhs, rhs); }
};

using ModFunc =
typename std::conditional<std::is_integral<T>::value, std::modulus<T>, FloatMod>::type;

int num_elements = shape.FlatSize();
for (int t = 0; t < num_elements; t++)
{
ModFunc mod_func;
auto in1_val = input1_data[t];
auto in2_val = input2_data[t];
T trunc_mod = mod_func(in1_val, in2_val);
T in1_val = input1_data[t];
T in2_val = input2_data[t];
T trunc_mod;
if constexpr (std::is_integral_v<T>)
{
trunc_mod = std::modulus<T>()(in1_val, in2_val);
}
else
{
trunc_mod = std::fmod(in1_val, in2_val);
}
output_data[t] =
(trunc_mod != 0) && ((in2_val < 0) != (trunc_mod < 0)) ? (trunc_mod + in2_val) : trunc_mod;
}
Expand Down
12 changes: 8 additions & 4 deletions runtime/compute/cker/include/cker/operation/Range.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ template <typename T> inline int GetSize(T start, T limit, T delta)
throw std::runtime_error("Range: invalid input values");
}

int size = (std::is_integral<T>::value
? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
: std::ceil(std::abs((limit - start) / delta)));
return size;
if constexpr (std::is_integral_v<T>)
{
return ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta));
}
else
{
return static_cast<int>(std::ceil(std::abs((limit - start) / delta)));
}
}

template <typename T>
Expand Down
7 changes: 3 additions & 4 deletions runtime/compute/cker/include/cker/operation/SoftMax.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,10 @@ inline void SoftmaxInt8LUT(const SoftmaxParams &params, const Shape &input_shape

// Offset is used to interpret the input data "correctly".
// If the input is uint8, the data will be unchanged.
// If the input is int8, since it will be reinterpret as uint8.
// e.g.,
// int8 127 will be applied "offset" to become 255 in uint8.
// If the input is int8, since it will be reinterpreted as uint8.
// e.g., int8 127 will be offset to become 255 in uint8.
uint8_t offset = 0;
if (std::is_same<In, int8_t>::value)
if constexpr (std::is_same_v<In, int8_t>)
{
offset = 0x80;
}
Expand Down
11 changes: 6 additions & 5 deletions runtime/compute/ruy/include/ruy/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ struct GemmParams
DstScalar clamp_min = std::is_floating_point<DstScalar>::value
? -std::numeric_limits<DstScalar>::infinity()
: std::numeric_limits<DstScalar>::lowest();

// max clamp bound of destination values.
DstScalar clamp_max = std::is_floating_point<DstScalar>::value
? std::numeric_limits<DstScalar>::infinity()
Expand All @@ -230,23 +231,23 @@ void ValidateGemmParams(
[[maybe_unused]] const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
{
// Guard consistency of the quantized multiplier fields.
if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
if constexpr (quantization_flavor == QuantizationFlavor::kFloatingPoint)
{
assert(!params.multiplier_fixedpoint);
assert(!params.multiplier_exponent);
assert(!params.multiplier_fixedpoint_perchannel);
assert(!params.multiplier_exponent_perchannel);
}
else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
!std::is_same<DstScalar, int32_t>::value)
else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
!std::is_same_v<DstScalar, int32_t>)
{
assert(params.multiplier_fixedpoint);
// Nothing to check about multiplier_exponent
assert(!params.multiplier_fixedpoint_perchannel);
assert(!params.multiplier_exponent_perchannel);
}
else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
!std::is_same<DstScalar, int32_t>::value)
else if constexpr (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
!std::is_same_v<DstScalar, int32_t>)
{
assert(!params.multiplier_fixedpoint);
assert(!params.multiplier_exponent);
Expand Down

0 comments on commit 5b49cc6

Please sign in to comment.