diff --git a/lib/utils/include/utils/bidict/algorithms/filter_bidict.h b/lib/utils/include/utils/bidict/algorithms/filter_bidict.h new file mode 100644 index 0000000000..910293c28a --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filter_bidict.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_BIDICT_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict filter_bidict(bidict const &b, F &&f) { + bidict result; + + for (std::pair const &p : b) { + if (f(p.first, p.second)) { + result.equate_strict(p.first, p.second); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/bidict/algorithms/filter_bidict.cc b/lib/utils/src/utils/bidict/algorithms/filter_bidict.cc new file mode 100644 index 0000000000..53cee0548e --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/filter_bidict.cc @@ -0,0 +1,12 @@ +#include "utils/bidict/algorithms/filter_bidict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; +using F = std::function; + +template bidict filter_bidict(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_bidict.cc b/lib/utils/test/src/utils/bidict/algorithms/filter_bidict.cc new file mode 100644 index 0000000000..c310f82459 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/filter_bidict.cc @@ -0,0 +1,28 @@ +#include "utils/bidict/algorithms/filter_bidict.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_bidict") { + bidict b = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + {4, "four"}, + }; + + auto filter_func = [](int k, std::string const &v) -> bool { + return (k % 2) == 0 || v.size() == 3; + }; + + bidict result = filter_bidict(b, filter_func); + bidict correct = { + {1, "one"}, + {2, "two"}, + {4, "four"}, + }; + + CHECK(result == correct); + } +}