|
1 | 1 | #ifndef CP_ALGO_MATH_CVECTOR_HPP
|
2 | 2 | #define CP_ALGO_MATH_CVECTOR_HPP
|
#include <algorithm>
#include <bit>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <numbers>
#include <ranges>
#include <type_traits>
#include <utility>
#include <vector>
|
7 | 8 | namespace cp_algo::math::fft {
|
// Scalar floating-point type used by every transform in this namespace.
using ftype = double;
// `bytes` is the size of the GNU vector type `vftype` declared below and
// `flen` is the number of `ftype` lanes that fit into it.
// Both are `inline constexpr` so that every translation unit including this
// header refers to the same entity: `flen` is bound by reference in
// `std::max(flen, ...)`, which is an odr-use.
inline constexpr size_t bytes = 32;
inline constexpr size_t flen = bytes / sizeof(ftype);
// A single complex number.
using point = std::complex<ftype>;
// `flen` scalars packed into one GNU vector (GCC/Clang vector extension).
using vftype [[gnu::vector_size(bytes)]] = ftype;
// `flen` complex numbers stored as a vector of real parts and a vector of
// imaginary parts.
// NOTE(review): the standard only specifies std::complex for float, double and
// long double; this alias relies on the library's generic template accepting
// a vector type.
using vpoint = std::complex<vftype>;
// WITH_IV(expr) evaluates `expr` inside an immediately-invoked generic lambda
// whose template parameter pack `i` is 0, 1, ..., flen - 1, so `expr` may
// expand `i...` to build one value per lane. The lambda captures by reference.
// The macro is variadic because `expr` may contain top-level commas inside
// braces, which the preprocessor would otherwise split into several arguments.
// It deliberately does not end in a semicolon, so it can be used wherever an
// expression is allowed (`return WITH_IV(...);`, `f(WITH_IV(...))`, ...).
#define WITH_IV(...)                                  \
    [&]<size_t ... i>(std::index_sequence<i...>) {    \
        return __VA_ARGS__;                           \
    }(std::make_index_sequence<flen>())
| 20 | + |
// Converts the scalar `x` to the floating-point type `ft`.
// When `ft` is a GNU vector type, the mixed scalar/vector subtraction
// broadcasts `x` to every lane. Subtracting a zero (rather than adding to it)
// returns `x` unchanged for every input, including negative zero, whose sign
// `ft{} + x` would lose.
template<typename ft, typename scalar>
constexpr ft to_ft(scalar x) {
    return x - ft{};
}
| 25 | + template<typename pt> |
| 26 | + constexpr pt to_pt(point r) { |
| 27 | + using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>; |
| 28 | + return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())}; |
| 29 | + } |
| 30 | + struct cvector { |
| 31 | + static constexpr size_t pre_roots = 1 << 17; |
| 32 | + std::vector<vftype> x, y; |
| 33 | + cvector(size_t n) { |
| 34 | + n = std::max(flen, std::bit_ceil(n)); |
| 35 | + x.resize(n / flen); |
| 36 | + y.resize(n / flen); |
16 | 37 | }
|
17 |
| - static auto dot_block(size_t k, ftvec const& A, ftvec const& B) { |
18 |
| - static std::array<point, 2 * threshold> r; |
19 |
| - std::ranges::fill(r, point(0)); |
20 |
| - for(size_t i = 0; i < threshold; i++) { |
21 |
| - for(size_t j = 0; j < threshold; j++) { |
22 |
| - r[i + j] += A[k + i] * B[k + j]; |
23 |
| - } |
24 |
| - } |
25 |
| - auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2); |
26 |
| - static std::array<point, threshold> res; |
27 |
| - for(size_t i = 0; i < threshold; i++) { |
28 |
| - res[i] = r[i] + r[i + threshold] * rt; |
| 38 | + template<class pt = point> |
| 39 | + void set(size_t k, pt t) { |
| 40 | + if constexpr(std::is_same_v<pt, point>) { |
| 41 | + x[k / flen][k % flen] = real(t); |
| 42 | + y[k / flen][k % flen] = imag(t); |
| 43 | + } else { |
| 44 | + x[k / flen] = real(t); |
| 45 | + y[k / flen] = imag(t); |
29 | 46 | }
|
30 |
| - return res; |
31 | 47 | }
|
32 |
| - |
33 |
| - void dot(ftvec const& t) { |
34 |
| - size_t n = this->size(); |
35 |
| - for(size_t k = 0; k < n; k += threshold) { |
36 |
| - std::ranges::copy(dot_block(k, *this, t), this->begin() + k); |
| 48 | + template<class pt = point> |
| 49 | + pt get(size_t k) const { |
| 50 | + if constexpr(std::is_same_v<pt, point>) { |
| 51 | + return {x[k / flen][k % flen], y[k / flen][k % flen]}; |
| 52 | + } else { |
| 53 | + return {x[k / flen], y[k / flen]}; |
37 | 54 | }
|
38 | 55 | }
|
39 |
| - static std::array<point, pre_roots> roots, evalp; |
40 |
| - static std::array<size_t, pre_roots> eval_args; |
41 |
| - static point root(size_t n, size_t k) { |
42 |
| - if(n + k < pre_roots && roots[n + k] != point{}) { |
43 |
| - return roots[n + k]; |
44 |
| - } |
45 |
| - auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); |
46 |
| - if(n + k < pre_roots) { |
47 |
| - roots[n + k] = res; |
48 |
| - } |
49 |
| - return res; |
| 56 | + vpoint vget(size_t k) const { |
| 57 | + return get<vpoint>(k); |
50 | 58 | }
|
51 |
| - static size_t eval_arg(size_t n) { |
52 |
| - if(n < pre_roots && eval_args[n]) { |
53 |
| - return eval_args[n]; |
54 |
| - } else if(n == 0) { |
55 |
| - return 0; |
56 |
| - } |
57 |
| - auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1); |
58 |
| - if(n < pre_roots) { |
59 |
| - eval_args[n] = res; |
60 |
| - } |
61 |
| - return res; |
| 59 | + |
| 60 | + size_t size() const { |
| 61 | + return flen * std::size(x); |
62 | 62 | }
|
63 |
| - static point eval_point(size_t n) { |
64 |
| - if(n < pre_roots && evalp[n] != point{}) { |
65 |
| - return evalp[n]; |
66 |
| - } else if(n == 0) { |
67 |
| - return point(1); |
| 63 | + void dot(cvector const& t) { |
| 64 | + size_t n = size(); |
| 65 | + for(size_t k = 0; k < n; k += flen) { |
| 66 | + set(k, get<vpoint>(k) * t.get<vpoint>(k)); |
68 | 67 | }
|
69 |
| - auto res = root(2 * std::bit_floor(n), eval_arg(n)); |
| 68 | + } |
| 69 | + static const cvector roots; |
| 70 | + template<class pt = point> |
| 71 | + static pt root(size_t n, size_t k) { |
70 | 72 | if(n < pre_roots) {
|
71 |
| - evalp[n] = res; |
| 73 | + return roots.get<pt>(n + k); |
| 74 | + } else { |
| 75 | + auto arg = std::numbers::pi / ftype(n); |
| 76 | + if constexpr(std::is_same_v<pt, point>) { |
| 77 | + return {cos(ftype(k) * arg), sin(ftype(k) * arg)}; |
| 78 | + } else { |
| 79 | + return WITH_IV(pt{vftype{cos(ftype(k + i) * arg)...}, |
| 80 | + vftype{sin(ftype(k + i) * arg)...}}); |
| 81 | + } |
72 | 82 | }
|
73 |
| - return res; |
74 | 83 | }
|
| 84 | + template<class pt = point> |
75 | 85 | static void exec_on_roots(size_t n, size_t m, auto &&callback) {
|
76 |
| - auto step = root(n, 1); |
77 |
| - auto rt = point(1); |
78 |
| - for(size_t i = 0; i < m; i++) { |
79 |
| - if(i % threshold == 0) { |
80 |
| - rt = root(n / threshold, i / threshold); |
| 86 | + size_t step = sizeof(pt) / sizeof(point); |
| 87 | + pt cur; |
| 88 | + pt arg = to_pt<pt>(root<point>(n, step)); |
| 89 | + for(size_t i = 0; i < m; i += step) { |
| 90 | + if(i % 64 == 0 || n < pre_roots) { |
| 91 | + cur = root<pt>(n, i); |
| 92 | + } else { |
| 93 | + cur *= arg; |
81 | 94 | }
|
82 |
| - callback(i, rt); |
83 |
| - rt *= step; |
84 |
| - } |
85 |
| - } |
86 |
| - static void exec_on_evals(size_t n, auto &&callback) { |
87 |
| - for(size_t i = 0; i < n; i++) { |
88 |
| - callback(i, eval_point(i)); |
| 95 | + callback(i, cur); |
89 | 96 | }
|
90 | 97 | }
|
91 | 98 |
|
92 | 99 | void ifft() {
|
93 |
| - size_t n = this->size(); |
94 |
| - for(size_t half = threshold; half <= n / 2; half *= 2) { |
95 |
| - exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
96 |
| - k *= 2 * half; |
97 |
| - for(size_t j = k; j < k + half; j++) { |
98 |
| - auto A = this->at(j) + this->at(j + half); |
99 |
| - auto B = this->at(j) - this->at(j + half); |
100 |
| - this->at(j) = A; |
101 |
| - this->at(j + half) = B * conj(rt); |
| 100 | + size_t n = size(); |
| 101 | + for(size_t i = 1; i < n; i *= 2) { |
| 102 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 103 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 104 | + k += j; |
| 105 | + auto t = get<pt>(k + i) * conj(rt); |
| 106 | + set(k + i, get<pt>(k) - t); |
| 107 | + set(k, get<pt>(k) + t); |
| 108 | + }; |
| 109 | + if(2 * i <= flen) { |
| 110 | + exec_on_roots(i, i, butterfly); |
| 111 | + } else { |
| 112 | + exec_on_roots<vpoint>(i, i, butterfly); |
102 | 113 | }
|
103 |
| - }); |
| 114 | + } |
104 | 115 | }
|
105 |
| - point ni = point(int(threshold)) / point(int(n)); |
106 |
| - for(auto &it: *this) { |
107 |
| - it *= ni; |
| 116 | + for(size_t k = 0; k < n; k += flen) { |
| 117 | + set(k, get<vpoint>(k) /= to_pt<vpoint>(ftype(n))); |
108 | 118 | }
|
109 | 119 | }
|
110 | 120 | void fft() {
|
111 |
| - size_t n = this->size(); |
112 |
| - for(size_t half = n / 2; half >= threshold; half /= 2) { |
113 |
| - exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
114 |
| - k *= 2 * half; |
115 |
| - for(size_t j = k; j < k + half; j++) { |
116 |
| - auto t = this->at(j + half) * rt; |
117 |
| - this->at(j + half) = this->at(j) - t; |
118 |
| - this->at(j) += t; |
| 121 | + size_t n = size(); |
| 122 | + for(size_t i = n / 2; i >= 1; i /= 2) { |
| 123 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 124 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 125 | + k += j; |
| 126 | + auto A = get<pt>(k) + get<pt>(k + i); |
| 127 | + auto B = get<pt>(k) - get<pt>(k + i); |
| 128 | + set(k, A); |
| 129 | + set(k + i, B * rt); |
| 130 | + }; |
| 131 | + if(2 * i <= flen) { |
| 132 | + exec_on_roots(i, i, butterfly); |
| 133 | + } else { |
| 134 | + exec_on_roots<vpoint>(i, i, butterfly); |
119 | 135 | }
|
120 |
| - }); |
| 136 | + } |
| 137 | + } |
| 138 | + } |
| 139 | + }; |
| 140 | + const cvector cvector::roots = []() { |
| 141 | + cvector res(pre_roots); |
| 142 | + for(size_t n = 1; n < res.size(); n *= 2) { |
| 143 | + auto base = std::polar(1., std::numbers::pi / ftype(n)); |
| 144 | + point cur = 1; |
| 145 | + for(size_t k = 0; k < n; k++) { |
| 146 | + if((k & 15) == 0) { |
| 147 | + cur = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); |
| 148 | + } |
| 149 | + res.set(n + k, cur); |
| 150 | + cur *= base; |
| 151 | + } |
| 152 | + } |
| 153 | + return res; |
| 154 | + }(); |
| 155 | + |
| 156 | + template<typename base> |
| 157 | + struct dft { |
| 158 | + cvector A; |
| 159 | + |
| 160 | + dft(std::vector<base> const& a, size_t n): A(n) { |
| 161 | + for(size_t i = 0; i < std::min(n, a.size()); i++) { |
| 162 | + A.set(i, a[i]); |
| 163 | + } |
| 164 | + if(n) { |
| 165 | + A.fft(); |
121 | 166 | }
|
122 | 167 | }
|
| 168 | + |
| 169 | + std::vector<base> operator *= (dft const& B) { |
| 170 | + assert(A.size() == B.A.size()); |
| 171 | + size_t n = A.size(); |
| 172 | + if(!n) { |
| 173 | + return std::vector<base>(); |
| 174 | + } |
| 175 | + A.dot(B.A); |
| 176 | + A.ifft(); |
| 177 | + std::vector<base> res(n); |
| 178 | + for(size_t k = 0; k < n; k++) { |
| 179 | + res[k] = A.get(k); |
| 180 | + } |
| 181 | + return res; |
| 182 | + } |
| 183 | + |
| 184 | + auto operator * (dft const& B) const { |
| 185 | + return dft(*this) *= B; |
| 186 | + } |
| 187 | + |
| 188 | + point operator [](int i) const {return A.get(i);} |
123 | 189 | };
|
124 |
| - std::array<point, ftvec::pre_roots> ftvec::roots = {}; |
125 |
| - std::array<point, ftvec::pre_roots> ftvec::evalp = {}; |
126 |
| - std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {}; |
127 | 190 | }
|
128 | 191 | #endif // CP_ALGO_MATH_CVECTOR_HPP
|
0 commit comments