|
| 1 | +/* |
| 2 | + * Copyright (c) Ian Pike |
| 3 | + * Copyright (c) CCMath contributors |
| 4 | + * |
| 5 | + * CCMath is provided under the Apache-2.0 License WITH LLVM-exception. |
| 6 | + * See LICENSE for more information. |
| 7 | + * |
| 8 | + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 9 | + */ |
| 10 | + |
| 11 | +#pragma once |
| 12 | + |
| 13 | +#include "ccmath/internal/predef/unlikely.hpp" |
| 14 | +#include "ccmath/math/compare/isinf.hpp" |
| 15 | +#include "ccmath/math/compare/isnan.hpp" |
| 16 | +#include "ccmath/math/compare/signbit.hpp" |
| 17 | +#include "ccmath/internal/math/generic/builtins/basic/fma.hpp" |
| 18 | + |
| 19 | +#include <limits> |
| 20 | +#include <type_traits> |
| 21 | + |
namespace ccm
{
    /**
     * @brief Fused multiply-add operation.
     *
     * On GCC 6.1+ (excluding Clang) the computation is forwarded to the matching
     * __builtin_fma* intrinsic. On every other compiler a fallback is used that
     * handles the IEEE-754 special values explicitly and then evaluates
     * (x * y) + z as two separate operations, so the fallback result may be
     * rounded twice unless the compiler contracts the expression.
     *
     * Special values handled by the fallback:
     *  - x or y is NaN                        -> +NaN
     *  - 0 * Inf or Inf * 0, z is NaN         -> +NaN
     *  - 0 * Inf or Inf * 0, z is not NaN     -> -NaN
     *  - z is NaN (any other x, y)            -> +NaN
     *  - x * y is an exact infinity and z is
     *    an infinity of the opposite sign     -> -NaN
     *
     * @tparam T Numeric (non-integral) type.
     * @param x First multiplicand.
     * @param y Second multiplicand.
     * @param z Addend.
     * @return If successful, returns the value of x * y + z as if calculated to infinite precision and rounded once to fit the result type (or, alternatively,
     * calculated as a single ternary floating-point operation).
     */
    template <typename T, std::enable_if_t<!std::is_integral_v<T>, bool> = true>
    constexpr T fma(T x, T y, T z) noexcept
    {
        // GCC 6.1 or later: the fma builtins are usable in constant expressions.
#if defined(__GNUC__) && (__GNUC__ > 6 || (__GNUC__ == 6 && __GNUC_MINOR__ >= 1)) && !defined(__clang__)
        if constexpr (std::is_same_v<T, float>) { return __builtin_fmaf(x, y, z); }
        else if constexpr (std::is_same_v<T, double>) { return __builtin_fma(x, y, z); }
        else if constexpr (std::is_same_v<T, long double>) { return __builtin_fmal(x, y, z); }
        else { return static_cast<T>(__builtin_fmal(x, y, z)); }
#else
        const T quiet_nan = std::numeric_limits<T>::quiet_NaN();

        // If x or y are NaN, NaN is returned.
        if (ccm::isnan(x) || ccm::isnan(y)) { return quiet_nan; }

        const bool x_is_inf = ccm::isinf(x);
        const bool y_is_inf = ccm::isinf(y);

        // 0 * Inf and Inf * 0 are invalid operations. This must be tested before
        // any arithmetic on x * y so the product is never actually evaluated.
        if ((x == T{0} && y_is_inf) || (y == T{0} && x_is_inf))
        {
            // +NaN when z is NaN, -NaN otherwise.
            return ccm::isnan(z) ? quiet_nan : -quiet_nan;
        }

        // If z is NaN (and x * y is not 0 * Inf or Inf * 0), +NaN is returned.
        if (ccm::isnan(z)) { return quiet_nan; }

        // Inf - Inf: the product is an *exact* infinity only when x or y is itself
        // infinite. Testing isinf(x * y) would also fire when a finite product
        // merely overflows, where the correct result is simply z.
        const bool product_is_negative = ccm::signbit(x) != ccm::signbit(y);
        if ((x_is_inf || y_is_inf) && ccm::isinf(z) && product_is_negative != ccm::signbit(z))
        {
            return -quiet_nan;
        }

        // Not truly fused: two roundings unless the compiler contracts this.
        return (x * y) + z;
#endif
    }

    /**
     * @brief Multiply-add for integers of a single type.
     * @tparam Integer Integral type.
     * @param x First multiplicand.
     * @param y Second multiplicand.
     * @param z Addend.
     * @return (x * y) + z evaluated with ordinary integer arithmetic, converted back to Integer.
     */
    template <typename Integer, std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
    constexpr Integer fma(Integer x, Integer y, Integer z) noexcept
    {
        return (x * y) + z;
    }

