Skip to content

Commit 2868c3a

Browse files
committed
Increase FFT precision, get rid of manual vectorization, use partial fft
1 parent ac2a9a0 commit 2868c3a

File tree

3 files changed

+131
-136
lines changed

3 files changed

+131
-136
lines changed

cp-algo/math/cvector.hpp

+96-101
Original file line numberDiff line numberDiff line change
@@ -1,133 +1,128 @@
11
#ifndef CP_ALGO_MATH_CVECTOR_HPP
22
#define CP_ALGO_MATH_CVECTOR_HPP
3-
#include "../util/complex.hpp"
4-
#include <experimental/simd>
3+
#include <algorithm>
4+
#include <complex>
5+
#include <vector>
6+
#include <ranges>
57
namespace cp_algo::math::fft {
68
using ftype = double;
7-
using point = complex<ftype>;
8-
using vftype = std::experimental::native_simd<ftype>;
9-
using vpoint = complex<vftype>;
10-
static constexpr size_t flen = vftype::size();
9+
using point = std::complex<ftype>;
1110

12-
struct cvector {
13-
static constexpr size_t pre_roots = 1 << 18;
14-
std::vector<vftype> x, y;
15-
cvector(size_t n) {
16-
n = std::max(flen, std::bit_ceil(n));
17-
x.resize(n / flen);
18-
y.resize(n / flen);
11+
struct ftvec: std::vector<point> {
12+
static constexpr size_t pre_roots = 1 << 16;
13+
static constexpr size_t threshold = 32;
14+
ftvec(size_t n) {
15+
this->resize(std::max(threshold, std::bit_ceil(n)));
1916
}
20-
template<class pt = point>
21-
void set(size_t k, pt t) {
22-
if constexpr(std::is_same_v<pt, point>) {
23-
x[k / flen][k % flen] = real(t);
24-
y[k / flen][k % flen] = imag(t);
25-
} else {
26-
x[k / flen] = real(t);
27-
y[k / flen] = imag(t);
17+
static auto dot_block(size_t k, ftvec const& A, ftvec const& B) {
18+
static std::array<point, 2 * threshold> r;
19+
std::ranges::fill(r, point(0));
20+
for(size_t i = 0; i < threshold; i++) {
21+
for(size_t j = 0; j < threshold; j++) {
22+
r[i + j] += A[k + i] * B[k + j];
23+
}
2824
}
29-
}
30-
template<class pt = point>
31-
pt get(size_t k) const {
32-
if constexpr(std::is_same_v<pt, point>) {
33-
return {x[k / flen][k % flen], y[k / flen][k % flen]};
34-
} else {
35-
return {x[k / flen], y[k / flen]};
25+
auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2);
26+
static std::array<point, threshold> res;
27+
for(size_t i = 0; i < threshold; i++) {
28+
res[i] = r[i] + r[i + threshold] * rt;
3629
}
37-
}
38-
vpoint vget(size_t k) const {
39-
return get<vpoint>(k);
30+
return res;
4031
}
4132

42-
size_t size() const {
43-
return flen * std::size(x);
33+
void dot(ftvec const& t) {
34+
size_t n = this->size();
35+
for(size_t k = 0; k < n; k += threshold) {
36+
std::ranges::copy(dot_block(k, *this, t), this->begin() + k);
37+
}
4438
}
45-
void dot(cvector const& t) {
46-
size_t n = size();
47-
for(size_t k = 0; k < n; k += flen) {
48-
set(k, get<vpoint>(k) * t.get<vpoint>(k));
39+
static std::array<point, pre_roots> roots, evalp;
40+
static std::array<size_t, pre_roots> eval_args;
41+
static point root(size_t n, size_t k) {
42+
if(n + k < pre_roots && roots[n + k] != point{}) {
43+
return roots[n + k];
4944
}
45+
auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n));
46+
if(n + k < pre_roots) {
47+
roots[n + k] = res;
48+
}
49+
return res;
5050
}
51-
static const cvector roots;
52-
template< bool precalc = false, class ft = point>
53-
static auto root(size_t n, size_t k, ft &&arg) {
54-
if(n < pre_roots && !precalc) {
55-
return roots.get<complex<ft>>(n + k);
56-
} else {
57-
return complex<ft>::polar(1., arg);
51+
static size_t eval_arg(size_t n) {
52+
if(n < pre_roots && eval_args[n]) {
53+
return eval_args[n];
54+
} else if(n == 0) {
55+
return 0;
56+
}
57+
auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
58+
if(n < pre_roots) {
59+
eval_args[n] = res;
5860
}
61+
return res;
62+
}
63+
static point eval_point(size_t n) {
64+
if(n < pre_roots && evalp[n] != point{}) {
65+
return evalp[n];
66+
} else if(n == 0) {
67+
return point(1);
68+
}
69+
auto res = root(2 * std::bit_floor(n), eval_arg(n));
70+
if(n < pre_roots) {
71+
evalp[n] = res;
72+
}
73+
return res;
5974
}
60-
template<class pt = point, bool precalc = false>
6175
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
62-
ftype arg = std::numbers::pi / (ftype)n;
63-
size_t step = sizeof(pt) / sizeof(point);
64-
using ft = pt::value_type;
65-
auto k = [&]() {
66-
if constexpr(std::is_same_v<pt, point>) {
67-
return ft{};
68-
} else {
69-
return ft{[](auto i) {return i;}};
76+
auto step = root(n, 1);
77+
auto rt = point(1);
78+
for(size_t i = 0; i < m; i++) {
79+
if(i % threshold == 0) {
80+
rt = root(n / threshold, i / threshold);
7081
}
71-
}();
72-
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
73-
callback(i, root<precalc>(n, i, arg * k));
82+
callback(i, rt);
83+
rt *= step;
84+
}
85+
}
86+
static void exec_on_evals(size_t n, auto &&callback) {
87+
for(size_t i = 0; i < n; i++) {
88+
callback(i, eval_point(i));
7489
}
7590
}
7691

7792
void ifft() {
78-
size_t n = size();
79-
for(size_t i = 1; i < n; i *= 2) {
80-
for(size_t j = 0; j < n; j += 2 * i) {
81-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
82-
k += j;
83-
auto t = get<pt>(k + i) * conj(rt);
84-
set(k + i, get<pt>(k) - t);
85-
set(k, get<pt>(k) + t);
86-
};
87-
if(i < flen) {
88-
exec_on_roots<point>(i, i, butterfly);
89-
} else {
90-
exec_on_roots<vpoint>(i, i, butterfly);
93+
size_t n = this->size();
94+
for(size_t half = threshold; half <= n / 2; half *= 2) {
95+
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
96+
k *= 2 * half;
97+
for(size_t j = k; j < k + half; j++) {
98+
auto A = this->at(j) + this->at(j + half);
99+
auto B = this->at(j) - this->at(j + half);
100+
this->at(j) = A;
101+
this->at(j + half) = B * conj(rt);
91102
}
92-
}
103+
});
93104
}
94-
for(size_t k = 0; k < n; k += flen) {
95-
set(k, get<vpoint>(k) /= (ftype)n);
105+
point ni = point(int(threshold)) / point(int(n));
106+
for(auto &it: *this) {
107+
it *= ni;
96108
}
97109
}
98110
void fft() {
99-
size_t n = size();
100-
for(size_t i = n / 2; i >= 1; i /= 2) {
101-
for(size_t j = 0; j < n; j += 2 * i) {
102-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
103-
k += j;
104-
auto A = get<pt>(k) + get<pt>(k + i);
105-
auto B = get<pt>(k) - get<pt>(k + i);
106-
set(k, A);
107-
set(k + i, B * rt);
108-
};
109-
if(i < flen) {
110-
exec_on_roots<point>(i, i, butterfly);
111-
} else {
112-
exec_on_roots<vpoint>(i, i, butterfly);
111+
size_t n = this->size();
112+
for(size_t half = n / 2; half >= threshold; half /= 2) {
113+
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
114+
k *= 2 * half;
115+
for(size_t j = k; j < k + half; j++) {
116+
auto t = this->at(j + half) * rt;
117+
this->at(j + half) = this->at(j) - t;
118+
this->at(j) += t;
113119
}
114-
}
120+
});
115121
}
116122
}
117123
};
118-
const cvector cvector::roots = []() {
119-
cvector res(pre_roots);
120-
for(size_t n = 1; n < res.size(); n *= 2) {
121-
auto propagate = [&](size_t k, auto rt) {
122-
res.set(n + k, rt);
123-
};
124-
if(n < flen) {
125-
res.exec_on_roots<point, true>(n, n, propagate);
126-
} else {
127-
res.exec_on_roots<vpoint, true>(n, n, propagate);
128-
}
129-
}
130-
return res;
131-
}();
124+
std::array<point, ftvec::pre_roots> ftvec::roots = {};
125+
std::array<point, ftvec::pre_roots> ftvec::evalp = {};
126+
std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {};
132127
}
133128
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/math/fft.hpp

+29-29
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
#define CP_ALGO_MATH_FFT_HPP
33
#include "../number_theory/modint.hpp"
44
#include "cvector.hpp"
5-
#include <ranges>
65
namespace cp_algo::math::fft {
76
template<typename base>
87
struct dft {
9-
cvector A;
8+
ftvec A;
109

1110
dft(std::vector<base> const& a, size_t n): A(n) {
1211
for(size_t i = 0; i < std::min(n, a.size()); i++) {
13-
A.set(i, a[i]);
12+
A[i] = a[i];
1413
}
1514
if(n) {
1615
A.fft();
@@ -27,30 +26,30 @@ namespace cp_algo::math::fft {
2726
A.ifft();
2827
std::vector<base> res(n);
2928
for(size_t k = 0; k < n; k++) {
30-
res[k] = A.get(k);
29+
res[k] = A[k];
3130
}
3231
return res;
3332
}
3433

3534
auto operator * (dft const& B) const {
3635
return dft(*this) *= B;
3736
}
38-
39-
point operator [](int i) const {return A.get(i);}
4037
};
4138

4239
template<modint_type base>
4340
struct dft<base> {
4441
int split;
45-
cvector A, B;
42+
ftvec A, B;
4643

4744
dft(auto const& a, size_t n): A(n), B(n) {
48-
split = int(std::sqrt(base::mod()));
49-
cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
45+
n = size(A);
46+
split = int(std::sqrt(base::mod())) + 1;
47+
ftvec::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
5048
size_t ti = std::min(i, i - n);
51-
A.set(ti, A.get(ti) + ftype(a[i].rem() % split) * rt);
52-
B.set(ti, B.get(ti) + ftype(a[i].rem() / split) * rt);
53-
49+
auto rem = std::remainder(a[i].rem(), split);
50+
auto quo = (a[i].rem() - rem) / split;
51+
A[ti] += rem * rt;
52+
B[ti] += quo * rt;
5453
});
5554
if(n) {
5655
A.fft();
@@ -65,21 +64,26 @@ namespace cp_algo::math::fft {
6564
res = {};
6665
return;
6766
}
68-
for(size_t i = 0; i < n; i += flen) {
69-
auto tmp = A.vget(i) * D.vget(i) + B.vget(i) * C.vget(i);
70-
A.set(i, A.vget(i) * C.vget(i));
71-
B.set(i, B.vget(i) * D.vget(i));
72-
C.set(i, tmp);
67+
for(size_t i = 0; i < n; i += ftvec::threshold) {
68+
auto AC = ftvec::dot_block(i, A, C);
69+
auto AD = ftvec::dot_block(i, A, D);
70+
auto BC = ftvec::dot_block(i, B, C);
71+
auto BD = ftvec::dot_block(i, B, D);
72+
for(size_t j = 0; j < ftvec::threshold; j++) {
73+
A[i + j] = AC[j];
74+
C[i + j] = AD[j] + BC[j];
75+
B[i + j] = BD[j];
76+
}
7377
}
7478
A.ifft();
7579
B.ifft();
7680
C.ifft();
7781
auto splitsplit = (base(split) * split).rem();
78-
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
82+
ftvec::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
7983
rt = conj(rt);
80-
auto Ai = A.get(i) * rt;
81-
auto Bi = B.get(i) * rt;
82-
auto Ci = C.get(i) * rt;
84+
auto Ai = A[i] * rt;
85+
auto Bi = B[i] * rt;
86+
auto Ci = C[i] * rt;
8387
int64_t A0 = llround(real(Ai));
8488
int64_t A1 = llround(real(Ci));
8589
int64_t A2 = llround(real(Bi));
@@ -97,7 +101,7 @@ namespace cp_algo::math::fft {
97101
mul(B.A, B.B, res, k);
98102
}
99103
void mul(auto const& B, auto& res, size_t k) {
100-
mul(cvector(B.A), B.B, res, k);
104+
mul(ftvec(B.A), B.B, res, k);
101105
}
102106
std::vector<base> operator *= (dft &B) {
103107
std::vector<base> res(2 * A.size());
@@ -112,8 +116,6 @@ namespace cp_algo::math::fft {
112116
auto operator * (dft const& B) const {
113117
return dft(*this) *= B;
114118
}
115-
116-
point operator [](int i) const {return A.get(i);}
117119
};
118120

119121
void mul_slow(auto &a, auto const& b, size_t k) {
@@ -135,17 +137,15 @@ namespace cp_algo::math::fft {
135137
if(!as || !bs) {
136138
return 0;
137139
}
138-
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
140+
return std::bit_ceil(as + bs - 1) / 2;
139141
}
140142
void mul_truncate(auto &a, auto const& b, size_t k) {
141143
using base = std::decay_t<decltype(a[0])>;
142-
if(std::min({k, size(a), size(b)}) < magic) {
144+
if(std::min({k, size(a), size(b)}) < 1) {
143145
mul_slow(a, b, k);
144146
return;
145147
}
146-
auto n = std::max(flen, std::bit_ceil(
147-
std::min(k, size(a)) + std::min(k, size(b)) - 1
148-
) / 2);
148+
auto n = std::bit_ceil(std::min(k, size(a)) + std::min(k, size(b)) - 1) / 2;
149149
a.resize(k);
150150
auto A = dft<base>(a, n);
151151
if(&a == &b) {

0 commit comments

Comments
 (0)