Skip to content

Commit 5786660

Browse files
committed
add convolution mod 2^64
1 parent 629da39 commit 5786660

File tree

6 files changed

+168
-13
lines changed

6 files changed

+168
-13
lines changed

cp-algo/graph/base.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ namespace cp_algo::graph {
4444
return std::views::iota(0, n());
4545
}
4646
auto edges_view() const {
47-
return std::views::iota(0, 2 * m()) | std::views::filter(
48-
[](edge_index e) {return !(e % 2);}
49-
);
47+
return std::views::iota(0, 2 * m()) | std::views::stride(2);
5048
}
5149
// Read-only access: returns a const reference to the member `adj`.
auto const& incidence_lists() const {return adj;}
5250
// Returns a const reference to `edges[e]`; the index is passed to operator[] as is.
edge_t const& edge(edge_index e) const {return edges[e];}

cp-algo/math/common.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,13 @@ namespace cp_algo::math {
2929
T bpow(T const& x, auto n) {
3030
return bpow(x, n, T(1));
3131
}
32+
// Multiplicative inverse of an odd integer modulo 2^w, where w is the bit
// width of the unsigned counterpart of x's type.
//
// @param x  odd integer (checked with assert); signed values are
//           reinterpreted modulo 2^w.
// @return   y of type std::make_unsigned_t<T> such that x * y == 1 (mod 2^w).
//
// Uses the Newton step y <- y * (2 - x * y), which doubles the number of
// correct low bits on every iteration, starting from y = 1 (correct modulo 2
// because x is odd).
template<typename T>
inline constexpr auto inv2(T x) {
    using U = std::make_unsigned_t<T>;
    // Work type: unsigned and at least as wide as `unsigned`, so that narrow
    // operands (uint8_t / uint16_t) are not promoted to signed int. Without
    // this the product is never reduced modulo 2^w and the loop cannot stop.
    using W = std::common_type_t<U, unsigned>;
    assert(x % 2);
    W const ux = static_cast<U>(x);
    U y = 1;
    while(static_cast<U>(ux * y) != 1) {
        y = static_cast<U>(y * (2 - ux * y));
    }
    return y;
}
3240
}
3341
#endif // CP_ALGO_MATH_COMMON_HPP

cp-algo/math/fft64.hpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#ifndef CP_ALGO_MATH_FFT64_HPP
2+
#define CP_ALGO_MATH_FFT64_HPP
3+
#include "../random/rng.hpp"
4+
#include "../math/common.hpp"
5+
#include "../math/cvector.hpp"
6+
7+
namespace cp_algo::math::fft {
8+
struct dft64 {
9+
std::vector<cp_algo::math::fft::cvector> cv;
10+
11+
static uint64_t factor, ifactor;
12+
static bool _init;
13+
14+
static void init() {
15+
if(_init) return;
16+
_init = true;
17+
factor = random::rng();
18+
if(factor % 2 == 0) {factor++;}
19+
ifactor = inv2(factor);
20+
}
21+
22+
dft64(auto const& a, size_t n): cv(4, n) {
23+
init();
24+
uint64_t cur = 1, step = bpow(factor, n);
25+
for(size_t i = 0; i < std::min(std::size(a), n); i++) {
26+
auto split = [&](size_t i, uint64_t mul) -> std::array<int16_t, 4> {
27+
uint64_t x = i < std::size(a) ? a[i] * mul : 0;
28+
std::array<int16_t, 4> res;
29+
for(int z = 0; z < 4; z++) {
30+
res[z] = int16_t(x);
31+
x = (x >> 16) + (res[z] < 0);
32+
}
33+
return res;
34+
};
35+
auto re = split(i, cur);
36+
auto im = split(n + i, cur * step);
37+
for(int z = 0; z < 4; z++) {
38+
real(cv[z].at(i))[i % 4] = re[z];
39+
imag(cv[z].at(i))[i % 4] = im[z];
40+
}
41+
cur *= factor;
42+
}
43+
checkpoint("dft64 init");
44+
for(auto &x: cv) {
45+
x.fft();
46+
}
47+
}
48+
49+
void dot(dft64 const& t) {
50+
size_t N = cv[0].size();
51+
cvector::exec_on_evals<1>(N / flen, [&](size_t k, point rt) {
52+
k *= flen;
53+
auto [A0x, A0y] = cv[0].at(k);
54+
auto [A1x, A1y] = cv[1].at(k);
55+
auto [A2x, A2y] = cv[2].at(k);
56+
auto [A3x, A3y] = cv[3].at(k);
57+
std::array B = {
58+
t.cv[0].at(k),
59+
t.cv[1].at(k),
60+
t.cv[2].at(k),
61+
t.cv[3].at(k)
62+
};
63+
64+
std::array<vpoint, 4> C = {vz, vz, vz, vz};
65+
for (size_t i = 0; i < flen; i++) {
66+
std::array A = {
67+
vpoint{vz + A0x[i], vz + A0y[i]},
68+
vpoint{vz + A1x[i], vz + A1y[i]},
69+
vpoint{vz + A2x[i], vz + A2y[i]},
70+
vpoint{vz + A3x[i], vz + A3y[i]}
71+
};
72+
for(size_t k = 0; k < 4; k++) {
73+
for(size_t i = 0; i <= k; i++) {
74+
C[k] += A[i] * B[k - i];
75+
}
76+
}
77+
for(size_t k = 0; k < 4; k++) {
78+
real(B[k]) = rotate_right(real(B[k]));
79+
imag(B[k]) = rotate_right(imag(B[k]));
80+
auto bx = real(B[k])[0], by = imag(B[k])[0];
81+
real(B[k])[0] = bx * real(rt) - by * imag(rt);
82+
imag(B[k])[0] = bx * imag(rt) + by * real(rt);
83+
}
84+
}
85+
cv[0].at(k) = C[0];
86+
cv[1].at(k) = C[1];
87+
cv[2].at(k) = C[2];
88+
cv[3].at(k) = C[3];
89+
});
90+
checkpoint("dot");
91+
for(auto &x: cv) {
92+
x.ifft();
93+
}
94+
}
95+
96+
void recover_mod(auto &res, size_t k) {
97+
size_t n = cv[0].size();
98+
uint64_t cur = 1, step = bpow(ifactor, n);
99+
for(size_t i = 0; i < std::min(k, n); i++) {
100+
std::array re = {real(cv[0].get(i)), real(cv[1].get(i)), real(cv[2].get(i)), real(cv[3].get(i))};
101+
std::array im = {imag(cv[0].get(i)), imag(cv[1].get(i)), imag(cv[2].get(i)), imag(cv[3].get(i))};
102+
auto set_i = [&](size_t i, auto &x, auto mul) {
103+
if (i >= k) return;
104+
res[i] = llround(x[0]) + (llround(x[1]) << 16) + (llround(x[2]) << 32) + (llround(x[3]) << 48);
105+
res[i] *= mul;
106+
};
107+
set_i(i, re, cur);
108+
set_i(n + i, im, cur * step);
109+
cur *= ifactor;
110+
}
111+
cp_algo::checkpoint("recover mod");
112+
}
113+
};
114+
uint64_t dft64::factor = 1, dft64::ifactor = 1;
115+
bool dft64::_init = false;
116+
117+
void conv64(auto& a, auto const& b) {
118+
size_t n = a.size(), m = b.size();
119+
size_t N = std::max(flen, std::bit_ceil(n + m - 1) / 2);
120+
dft64 A(a, N), B(b, N);
121+
A.dot(B);
122+
a.resize(n + m - 1);
123+
A.recover_mod(a, n + m - 1);
124+
}
125+
}
126+
#endif // CP_ALGO_MATH_FFT64_HPP

