@@ -30,6 +30,16 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
30
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
31
******************************<GINKGO LICENSE>*******************************/
32
32
33
+ // We need this struct, because otherwise we would call a __host__ function in a
34
+ // __device__ function (even though it is constexpr)
35
+ template <typename T>
36
+ struct device_numeric_limits {
37
+ static constexpr auto inf = std::numeric_limits<T>::infinity();
38
+ static constexpr auto max = std::numeric_limits<T>::max();
39
+ static constexpr auto min = std::numeric_limits<T>::min();
40
+ };
41
+
42
+
33
43
namespace detail {
34
44
35
45
@@ -50,6 +60,38 @@ struct truncate_type_impl<thrust::complex<T>> {
50
60
};
51
61
52
62
63
+ /* *
64
+ * Checks if a given value is finite, meaning it is neither +/- infinity
65
+ * nor NaN.
66
+ *
67
+ * @internal Should only be used if the provided one (from CUDA or HIP) can
68
+ * not be used.
69
+ * Designed to mirror the math function of CUDA (PTX code was
70
+ * identical in the testcase).
71
+ *
72
+ * @tparam T type of the value to check
73
+ *
74
+ * @param value value to check
75
+ *
76
+ * returns `true` if the given value is finite, meaning it is neither
77
+ * +/- infinity nor NaN.
78
+ */
79
+ template <typename T>
80
+ GKO_INLINE __device__ xstd::enable_if_t<!is_complex_s<T>::value, bool>
81
+ custom_isfinite(T value)
82
+ {
83
+ constexpr T infinity{device_numeric_limits<T>::inf};
84
+ return abs(value) < infinity;
85
+ }
86
+
87
+ template <typename T>
88
+ GKO_INLINE __device__ xstd::enable_if_t<is_complex_s<T>::value, bool>
89
+ custom_isfinite(T value)
90
+ {
91
+ return custom_isfinite(value.real()) && custom_isfinite(value.imag());
92
+ }
93
+
94
+
53
95
} // namespace detail
54
96
55
97
@@ -66,103 +108,19 @@ struct truncate_type_impl<thrust::complex<T>> {
66
108
(defined(__clang__) || defined(__ICC) || defined(__ICL))))
67
109
68
110
69
- namespace detail {
70
-
71
-
72
- /* *
73
- * This structure can be used to get the exponent mask of a given floating
74
- * point type. Uses specialization to implement different types.
75
- */
76
- template <typename T>
77
- struct mask_creator {};
78
-
79
- template <>
80
- struct mask_creator<float> {
81
- using int_type = int32;
82
- static constexpr int_type number_exponent_bits = 8;
83
- static constexpr int_type number_significand_bits = 23;
84
- // integer representation of a floating point number, where all exponent
85
- // bits are set
86
- static constexpr int_type exponent_mask =
87
- ((int_type{1} << number_exponent_bits) - 1) << number_significand_bits;
88
- static __device__ int_type reinterpret_int(const float &value)
89
- {
90
- return __float_as_int(value);
91
- }
92
- };
93
-
94
- template <>
95
- struct mask_creator<double> {
96
- using int_type = int64;
97
- static constexpr int_type number_exponent_bits = 11;
98
- static constexpr int_type number_significand_bits = 52;
99
- // integer representation of a floating point number, where all exponent
100
- // bits are set
101
- static constexpr int_type exponent_mask =
102
- ((int_type{1} << number_exponent_bits) - 1) << number_significand_bits;
103
- static __device__ int_type reinterpret_int(const double &value)
104
- {
105
- return __double_as_longlong(value);
106
- }
107
- };
108
-
109
-
110
- } // namespace detail
111
-
112
-
113
- /* *
114
- * Checks if a given value is finite, meaning it is neither +/- infinity
115
- * nor NaN.
116
- *
117
- * @internal It checks if all exponent bits are set. If all are set, the
118
- * number either represents NaN or +/- infinity, meaning it is a
119
- * non-finite number.
120
- *
121
- * @param value value to check
122
- *
123
- * returns `true` if the given value is finite, meaning it is neither
124
- * +/- infinity nor NaN.
125
- */
126
- # define GKO_DEFINE_ISFINITE_FOR_TYPE(_type) \
127
- GKO_INLINE __device__ bool isfinite(const _type &value) \
128
- { \
129
- constexpr auto mask = detail::mask_creator<_type>::exponent_mask; \
130
- const auto re_int = \
131
- detail::mask_creator<_type>::reinterpret_int(value); \
132
- return (re_int & mask) != mask; \
111
+ # define GKO_DEFINE_ISFINITE_FOR_TYPE(_type) \
112
+ GKO_INLINE __device__ bool isfinite(const _type &value) \
113
+ { \
114
+ return detail::custom_isfinite(value); \
133
115
}
134
116
135
117
GKO_DEFINE_ISFINITE_FOR_TYPE(float)
136
118
GKO_DEFINE_ISFINITE_FOR_TYPE(double)
119
+ GKO_DEFINE_ISFINITE_FOR_TYPE(thrust::complex<float>)
120
+ GKO_DEFINE_ISFINITE_FOR_TYPE(thrust::complex<double>)
137
121
# undef GKO_DEFINE_ISFINITE_FOR_TYPE
138
122
139
123
140
- /* *
141
- * Checks if all components of a complex value are finite, meaning they are
142
- * neither +/- infinity nor NaN.
143
- *
144
- * @internal required for the clang compiler. This function will be used rather
145
- * than the `isfinite` function in the public `math.hpp` because
146
- * there is no template parameter, so it is prefered during lookup.
147
- *
148
- * @tparam T complex type of the value to check
149
- *
150
- * @param value complex value to check
151
- *
152
- * returns `true` if both components of the given value are finite, meaning
153
- * they are neither +/- infinity nor NaN.
154
- */
155
- # define GKO_DEFINE_ISFINITE_FOR_COMPLEX_TYPE(_type) \
156
- GKO_INLINE __device__ bool isfinite(const _type &value) \
157
- { \
158
- return isfinite(value.real()) && isfinite(value.imag()); \
159
- }
160
-
161
- GKO_DEFINE_ISFINITE_FOR_COMPLEX_TYPE(thrust::complex<float>)
162
- GKO_DEFINE_ISFINITE_FOR_COMPLEX_TYPE(thrust::complex<double>)
163
- # undef GKO_DEFINE_ISFINITE_FOR_COMPLEX_TYPE
164
-
165
-
166
124
// For all other compiler in combination with CUDA or HIP, just use the provided
167
125
// `isfinite` function
168
126
# elif defined(__CUDA_ARCH__) || __HIP_DEVICE_COMPILE__
@@ -173,13 +131,3 @@ using ::isfinite;
173
131
174
132
175
133
# endif // defined(__CUDA_ARCH__) || __HIP_DEVICE_COMPILE__
176
-
177
-
178
- // We need this struct, because otherwise we would call a __host__ function in a
179
- // __device__ function (even though it is constexpr)
180
- template <typename T>
181
- struct device_numeric_limits {
182
- static constexpr auto inf = std::numeric_limits<T>::infinity();
183
- static constexpr auto max = std::numeric_limits<T>::max();
184
- static constexpr auto min = std::numeric_limits<T>::min();
185
- };
0 commit comments