|
| 1 | +#ifndef CP_ALGO_MATH_CVECTOR_HPP |
| 2 | +#define CP_ALGO_MATH_CVECTOR_HPP |
| 3 | +#include "../util/complex.hpp" |
| 4 | +#include <experimental/simd> |
| 5 | +namespace cp_algo::math::fft { |
| 6 | + using ftype = double; |
| 7 | + using point = complex<ftype>; |
| 8 | + using vftype = std::experimental::native_simd<ftype>; |
| 9 | + using vpoint = complex<vftype>; |
| 10 | + static constexpr size_t flen = vftype::size(); |
| 11 | + |
| 12 | + struct cvector { |
| 13 | + static constexpr size_t pre_roots = 1 << 18; |
| 14 | + std::vector<vftype> x, y; |
| 15 | + cvector(size_t n) { |
| 16 | + n = std::max(flen, std::bit_ceil(n)); |
| 17 | + x.resize(n / flen); |
| 18 | + y.resize(n / flen); |
| 19 | + } |
| 20 | + template<class pt = point> |
| 21 | + void set(size_t k, pt t) { |
| 22 | + if constexpr(std::is_same_v<pt, point>) { |
| 23 | + x[k / flen][k % flen] = real(t); |
| 24 | + y[k / flen][k % flen] = imag(t); |
| 25 | + } else { |
| 26 | + x[k / flen] = real(t); |
| 27 | + y[k / flen] = imag(t); |
| 28 | + } |
| 29 | + } |
| 30 | + template<class pt = point> |
| 31 | + pt get(size_t k) const { |
| 32 | + if constexpr(std::is_same_v<pt, point>) { |
| 33 | + return {x[k / flen][k % flen], y[k / flen][k % flen]}; |
| 34 | + } else { |
| 35 | + return {x[k / flen], y[k / flen]}; |
| 36 | + } |
| 37 | + } |
| 38 | + vpoint vget(size_t k) const { |
| 39 | + return get<vpoint>(k); |
| 40 | + } |
| 41 | + |
| 42 | + size_t size() const { |
| 43 | + return flen * std::size(x); |
| 44 | + } |
| 45 | + void dot(cvector const& t) { |
| 46 | + size_t n = size(); |
| 47 | + for(size_t k = 0; k < n; k += flen) { |
| 48 | + set(k, get<vpoint>(k) * t.get<vpoint>(k)); |
| 49 | + } |
| 50 | + } |
| 51 | + static const cvector roots; |
| 52 | + template< bool precalc = false, class ft = point> |
| 53 | + static auto root(size_t n, size_t k, ft &&arg) { |
| 54 | + if(n < pre_roots && !precalc) { |
| 55 | + return roots.get<complex<ft>>(n + k); |
| 56 | + } else { |
| 57 | + return complex<ft>::polar(1., arg); |
| 58 | + } |
| 59 | + } |
| 60 | + template<class pt = point, bool precalc = false> |
| 61 | + static void exec_on_roots(size_t n, size_t m, auto &&callback) { |
| 62 | + ftype arg = std::numbers::pi / (ftype)n; |
| 63 | + size_t step = sizeof(pt) / sizeof(point); |
| 64 | + using ft = pt::value_type; |
| 65 | + auto k = [&]() { |
| 66 | + if constexpr(std::is_same_v<pt, point>) { |
| 67 | + return ft{}; |
| 68 | + } else { |
| 69 | + return ft{[](auto i) {return i;}}; |
| 70 | + } |
| 71 | + }(); |
| 72 | + for(size_t i = 0; i < m; i += step, k += (ftype)step) { |
| 73 | + callback(i, root<precalc>(n, i, arg * k)); |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + void ifft() { |
| 78 | + size_t n = size(); |
| 79 | + for(size_t i = 1; i < n; i *= 2) { |
| 80 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 81 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 82 | + k += j; |
| 83 | + auto t = get<pt>(k + i) * conj(rt); |
| 84 | + set(k + i, get<pt>(k) - t); |
| 85 | + set(k, get<pt>(k) + t); |
| 86 | + }; |
| 87 | + if(i < flen) { |
| 88 | + exec_on_roots<point>(i, i, butterfly); |
| 89 | + } else { |
| 90 | + exec_on_roots<vpoint>(i, i, butterfly); |
| 91 | + } |
| 92 | + } |
| 93 | + } |
| 94 | + for(size_t k = 0; k < n; k += flen) { |
| 95 | + set(k, get<vpoint>(k) /= (ftype)n); |
| 96 | + } |
| 97 | + } |
| 98 | + void fft() { |
| 99 | + size_t n = size(); |
| 100 | + for(size_t i = n / 2; i >= 1; i /= 2) { |
| 101 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 102 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 103 | + k += j; |
| 104 | + auto A = get<pt>(k) + get<pt>(k + i); |
| 105 | + auto B = get<pt>(k) - get<pt>(k + i); |
| 106 | + set(k, A); |
| 107 | + set(k + i, B * rt); |
| 108 | + }; |
| 109 | + if(i < flen) { |
| 110 | + exec_on_roots<point>(i, i, butterfly); |
| 111 | + } else { |
| 112 | + exec_on_roots<vpoint>(i, i, butterfly); |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + }; |
| 118 | + const cvector cvector::roots = []() { |
| 119 | + cvector res(pre_roots); |
| 120 | + for(size_t n = 1; n < res.size(); n *= 2) { |
| 121 | + auto propagate = [&](size_t k, auto rt) { |
| 122 | + res.set(n + k, rt); |
| 123 | + }; |
| 124 | + if(n < flen) { |
| 125 | + res.exec_on_roots<point, true>(n, n, propagate); |
| 126 | + } else { |
| 127 | + res.exec_on_roots<vpoint, true>(n, n, propagate); |
| 128 | + } |
| 129 | + } |
| 130 | + return res; |
| 131 | + }(); |
| 132 | +} |
| 133 | +#endif // CP_ALGO_MATH_CVECTOR_HPP |
0 commit comments