Skip to content

Commit c6a5e57

Browse files
committed
Factor out cvector
1 parent f63041a commit c6a5e57

File tree

3 files changed

+135
-135
lines changed

3 files changed

+135
-135
lines changed

cp-algo/math/cvector.hpp

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#ifndef CP_ALGO_MATH_CVECTOR_HPP
2+
#define CP_ALGO_MATH_CVECTOR_HPP
3+
#include "../util/complex.hpp"
#include <algorithm>
#include <bit>
#include <cstddef>
#include <experimental/simd>
#include <numbers>
#include <type_traits>
#include <vector>
5+
namespace cp_algo::math::fft {
6+
using ftype = double;
7+
using point = complex<ftype>;
8+
using vftype = std::experimental::native_simd<ftype>;
9+
using vpoint = complex<vftype>;
10+
static constexpr size_t flen = vftype::size();
11+
12+
struct cvector {
13+
static constexpr size_t pre_roots = 1 << 18;
14+
std::vector<vftype> x, y;
15+
cvector(size_t n) {
16+
n = std::max(flen, std::bit_ceil(n));
17+
x.resize(n / flen);
18+
y.resize(n / flen);
19+
}
20+
template<class pt = point>
21+
void set(size_t k, pt t) {
22+
if constexpr(std::is_same_v<pt, point>) {
23+
x[k / flen][k % flen] = real(t);
24+
y[k / flen][k % flen] = imag(t);
25+
} else {
26+
x[k / flen] = real(t);
27+
y[k / flen] = imag(t);
28+
}
29+
}
30+
template<class pt = point>
31+
pt get(size_t k) const {
32+
if constexpr(std::is_same_v<pt, point>) {
33+
return {x[k / flen][k % flen], y[k / flen][k % flen]};
34+
} else {
35+
return {x[k / flen], y[k / flen]};
36+
}
37+
}
38+
vpoint vget(size_t k) const {
39+
return get<vpoint>(k);
40+
}
41+
42+
size_t size() const {
43+
return flen * std::size(x);
44+
}
45+
void dot(cvector const& t) {
46+
size_t n = size();
47+
for(size_t k = 0; k < n; k += flen) {
48+
set(k, get<vpoint>(k) * t.get<vpoint>(k));
49+
}
50+
}
51+
static const cvector roots;
52+
template< bool precalc = false, class ft = point>
53+
static auto root(size_t n, size_t k, ft &&arg) {
54+
if(n < pre_roots && !precalc) {
55+
return roots.get<complex<ft>>(n + k);
56+
} else {
57+
return complex<ft>::polar(1., arg);
58+
}
59+
}
60+
template<class pt = point, bool precalc = false>
61+
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
62+
ftype arg = std::numbers::pi / (ftype)n;
63+
size_t step = sizeof(pt) / sizeof(point);
64+
using ft = pt::value_type;
65+
auto k = [&]() {
66+
if constexpr(std::is_same_v<pt, point>) {
67+
return ft{};
68+
} else {
69+
return ft{[](auto i) {return i;}};
70+
}
71+
}();
72+
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
73+
callback(i, root<precalc>(n, i, arg * k));
74+
}
75+
}
76+
77+
void ifft() {
78+
size_t n = size();
79+
for(size_t i = 1; i < n; i *= 2) {
80+
for(size_t j = 0; j < n; j += 2 * i) {
81+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
82+
k += j;
83+
auto t = get<pt>(k + i) * conj(rt);
84+
set(k + i, get<pt>(k) - t);
85+
set(k, get<pt>(k) + t);
86+
};
87+
if(i < flen) {
88+
exec_on_roots<point>(i, i, butterfly);
89+
} else {
90+
exec_on_roots<vpoint>(i, i, butterfly);
91+
}
92+
}
93+
}
94+
for(size_t k = 0; k < n; k += flen) {
95+
set(k, get<vpoint>(k) /= (ftype)n);
96+
}
97+
}
98+
void fft() {
99+
size_t n = size();
100+
for(size_t i = n / 2; i >= 1; i /= 2) {
101+
for(size_t j = 0; j < n; j += 2 * i) {
102+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
103+
k += j;
104+
auto A = get<pt>(k) + get<pt>(k + i);
105+
auto B = get<pt>(k) - get<pt>(k + i);
106+
set(k, A);
107+
set(k + i, B * rt);
108+
};
109+
if(i < flen) {
110+
exec_on_roots<point>(i, i, butterfly);
111+
} else {
112+
exec_on_roots<vpoint>(i, i, butterfly);
113+
}
114+
}
115+
}
116+
}
117+
};
118+
// Builds the root lookup table once at start-up.
// For each power of two n below the table size, slots n .. 2n - 1 receive the
// values produced by exec_on_roots<_, true>(n, n, ...), i.e. polar(1, (pi / n) * k)
// computed directly (precalc == true, so the table is not read while it is built).
// Defined `inline` because this lives in a header: a plain out-of-class
// definition would be emitted by every translation unit that includes it.
inline const cvector cvector::roots = []() {
    cvector res(pre_roots);
    for(size_t n = 1; n < res.size(); n *= 2) {
        auto propagate = [&](size_t k, auto rt) {
            res.set(n + k, rt);
        };
        // Blocks narrower than one register are filled lane by lane.
        if(n < flen) {
            cvector::exec_on_roots<point, true>(n, n, propagate);
        } else {
            cvector::exec_on_roots<vpoint, true>(n, n, propagate);
        }
    }
    return res;
}();
132+
}
133+
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/math/fft.hpp