    /**
     * @brief Fused multiply-add operation for mixed argument types where at least one is not an integer.
     *
     * All three arguments are converted to a single shared type, chosen as the
     * argument type with the smallest machine epsilon (i.e. the widest
     * floating-point type among T, U and V), and the single-type overload is called.
     *
     * @tparam T Floating-point or integer type converted to a common type.
     * @tparam U Floating-point or integer type converted to a common type.
     * @tparam V Floating-point or integer type converted to a common type.
     * @param x Floating-point or integer value converted to a common type.
     * @param y Floating-point or integer value converted to a common type.
     * @param z Floating-point or integer value converted to a common type.
     * @return If successful, returns the value of x * y + z as if calculated to infinite precision and rounded once to fit the result type (or, alternatively,
     * calculated as a single ternary floating-point operation).
     */
    template <typename T, typename U, typename V,
              // Excluded when every argument is integral; otherwise this overload and the
              // all-integer overload below would be equally viable and the call ambiguous.
              std::enable_if_t<!(std::is_integral_v<T> && std::is_integral_v<U> && std::is_integral_v<V>), bool> = true>
    constexpr auto fma(T x, U y, V z) noexcept
    {
        // If our type is an integer epsilon will be set to 0 by default.
        // Instead, set epsilon to 1 so that an integer never ranks as "more precise" than a floating-point type.
        constexpr auto TCommon = std::numeric_limits<T>::epsilon() > 0 ? std::numeric_limits<T>::epsilon() : 1;
        constexpr auto UCommon = std::numeric_limits<U>::epsilon() > 0 ? std::numeric_limits<U>::epsilon() : 1;
        constexpr auto VCommon = std::numeric_limits<V>::epsilon() > 0 ? std::numeric_limits<V>::epsilon() : 1;

        // Widest of the three epsilon value types; used as the reference precision and as the last-resort result type.
        using epsilon_type = std::common_type_t<decltype(TCommon), decltype(UCommon), decltype(VCommon)>;

        // Pick the first argument type whose epsilon matches the reference precision.
        using shared_type = std::conditional_t<
            TCommon <= std::numeric_limits<epsilon_type>::epsilon() && TCommon <= UCommon, T,
            std::conditional_t<UCommon <= std::numeric_limits<epsilon_type>::epsilon() && UCommon <= TCommon, U,
                               std::conditional_t<VCommon <= std::numeric_limits<epsilon_type>::epsilon() && VCommon <= UCommon, V, epsilon_type>>>;

        return ccm::fma<shared_type>(static_cast<shared_type>(x), static_cast<shared_type>(y), static_cast<shared_type>(z));
    }

    /**
     * @brief Multiply-add for mixed integer types.
     *
     * Special case for when all argument types are integers: the arguments are
     * converted to std::common_type_t<T, U, V> and the single-type integer
     * overload is called.
     *
     * @tparam T Integer type converted to a common type.
     * @tparam U Integer type converted to a common type.
     * @tparam V Integer type converted to a common type.
     * @param x Integer value converted to a common type.
     * @param y Integer value converted to a common type.
     * @param z Integer value converted to a common type.
     * @return (x * y) + z evaluated in the common integer type.
     */
    template <typename T, typename U, typename V, std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U> && std::is_integral_v<V>, bool> = true>
    constexpr auto fma(T x, U y, V z) noexcept
    {
        using shared_type = std::common_type_t<T, U, V>;
        return ccm::fma<shared_type>(static_cast<shared_type>(x), static_cast<shared_type>(y), static_cast<shared_type>(z));
    }

    /**
     * @brief Fused multiply-add operation (float).
     * @param x Floating-point value.
     * @param y Floating-point value.
     * @param z Floating-point value.
     * @return The result of ccm::fma<float>(x, y, z).
     */
    constexpr float fmaf(float x, float y, float z) noexcept
    {
        return ccm::fma<float>(x, y, z);
    }

    /**
     * @brief Fused multiply-add operation (long double).
     * @param x Floating-point value.
     * @param y Floating-point value.
     * @param z Floating-point value.
     * @return The result of ccm::fma<long double>(x, y, z).
     */
    constexpr long double fmal(long double x, long double y, long double z) noexcept
    {
        return ccm::fma<long double>(x, y, z);
    }
} // namespace ccm
| 158 | + |
| 159 | +/// @ingroup basic |
0 commit comments