Skip to content

Commit 7a11568

Browse files
committed
Rollback to faster fft...
1 parent 2868c3a commit 7a11568

File tree

3 files changed

+181
-152
lines changed

3 files changed

+181
-152
lines changed

cp-algo/math/cvector.hpp

+153-90
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,191 @@
11
#ifndef CP_ALGO_MATH_CVECTOR_HPP
22
#define CP_ALGO_MATH_CVECTOR_HPP
33
#include <algorithm>
4+
#include <cassert>
45
#include <complex>
56
#include <vector>
67
#include <ranges>
78
namespace cp_algo::math::fft {
89
using ftype = double;
10+
static constexpr size_t bytes = 32;
11+
static constexpr size_t flen = bytes / sizeof(ftype);
912
using point = std::complex<ftype>;
13+
using vftype [[gnu::vector_size(bytes)]] = ftype;
14+
using vpoint = std::complex<vftype>;
1015

11-
struct ftvec: std::vector<point> {
12-
static constexpr size_t pre_roots = 1 << 16;
13-
static constexpr size_t threshold = 32;
14-
ftvec(size_t n) {
15-
this->resize(std::max(threshold, std::bit_ceil(n)));
16+
#define WITH_IV(...) \
17+
[&]<size_t ... i>(std::index_sequence<i...>) { \
18+
return __VA_ARGS__; \
19+
}(std::make_index_sequence<flen>());
20+
21+
template<typename ft>
22+
constexpr ft to_ft(auto x) {
23+
return ft{} + x;
24+
}
25+
template<typename pt>
26+
constexpr pt to_pt(point r) {
27+
using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>;
28+
return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())};
29+
}
30+
struct cvector {
31+
static constexpr size_t pre_roots = 1 << 17;
32+
std::vector<vftype> x, y;
33+
cvector(size_t n) {
34+
n = std::max(flen, std::bit_ceil(n));
35+
x.resize(n / flen);
36+
y.resize(n / flen);
1637
}
17-
static auto dot_block(size_t k, ftvec const& A, ftvec const& B) {
18-
static std::array<point, 2 * threshold> r;
19-
std::ranges::fill(r, point(0));
20-
for(size_t i = 0; i < threshold; i++) {
21-
for(size_t j = 0; j < threshold; j++) {
22-
r[i + j] += A[k + i] * B[k + j];
23-
}
24-
}
25-
auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2);
26-
static std::array<point, threshold> res;
27-
for(size_t i = 0; i < threshold; i++) {
28-
res[i] = r[i] + r[i + threshold] * rt;
38+
template<class pt = point>
39+
void set(size_t k, pt t) {
40+
if constexpr(std::is_same_v<pt, point>) {
41+
x[k / flen][k % flen] = real(t);
42+
y[k / flen][k % flen] = imag(t);
43+
} else {
44+
x[k / flen] = real(t);
45+
y[k / flen] = imag(t);
2946
}
30-
return res;
3147
}
32-
33-
void dot(ftvec const& t) {
34-
size_t n = this->size();
35-
for(size_t k = 0; k < n; k += threshold) {
36-
std::ranges::copy(dot_block(k, *this, t), this->begin() + k);
48+
template<class pt = point>
49+
pt get(size_t k) const {
50+
if constexpr(std::is_same_v<pt, point>) {
51+
return {x[k / flen][k % flen], y[k / flen][k % flen]};
52+
} else {
53+
return {x[k / flen], y[k / flen]};
3754
}
3855
}
39-
static std::array<point, pre_roots> roots, evalp;
40-
static std::array<size_t, pre_roots> eval_args;
41-
static point root(size_t n, size_t k) {
42-
if(n + k < pre_roots && roots[n + k] != point{}) {
43-
return roots[n + k];
44-
}
45-
auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n));
46-
if(n + k < pre_roots) {
47-
roots[n + k] = res;
48-
}
49-
return res;
56+
vpoint vget(size_t k) const {
57+
return get<vpoint>(k);
5058
}
51-
static size_t eval_arg(size_t n) {
52-
if(n < pre_roots && eval_args[n]) {
53-
return eval_args[n];
54-
} else if(n == 0) {
55-
return 0;
56-
}
57-
auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
58-
if(n < pre_roots) {
59-
eval_args[n] = res;
60-
}
61-
return res;
59+
60+
size_t size() const {
61+
return flen * std::size(x);
6262
}
63-
static point eval_point(size_t n) {
64-
if(n < pre_roots && evalp[n] != point{}) {
65-
return evalp[n];
66-
} else if(n == 0) {
67-
return point(1);
63+
void dot(cvector const& t) {
64+
size_t n = size();
65+
for(size_t k = 0; k < n; k += flen) {
66+
set(k, get<vpoint>(k) * t.get<vpoint>(k));
6867
}
69-
auto res = root(2 * std::bit_floor(n), eval_arg(n));
68+
}
69+
static const cvector roots;
70+
template<class pt = point>
71+
static pt root(size_t n, size_t k) {
7072
if(n < pre_roots) {
71-
evalp[n] = res;
73+
return roots.get<pt>(n + k);
74+
} else {
75+
auto arg = std::numbers::pi / ftype(n);
76+
if constexpr(std::is_same_v<pt, point>) {
77+
return {cos(ftype(k) * arg), sin(ftype(k) * arg)};
78+
} else {
79+
return WITH_IV(pt{vftype{cos(ftype(k + i) * arg)...},
80+
vftype{sin(ftype(k + i) * arg)...}});
81+
}
7282
}
73-
return res;
7483
}
84+
template<class pt = point>
7585
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
76-
auto step = root(n, 1);
77-
auto rt = point(1);
78-
for(size_t i = 0; i < m; i++) {
79-
if(i % threshold == 0) {
80-
rt = root(n / threshold, i / threshold);
86+
size_t step = sizeof(pt) / sizeof(point);
87+
pt cur;
88+
pt arg = to_pt<pt>(root<point>(n, step));
89+
for(size_t i = 0; i < m; i += step) {
90+
if(i % 64 == 0 || n < pre_roots) {
91+
cur = root<pt>(n, i);
92+
} else {
93+
cur *= arg;
8194
}
82-
callback(i, rt);
83-
rt *= step;
84-
}
85-
}
86-
static void exec_on_evals(size_t n, auto &&callback) {
87-
for(size_t i = 0; i < n; i++) {
88-
callback(i, eval_point(i));
95+
callback(i, cur);
8996
}
9097
}
9198

9299
void ifft() {
93-
size_t n = this->size();
94-
for(size_t half = threshold; half <= n / 2; half *= 2) {
95-
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
96-
k *= 2 * half;
97-
for(size_t j = k; j < k + half; j++) {
98-
auto A = this->at(j) + this->at(j + half);
99-
auto B = this->at(j) - this->at(j + half);
100-
this->at(j) = A;
101-
this->at(j + half) = B * conj(rt);
100+
size_t n = size();
101+
for(size_t i = 1; i < n; i *= 2) {
102+
for(size_t j = 0; j < n; j += 2 * i) {
103+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
104+
k += j;
105+
auto t = get<pt>(k + i) * conj(rt);
106+
set(k + i, get<pt>(k) - t);
107+
set(k, get<pt>(k) + t);
108+
};
109+
if(2 * i <= flen) {
110+
exec_on_roots(i, i, butterfly);
111+
} else {
112+
exec_on_roots<vpoint>(i, i, butterfly);
102113
}
103-
});
114+
}
104115
}
105-
point ni = point(int(threshold)) / point(int(n));
106-
for(auto &it: *this) {
107-
it *= ni;
116+
for(size_t k = 0; k < n; k += flen) {
117+
set(k, get<vpoint>(k) /= to_pt<vpoint>(ftype(n)));
108118
}
109119
}
110120
void fft() {
111-
size_t n = this->size();
112-
for(size_t half = n / 2; half >= threshold; half /= 2) {
113-
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
114-
k *= 2 * half;
115-
for(size_t j = k; j < k + half; j++) {
116-
auto t = this->at(j + half) * rt;
117-
this->at(j + half) = this->at(j) - t;
118-
this->at(j) += t;
121+
size_t n = size();
122+
for(size_t i = n / 2; i >= 1; i /= 2) {
123+
for(size_t j = 0; j < n; j += 2 * i) {
124+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
125+
k += j;
126+
auto A = get<pt>(k) + get<pt>(k + i);
127+
auto B = get<pt>(k) - get<pt>(k + i);
128+
set(k, A);
129+
set(k + i, B * rt);
130+
};
131+
if(2 * i <= flen) {
132+
exec_on_roots(i, i, butterfly);
133+
} else {
134+
exec_on_roots<vpoint>(i, i, butterfly);
119135
}
120-
});
136+
}
137+
}
138+
}
139+
};
140+
const cvector cvector::roots = []() {
141+
cvector res(pre_roots);
142+
for(size_t n = 1; n < res.size(); n *= 2) {
143+
auto base = std::polar(1., std::numbers::pi / ftype(n));
144+
point cur = 1;
145+
for(size_t k = 0; k < n; k++) {
146+
if((k & 15) == 0) {
147+
cur = std::polar(1., std::numbers::pi * ftype(k) / ftype(n));
148+
}
149+
res.set(n + k, cur);
150+
cur *= base;
151+
}
152+
}
153+
return res;
154+
}();
155+
156+
template<typename base>
157+
struct dft {
158+
cvector A;
159+
160+
dft(std::vector<base> const& a, size_t n): A(n) {
161+
for(size_t i = 0; i < std::min(n, a.size()); i++) {
162+
A.set(i, a[i]);
163+
}
164+
if(n) {
165+
A.fft();
121166
}
122167
}
168+
169+
std::vector<base> operator *= (dft const& B) {
170+
assert(A.size() == B.A.size());
171+
size_t n = A.size();
172+
if(!n) {
173+
return std::vector<base>();
174+
}
175+
A.dot(B.A);
176+
A.ifft();
177+
std::vector<base> res(n);
178+
for(size_t k = 0; k < n; k++) {
179+
res[k] = A.get(k);
180+
}
181+
return res;
182+
}
183+
184+
auto operator * (dft const& B) const {
185+
return dft(*this) *= B;
186+
}
187+
188+
point operator [](int i) const {return A.get(i);}
123189
};
124-
std::array<point, ftvec::pre_roots> ftvec::roots = {};
125-
std::array<point, ftvec::pre_roots> ftvec::evalp = {};
126-
std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {};
127190
}
128191
#endif // CP_ALGO_MATH_CVECTOR_HPP

0 commit comments

Comments
 (0)