88#include < ranges>
99#include < vector>
1010#include < bit>
11+ #include < experimental/simd>
1112
1213namespace cp_algo ::math::fft {
1314 using ftype = double ;
14- static constexpr size_t bytes = 32 ;
15- static constexpr size_t flen = bytes / sizeof (ftype);
1615 using point = std::complex <ftype>;
17- using vftype [[gnu::vector_size(bytes)]] = ftype;
16+ using vftype = std::experimental::native_simd< ftype> ;
1817 using vpoint = std::complex <vftype>;
18+ static constexpr size_t flen = vftype::size();
1919
20- #define WITH_IV (...) \
21- [&]<size_t ... i>(std::index_sequence<i...>) { \
22- return __VA_ARGS__; \
23- }(std::make_index_sequence<flen>());
2420
2521 template <typename ft>
2622 constexpr ft to_ft (auto x) {
@@ -76,12 +72,12 @@ namespace cp_algo::math::fft {
7672 if (n < pre_roots) {
7773 return roots.get <pt>(n + k);
7874 } else {
79- auto arg = std::numbers::pi / n;
75+ auto arg = std::numbers::pi / (ftype) n;
8076 if constexpr (std::is_same_v<pt, point>) {
81- return {cos (k * arg), sin (k * arg)};
77+ return {(ftype) cos (k * arg), (ftype) sin (k * arg)};
8278 } else {
83- return WITH_IV ( pt{vftype{cos ((k + i) * arg)... },
84- vftype{ sin ((k + i) * arg)...}}) ;
79+ return pt{vftype{[&]( auto i) { return cos (ftype (k + i) * arg);} },
80+ vftype{[&]( auto i) { return sin (ftype (k + i) * arg);}}} ;
8581 }
8682 }
8783 }
@@ -118,7 +114,7 @@ namespace cp_algo::math::fft {
118114 }
119115 }
120116 for (size_t k = 0 ; k < n; k += flen) {
121- set (k, get<vpoint>(k) /= to_pt<vpoint>(n));
117+ set (k, get<vpoint>(k) /= to_pt<vpoint>((ftype) n));
122118 }
123119 }
124120 void fft () {
@@ -144,11 +140,11 @@ namespace cp_algo::math::fft {
144140 const cvector cvector::roots = []() {
145141 cvector res (pre_roots);
146142 for (size_t n = 1 ; n < res.size (); n *= 2 ) {
147- auto base = std::polar (1 ., std::numbers::pi / n);
143+ auto base = std::polar (1 ., std::numbers::pi / (ftype) n);
148144 point cur = 1 ;
149145 for (size_t k = 0 ; k < n; k++) {
150146 if ((k & 15 ) == 0 ) {
151- cur = std::polar (1 ., std::numbers::pi * k / n);
147+ cur = std::polar (1 ., std::numbers::pi * (ftype) k / (ftype) n);
152148 }
153149 res.set (n + k, cur);
154150 cur *= base;
@@ -198,7 +194,7 @@ namespace cp_algo::math::fft {
198194 cvector A, B;
199195
200196 dft (auto const & a, size_t n): A(n), B(n) {
201- split = std::sqrt (base::mod ());
197+ split = int ( std::sqrt (base::mod () ));
202198 cvector::exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
203199 size_t ti = std::min (i, i - n);
204200 A.set (ti, A.get (ti) + ftype (a[i].rem () % split) * rt);
@@ -273,12 +269,12 @@ namespace cp_algo::math::fft {
273269 if (empty (a) || empty (b)) {
274270 a.clear ();
275271 } else {
276- int n = std::min (k, size (a));
277- int m = std::min (k, size (b));
272+ size_t n = std::min (k, size (a));
273+ size_t m = std::min (k, size (b));
278274 a.resize (k);
279- for (int j = k - 1 ; j >= 0 ; j--) {
275+ for (int j = int ( k - 1 ) ; j >= 0 ; j--) {
280276 a[j] *= b[0 ];
281- for (int i = std::max (j - n, 0 ) + 1 ; i < std::min (j + 1 , m); i++) {
277+ for (size_t i = std::max< size_t > (j - n, 0 ) + 1 ; i < std::min< size_t > (j + 1 , m); i++) {
282278 a[j] += a[j - i] * b[i];
283279 }
284280 }
0 commit comments