@@ -40,7 +40,7 @@ template <char ...chars> constexpr bool is_in(char x) { return ((x == chars) ||
4040
4141static bool is_in (char c, const char * symbols, size_t num_chars)
4242{
43- for (auto i = 0u ; i < num_chars; i++ )
43+ for (size_t i = 0u ; i < num_chars; ++i )
4444 {
4545 if (c == symbols[i])
4646 {
@@ -66,6 +66,43 @@ inline __m128i mm_is_in(__m128i bytes)
6666 __m128i eq = mm_is_in<s1, tail...>(bytes);
6767 return _mm_or_si128 (eq0, eq);
6868}
69+
/// Compares every byte of `bytes` against each of the `num_chars` characters in `symbols`.
/// A byte of the result has all bits set if the corresponding input byte equals at least
/// one of the symbols, and is zero otherwise. With `num_chars == 0` the result is all zeros.
inline __m128i mm_is_in(__m128i bytes, const char * symbols, size_t num_chars)
{
    __m128i matches = _mm_setzero_si128();

    const char * const symbols_end = symbols + num_chars;
    for (const char * symbol = symbols; symbol != symbols_end; ++symbol)
    {
        /// Broadcast one symbol to all 16 lanes, compare, and fold into the running mask.
        const __m128i broadcast = _mm_set1_epi8(*symbol);
        matches = _mm_or_si128(matches, _mm_cmpeq_epi8(bytes, broadcast));
    }

    return matches;
}
81+
/// Builds one broadcast vector per symbol: element i of the result has all 16 bytes
/// equal to `symbols[i]`. The result is meant to be passed to mm_is_in_execute, so the
/// broadcasts are done once rather than for every 16-byte chunk of the haystack.
inline std::vector<__m128i> mm_is_in_prepare(const char * symbols, size_t num_chars)
{
    std::vector<__m128i> needles;
    needles.reserve(num_chars);

    const char * const symbols_end = symbols + num_chars;
    for (const char * symbol = symbols; symbol != symbols_end; ++symbol)
        needles.emplace_back(_mm_set1_epi8(*symbol));

    return needles;
}
94+
/// Matches `bytes` against pre-broadcast needles (see mm_is_in_prepare).
/// A byte of the result has all bits set if the corresponding input byte equals the
/// byte stored in at least one needle, and is zero otherwise. An empty needle list
/// yields an all-zero result.
inline __m128i mm_is_in_execute(__m128i bytes, const std::vector<__m128i> & needles)
{
    __m128i matches = _mm_setzero_si128();

    for (size_t i = 0; i < needles.size(); ++i)
        matches = _mm_or_si128(matches, _mm_cmpeq_epi8(bytes, needles[i]));

    return matches;
}
69106#endif
70107
71108template <bool positive>
@@ -112,6 +149,32 @@ inline const char * find_first_symbols_sse2(const char * const begin, const char
112149 return return_mode == ReturnMode::End ? end : nullptr ;
113150}
114151
152+ template <bool positive, ReturnMode return_mode>
153+ inline const char * find_first_symbols_sse2 (const char * const begin, const char * const end, const char * symbols, size_t num_chars)
154+ {
155+ const char * pos = begin;
156+ const auto needles = mm_is_in_prepare (symbols, num_chars);
157+
158+ #if defined(__SSE2__)
159+ for (; pos + 15 < end; pos += 16 )
160+ {
161+ __m128i bytes = _mm_loadu_si128 (reinterpret_cast <const __m128i *>(pos));
162+
163+ __m128i eq = mm_is_in_execute (bytes, needles);
164+
165+ uint16_t bit_mask = maybe_negate<positive>(uint16_t (_mm_movemask_epi8 (eq)));
166+ if (bit_mask)
167+ return pos + __builtin_ctz (bit_mask);
168+ }
169+ #endif
170+
171+ for (; pos < end; ++pos)
172+ if (maybe_negate<positive>(is_in (*pos, symbols, num_chars)))
173+ return pos;
174+
175+ return return_mode == ReturnMode::End ? end : nullptr ;
176+ }
177+
115178
116179template <bool positive, ReturnMode return_mode, char ... symbols>
117180inline const char * find_last_symbols_sse2 (const char * const begin, const char * const end)
@@ -192,21 +255,6 @@ inline const char * find_first_symbols_sse42(const char * const begin, const cha
192255 return return_mode == ReturnMode::End ? end : nullptr ;
193256}
194257
195-
196- // / NOTE No SSE 4.2 implementation for find_last_symbols_or_null. Not worth to do.
197-
198- template <bool positive, ReturnMode return_mode, char ... symbols>
199- inline const char * find_first_symbols_dispatch (const char * begin, const char * end)
200- requires(0 <= sizeof ...(symbols) && sizeof...(symbols) <= 16)
201- {
202- #if defined(__SSE4_2__)
203- if (sizeof ...(symbols) >= 5 )
204- return find_first_symbols_sse42<positive, return_mode, sizeof ...(symbols), symbols...>(begin, end);
205- else
206- #endif
207- return find_first_symbols_sse2<positive, return_mode, symbols...>(begin, end);
208- }
209-
210258template <bool positive, ReturnMode return_mode>
211259inline const char * find_first_symbols_sse42 (const char * const begin, const char * const end, const char * symbols, size_t num_chars)
212260{
@@ -215,7 +263,10 @@ inline const char * find_first_symbols_sse42(const char * const begin, const cha
215263#if defined(__SSE4_2__)
216264 constexpr int mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT;
217265
218- const __m128i set = _mm_loadu_si128 (reinterpret_cast <const __m128i *>(symbols));
266+ // This is to avoid read past end of `symbols` if `num_chars < 16`.
267+ char buffer[16 ] = {' \0 ' };
268+ memcpy (buffer, symbols, num_chars);
269+ const __m128i set = _mm_loadu_si128 (reinterpret_cast <const __m128i *>(buffer));
219270
220271 for (; pos + 15 < end; pos += 16 )
221272 {
@@ -241,10 +292,30 @@ inline const char * find_first_symbols_sse42(const char * const begin, const cha
241292 return return_mode == ReturnMode::End ? end : nullptr ;
242293}
243294
/// NOTE: There is no SSE 4.2 implementation for find_last_symbols_or_null. It is not worth doing.
296+
297+ template <bool positive, ReturnMode return_mode, char ... symbols>
298+ inline const char * find_first_symbols_dispatch (const char * begin, const char * end)
299+ requires(0 <= sizeof ...(symbols) && sizeof...(symbols) <= 16)
300+ {
301+ #if defined(__SSE4_2__)
302+ if (sizeof ...(symbols) >= 5 )
303+ return find_first_symbols_sse42<positive, return_mode, sizeof ...(symbols), symbols...>(begin, end);
304+ else
305+ #endif
306+ return find_first_symbols_sse2<positive, return_mode, symbols...>(begin, end);
307+ }
308+
244309template <bool positive, ReturnMode return_mode>
245- auto find_first_symbols_sse42 ( std::string_view haystack, std::string_view symbols)
310+ inline const char * find_first_symbols_dispatch ( const std::string_view haystack, const std::string_view symbols)
246311{
247- return find_first_symbols_sse42<positive, return_mode>(haystack.begin (), haystack.end (), symbols.begin (), symbols.size ());
312+ const size_t num_chars = std::min<size_t >(symbols.size (), 16 );
313+ #if defined(__SSE4_2__)
314+ if (num_chars >= 5 )
315+ return find_first_symbols_sse42<positive, return_mode>(haystack.begin (), haystack.end (), symbols.begin (), num_chars);
316+ else
317+ #endif
318+ return find_first_symbols_sse2<positive, return_mode>(haystack.begin (), haystack.end (), symbols.begin (), num_chars);
248319}
249320
250321}
@@ -266,7 +337,7 @@ inline char * find_first_symbols(char * begin, char * end)
266337
267338inline const char * find_first_symbols (std::string_view haystack, std::string_view symbols)
268339{
269- return detail::find_first_symbols_sse42 <true , detail::ReturnMode::End>(haystack, symbols);
340+ return detail::find_first_symbols_dispatch <true , detail::ReturnMode::End>(haystack, symbols);
270341}
271342
272343template <char ... symbols>
@@ -283,7 +354,7 @@ inline char * find_first_not_symbols(char * begin, char * end)
283354
284355inline const char * find_first_not_symbols (std::string_view haystack, std::string_view symbols)
285356{
286- return detail::find_first_symbols_sse42 <false , detail::ReturnMode::End>(haystack, symbols);
357+ return detail::find_first_symbols_dispatch <false , detail::ReturnMode::End>(haystack, symbols);
287358}
288359
289360template <char ... symbols>
@@ -300,7 +371,7 @@ inline char * find_first_symbols_or_null(char * begin, char * end)
300371
301372inline const char * find_first_symbols_or_null (std::string_view haystack, std::string_view symbols)
302373{
303- return detail::find_first_symbols_sse42 <true , detail::ReturnMode::Nullptr>(haystack, symbols);
374+ return detail::find_first_symbols_dispatch <true , detail::ReturnMode::Nullptr>(haystack, symbols);
304375}
305376
306377template <char ... symbols>
@@ -317,7 +388,7 @@ inline char * find_first_not_symbols_or_null(char * begin, char * end)
317388
318389inline const char * find_first_not_symbols_or_null (std::string_view haystack, std::string_view symbols)
319390{
320- return detail::find_first_symbols_sse42 <false , detail::ReturnMode::Nullptr>(haystack, symbols);
391+ return detail::find_first_symbols_dispatch <false , detail::ReturnMode::Nullptr>(haystack, symbols);
321392}
322393
323394template <char ... symbols>
0 commit comments