8
8
#include < ranges>
9
9
#include < vector>
10
10
#include < bit>
11
+ #include < experimental/simd>
11
12
12
13
namespace cp_algo ::math::fft {
13
14
using ftype = double ;
14
- static constexpr size_t bytes = 32 ;
15
- static constexpr size_t flen = bytes / sizeof (ftype);
16
15
using point = std::complex<ftype>;
17
- using vftype [[gnu::vector_size(bytes)]] = ftype;
16
+ using vftype = std::experimental::native_simd< ftype> ;
18
17
using vpoint = std::complex<vftype>;
18
+ static constexpr size_t flen = vftype::size();
19
19
20
- #define WITH_IV (...) \
21
- [&]<size_t ... i>(std::index_sequence<i...>) { \
22
- return __VA_ARGS__; \
23
- }(std::make_index_sequence<flen>());
24
20
25
21
template <typename ft>
26
22
constexpr ft to_ft (auto x) {
@@ -76,12 +72,12 @@ namespace cp_algo::math::fft {
76
72
if (n < pre_roots) {
77
73
return roots.get <pt>(n + k);
78
74
} else {
79
- auto arg = std::numbers::pi / n;
75
+ auto arg = std::numbers::pi / (ftype) n;
80
76
if constexpr (std::is_same_v<pt, point>) {
81
- return {cos (k * arg), sin (k * arg)};
77
+ return {(ftype) cos (k * arg), (ftype) sin (k * arg)};
82
78
} else {
83
- return WITH_IV ( pt{vftype{cos ((k + i) * arg)... },
84
- vftype{ sin ((k + i) * arg)...}}) ;
79
+ return pt{vftype{[&]( auto i) { return cos (ftype (k + i) * arg);} },
80
+ vftype{[&]( auto i) { return sin (ftype (k + i) * arg);}}} ;
85
81
}
86
82
}
87
83
}
@@ -118,7 +114,7 @@ namespace cp_algo::math::fft {
118
114
}
119
115
}
120
116
for (size_t k = 0 ; k < n; k += flen) {
121
- set (k, get<vpoint>(k) /= to_pt<vpoint>(n));
117
+ set (k, get<vpoint>(k) /= to_pt<vpoint>((ftype) n));
122
118
}
123
119
}
124
120
void fft () {
@@ -144,11 +140,11 @@ namespace cp_algo::math::fft {
144
140
const cvector cvector::roots = []() {
145
141
cvector res (pre_roots);
146
142
for (size_t n = 1 ; n < res.size (); n *= 2 ) {
147
- auto base = std::polar (1 ., std::numbers::pi / n);
143
+ auto base = std::polar (1 ., std::numbers::pi / (ftype) n);
148
144
point cur = 1 ;
149
145
for (size_t k = 0 ; k < n; k++) {
150
146
if ((k & 15 ) == 0 ) {
151
- cur = std::polar (1 ., std::numbers::pi * k / n);
147
+ cur = std::polar (1 ., std::numbers::pi * (ftype) k / (ftype) n);
152
148
}
153
149
res.set (n + k, cur);
154
150
cur *= base;
@@ -198,7 +194,7 @@ namespace cp_algo::math::fft {
198
194
cvector A, B;
199
195
200
196
dft (auto const & a, size_t n): A(n), B(n) {
201
- split = std::sqrt (base::mod ());
197
+ split = int ( std::sqrt (base::mod () ));
202
198
cvector::exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
203
199
size_t ti = std::min (i, i - n);
204
200
A.set (ti, A.get (ti) + ftype (a[i].rem () % split) * rt);
@@ -273,12 +269,12 @@ namespace cp_algo::math::fft {
273
269
if (empty (a) || empty (b)) {
274
270
a.clear ();
275
271
} else {
276
- int n = std::min (k, size (a));
277
- int m = std::min (k, size (b));
272
+ size_t n = std::min (k, size (a));
273
+ size_t m = std::min (k, size (b));
278
274
a.resize (k);
279
- for (int j = k - 1 ; j >= 0 ; j--) {
275
+ for (int j = int ( k - 1 ) ; j >= 0 ; j--) {
280
276
a[j] *= b[0 ];
281
- for (int i = std::max (j - n, 0 ) + 1 ; i < std::min (j + 1 , m); i++) {
277
+ for (size_t i = std::max< size_t > (j - n, 0 ) + 1 ; i < std::min< size_t > (j + 1 , m); i++) {
282
278
a[j] += a[j - i] * b[i];
283
279
}
284
280
}
0 commit comments