diff --git a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc index bc528493a8..e514e6a753 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc @@ -11,6 +11,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "utils/containers/filtrans.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/digraph.h" @@ -63,7 +64,7 @@ milliseconds_t task_simulator_estimate_forward_pass_time( std::unordered_set devices_occupied = set_union(transform(in_progress_tasks, get_devices)); std::unordered_set required_devices = get_devices(task); - return intersection(devices_occupied, required_devices).empty(); + return set_intersection(devices_occupied, required_devices).empty(); }; TaskExecutionConstraint constraint = diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 4c1b9d4609..79bd091c0a 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -13,9 +13,9 @@ #include "utils/bidict/algorithms/unordered_set_of.h" #include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" -#include "utils/containers/intersection.h" #include "utils/containers/map_values2.h" #include "utils/containers/set_difference.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/transform.h" #include "utils/optional.h" @@ -67,7 +67,7 @@ static std::pair // those will not result in actual copies once shard expansion is performed std::unordered_set< std::pair> - remove = intersection(input_mapping, output_mapping); + remove = set_intersection(input_mapping, output_mapping); DynamicValueAttrs filtered_input = input; filtered_input.mapping = diff --git a/lib/utils/include/utils/containers/are_disjoint.h b/lib/utils/include/utils/containers/are_disjoint.h index 4b5c51fb12..0d5882889c 100644 --- a/lib/utils/include/utils/containers/are_disjoint.h +++ b/lib/utils/include/utils/containers/are_disjoint.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_DISJOINT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_DISJOINT_H -#include "utils/containers/intersection.h" +#include "utils/containers/set_intersection.h" namespace FlexFlow { template bool are_disjoint(std::unordered_set const &l, std::unordered_set const &r) { - return intersection(l, r).empty(); + return set_intersection(l, r).empty(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h index 06a42327e1..38db9b34d3 100644 --- a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h @@ -14,7 +14,7 @@ std::unordered_map std::unordered_set lhs_keys = keys(lhs); std::unordered_set rhs_keys = keys(rhs); - std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + std::unordered_set shared_keys = set_intersection(lhs_keys, rhs_keys); ASSERT(shared_keys.empty()); return binary_merge_maps_with( diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with.h b/lib/utils/include/utils/containers/binary_merge_maps_with.h index a7c196d061..3ba95342e6 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with.h @@ -2,10 +2,10 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H #include "utils/containers/generate_map.h" -#include "utils/containers/intersection.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps_with_right_dominating.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_minus.h" #include @@ -22,7 +22,7 @@ std::unordered_map std::unordered_set l_only_keys = set_minus(l_keys, r_keys); std::unordered_set r_only_keys = set_minus(r_keys, l_keys); - std::unordered_set both_keys = intersection(r_keys, l_keys); + std::unordered_set both_keys = set_intersection(r_keys, l_keys); std::unordered_map l_only = restrict_keys(lhs, l_only_keys); std::unordered_map r_only = restrict_keys(rhs, r_only_keys); diff --git a/lib/utils/include/utils/containers/intersection.h b/lib/utils/include/utils/containers/set_intersection.h similarity index 56% rename from lib/utils/include/utils/containers/intersection.h rename to lib/utils/include/utils/containers/set_intersection.h index 55e6c7a5f8..7240da6f55 100644 --- a/lib/utils/include/utils/containers/intersection.h +++ b/lib/utils/include/utils/containers/set_intersection.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INTERSECTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INTERSECTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_INTERSECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_INTERSECTION_H #include "utils/containers/contains.h" #include @@ -9,8 +9,8 @@ namespace FlexFlow { template -std::unordered_set intersection(std::unordered_set const &l, - std::unordered_set const &r) { +std::unordered_set set_intersection(std::unordered_set const &l, + std::unordered_set const &r) { std::unordered_set result; for (T const &ll : l) { if (contains(r, ll)) { @@ -21,7 +21,7 @@ std::unordered_set intersection(std::unordered_set const &l, } template -std::set intersection(std::set const &l, std::set const &r) { +std::set set_intersection(std::set const &l, std::set const &r) { std::set result; for (T const &ll : l) { if (contains(r, ll)) { @@ -32,10 +32,10 @@ std::set intersection(std::set const &l, std::set const &r) { } template -std::optional intersection(C const &c) { +std::optional set_intersection(C const &c) { std::optional result; for (T const &t : c) { - result = intersection(result.value_or(t), t); + result = set_intersection(result.value_or(t), t); } return result; diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index c4234888cb..c59268d6eb 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -5,7 +5,7 @@ #include "utils/containers/contains.h" #include "utils/containers/filter.h" #include "utils/containers/filter_keys.h" -#include "utils/containers/intersection.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_of.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" @@ -146,7 +146,7 @@ query_set query_intersection(query_set const &lhs, return lhs; } else { return query_set::match_values_in( - set_of(intersection(allowed_values(lhs), allowed_values(rhs)))); + set_of(set_intersection(allowed_values(lhs), allowed_values(rhs)))); } } diff --git a/lib/utils/src/utils/containers/intersection.cc b/lib/utils/src/utils/containers/intersection.cc deleted file mode 100644 index 7b89acf69e..0000000000 --- a/lib/utils/src/utils/containers/intersection.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "utils/containers/intersection.h" -#include "utils/archetypes/ordered_value_type.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; - -template std::unordered_set intersection(std::unordered_set const &, - std::unordered_set const &); -template std::optional> - intersection(std::vector> const &); - -using T2 = ordered_value_type<0>; - -template std::set intersection(std::set const &, std::set const &); -template std::optional> - intersection(std::vector> const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/set_intersection.cc b/lib/utils/src/utils/containers/set_intersection.cc new file mode 100644 index 0000000000..2548120859 --- /dev/null +++ b/lib/utils/src/utils/containers/set_intersection.cc @@ -0,0 +1,21 @@ +#include "utils/containers/set_intersection.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template std::unordered_set set_intersection(std::unordered_set const &, + std::unordered_set const &); +template std::optional> + set_intersection(std::vector> const &); + +using T2 = ordered_value_type<0>; + +template std::set set_intersection(std::set const &, + std::set const &); +template std::optional> + set_intersection(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index f09db77282..1206c56375 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -1,7 +1,6 @@ #include "utils/graph/algorithms.h" #include "utils/containers/flatmap.h" #include "utils/containers/get_only.h" -#include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_difference.h" #include "utils/containers/set_of.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators.cc index 7c5a863359..3837c93af2 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -1,8 +1,10 @@ #include "utils/graph/digraph/algorithms/get_dominators.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/hash/unordered_set.h" +#include "utils/optional.h" #include namespace FlexFlow { @@ -13,14 +15,12 @@ std::unordered_set get_dominators(DiGraphView const &g, Node const &n) { std::unordered_set get_dominators(DiGraphView const &g, std::unordered_set const &n) { - if (n.empty()) { - throw mk_runtime_error("Cannot find dominators of no nodes"); - } + ASSERT(n.size() > 0, "Cannot find dominators of no nodes"); + std::optional> result = - intersection(values(restrict_keys(get_dominators_map(g), n))); - assert(result.has_value()); + set_intersection(values(restrict_keys(get_dominators_map(g), n))); - return result.value(); + return assert_unwrap(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 2da1b208f4..53b8027e07 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -1,6 +1,7 @@ #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/containers/generate_map.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_initial_nodes.h" @@ -33,7 +34,7 @@ std::unordered_map> std::unordered_set old_result_entry = result.at(n); result.at(n) = - intersection(transform(get_predecessors(g, n), [&](Node const &n) { + set_intersection(transform(get_predecessors(g, n), [&](Node const &n) { return result.at(n); })).value_or(std::unordered_set{}); result.at(n).insert(n); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 39523f2ec1..53deeb823c 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -2,7 +2,6 @@ #include "utils/containers/generate_map.h" #include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" -#include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/apply_contraction.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc index 0d5705854d..dbcd07b0ff 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc @@ -1,7 +1,7 @@ #include "utils/containers/filter.h" -#include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/maximum.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/get_ancestors.h" @@ -26,7 +26,8 @@ std::optional> transform(nodes, [&](Node const &n) { return set_union(get_ancestors(g, n), {n}); }); - std::unordered_set common_ancestors = intersection(ancestors).value(); + std::unordered_set common_ancestors = + set_intersection(ancestors).value(); if (common_ancestors.empty()) { return std::unordered_set{}; diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 1e5d0b0ae7..b7c304dcd1 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -2,6 +2,7 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/bidict/algorithms/transform_keys.h" #include "utils/containers/is_subseteq_of.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" @@ -18,7 +19,7 @@ DirectedEdgeMaskView::DirectedEdgeMaskView( std::unordered_set DirectedEdgeMaskView::query_edges(DirectedEdgeQuery const &q) const { - return intersection(g.query_edges(q), this->edge_mask); + return set_intersection(g.query_edges(q), this->edge_mask); } std::unordered_set diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc index 36b8e7294b..f6249c7559 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -2,10 +2,10 @@ #include "utils/containers/filter_keys.h" #include "utils/containers/get_only.h" #include "utils/containers/group_by.h" -#include "utils/containers/intersection.h" #include "utils/containers/map_values.h" #include "utils/containers/maximum.h" #include "utils/containers/range.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" @@ -137,7 +137,7 @@ static std::unordered_set auto subtrees_overlapping_with_component = filter(subtrees, [&](std::unordered_set subtree) { - return intersection(subtree, component).size() > 0; + return set_intersection(subtree, component).size() > 0; }); std::unordered_set forest = diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc index 7206ec5cda..60e99dbea4 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc @@ -5,11 +5,11 @@ #include "utils/containers/filter.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" -#include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/maximum.h" #include "utils/containers/set_difference.h" +#include "utils/containers/set_intersection.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" @@ -66,7 +66,7 @@ static std::unordered_set std::unordered_set> overlapping_subtrees = filter(subtrees, [&](std::unordered_set const &subtree) { - return !intersection(subtree, component).empty(); + return !set_intersection(subtree, component).empty(); }); std::unordered_set forest = set_union(overlapping_subtrees); @@ -89,7 +89,7 @@ static UpDownPartition node_roles); std::unordered_set base_down = nodes; - std::unordered_set base_up = intersection( + std::unordered_set base_up = set_intersection( set_union(transform( nodes, [&](Node const &n) { return get_ancestors(sp_pure, n); })), forest); diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/set_intersection.cc similarity index 68% rename from lib/utils/test/src/utils/containers/intersection.cc rename to lib/utils/test/src/utils/containers/set_intersection.cc index d34d08a3f2..d206985792 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/set_intersection.cc @@ -1,4 +1,4 @@ -#include "utils/containers/intersection.h" +#include "utils/containers/set_intersection.h" #include "test/utils/doctest/fmt/optional.h" #include "test/utils/doctest/fmt/set.h" #include "test/utils/doctest/fmt/unordered_set.h" @@ -8,22 +8,22 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE( - "intersection(S, S)", S, std::unordered_set, std::set) { + "set_intersection(S, S)", S, std::unordered_set, std::set) { S input_l = {1, 2, 3}; S input_r = {2, 3, 5}; - S result = intersection(input_l, input_r); + S result = set_intersection(input_l, input_r); S correct = {2, 3}; CHECK(result == correct); } TEST_CASE_TEMPLATE( - "intersection(C)", S, std::unordered_set, std::set) { + "set_intersection(C)", S, std::unordered_set, std::set) { SUBCASE("input is empty container") { std::vector input = {}; - std::optional result = intersection(input); + std::optional result = set_intersection(input); std::optional correct = std::nullopt; CHECK(result == correct); @@ -32,7 +32,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is has only one set") { std::vector input = {{1, 2, 3}}; - std::optional result = intersection(input); + std::optional result = set_intersection(input); std::optional correct = {{1, 2, 3}}; CHECK(result == correct); @@ -41,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input has multiple sets") { std::vector input = {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}; - std::optional result = intersection(input); + std::optional result = set_intersection(input); std::optional correct = {{3}}; CHECK(result == correct);