+1-134
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,9 @@
11
#ifndef CP_ALGO_MATH_FFT_HPP
22
#define CP_ALGO_MATH_FFT_HPP
3-
#include "common.hpp"
43
#include "../number_theory/modint.hpp"
5-
#include "../util/complex.hpp"
6-
#include <algorithm>
7-
#include <cassert>
4+
#include "cvector.hpp"
85
#include <ranges>
9-
#include <vector>
10-
#include <bit>
11-
#include <experimental/simd>
126
namespace cp_algo::math::fft {
13-
using ftype = double;
14-
using point = complex<ftype>;
15-
using vftype = std::experimental::native_simd<ftype>;
16-
using vpoint = complex<vftype>;
17-
static constexpr size_t flen = vftype::size();
18-
19-
struct cvector {
20-
static constexpr size_t pre_roots = 1 << 18;
21-
std::vector<vftype> x, y;
22-
cvector(size_t n) {
23-
n = std::max(flen, std::bit_ceil(n));
24-
x.resize(n / flen);
25-
y.resize(n / flen);
26-
}
27-
template<class pt = point>
28-
void set(size_t k, pt t) {
29-
if constexpr(std::is_same_v<pt, point>) {
30-
x[k / flen][k % flen] = real(t);
31-
y[k / flen][k % flen] = imag(t);
32-
} else {
33-
x[k / flen] = real(t);
34-
y[k / flen] = imag(t);
35-
}
36-
}
37-
template<class pt = point>
38-
pt get(size_t k) const {
39-
if constexpr(std::is_same_v<pt, point>) {
40-
return {x[k / flen][k % flen], y[k / flen][k % flen]};
41-
} else {
42-
return {x[k / flen], y[k / flen]};
43-
}
44-
}
45-
vpoint vget(size_t k) const {
46-
return get<vpoint>(k);
47-
}
48-
49-
size_t size() const {
50-
return flen * std::size(x);
51-
}
52-
void dot(cvector const& t) {
53-
size_t n = size();
54-
for(size_t k = 0; k < n; k += flen) {
55-
set(k, get<vpoint>(k) * t.get<vpoint>(k));
56-
}
57-
}
58-
static const cvector roots;
59-
template< bool precalc = false, class ft = point>
60-
static auto root(size_t n, size_t k, ft &&arg) {
61-
if(n < pre_roots && !precalc) {
62-
return roots.get<complex<ft>>(n + k);
63-
} else {
64-
return complex<ft>::polar(1., arg);
65-
}
66-
}
67-
template<class pt = point, bool precalc = false>
68-
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
69-
ftype arg = std::numbers::pi / (ftype)n;
70-
size_t step = sizeof(pt) / sizeof(point);
71-
using ft = pt::value_type;
72-
auto k = [&]() {
73-
if constexpr(std::is_same_v<pt, point>) {
74-
return ft{};
75-
} else {
76-
return ft{[](auto i) {return i;}};
77-
}
78-
}();
79-
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
80-
callback(i, root<precalc>(n, i, arg * k));
81-
}
82-
}
83-
84-
void ifft() {
85-
size_t n = size();
86-
for(size_t i = 1; i < n; i *= 2) {
87-
for(size_t j = 0; j < n; j += 2 * i) {
88-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
89-
k += j;
90-
auto t = get<pt>(k + i) * conj(rt);
91-
set(k + i, get<pt>(k) - t);
92-
set(k, get<pt>(k) + t);
93-
};
94-
if(i < flen) {
95-
exec_on_roots<point>(i, i, butterfly);
96-
} else {
97-
exec_on_roots<vpoint>(i, i, butterfly);
98-
}
99-
}
100-
}
101-
for(size_t k = 0; k < n; k += flen) {
102-
set(k, get<vpoint>(k) /= (ftype)n);
103-
}
104-
}
105-
void fft() {
106-
size_t n = size();
107-
for(size_t i = n / 2; i >= 1; i /= 2) {
108-
for(size_t j = 0; j < n; j += 2 * i) {
109-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
110-
k += j;
111-
auto A = get<pt>(k) + get<pt>(k + i);
112-
auto B = get<pt>(k) - get<pt>(k + i);
113-
set(k, A);
114-
set(k + i, B * rt);
115-
};
116-
if(i < flen) {
117-
exec_on_roots<point>(i, i, butterfly);
118-
} else {
119-
exec_on_roots<vpoint>(i, i, butterfly);
120-
}
121-
}
122-
}
123-
}
124-
};
125-
const cvector cvector::roots = []() {
126-
cvector res(pre_roots);
127-
for(size_t n = 1; n < res.size(); n *= 2) {
128-
auto propagate = [&](size_t k, auto rt) {
129-
res.set(n + k, rt);
130-
};
131-
if(n < flen) {
132-
res.exec_on_roots<point, true>(n, n, propagate);
133-
} else {
134-
res.exec_on_roots<vpoint, true>(n, n, propagate);
135-
}
136-
}
137-
return res;
138-
}();
139-
1407
template<typename base>
1418
struct dft {
1429
cvector A;

verify/poly/wildcard.test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// @brief Wildcard Pattern Matching
22
#define PROBLEM "https://judge.yosupo.jp/problem/wildcard_pattern_matching"
33
#pragma GCC optimize("Ofast,unroll-loops")
4-
#include "cp-algo/math/fft.hpp"
4+
#include "cp-algo/math/cvector.hpp"
55
#include "cp-algo/random/rng.hpp"
66
#include <bits/stdc++.h>
77

0 commit comments

Comments
 (0)