Skip to content

Commit 25faa7e

Browse files
committed
bigint multiplication tests
1 parent 90035f7 commit 25faa7e

File tree

3 files changed

+142
-32
lines changed

3 files changed

+142
-32
lines changed

cp-algo/math/bigint.hpp

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
#ifndef CP_ALGO_MATH_BIGINT_HPP
22
#define CP_ALGO_MATH_BIGINT_HPP
33
#include "../util/big_alloc.hpp"
4+
#include "../math/fft64.hpp"
45
#include <bits/stdc++.h>
56

67
namespace cp_algo::math {
78
enum base_v {
89
x10 = uint64_t(1e18),
9-
x16 = uint64_t(0)
10+
x16 = uint64_t(1ull << 60)
1011
};
1112
template<base_v base = x10>
1213
struct bigint {
14+
static constexpr uint16_t digit_length = base == x10 ? 18 : 15;
15+
static constexpr uint16_t sub_base = base == x10 ? 10 : 16;
16+
static constexpr uint32_t meta_base = base == x10 ? uint32_t(1e6) : uint32_t(1 << 20);
1317
big_vector<uint64_t> digits;
1418
bool negative;
1519

@@ -42,18 +46,9 @@ namespace cp_algo::math {
4246
size_t N = size(other.digits);
4347
size_t i = 0;
4448
for (; i < N; i++) {
45-
if constexpr (base == x10) {
46-
d_ptr[i] -= o_ptr[i] + carry;
47-
carry = d_ptr[i] >= base;
48-
d_ptr[i] += carry ? uint64_t(base) : 0;
49-
} else if constexpr (base == x16) {
50-
auto sub = o_ptr[i] + carry;
51-
auto new_carry = sub ? (d_ptr[i] < sub) : carry;
52-
d_ptr[i] -= sub;
53-
carry = new_carry;
54-
} else {
55-
static_assert(base == x10 || base == x16, "Unsupported base");
56-
}
49+
d_ptr[i] -= o_ptr[i] + carry;
50+
carry = d_ptr[i] >= base;
51+
d_ptr[i] += carry ? uint64_t(base) : 0;
5752
}
5853
if (carry) {
5954
N = size(digits);
@@ -83,18 +78,9 @@ namespace cp_algo::math {
8378
size_t N = size(other.digits);
8479
size_t i = 0;
8580
for (; i < N; i++) {
86-
if constexpr (base == x10) {
87-
d_ptr[i] += o_ptr[i] + carry;
88-
carry = d_ptr[i] >= base;
89-
d_ptr[i] -= carry ? uint64_t(base) : 0;
90-
} else if constexpr (base == x16) {
91-
auto add = o_ptr[i] + carry;
92-
auto new_carry = add ? (d_ptr[i] >= -add) : carry;
93-
d_ptr[i] += add;
94-
carry = new_carry;
95-
} else {
96-
static_assert(base == x10 || base == x16, "Unsupported base");
97-
}
81+
d_ptr[i] += o_ptr[i] + carry;
82+
carry = d_ptr[i] >= base;
83+
d_ptr[i] -= carry ? uint64_t(base) : 0;
9884
}
9985
if (carry) {
10086
N = size(digits);
@@ -117,8 +103,6 @@ namespace cp_algo::math {
117103
}
118104
size_t len = size(s);
119105
assert(len > 0);
120-
constexpr auto digit_length = base == x10 ? 18 : base == x16 ? 16 : 0;
121-
constexpr auto sub_base = base == x10 ? 10 : base == x16 ? 16 : 0;
122106
size_t num_digits = (len + digit_length - 1) / digit_length;
123107
digits.resize(num_digits);
124108
size_t i = len;
@@ -136,6 +120,68 @@ namespace cp_algo::math {
136120
bigint operator - (const bigint& other) const {
137121
return bigint(*this) -= other;
138122
}
123+
void to_metabase() {
124+
auto N = ssize(digits);
125+
digits.resize(3 * N);
126+
for (auto i = N - 1; i >= 0; i--) {
127+
uint64_t val = digits[i];
128+
digits[3 * i] = val % meta_base;
129+
val /= meta_base;
130+
digits[3 * i + 1] = val % meta_base;
131+
val /= meta_base;
132+
digits[3 * i + 2] = val;
133+
}
134+
}
135+
void from_metabase() {
136+
auto N = (ssize(digits) + 2) / 3;
137+
digits.resize(3 * N);
138+
uint64_t carry = 0;
139+
for (int i = 0; i < N; i++) {
140+
__uint128_t val = digits[3 * i + 2];
141+
val = val * meta_base + digits[3 * i + 1];
142+
val = val * meta_base + digits[3 * i];
143+
val += carry;
144+
digits[i] = uint64_t(val % base);
145+
carry = uint64_t(val / base);
146+
}
147+
digits.resize(N);
148+
while (carry) {
149+
digits.push_back(carry % base);
150+
carry /= base;
151+
}
152+
}
153+
bigint& mul_inplace(auto &&other) {
154+
size_t n = size(digits);
155+
size_t m = size(other.digits);
156+
negative ^= other.negative;
157+
if (std::min(n, m) < 128) {
158+
big_vector<uint64_t> result(n + m);
159+
for (size_t i = 0; i < n; i++) {
160+
uint64_t carry = 0;
161+
for (size_t j = 0; j < m || carry; j++) {
162+
__uint128_t cur = result[i + j] + carry;
163+
if (j < m) {
164+
cur += __uint128_t(digits[i]) * other.digits[j];
165+
}
166+
result[i + j] = uint64_t(cur % base);
167+
carry = uint64_t(cur / base);
168+
}
169+
}
170+
digits = std::move(result);
171+
return normalize();
172+
}
173+
to_metabase();
174+
other.to_metabase();
175+
fft::conv64(digits, other.digits);
176+
from_metabase();
177+
return normalize();
178+
}
179+
bigint& operator *= (bigint const& other) {
180+
return mul_inplace(bigint(other));
181+
}
182+
bigint operator * (const bigint& other) const {
183+
return bigint(*this).mul_inplace(bigint(other));
184+
}
139185
};
140186

141187
template<base_v base>
@@ -154,21 +200,19 @@ namespace cp_algo::math {
154200
if (empty(x.digits)) {
155201
return out << '0';
156202
}
157-
constexpr auto digit_length = base == x10 ? 18 : base == x16 ? 16 : 0;
158-
constexpr auto sub_base = base == x10 ? 10 : base == x16 ? 16 : 0;
159203
char buf[20];
160-
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), x.digits.back(), sub_base);
204+
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), x.digits.back(), bigint<base>::sub_base);
161205
if constexpr (base == x16) {
162206
std::ranges::transform(buf, buf, toupper);
163207
}
164208
out << std::string_view(buf, ptr - buf);
165209
for (auto d: x.digits | std::views::reverse | std::views::drop(1)) {
166-
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), d, sub_base);
210+
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), d, bigint<base>::sub_base);
167211
if constexpr (base == x16) {
168212
std::ranges::transform(buf, buf, toupper);
169213
}
170214
auto len = ptr - buf;
171-
out << std::string(digit_length - len, '0');
215+
out << std::string(bigint<base>::digit_length - len, '0');
172216
out << std::string_view(buf, len);
173217
}
174218
return out;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// @brief Multiplication of Hex Big Integers
2+
#define PROBLEM "https://judge.yosupo.jp/problem/multiplication_of_hex_big_integers"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
#define CP_ALGO_CHECKPOINT
7+
#include <iostream>
8+
#include "blazingio/blazingio.min.hpp"
9+
#include "cp-algo/math/bigint.hpp"
10+
#include "cp-algo/util/checkpoint.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
using namespace cp_algo::math;
15+
16+
void solve() {
17+
bigint<x16> a, b;
18+
cin >> a >> b;
19+
a.mul_inplace(b);
20+
cout << a << "\n";
21+
}
22+
23+
signed main() {
24+
//freopen("input.txt", "r", stdin);
25+
ios::sync_with_stdio(0);
26+
cin.tie(0);
27+
int t = 1;
28+
cin >> t;
29+
while(t--) {
30+
solve();
31+
}
32+
cp_algo::checkpoint<1>();
33+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// @brief Multiplication of Big Integers
2+
#define PROBLEM "https://judge.yosupo.jp/problem/multiplication_of_big_integers"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
#define CP_ALGO_CHECKPOINT
7+
#include <iostream>
8+
#include "blazingio/blazingio.min.hpp"
9+
#include "cp-algo/math/bigint.hpp"
10+
#include "cp-algo/util/checkpoint.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
using namespace cp_algo::math;
15+
16+
void solve() {
17+
bigint a, b;
18+
cin >> a >> b;
19+
a.mul_inplace(b);
20+
cout << a << "\n";
21+
}
22+
23+
signed main() {
24+
//freopen("input.txt", "r", stdin);
25+
ios::sync_with_stdio(0);
26+
cin.tie(0);
27+
int t = 1;
28+
cin >> t;
29+
while(t--) {
30+
solve();
31+
}
32+
cp_algo::checkpoint<1>();
33+
}

0 commit comments

Comments
 (0)