cp-algo/number_theory/modint.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,6 @@ namespace cp_algo::math {
111111
auto getr() const {return Base::r;}
112112
};
113113

114-
inline constexpr auto inv2(auto x) {
115-
assert(x % 2);
116-
std::make_unsigned_t<decltype(x)> y = 1;
117-
while(y * x != 1) {
118-
y *= 2 - x * y;
119-
}
120-
return y;
121-
}
122-
123114
template<typename Int = int64_t>
124115
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
125116
using Base = modint_base<dynamic_modint<Int>, Int>;

cp-algo/util/simd.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
namespace cp_algo {
88
template<typename T, size_t len>
99
using simd [[gnu::vector_size(len * sizeof(T))]] = T;
10+
using u32x8 = simd<uint32_t, 8>;
1011
using i64x4 = simd<int64_t, 4>;
1112
using u64x4 = simd<uint64_t, 4>;
12-
using u32x8 = simd<uint32_t, 8>;
1313
using i32x4 = simd<int32_t, 4>;
1414
using u32x4 = simd<uint32_t, 4>;
15+
using i16x4 = simd<int16_t, 4>;
1516
using dx4 = simd<double, 4>;
1617

1718
[[gnu::always_inline]] inline dx4 abs(dx4 a) {

verify/poly/convolution64.test.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// @brief Convolution (Mod $2^{64}$)
2+
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_2_64"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#define CP_ALGO_CHECKPOINT
5+
#include <bits/stdc++.h>
6+
#include "cp-algo/math/fft64.hpp"
7+
#include "blazingio/blazingio.min.hpp"
8+
9+
using namespace std;
10+
11+
void solve() {
12+
int n, m;
13+
cin >> n >> m;
14+
vector<uint64_t, cp_algo::big_alloc<uint64_t>> a(n), b(m);
15+
for(auto &x : a) cin >> x;
16+
for(auto &x : b) cin >> x;
17+
cp_algo::checkpoint("read");
18+
cp_algo::math::fft::conv64(a, b);
19+
for(auto x: a) {
20+
cout << uint64_t(x) << " ";
21+
}
22+
cp_algo::checkpoint("write");
23+
cp_algo::checkpoint<1>();
24+
}
25+
26+
// Entry point: unties the C++ streams from C stdio and from each other,
// then runs the single test case.
signed main() {
    //freopen("input.txt", "r", stdin);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
}

0 commit comments

Comments
 (0)