|
1 | 1 | #ifndef CP_ALGO_MATH_CVECTOR_HPP
|
2 | 2 | #define CP_ALGO_MATH_CVECTOR_HPP
|
3 |
| -#include "../util/complex.hpp" |
4 |
| -#include <experimental/simd> |
| 3 | +#include <algorithm> |
| 4 | +#include <complex> |
| 5 | +#include <vector> |
| 6 | +#include <ranges> |
5 | 7 | namespace cp_algo::math::fft {
|
6 | 8 | using ftype = double;
|
7 |
| - using point = complex<ftype>; |
8 |
| - using vftype = std::experimental::native_simd<ftype>; |
9 |
| - using vpoint = complex<vftype>; |
10 |
| - static constexpr size_t flen = vftype::size(); |
| 9 | + using point = std::complex<ftype>; |
11 | 10 |
|
12 |
| - struct cvector { |
13 |
| - static constexpr size_t pre_roots = 1 << 18; |
14 |
| - std::vector<vftype> x, y; |
15 |
| - cvector(size_t n) { |
16 |
| - n = std::max(flen, std::bit_ceil(n)); |
17 |
| - x.resize(n / flen); |
18 |
| - y.resize(n / flen); |
| 11 | + struct ftvec: std::vector<point> { |
| 12 | + static constexpr size_t pre_roots = 1 << 16; |
| 13 | + static constexpr size_t threshold = 32; |
| 14 | + ftvec(size_t n) { |
| 15 | + this->resize(std::max(threshold, std::bit_ceil(n))); |
19 | 16 | }
|
20 |
| - template<class pt = point> |
21 |
| - void set(size_t k, pt t) { |
22 |
| - if constexpr(std::is_same_v<pt, point>) { |
23 |
| - x[k / flen][k % flen] = real(t); |
24 |
| - y[k / flen][k % flen] = imag(t); |
25 |
| - } else { |
26 |
| - x[k / flen] = real(t); |
27 |
| - y[k / flen] = imag(t); |
| 17 | + static auto dot_block(size_t k, ftvec const& A, ftvec const& B) { |
| 18 | + static std::array<point, 2 * threshold> r; |
| 19 | + std::ranges::fill(r, point(0)); |
| 20 | + for(size_t i = 0; i < threshold; i++) { |
| 21 | + for(size_t j = 0; j < threshold; j++) { |
| 22 | + r[i + j] += A[k + i] * B[k + j]; |
| 23 | + } |
28 | 24 | }
|
29 |
| - } |
30 |
| - template<class pt = point> |
31 |
| - pt get(size_t k) const { |
32 |
| - if constexpr(std::is_same_v<pt, point>) { |
33 |
| - return {x[k / flen][k % flen], y[k / flen][k % flen]}; |
34 |
| - } else { |
35 |
| - return {x[k / flen], y[k / flen]}; |
| 25 | + auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2); |
| 26 | + static std::array<point, threshold> res; |
| 27 | + for(size_t i = 0; i < threshold; i++) { |
| 28 | + res[i] = r[i] + r[i + threshold] * rt; |
36 | 29 | }
|
37 |
| - } |
38 |
| - vpoint vget(size_t k) const { |
39 |
| - return get<vpoint>(k); |
| 30 | + return res; |
40 | 31 | }
|
41 | 32 |
|
42 |
| - size_t size() const { |
43 |
| - return flen * std::size(x); |
| 33 | + void dot(ftvec const& t) { |
| 34 | + size_t n = this->size(); |
| 35 | + for(size_t k = 0; k < n; k += threshold) { |
| 36 | + std::ranges::copy(dot_block(k, *this, t), this->begin() + k); |
| 37 | + } |
44 | 38 | }
|
45 |
| - void dot(cvector const& t) { |
46 |
| - size_t n = size(); |
47 |
| - for(size_t k = 0; k < n; k += flen) { |
48 |
| - set(k, get<vpoint>(k) * t.get<vpoint>(k)); |
| 39 | + static std::array<point, pre_roots> roots, evalp; |
| 40 | + static std::array<size_t, pre_roots> eval_args; |
| 41 | + static point root(size_t n, size_t k) { |
| 42 | + if(n + k < pre_roots && roots[n + k] != point{}) { |
| 43 | + return roots[n + k]; |
49 | 44 | }
|
| 45 | + auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); |
| 46 | + if(n + k < pre_roots) { |
| 47 | + roots[n + k] = res; |
| 48 | + } |
| 49 | + return res; |
50 | 50 | }
|
51 |
| - static const cvector roots; |
52 |
| - template< bool precalc = false, class ft = point> |
53 |
| - static auto root(size_t n, size_t k, ft &&arg) { |
54 |
| - if(n < pre_roots && !precalc) { |
55 |
| - return roots.get<complex<ft>>(n + k); |
56 |
| - } else { |
57 |
| - return complex<ft>::polar(1., arg); |
| 51 | + static size_t eval_arg(size_t n) { |
| 52 | + if(n < pre_roots && eval_args[n]) { |
| 53 | + return eval_args[n]; |
| 54 | + } else if(n == 0) { |
| 55 | + return 0; |
| 56 | + } |
| 57 | + auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1); |
| 58 | + if(n < pre_roots) { |
| 59 | + eval_args[n] = res; |
58 | 60 | }
|
| 61 | + return res; |
| 62 | + } |
| 63 | + static point eval_point(size_t n) { |
| 64 | + if(n < pre_roots && evalp[n] != point{}) { |
| 65 | + return evalp[n]; |
| 66 | + } else if(n == 0) { |
| 67 | + return point(1); |
| 68 | + } |
| 69 | + auto res = root(2 * std::bit_floor(n), eval_arg(n)); |
| 70 | + if(n < pre_roots) { |
| 71 | + evalp[n] = res; |
| 72 | + } |
| 73 | + return res; |
59 | 74 | }
|
60 |
| - template<class pt = point, bool precalc = false> |
61 | 75 | static void exec_on_roots(size_t n, size_t m, auto &&callback) {
|
62 |
| - ftype arg = std::numbers::pi / (ftype)n; |
63 |
| - size_t step = sizeof(pt) / sizeof(point); |
64 |
| - using ft = pt::value_type; |
65 |
| - auto k = [&]() { |
66 |
| - if constexpr(std::is_same_v<pt, point>) { |
67 |
| - return ft{}; |
68 |
| - } else { |
69 |
| - return ft{[](auto i) {return i;}}; |
| 76 | + auto step = root(n, 1); |
| 77 | + auto rt = point(1); |
| 78 | + for(size_t i = 0; i < m; i++) { |
| 79 | + if(i % threshold == 0) { |
| 80 | + rt = root(n / threshold, i / threshold); |
70 | 81 | }
|
71 |
| - }(); |
72 |
| - for(size_t i = 0; i < m; i += step, k += (ftype)step) { |
73 |
| - callback(i, root<precalc>(n, i, arg * k)); |
| 82 | + callback(i, rt); |
| 83 | + rt *= step; |
| 84 | + } |
| 85 | + } |
| 86 | + static void exec_on_evals(size_t n, auto &&callback) { |
| 87 | + for(size_t i = 0; i < n; i++) { |
| 88 | + callback(i, eval_point(i)); |
74 | 89 | }
|
75 | 90 | }
|
76 | 91 |
|
77 | 92 | void ifft() {
|
78 |
| - size_t n = size(); |
79 |
| - for(size_t i = 1; i < n; i *= 2) { |
80 |
| - for(size_t j = 0; j < n; j += 2 * i) { |
81 |
| - auto butterfly = [&]<class pt>(size_t k, pt rt) { |
82 |
| - k += j; |
83 |
| - auto t = get<pt>(k + i) * conj(rt); |
84 |
| - set(k + i, get<pt>(k) - t); |
85 |
| - set(k, get<pt>(k) + t); |
86 |
| - }; |
87 |
| - if(i < flen) { |
88 |
| - exec_on_roots<point>(i, i, butterfly); |
89 |
| - } else { |
90 |
| - exec_on_roots<vpoint>(i, i, butterfly); |
| 93 | + size_t n = this->size(); |
| 94 | + for(size_t half = threshold; half <= n / 2; half *= 2) { |
| 95 | + exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
| 96 | + k *= 2 * half; |
| 97 | + for(size_t j = k; j < k + half; j++) { |
| 98 | + auto A = this->at(j) + this->at(j + half); |
| 99 | + auto B = this->at(j) - this->at(j + half); |
| 100 | + this->at(j) = A; |
| 101 | + this->at(j + half) = B * conj(rt); |
91 | 102 | }
|
92 |
| - } |
| 103 | + }); |
93 | 104 | }
|
94 |
| - for(size_t k = 0; k < n; k += flen) { |
95 |
| - set(k, get<vpoint>(k) /= (ftype)n); |
| 105 | + point ni = point(int(threshold)) / point(int(n)); |
| 106 | + for(auto &it: *this) { |
| 107 | + it *= ni; |
96 | 108 | }
|
97 | 109 | }
|
98 | 110 | void fft() {
|
99 |
| - size_t n = size(); |
100 |
| - for(size_t i = n / 2; i >= 1; i /= 2) { |
101 |
| - for(size_t j = 0; j < n; j += 2 * i) { |
102 |
| - auto butterfly = [&]<class pt>(size_t k, pt rt) { |
103 |
| - k += j; |
104 |
| - auto A = get<pt>(k) + get<pt>(k + i); |
105 |
| - auto B = get<pt>(k) - get<pt>(k + i); |
106 |
| - set(k, A); |
107 |
| - set(k + i, B * rt); |
108 |
| - }; |
109 |
| - if(i < flen) { |
110 |
| - exec_on_roots<point>(i, i, butterfly); |
111 |
| - } else { |
112 |
| - exec_on_roots<vpoint>(i, i, butterfly); |
| 111 | + size_t n = this->size(); |
| 112 | + for(size_t half = n / 2; half >= threshold; half /= 2) { |
| 113 | + exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
| 114 | + k *= 2 * half; |
| 115 | + for(size_t j = k; j < k + half; j++) { |
| 116 | + auto t = this->at(j + half) * rt; |
| 117 | + this->at(j + half) = this->at(j) - t; |
| 118 | + this->at(j) += t; |
113 | 119 | }
|
114 |
| - } |
| 120 | + }); |
115 | 121 | }
|
116 | 122 | }
|
117 | 123 | };
|
118 |
| - const cvector cvector::roots = []() { |
119 |
| - cvector res(pre_roots); |
120 |
| - for(size_t n = 1; n < res.size(); n *= 2) { |
121 |
| - auto propagate = [&](size_t k, auto rt) { |
122 |
| - res.set(n + k, rt); |
123 |
| - }; |
124 |
| - if(n < flen) { |
125 |
| - res.exec_on_roots<point, true>(n, n, propagate); |
126 |
| - } else { |
127 |
| - res.exec_on_roots<vpoint, true>(n, n, propagate); |
128 |
| - } |
129 |
| - } |
130 |
| - return res; |
131 |
| - }(); |
| 124 | + std::array<point, ftvec::pre_roots> ftvec::roots = {}; |
| 125 | + std::array<point, ftvec::pre_roots> ftvec::evalp = {}; |
| 126 | + std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {}; |
132 | 127 | }
|
133 | 128 | #endif // CP_ALGO_MATH_CVECTOR_HPP
|
0 commit comments