Skip to content

Commit cb6bde9

Browse files
committed
support generic wheels
1 parent a8f4146 commit cb6bde9

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

cp-algo/math/sieve.hpp

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,74 +16,75 @@ namespace cp_algo::math {
1616
using cp_algo::structures::dynamic_bit_array;
1717
using cp_algo::structures::bit_array;
1818

19-
constexpr uint32_t period = 210;
20-
constexpr uint32_t coprime = 48;
21-
constexpr auto coprime210 = [](auto x) {
22-
return x % 2 && x % 3 && x % 5 && x % 7;
19+
constexpr auto wheel_primes = std::array{2u, 3u, 5u, 7u};
20+
constexpr uint32_t period = std::ranges::fold_left(wheel_primes, 1u, std::multiplies{});
21+
constexpr uint32_t coprime = std::ranges::fold_left(wheel_primes, 1u, [](auto a, auto b){ return a * (b - 1); });
22+
constexpr auto coprime_wheel = [](auto x) {
23+
return std::ranges::all_of(wheel_primes, [x](auto p){ return x % p; });
2324
};
2425

25-
// Residues coprime to 210
26-
constexpr auto res210 = []() {
26+
// Residues coprime to period
27+
constexpr auto res_wheel = []() {
2728
std::array<uint8_t, coprime> res;
2829
int idx = 0;
2930
for(uint8_t i = 1; i < period; i += 2) {
30-
if (coprime210(i)) {
31+
if (coprime_wheel(i)) {
3132
res[idx++] = i;
3233
}
3334
}
3435
return res;
3536
}();
3637

37-
// Maps residue mod 210 to pre-upper_bound index in res210
38-
constexpr auto state210 = []() {
38+
// Maps residue mod period to the number of coprime residues strictly below it (its lower_bound index in res_wheel)
39+
constexpr auto state_wheel = []() {
3940
std::array<uint8_t, period> state;
4041
uint8_t idx = 0;
4142
for(uint8_t i = 0; i < period; i++) {
4243
state[i] = idx;
43-
idx += coprime210(i);
44+
idx += coprime_wheel(i);
4445
}
4546
return state;
4647
}();
4748

4849
// Add to reach next coprime residue
49-
constexpr auto add210 = []() {
50+
constexpr auto add_wheel = []() {
5051
std::array<uint8_t, period> add;
5152
for(uint8_t i = 0; i < period; i++) {
5253
add[i] = 1;
53-
while (!coprime210(i + add[i])) {
54+
while (!coprime_wheel(i + add[i])) {
5455
add[i]++;
5556
}
5657
}
5758
return add;
5859
}();
5960

60-
constexpr auto gap210 = []() {
61+
constexpr auto gap_wheel = []() {
6162
std::array<uint8_t, coprime> gap;
6263
for(uint8_t i = 0; i < coprime; i++) {
63-
gap[i] = add210[res210[i]];
64+
gap[i] = add_wheel[res_wheel[i]];
6465
}
6566
return gap;
6667
}();
6768

6869
// Convert value to ordinal (index in compressed bit array)
6970
constexpr uint32_t to_ord(uint32_t x) {
70-
return (x / period) * coprime + state210[x % period];
71+
return (x / period) * coprime + state_wheel[x % period];
7172
}
7273

7374
// Convert ordinal to value
7475
constexpr uint32_t to_val(uint32_t x) {
75-
return (x / coprime) * period + res210[x % coprime];
76+
return (x / coprime) * period + res_wheel[x % coprime];
7677
}
7778

78-
constexpr size_t sqrt_threshold = 1 << 16;
79+
constexpr size_t sqrt_threshold = 1 << 15;
7980
constexpr auto sqrt_prime_bits = []() {
80-
const int size = sqrt_threshold / 4;
81+
const int size = sqrt_threshold / 2;
8182
bit_array<size> prime;
8283
prime.set_all();
8384
prime.reset(to_ord(1));
84-
for(uint32_t i = 11; to_ord(i * i) < size; i += add210[i % period]) {
85+
for(uint32_t i = res_wheel[1]; to_ord(i * i) < size; i += add_wheel[i % period]) {
8586
if (prime[to_ord(i)]) {
86-
for(uint32_t k = i; to_ord(i * k) < size; k += add210[k % period]) {
87+
for(uint32_t k = i; to_ord(i * k) < size; k += add_wheel[k % period]) {
8788
prime.reset(to_ord(i * k));
8889
}
8990
}
@@ -93,15 +94,15 @@ namespace cp_algo::math {
9394

9495
constexpr size_t num_primes = []() {
9596
size_t cnt = 0;
96-
for(uint32_t i = 11; i < sqrt_threshold; i += add210[i % period]) {
97+
for(uint32_t i = res_wheel[1]; i < sqrt_threshold; i += add_wheel[i % period]) {
9798
cnt += sqrt_prime_bits[to_ord(i)];
9899
}
99100
return cnt;
100101
}();
101102
constexpr auto sqrt_primes = []() {
102103
std::array<uint32_t, num_primes> primes;
103104
size_t j = 0;
104-
for(uint32_t i = 11; i < sqrt_threshold; i += add210[i % period]) {
105+
for(uint32_t i = res_wheel[1]; i < sqrt_threshold; i += add_wheel[i % period]) {
105106
if (sqrt_prime_bits[to_ord(i)]) {
106107
primes[j++] = i;
107108
}
@@ -121,7 +122,7 @@ namespace cp_algo::math {
121122
wheel.mask.resize(product / period * coprime);
122123
wheel.mask.set_all();
123124
for(auto p: primes) {
124-
for (uint32_t k = 1; p * k < product; k += add210[k % period]) {
125+
for (uint32_t k = 1; p * k < product; k += add_wheel[k % period]) {
125126
wheel.mask.reset(to_ord(p * k));
126127
}
127128
}
@@ -150,15 +151,15 @@ namespace cp_algo::math {
150151
}
151152

152153
template <class BitArray>
153-
constexpr void sieve210(BitArray& prime, uint32_t l, uint32_t r, size_t i, int state) {
154+
constexpr void sieve_wheel(BitArray& prime, uint32_t l, uint32_t r, size_t i, int state) {
154155
static const auto ord_step = []() {
155156
big_vector<std::array<uint32_t, 2 * coprime>> ord_steps(num_primes);
156157
for (uint32_t i = 0; i < size(sqrt_primes); i++) {
157158
auto p = sqrt_primes[i];
158159
auto &ords = ord_steps[i];
159160
auto last = to_ord(p);
160161
for(uint32_t j = 0; j < coprime; j++) {
161-
auto next = to_ord(p * (res210[j] + gap210[j]));
162+
auto next = to_ord(p * (res_wheel[j] + gap_wheel[j]));
162163
ords[j] = ords[j + coprime] = next - last;
163164
last = next;
164165
}
@@ -182,21 +183,21 @@ namespace cp_algo::math {
182183
}
183184

184185
// Primes less than or equal to N
185-
constexpr dynamic_bit_array sieve210(uint32_t N) {
186+
constexpr dynamic_bit_array sieve_wheel(uint32_t N) {
186187
N++;
187188
dynamic_bit_array prime(to_ord(N));
188189
prime.set_all();
189190
static const auto [wheels, medium_primes_begin] = []() {
190191
constexpr size_t max_wheel_size = 1 << 20;
191-
uint32_t product = period * dynamic_bit_array::width / 4;
192+
uint32_t product = period * dynamic_bit_array::width >> (size(wheel_primes) - 2);
192193
big_vector<uint32_t> current;
193194
big_vector<wheel_t> wheels;
194195
for(size_t i = 0; i < size(sqrt_primes); i++) {
195196
uint32_t p = sqrt_primes[i];
196197
if (product * p > max_wheel_size) {
197198
wheels.push_back(make_wheel(current, product));
198199
current = {p};
199-
product = period * dynamic_bit_array::width / 4 * p;
200+
product = (period * dynamic_bit_array::width >> (size(wheel_primes) - 2)) * p;
200201
if (product > max_wheel_size) {
201202
checkpoint("make wheels");
202203
return std::pair{wheels, i};
@@ -224,8 +225,8 @@ namespace cp_algo::math {
224225
auto p = sqrt_primes[i];
225226
if(p * p >= r) break;
226227
auto k = std::max(start / p, p);
227-
if (!coprime210(k)) {k += add210[k % 210];}
228-
sieve210(prime, to_ord(p * k), to_ord(r), i, state210[k % 210]);
228+
if (!coprime_wheel(k)) {k += add_wheel[k % period];}
229+
sieve_wheel(prime, to_ord(p * k), to_ord(r), i, state_wheel[k % period]);
229230
}
230231
}
231232
checkpoint("sparse sieve");

verify/number_theory/enumerate_primes.test.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ using namespace cp_algo::math;
1515
void solve() {
1616
uint32_t N, A, B;
1717
cin >> N >> A >> B;
18-
auto primes = sieve210(N);
19-
auto cnt = count(primes) + (N >= 2) + (N >= 3) + (N >= 5) + (N >= 7);
18+
auto primes = sieve_wheel(N);
19+
auto cnt = count(primes) + ranges::fold_left(wheel_primes, 0u,
20+
[N](auto sum, auto p) {return sum + (N >= p); }
21+
);
2022
cp_algo::checkpoint("count");
2123
auto X = cnt < B ? 0 : (cnt - B + A - 1) / A;
2224
cout << cnt << ' ' << X << endl;
23-
for(uint32_t p: {2u, 3u, 5u, 7u}) {
25+
for(uint32_t p: wheel_primes) {
2426
if (B == 0 && X && p <= N) {
2527
cout << p << ' ';
2628
X--;

0 commit comments

Comments
 (0)