Skip to content

Commit f8c0434

Browse files
committed
Better sieve
1 parent df5c9c2 commit f8c0434

File tree

2 files changed

+138
-95
lines changed

2 files changed

+138
-95
lines changed

cp-algo/math/sieve.hpp

Lines changed: 131 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,80 +16,136 @@ namespace cp_algo::math {
1616
using cp_algo::structures::dynamic_bit_array;
1717
using cp_algo::structures::bit_array;
1818

19-
constexpr size_t base_threshold = 1 << 21;
20-
21-
constexpr auto to_ord(auto x) {
22-
return x / 2;
19+
constexpr uint32_t period = 210;
20+
constexpr uint32_t coprime = 48;
21+
constexpr auto coprime210 = [](auto x) {
22+
return x % 2 && x % 3 && x % 5 && x % 7;
23+
};
24+
25+
// Residues coprime to 210
26+
constexpr auto res210 = []() {
27+
std::array<uint8_t, coprime> res;
28+
int idx = 0;
29+
for(uint8_t i = 1; i < period; i += 2) {
30+
if (coprime210(i)) {
31+
res[idx++] = i;
32+
}
33+
}
34+
return res;
35+
}();
36+
37+
// Maps residue mod 210 to pre-upper_bound index in res210
38+
constexpr auto state210 = []() {
39+
std::array<uint8_t, period> state;
40+
uint8_t idx = 0;
41+
for(uint8_t i = 0; i < period; i++) {
42+
state[i] = idx;
43+
idx += coprime210(i);
44+
}
45+
return state;
46+
}();
47+
48+
// Add to reach next coprime residue
49+
constexpr auto add210 = []() {
50+
std::array<uint8_t, period> add;
51+
for(uint8_t i = 0; i < period; i++) {
52+
add[i] = 1;
53+
while (!coprime210(i + add[i])) {
54+
add[i]++;
55+
}
56+
}
57+
return add;
58+
}();
59+
60+
constexpr auto gap210 = []() {
61+
std::array<uint8_t, coprime> gap;
62+
for(uint8_t i = 0; i < coprime; i++) {
63+
gap[i] = add210[res210[i]];
64+
}
65+
return gap;
66+
}();
67+
68+
// Convert value to ordinal (index in compressed bit array)
69+
constexpr uint32_t to_ord(uint32_t x) {
70+
return (x / period) * coprime + state210[x % period];
2371
}
24-
constexpr auto to_val(auto x) {
25-
return 2 * x + 1;
72+
73+
// Convert ordinal to value
74+
constexpr uint32_t to_val(uint32_t x) {
75+
return (x / coprime) * period + res210[x % coprime];
2676
}
2777

28-
const auto base_prime_bits = []() {
29-
dynamic_bit_array prime(base_threshold);
78+
constexpr size_t sqrt_threshold = 1 << 15;
79+
constexpr auto sqrt_prime_bits = []() {
80+
const int size = sqrt_threshold / 4;
81+
bit_array<size> prime;
3082
prime.set_all();
3183
prime.reset(to_ord(1));
32-
for(size_t i = 3; to_ord(i * i) < base_threshold; i += 2) {
84+
for(uint32_t i = 11; to_ord(i * i) < size; i += add210[i % period]) {
3385
if (prime[to_ord(i)]) {
34-
for(size_t j = i * i; to_ord(j) < base_threshold; j += 2 * i) {
35-
prime.reset(to_ord(j));
86+
for(uint32_t k = i; to_ord(i * k) < size; k += add210[k % period]) {
87+
prime.reset(to_ord(i * k));
3688
}
3789
}
3890
}
3991
return prime;
4092
}();
4193

42-
const auto base_primes = []() {
43-
big_vector<uint32_t> primes;
44-
for(uint32_t i = 3; to_ord(i) < base_threshold; i += 2) {
45-
if (base_prime_bits[to_ord(i)]) {
46-
primes.push_back(i);
94+
constexpr size_t num_primes = []() {
95+
size_t cnt = 0;
96+
for(uint32_t i = 11; i < sqrt_threshold; i += add210[i % period]) {
97+
cnt += sqrt_prime_bits[to_ord(i)];
98+
}
99+
return cnt;
100+
}();
101+
constexpr auto sqrt_primes = []() {
102+
std::array<uint32_t, num_primes> primes;
103+
size_t j = 0;
104+
for(uint32_t i = 11; i < sqrt_threshold; i += add210[i % period]) {
105+
if (sqrt_prime_bits[to_ord(i)]) {
106+
primes[j++] = i;
47107
}
48108
}
49109
return primes;
50110
}();
51111

52-
constexpr size_t sqrt_threshold = 1 << 16;
53-
const auto sqrt_primes = std::span(
54-
base_primes.begin(),
55-
std::ranges::upper_bound(base_primes, sqrt_threshold)
56-
);
57-
58-
constexpr size_t max_wheel_size = std::min<size_t>(base_threshold, 1 << 21);
112+
constexpr size_t max_wheel_size = 1 << 21;
59113
struct wheel_t {
60114
dynamic_bit_array mask;
61115
uint32_t product;
62116
};
63117

64118
auto make_wheel(big_vector<uint32_t> primes, uint32_t product) {
65-
assert(product % (2 * dynamic_bit_array::width) == 0);
119+
assert(product % (period * dynamic_bit_array::width) == 0);
66120
wheel_t wheel;
67121
wheel.product = product;
68-
wheel.mask.resize(product / 2);
122+
wheel.mask.resize(product / period * coprime);
69123
wheel.mask.set_all();
70124
for(auto p: primes) {
71-
for(size_t j = to_ord(p); j < wheel.mask.size(); j += p) {
72-
wheel.mask.reset(j);
125+
for (uint32_t k = 1; p * k < product; k += add210[k % period]) {
126+
wheel.mask.reset(to_ord(p * k));
73127
}
74128
}
75129
return wheel;
76130
}
77131

78-
auto medium_primes = sqrt_primes;
132+
constexpr uint32_t wheel_threshold = 400;
133+
size_t medium_primes_begin;
79134
const auto wheels = []() {
80-
uint32_t product = 2 * dynamic_bit_array::width;
135+
uint32_t product = period * dynamic_bit_array::width;
81136
big_vector<uint32_t> current;
82137
big_vector<wheel_t> wheels;
83138
for(size_t i = 0; i < size(sqrt_primes); i++) {
84139
uint32_t p = sqrt_primes[i];
85140
if (product * p > max_wheel_size) {
86-
if (size(current) == 1) {
87-
medium_primes = sqrt_primes.subspan(i - 1);
88-
return wheels;
89-
}
90141
wheels.push_back(make_wheel(current, product));
91142
current = {p};
92-
product = 2 * dynamic_bit_array::width * p;
143+
product = period * dynamic_bit_array::width * p;
144+
if (product > max_wheel_size || p > wheel_threshold) {
145+
medium_primes_begin = i;
146+
checkpoint("make wheels");
147+
return wheels;
148+
}
93149
} else {
94150
current.push_back(p);
95151
product *= p;
@@ -98,6 +154,22 @@ namespace cp_algo::math {
98154
assert(false);
99155
}();
100156

157+
const auto [ord_step, step_sum] = []() {
158+
big_vector<std::array<uint32_t, 2 * coprime>> ord_steps(num_primes);
159+
big_vector<uint32_t> sums(num_primes);
160+
for (uint32_t i = 0; i < size(sqrt_primes); i++) {
161+
auto p = sqrt_primes[i];
162+
for(uint32_t j = 0; j < coprime; j++) {
163+
ord_steps[i][j] = to_ord(p * (res210[j] + gap210[j])) - to_ord(p * res210[j]);
164+
}
165+
sums[i] = std::ranges::fold_left(ord_steps[i], 0u, std::plus{});
166+
for(uint32_t j = 0; j < coprime; j++) {
167+
ord_steps[i][j + coprime] = ord_steps[i][j];
168+
}
169+
}
170+
return std::pair{ord_steps, sums};
171+
}();
172+
101173
void sieve_dense(auto &prime, uint32_t l, uint32_t r, wheel_t const& wheel) {
102174
if (l >= r) return;
103175
const auto width = (uint32_t)dynamic_bit_array::width;
@@ -112,81 +184,51 @@ namespace cp_algo::math {
112184
}
113185
}
114186

115-
constexpr auto add210 = []() {
116-
std::array<uint8_t, 210> add;
117-
auto good = [&](int x) {
118-
return x % 2 && x % 3 && x % 5 && x % 7;
119-
};
120-
for(int i = 0; i < 210; i++) {
121-
add[i] = 1;
122-
while (!good(i + add[i])) {
123-
add[i]++;
124-
}
125-
}
126-
return add;
127-
}();
128-
129-
constexpr uint8_t gap210[] = {
130-
5, 1, 2, 1, 2, 3, 1, 3,
131-
2, 1, 2, 3, 3, 1, 3, 2,
132-
1, 3, 2, 3, 4, 2, 1, 2,
133-
1, 2, 4, 3, 2, 3, 1, 2,
134-
3, 1, 3, 3, 2, 1, 2, 3,
135-
1, 3, 2, 1, 2, 1, 5, 1
136-
};
137-
138-
constexpr auto state210 = []() {
139-
std::array<uint8_t, 210> state;
140-
int idx = 0;
141-
for(int i = 0; i < 210; i++) {
142-
if (i % 2 && i % 3 && i % 5 && i % 7) {
143-
state[i] = uint8_t(idx++);
144-
} else {
145-
state[i] = -1;
187+
template <class BitArray>
188+
void sieve210(BitArray& prime, uint32_t l, uint32_t r, size_t i, int state) {
189+
while (l + step_sum[i] <= r) {
190+
for (size_t j = 0; j < coprime; j++) {
191+
prime.reset(l);
192+
l += ord_step[i][state++];
146193
}
194+
state -= coprime;
147195
}
148-
return state;
149-
}();
150-
151-
template <class BitArray>
152-
void sieve210(BitArray &prime, uint32_t l, uint32_t r, uint32_t p, uint8_t state) {
153-
if (l >= r) return;
154196
while (l < r) {
155197
prime.reset(l);
156-
l += p * gap210[state];
157-
state = state == 47 ? 0 : state + 1;
198+
l += ord_step[i][state++];
199+
state = state == coprime ? 0 : state;
158200
}
159201
}
160202

161203
// Primes smaller or equal than N
162-
dynamic_bit_array odd_sieve(uint32_t N) {
204+
dynamic_bit_array sieve210(uint32_t N) {
163205
N++;
164-
dynamic_bit_array prime(N / 2);
206+
dynamic_bit_array prime(to_ord(N));
165207
prime.set_all();
166-
for(size_t i = 0; i < std::min(prime.words, base_prime_bits.words); i++) {
167-
prime.word(i) = base_prime_bits.word(i);
168-
}
169-
cp_algo::checkpoint("init");
170-
static constexpr uint32_t dense_block = 1 << 24;
171-
for(uint32_t start = base_threshold; start < N; start += dense_block) {
172-
uint32_t r = std::min(start + dense_block, N);
208+
static constexpr uint32_t dense_block = 1 << 25;
209+
for(uint32_t start = 0; start < N; start += dense_block) {
210+
uint32_t r = std::min(start + dense_block, N);
173211
for(auto const& wheel: wheels) {
174212
auto l = start / wheel.product * wheel.product;
175213
sieve_dense(prime, to_ord(l), to_ord(r), wheel);
176214
}
177215
}
178-
cp_algo::checkpoint("dense sieve");
179-
static constexpr uint32_t sparse_block = 1 << 22;
180-
for(uint32_t start = base_threshold; start < N; start += sparse_block) {
216+
checkpoint("dense sieve");
217+
static constexpr uint32_t sparse_block = 1 << 24;
218+
for(uint32_t start = 0; start < N; start += sparse_block) {
181219
uint32_t r = std::min(start + sparse_block, N);
182-
for(auto p: medium_primes) {
220+
for(size_t i = medium_primes_begin; i < size(sqrt_primes); i++) {
221+
auto p = sqrt_primes[i];
183222
if(p * p >= r) break;
184223
auto k = std::max(start / p, p);
185-
if (state210[k % 210] == 0xFF) {k += add210[k % 210];}
186-
sieve210(prime, to_ord(k * p), to_ord(r), p, state210[k % 210]);
224+
if (!coprime210(k)) {k += add210[k % 210];}
225+
sieve210(prime, to_ord(p * k), to_ord(r), i, state210[k % 210]);
187226
}
188227
}
189-
cp_algo::checkpoint("sparse sieve");
228+
checkpoint("sparse sieve");
229+
for(size_t i = 0; i < std::min(prime.words, sqrt_prime_bits.words); i++) {
230+
prime.word(i) = sqrt_prime_bits.word(i);
231+
}
190232
return prime;
191233
}
192234
}

verify/number_theory/enumerate_primes.test.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@ using namespace cp_algo::math;
1515
void solve() {
1616
uint32_t N, A, B;
1717
cin >> N >> A >> B;
18-
auto primes = odd_sieve(N);
19-
auto cnt = count(primes) + (N >= 2);
18+
auto primes = sieve210(N);
19+
auto cnt = count(primes) + (N >= 2) + (N >= 3) + (N >= 5) + (N >= 7);
2020
cp_algo::checkpoint("count");
2121
auto X = cnt < B ? 0 : (cnt - B + A - 1) / A;
2222
cout << cnt << ' ' << X << endl;
23-
if (X) {
24-
if (B == 0) {
25-
cout << 2 << ' ';
23+
for(uint32_t p: {2u, 3u, 5u, 7u}) {
24+
if (B == 0 && X && p <= N) {
25+
cout << p << ' ';
26+
X--;
2627
}
2728
B = (B - 1 + A) % A;
2829
}
2930
for(size_t i = skip(primes, 0, B); i < primes.size(); i = skip(primes, i, A)) {
30-
cout << to_val(i) << ' ';
31+
cout << to_val(uint32_t(i)) << ' ';
3132
}
3233
cout << "\n";
3334
cp_algo::checkpoint("print");

0 commit comments

Comments
 (0)