@@ -16,74 +16,75 @@ namespace cp_algo::math {
1616 using cp_algo::structures::dynamic_bit_array;
1717 using cp_algo::structures::bit_array;
1818
19- constexpr uint32_t period = 210 ;
20- constexpr uint32_t coprime = 48 ;
21- constexpr auto coprime210 = [](auto x) {
22- return x % 2 && x % 3 && x % 5 && x % 7 ;
19+ constexpr auto wheel_primes = std::array{2u , 3u , 5u , 7u };
20+ constexpr uint32_t period = std::ranges::fold_left(wheel_primes, 1u , std::multiplies{});
21+ constexpr uint32_t coprime = std::ranges::fold_left(wheel_primes, 1u , [](auto a, auto b){ return a * (b - 1 ); });
22+ constexpr auto coprime_wheel = [](auto x) {
23+ return std::ranges::all_of (wheel_primes, [x](auto p){ return x % p; });
2324 };
2425
25- // Residues coprime to 210
26- constexpr auto res210 = []() {
26+ // Residues coprime to period
27+ constexpr auto res_wheel = []() {
2728 std::array<uint8_t , coprime> res;
2829 int idx = 0 ;
2930 for (uint8_t i = 1 ; i < period; i += 2 ) {
30- if (coprime210 (i)) {
31+ if (coprime_wheel (i)) {
3132 res[idx++] = i;
3233 }
3334 }
3435 return res;
3536 }();
3637
37- // Maps residue mod 210 to pre-upper_bound index in res210
38- constexpr auto state210 = []() {
38+ // Maps residue mod period to pre-upper_bound index in res_wheel
39+ constexpr auto state_wheel = []() {
3940 std::array<uint8_t , period> state;
4041 uint8_t idx = 0 ;
4142 for (uint8_t i = 0 ; i < period; i++) {
4243 state[i] = idx;
43- idx += coprime210 (i);
44+ idx += coprime_wheel (i);
4445 }
4546 return state;
4647 }();
4748
4849 // Add to reach next coprime residue
49- constexpr auto add210 = []() {
50+ constexpr auto add_wheel = []() {
5051 std::array<uint8_t , period> add;
5152 for (uint8_t i = 0 ; i < period; i++) {
5253 add[i] = 1 ;
53- while (!coprime210 (i + add[i])) {
54+ while (!coprime_wheel (i + add[i])) {
5455 add[i]++;
5556 }
5657 }
5758 return add;
5859 }();
5960
60- constexpr auto gap210 = []() {
61+ constexpr auto gap_wheel = []() {
6162 std::array<uint8_t , coprime> gap;
6263 for (uint8_t i = 0 ; i < coprime; i++) {
63- gap[i] = add210[res210 [i]];
64+ gap[i] = add_wheel[res_wheel [i]];
6465 }
6566 return gap;
6667 }();
6768
6869 // Convert value to ordinal (index in compressed bit array)
6970 constexpr uint32_t to_ord (uint32_t x) {
70- return (x / period) * coprime + state210 [x % period];
71+ return (x / period) * coprime + state_wheel [x % period];
7172 }
7273
7374 // Convert ordinal to value
7475 constexpr uint32_t to_val (uint32_t x) {
75- return (x / coprime) * period + res210 [x % coprime];
76+ return (x / coprime) * period + res_wheel [x % coprime];
7677 }
7778
78- constexpr size_t sqrt_threshold = 1 << 16 ;
79+ constexpr size_t sqrt_threshold = 1 << 15 ;
7980 constexpr auto sqrt_prime_bits = []() {
80- const int size = sqrt_threshold / 4 ;
81+ const int size = sqrt_threshold / 2 ;
8182 bit_array<size> prime;
8283 prime.set_all ();
8384 prime.reset (to_ord (1 ));
84- for (uint32_t i = 11 ; to_ord (i * i) < size; i += add210 [i % period]) {
85+ for (uint32_t i = res_wheel[ 1 ] ; to_ord (i * i) < size; i += add_wheel [i % period]) {
8586 if (prime[to_ord (i)]) {
86- for (uint32_t k = i; to_ord (i * k) < size; k += add210 [k % period]) {
87+ for (uint32_t k = i; to_ord (i * k) < size; k += add_wheel [k % period]) {
8788 prime.reset (to_ord (i * k));
8889 }
8990 }
@@ -93,15 +94,15 @@ namespace cp_algo::math {
9394
9495 constexpr size_t num_primes = []() {
9596 size_t cnt = 0 ;
96- for (uint32_t i = 11 ; i < sqrt_threshold; i += add210 [i % period]) {
97+ for (uint32_t i = res_wheel[ 1 ] ; i < sqrt_threshold; i += add_wheel [i % period]) {
9798 cnt += sqrt_prime_bits[to_ord (i)];
9899 }
99100 return cnt;
100101 }();
101102 constexpr auto sqrt_primes = []() {
102103 std::array<uint32_t , num_primes> primes;
103104 size_t j = 0 ;
104- for (uint32_t i = 11 ; i < sqrt_threshold; i += add210 [i % period]) {
105+ for (uint32_t i = res_wheel[ 1 ] ; i < sqrt_threshold; i += add_wheel [i % period]) {
105106 if (sqrt_prime_bits[to_ord (i)]) {
106107 primes[j++] = i;
107108 }
@@ -121,7 +122,7 @@ namespace cp_algo::math {
121122 wheel.mask .resize (product / period * coprime);
122123 wheel.mask .set_all ();
123124 for (auto p: primes) {
124- for (uint32_t k = 1 ; p * k < product; k += add210 [k % period]) {
125+ for (uint32_t k = 1 ; p * k < product; k += add_wheel [k % period]) {
125126 wheel.mask .reset (to_ord (p * k));
126127 }
127128 }
@@ -150,15 +151,15 @@ namespace cp_algo::math {
150151 }
151152
152153 template <class BitArray >
153- constexpr void sieve210 (BitArray& prime, uint32_t l, uint32_t r, size_t i, int state) {
154+ constexpr void sieve_wheel (BitArray& prime, uint32_t l, uint32_t r, size_t i, int state) {
154155 static const auto ord_step = []() {
155156 big_vector<std::array<uint32_t , 2 * coprime>> ord_steps (num_primes);
156157 for (uint32_t i = 0 ; i < size (sqrt_primes); i++) {
157158 auto p = sqrt_primes[i];
158159 auto &ords = ord_steps[i];
159160 auto last = to_ord (p);
160161 for (uint32_t j = 0 ; j < coprime; j++) {
161- auto next = to_ord (p * (res210 [j] + gap210 [j]));
162+ auto next = to_ord (p * (res_wheel [j] + gap_wheel [j]));
162163 ords[j] = ords[j + coprime] = next - last;
163164 last = next;
164165 }
@@ -182,21 +183,21 @@ namespace cp_algo::math {
182183 }
183184
184185 // Primes smaller or equal than N
185- constexpr dynamic_bit_array sieve210 (uint32_t N) {
186+ constexpr dynamic_bit_array sieve_wheel (uint32_t N) {
186187 N++;
187188 dynamic_bit_array prime (to_ord (N));
188189 prime.set_all ();
189190 static const auto [wheels, medium_primes_begin] = []() {
190191 constexpr size_t max_wheel_size = 1 << 20 ;
191- uint32_t product = period * dynamic_bit_array::width / 4 ;
192+ uint32_t product = period * dynamic_bit_array::width >> ( size (wheel_primes) - 2 ) ;
192193 big_vector<uint32_t > current;
193194 big_vector<wheel_t > wheels;
194195 for (size_t i = 0 ; i < size (sqrt_primes); i++) {
195196 uint32_t p = sqrt_primes[i];
196197 if (product * p > max_wheel_size) {
197198 wheels.push_back (make_wheel (current, product));
198199 current = {p};
199- product = period * dynamic_bit_array::width / 4 * p;
200+ product = ( period * dynamic_bit_array::width >> ( size (wheel_primes) - 2 )) * p;
200201 if (product > max_wheel_size) {
201202 checkpoint (" make wheels" );
202203 return std::pair{wheels, i};
@@ -224,8 +225,8 @@ namespace cp_algo::math {
224225 auto p = sqrt_primes[i];
225226 if (p * p >= r) break ;
226227 auto k = std::max (start / p, p);
227- if (!coprime210 (k)) {k += add210 [k % 210 ];}
228- sieve210 (prime, to_ord (p * k), to_ord (r), i, state210 [k % 210 ]);
228+ if (!coprime_wheel (k)) {k += add_wheel [k % period ];}
229+ sieve_wheel (prime, to_ord (p * k), to_ord (r), i, state_wheel [k % period ]);
229230 }
230231 }
231232 checkpoint (" sparse sieve" );
0 commit comments