Skip to content

Commit ffd4bdd

Browse files
committed Nov 20, 2024
Use native_simd instead of GCC vectors, fix -Wconversion warnings
1 parent 6cda7ab commit ffd4bdd

File tree

5 files changed

+25
-29
lines changed

5 files changed

+25
-29
lines changed
 

‎.verify-helper/config.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
[[languages.cpp.environments]]
22
CXX = "g++"
3-
CXXFLAGS = ["-std=c++23", "-Wall", "-Wextra", "-pedantic", "-Werror", "-O2", "-march=native"]
3+
CXXFLAGS = ["-std=c++23", "-Wall", "-Wextra", "-Wconversion", "-Werror", "-pedantic", "-O2", "-march=native"]

‎cp-algo/math/fft.hpp

+15-19
Original file line number | Diff line number | Diff line change
@@ -8,19 +8,15 @@
88
#include <ranges>
99
#include <vector>
1010
#include <bit>
11+
#include <experimental/simd>
1112

1213
namespace cp_algo::math::fft {
1314
using ftype = double;
14-
static constexpr size_t bytes = 32;
15-
static constexpr size_t flen = bytes / sizeof(ftype);
1615
using point = std::complex<ftype>;
17-
using vftype [[gnu::vector_size(bytes)]] = ftype;
16+
using vftype = std::experimental::native_simd<ftype>;
1817
using vpoint = std::complex<vftype>;
18+
static constexpr size_t flen = vftype::size();
1919

20-
#define WITH_IV(...) \
21-
[&]<size_t ... i>(std::index_sequence<i...>) { \
22-
return __VA_ARGS__; \
23-
}(std::make_index_sequence<flen>());
2420

2521
template<typename ft>
2622
constexpr ft to_ft(auto x) {
@@ -76,12 +72,12 @@ namespace cp_algo::math::fft {
7672
if(n < pre_roots) {
7773
return roots.get<pt>(n + k);
7874
} else {
79-
auto arg = std::numbers::pi / n;
75+
auto arg = std::numbers::pi / (ftype)n;
8076
if constexpr(std::is_same_v<pt, point>) {
81-
return {cos(k * arg), sin(k * arg)};
77+
return {(ftype)cos(k * arg), (ftype)sin(k * arg)};
8278
} else {
83-
return WITH_IV(pt{vftype{cos((k + i) * arg)...},
84-
vftype{sin((k + i) * arg)...}});
79+
return pt{vftype{[&](auto i) {return cos(ftype(k + i) * arg);}},
80+
vftype{[&](auto i) {return sin(ftype(k + i) * arg);}}};
8581
}
8682
}
8783
}
@@ -118,7 +114,7 @@ namespace cp_algo::math::fft {
118114
}
119115
}
120116
for(size_t k = 0; k < n; k += flen) {
121-
set(k, get<vpoint>(k) /= to_pt<vpoint>(n));
117+
set(k, get<vpoint>(k) /= to_pt<vpoint>((ftype)n));
122118
}
123119
}
124120
void fft() {
@@ -144,11 +140,11 @@ namespace cp_algo::math::fft {
144140
const cvector cvector::roots = []() {
145141
cvector res(pre_roots);
146142
for(size_t n = 1; n < res.size(); n *= 2) {
147-
auto base = std::polar(1., std::numbers::pi / n);
143+
auto base = std::polar(1., std::numbers::pi / (ftype)n);
148144
point cur = 1;
149145
for(size_t k = 0; k < n; k++) {
150146
if((k & 15) == 0) {
151-
cur = std::polar(1., std::numbers::pi * k / n);
147+
cur = std::polar(1., std::numbers::pi * (ftype)k / (ftype)n);
152148
}
153149
res.set(n + k, cur);
154150
cur *= base;
@@ -198,7 +194,7 @@ namespace cp_algo::math::fft {
198194
cvector A, B;
199195

200196
dft(auto const& a, size_t n): A(n), B(n) {
201-
split = std::sqrt(base::mod());
197+
split = int(std::sqrt(base::mod()));
202198
cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
203199
size_t ti = std::min(i, i - n);
204200
A.set(ti, A.get(ti) + ftype(a[i].rem() % split) * rt);
@@ -273,12 +269,12 @@ namespace cp_algo::math::fft {
273269
if(empty(a) || empty(b)) {
274270
a.clear();
275271
} else {
276-
int n = std::min(k, size(a));
277-
int m = std::min(k, size(b));
272+
size_t n = std::min(k, size(a));
273+
size_t m = std::min(k, size(b));
278274
a.resize(k);
279-
for(int j = k - 1; j >= 0; j--) {
275+
for(int j = int(k - 1); j >= 0; j--) {
280276
a[j] *= b[0];
281-
for(int i = std::max(j - n, 0) + 1; i < std::min(j + 1, m); i++) {
277+
for(size_t i = std::max<size_t>(j - n, 0) + 1; i < std::min<size_t>(j + 1, m); i++) {
282278
a[j] += a[j - i] * b[i];
283279
}
284280
}

‎cp-algo/math/poly.hpp

+3-3
Original file line number | Diff line number | Diff line change
@@ -299,9 +299,9 @@ namespace cp_algo::math {
299299
if(is_zero()) {
300300
return k ? *this : poly_t(1);
301301
}
302-
int i = trailing_xk();
302+
size_t i = trailing_xk();
303303
if(i > 0) {
304-
return k >= int64_t(n + i - 1) / i ? poly_t(T(0)) : div_xk(i).pow(k, n - i * k).mul_xk(i * k);
304+
return k >= int64_t(n + i - 1) / (int64_t)i ? poly_t(T(0)) : div_xk(i).pow(k, n - i * k).mul_xk(i * k);
305305
}
306306
if(std::min(deg(), (int)n) <= magic) {
307307
return pow_dn(k, n);
@@ -319,7 +319,7 @@ namespace cp_algo::math {
319319
if(is_zero()) {
320320
return *this;
321321
}
322-
int i = trailing_xk();
322+
size_t i = trailing_xk();
323323
if(i % 2) {
324324
return std::nullopt;
325325
} else if(i > 0) {

‎cp-algo/math/poly/impl/div.hpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ namespace cp_algo::math::poly::impl {
8282
auto [q0, q1] = q.bisect();
8383
auto qq = q0 * q0 - (q1 * q1).mul_xk_inplace(1);
8484
inv_inplace(qq, k / 2 - q.deg() / 2, (n + 1) / 2 + q.deg() / 2);
85-
int N = fft::com_size(size(q0.a), size(qq.a));
85+
size_t N = fft::com_size(size(q0.a), size(qq.a));
8686
auto q0f = fft::dft<base>(q0.a, N);
8787
auto q1f = fft::dft<base>(q1.a, N);
8888
auto qqf = fft::dft<base>(qq.a, N);
@@ -109,7 +109,7 @@ namespace cp_algo::math::poly::impl {
109109
// Q(-x) = P0(x^2) + xP1(x^2)
110110
auto [q0, q1] = p.bisect(n);
111111

112-
int N = fft::com_size(size(q0.a), (n + 1) / 2);
112+
size_t N = fft::com_size(size(q0.a), (n + 1) / 2);
113113

114114
auto q0f = fft::dft<base>(q0.a, N);
115115
auto q1f = fft::dft<base>(q1.a, N);

‎cp-algo/number_theory/modint.hpp

+4-4
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,10 @@ namespace cp_algo::math {
3131
}
3232
static UInt m_reduce(UInt2 ab) {
3333
if(mod() % 2 == 0) [[unlikely]] {
34-
return ab % mod();
34+
return UInt(ab % mod());
3535
} else {
36-
UInt m = ab * imod();
37-
return (ab + (UInt2)m * mod()) >> bits;
36+
UInt2 m = (UInt)ab * imod();
37+
return UInt((ab + m * mod()) >> bits);
3838
}
3939
}
4040
static UInt m_transform(UInt a) {
@@ -45,7 +45,7 @@ namespace cp_algo::math {
4545
}
4646
}
4747
modint_base(): r(0) {}
48-
modint_base(Int2 rr): r(rr % mod()) {
48+
modint_base(Int2 rr): r(UInt(rr % mod())) {
4949
r = std::min(r, r + mod());
5050
r = m_transform(r);
5151
}

0 commit comments

Comments
 (0)
Please sign in to comment.