Skip to content

Commit 7e8bd88

Browse files
committed
conv_simple for multiplication
1 parent 1ec2ba2 commit 7e8bd88

File tree

4 files changed

+69
-18
lines changed

4 files changed

+69
-18
lines changed

cp-algo/math/bigint.hpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
#ifndef CP_ALGO_MATH_BIGINT_HPP
22
#define CP_ALGO_MATH_BIGINT_HPP
33
#include "../util/big_alloc.hpp"
4-
#include "../math/fft64.hpp"
4+
#include "../math/fft_simple.hpp"
55
#include <bits/stdc++.h>
66

77
namespace cp_algo::math {
88
enum base_v {
9-
x10 = uint64_t(1e18),
9+
x10 = uint64_t(1e16),
1010
x16 = uint64_t(1ull << 60)
1111
};
1212
template<base_v base = x10>
1313
struct bigint {
14-
static constexpr uint16_t digit_length = base == x10 ? 18 : 15;
14+
static constexpr uint16_t digit_length = base == x10 ? 16 : 15;
1515
static constexpr uint16_t sub_base = base == x10 ? 10 : 16;
16-
static constexpr uint32_t meta_base = base == x10 ? uint32_t(1e6) : uint32_t(1 << 20);
16+
static constexpr uint32_t meta_base = base == x10 ? uint32_t(1e4) : uint32_t(1 << 15);
1717
big_vector<uint64_t> digits;
1818
bool negative;
1919

@@ -122,24 +122,27 @@ namespace cp_algo::math {
122122
}
123123
void to_metabase() {
124124
auto N = ssize(digits);
125-
digits.resize(3 * N);
125+
digits.resize(4 * N);
126126
for (auto i = N - 1; i >= 0; i--) {
127127
uint64_t val = digits[i];
128-
digits[3 * i] = val % meta_base;
128+
digits[4 * i] = val % meta_base;
129129
val /= meta_base;
130-
digits[3 * i + 1] = val % meta_base;
130+
digits[4 * i + 1] = val % meta_base;
131131
val /= meta_base;
132-
digits[3 * i + 2] = val;
132+
digits[4 * i + 2] = val % meta_base;
133+
val /= meta_base;
134+
digits[4 * i + 3] = val;
133135
}
134136
}
135137
void from_metabase() {
136-
auto N = (ssize(digits) + 2) / 3;
137-
digits.resize(3 * N);
138+
auto N = (ssize(digits) + 3) / 4;
139+
digits.resize(4 * N);
138140
uint64_t carry = 0;
139141
for (int i = 0; i < N; i++) {
140-
__uint128_t val = digits[3 * i + 2];
141-
val = val * meta_base + digits[3 * i + 1];
142-
val = val * meta_base + digits[3 * i];
142+
__uint128_t val = digits[4 * i + 3];
143+
val = val * meta_base + digits[4 * i + 2];
144+
val = val * meta_base + digits[4 * i + 1];
145+
val = val * meta_base + digits[4 * i];
143146
val += carry;
144147
digits[i] = uint64_t(val % base);
145148
carry = uint64_t(val / base);
@@ -172,7 +175,7 @@ namespace cp_algo::math {
172175
}
173176
to_metabase();
174177
other.to_metabase();
175-
fft::conv64(digits, other.digits);
178+
fft::conv_simple(digits, other.digits);
176179
from_metabase();
177180
return normalize();
178181
}

cp-algo/math/fft_simple.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef CP_ALGO_MATH_FFT_SIMPLE_HPP
2+
#define CP_ALGO_MATH_FFT_SIMPLE_HPP
3+
#include "../random/rng.hpp"
4+
#include "../math/common.hpp"
5+
#include "../math/cvector.hpp"
6+
CP_ALGO_SIMD_PRAGMA_PUSH
7+
namespace cp_algo::math::fft {
8+
struct dft_simple {
9+
cp_algo::math::fft::cvector cv;
10+
11+
dft_simple(auto const& a, size_t n): cv(n) {
12+
for(size_t i = 0; i < std::min(std::size(a), n); i++) {
13+
real(cv.at(i))[i % 4] = ftype(a[i]);
14+
imag(cv.at(i))[i % 4] = ftype(i + n < std::size(a) ? a[i + n] : 0);
15+
}
16+
checkpoint("dft64 init");
17+
cv.fft();
18+
}
19+
20+
void dot(dft_simple const& t) {
21+
cv.dot(t.cv);
22+
}
23+
24+
void recover_mod(auto &res, size_t k) {
25+
cv.ifft();
26+
size_t n = cv.size();
27+
for(size_t i = 0; i < std::min(k, n); i++) {
28+
res[i] = llround(real(cv.get(i)));
29+
}
30+
for(size_t i = n; i < k; i++) {
31+
res[i] = llround(imag(cv.get(i - n)));
32+
}
33+
cp_algo::checkpoint("recover mod");
34+
}
35+
};
36+
37+
// Multiplies a and b, assuming perfect precision and no overflow
38+
void conv_simple(auto& a, auto const& b) {
39+
size_t n = a.size(), m = b.size();
40+
size_t N = std::max(flen, std::bit_ceil(n + m - 1) / 2);
41+
dft_simple A(a, N), B(b, N);
42+
A.dot(B);
43+
a.resize(n + m - 1);
44+
A.recover_mod(a, n + m - 1);
45+
}
46+
}
47+
#pragma GCC pop_options
48+
#endif // CP_ALGO_MATH_FFT_SIMPLE_HPP

verify/bigint/hex_multiplication.test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#pragma GCC optimize("O3,unroll-loops")
44
#include <bits/allocator.h>
55
#pragma GCC target("avx2")
6-
#define CP_ALGO_CHECKPOINT
6+
//#define CP_ALGO_CHECKPOINT
77
#include <iostream>
88
#include "blazingio/blazingio.min.hpp"
99
#include "cp-algo/math/bigint.hpp"
@@ -29,5 +29,5 @@ signed main() {
2929
while(t--) {
3030
solve();
3131
}
32-
cp_algo::checkpoint<1>();
32+
//cp_algo::checkpoint<1>();
3333
}

verify/bigint/multiplication.test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#pragma GCC optimize("O3,unroll-loops")
44
#include <bits/allocator.h>
55
#pragma GCC target("avx2")
6-
#define CP_ALGO_CHECKPOINT
6+
//#define CP_ALGO_CHECKPOINT
77
#include <iostream>
88
#include "blazingio/blazingio.min.hpp"
99
#include "cp-algo/math/bigint.hpp"
@@ -29,5 +29,5 @@ signed main() {
2929
while(t--) {
3030
solve();
3131
}
32-
cp_algo::checkpoint<1>();
32+
//cp_algo::checkpoint<1>();
3333
}

0 commit comments

Comments
 (0)