11#ifndef CP_ALGO_MATH_BIGINT_HPP
22#define CP_ALGO_MATH_BIGINT_HPP
33#include " ../util/big_alloc.hpp"
4+ #include " ../math/fft64.hpp"
45#include < bits/stdc++.h>
56
67namespace cp_algo ::math {
78 enum base_v {
89 x10 = uint64_t (1e18 ),
9- x16 = uint64_t (0 )
10+ x16 = uint64_t (1ull << 60 )
1011 };
1112 template <base_v base = x10>
1213 struct bigint {
14+ static constexpr uint16_t digit_length = base == x10 ? 18 : 15 ;
15+ static constexpr uint16_t sub_base = base == x10 ? 10 : 16 ;
16+ static constexpr uint32_t meta_base = base == x10 ? uint32_t (1e6 ) : uint32_t (1 << 20 );
1317 big_vector<uint64_t > digits;
1418 bool negative;
1519
@@ -42,18 +46,9 @@ namespace cp_algo::math {
4246 size_t N = size (other.digits );
4347 size_t i = 0 ;
4448 for (; i < N; i++) {
45- if constexpr (base == x10) {
46- d_ptr[i] -= o_ptr[i] + carry;
47- carry = d_ptr[i] >= base;
48- d_ptr[i] += carry ? uint64_t (base) : 0 ;
49- } else if constexpr (base == x16) {
50- auto sub = o_ptr[i] + carry;
51- auto new_carry = sub ? (d_ptr[i] < sub) : carry;
52- d_ptr[i] -= sub;
53- carry = new_carry;
54- } else {
55- static_assert (base == x10 || base == x16, " Unsupported base" );
56- }
49+ d_ptr[i] -= o_ptr[i] + carry;
50+ carry = d_ptr[i] >= base;
51+ d_ptr[i] += carry ? uint64_t (base) : 0 ;
5752 }
5853 if (carry) {
5954 N = size (digits);
@@ -83,18 +78,9 @@ namespace cp_algo::math {
8378 size_t N = size (other.digits );
8479 size_t i = 0 ;
8580 for (; i < N; i++) {
86- if constexpr (base == x10) {
87- d_ptr[i] += o_ptr[i] + carry;
88- carry = d_ptr[i] >= base;
89- d_ptr[i] -= carry ? uint64_t (base) : 0 ;
90- } else if constexpr (base == x16) {
91- auto add = o_ptr[i] + carry;
92- auto new_carry = add ? (d_ptr[i] >= -add) : carry;
93- d_ptr[i] += add;
94- carry = new_carry;
95- } else {
96- static_assert (base == x10 || base == x16, " Unsupported base" );
97- }
81+ d_ptr[i] += o_ptr[i] + carry;
82+ carry = d_ptr[i] >= base;
83+ d_ptr[i] -= carry ? uint64_t (base) : 0 ;
9884 }
9985 if (carry) {
10086 N = size (digits);
@@ -117,8 +103,6 @@ namespace cp_algo::math {
117103 }
118104 size_t len = size (s);
119105 assert (len > 0 );
120- constexpr auto digit_length = base == x10 ? 18 : base == x16 ? 16 : 0 ;
121- constexpr auto sub_base = base == x10 ? 10 : base == x16 ? 16 : 0 ;
122106 size_t num_digits = (len + digit_length - 1 ) / digit_length;
123107 digits.resize (num_digits);
124108 size_t i = len;
@@ -136,6 +120,68 @@ namespace cp_algo::math {
136120 bigint operator - (const bigint& other) const {
137121 return bigint (*this ) -= other;
138122 }
123+ void to_metabase () {
124+ auto N = ssize (digits);
125+ digits.resize (3 * N);
126+ for (auto i = N - 1 ; i >= 0 ; i--) {
127+ uint64_t val = digits[i];
128+ digits[3 * i] = val % meta_base;
129+ val /= meta_base;
130+ digits[3 * i + 1 ] = val % meta_base;
131+ val /= meta_base;
132+ digits[3 * i + 2 ] = val;
133+ }
134+ }
135+ void from_metabase () {
136+ auto N = (ssize (digits) + 2 ) / 3 ;
137+ digits.resize (3 * N);
138+ uint64_t carry = 0 ;
139+ for (int i = 0 ; i < N; i++) {
140+ __uint128_t val = digits[3 * i + 2 ];
141+ val = val * meta_base + digits[3 * i + 1 ];
142+ val = val * meta_base + digits[3 * i];
143+ val += carry;
144+ digits[i] = uint64_t (val % base);
145+ carry = uint64_t (val / base);
146+ }
147+ digits.resize (N);
148+ while (carry) {
149+ digits.push_back (carry % base);
150+ carry /= base;
151+ }
152+ }
153+ bigint& mul_inplace (auto &&other) {
154+ size_t n = size (digits);
155+ size_t m = size (other.digits );
156+ negative ^= other.negative ;
157+ if (std::min (n, m) < 128 ) {
158+ big_vector<uint64_t > result (n + m);
159+ for (size_t i = 0 ; i < n; i++) {
160+ uint64_t carry = 0 ;
161+ for (size_t j = 0 ; j < m || carry; j++) {
162+ __uint128_t cur = result[i + j] + carry;
163+ if (j < m) {
164+ cur += __uint128_t (digits[i]) * other.digits [j];
165+ }
166+ result[i + j] = uint64_t (cur % base);
167+ carry = uint64_t (cur / base);
168+ }
169+ }
170+ digits = std::move (result);
171+ return normalize ();
172+ }
173+ to_metabase ();
174+ other.to_metabase ();
175+ fft::conv64 (digits, other.digits );
176+ from_metabase ();
177+ return normalize ();
178+ }
179+ bigint& operator *= (bigint const & other) {
180+ return mul_inplace (bigint (other));
181+ }
182+ bigint operator * (const bigint& other) const {
183+ return bigint (*this ).mul_inplace (bigint (other));
184+ }
139185 };
140186
141187 template <base_v base>
@@ -154,21 +200,19 @@ namespace cp_algo::math {
154200 if (empty (x.digits )) {
155201 return out << ' 0' ;
156202 }
157- constexpr auto digit_length = base == x10 ? 18 : base == x16 ? 16 : 0 ;
158- constexpr auto sub_base = base == x10 ? 10 : base == x16 ? 16 : 0 ;
159203 char buf[20 ];
160- auto [ptr, ec] = std::to_chars (buf, buf + sizeof (buf), x.digits .back (), sub_base);
204+ auto [ptr, ec] = std::to_chars (buf, buf + sizeof (buf), x.digits .back (), bigint<base>:: sub_base);
161205 if constexpr (base == x16) {
162206 std::ranges::transform (buf, buf, toupper);
163207 }
164208 out << std::string_view (buf, ptr - buf);
165209 for (auto d: x.digits | std::views::reverse | std::views::drop (1 )) {
166- auto [ptr, ec] = std::to_chars (buf, buf + sizeof (buf), d, sub_base);
210+ auto [ptr, ec] = std::to_chars (buf, buf + sizeof (buf), d, bigint<base>:: sub_base);
167211 if constexpr (base == x16) {
168212 std::ranges::transform (buf, buf, toupper);
169213 }
170214 auto len = ptr - buf;
171- out << std::string (digit_length - len, ' 0' );
215+ out << std::string (bigint<base>:: digit_length - len, ' 0' );
172216 out << std::string_view (buf, len);
173217 }
174218 return out;
0 commit comments