Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h"
#include "utils/containers/filtrans.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"
#include "utils/graph/digraph/digraph.h"
Expand Down Expand Up @@ -63,7 +64,7 @@ milliseconds_t task_simulator_estimate_forward_pass_time(
std::unordered_set<device_id_t> devices_occupied =
set_union(transform(in_progress_tasks, get_devices));
std::unordered_set<device_id_t> required_devices = get_devices(task);
return intersection(devices_occupied, required_devices).empty();
return set_intersection(devices_occupied, required_devices).empty();
};

TaskExecutionConstraint constraint =
Expand Down
4 changes: 2 additions & 2 deletions lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#include "utils/bidict/algorithms/unordered_set_of.h"
#include "utils/containers/contains_key.h"
#include "utils/containers/flatmap.h"
#include "utils/containers/intersection.h"
#include "utils/containers/map_values2.h"
#include "utils/containers/set_difference.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/transform.h"
#include "utils/optional.h"

Expand Down Expand Up @@ -67,7 +67,7 @@ static std::pair<DynamicValueAttrs, DynamicValueAttrs>
// those will not result in actual copies once shard expansion is performed
std::unordered_set<
std::pair<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>>
remove = intersection(input_mapping, output_mapping);
remove = set_intersection(input_mapping, output_mapping);

DynamicValueAttrs filtered_input = input;
filtered_input.mapping =
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/containers/are_disjoint.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_DISJOINT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_DISJOINT_H

#include "utils/containers/intersection.h"
#include "utils/containers/set_intersection.h"

namespace FlexFlow {

template <typename T>
bool are_disjoint(std::unordered_set<T> const &l,
std::unordered_set<T> const &r) {
return intersection<T>(l, r).empty();
return set_intersection<T>(l, r).empty();
}

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ std::unordered_map<K, V>
std::unordered_set<K> lhs_keys = keys(lhs);
std::unordered_set<K> rhs_keys = keys(rhs);

std::unordered_set<K> shared_keys = intersection(lhs_keys, rhs_keys);
std::unordered_set<K> shared_keys = set_intersection(lhs_keys, rhs_keys);
ASSERT(shared_keys.empty());

return binary_merge_maps_with(
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/containers/binary_merge_maps_with.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H

#include "utils/containers/generate_map.h"
#include "utils/containers/intersection.h"
#include "utils/containers/keys.h"
#include "utils/containers/merge_maps_with_right_dominating.h"
#include "utils/containers/restrict_keys.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_minus.h"
#include <unordered_map>

Expand All @@ -22,7 +22,7 @@ std::unordered_map<K, V>

std::unordered_set<K> l_only_keys = set_minus(l_keys, r_keys);
std::unordered_set<K> r_only_keys = set_minus(r_keys, l_keys);
std::unordered_set<K> both_keys = intersection(r_keys, l_keys);
std::unordered_set<K> both_keys = set_intersection(r_keys, l_keys);

std::unordered_map<K, V> l_only = restrict_keys(lhs, l_only_keys);
std::unordered_map<K, V> r_only = restrict_keys(rhs, r_only_keys);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INTERSECTION_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INTERSECTION_H
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_INTERSECTION_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_INTERSECTION_H

#include "utils/containers/contains.h"
#include <optional>
Expand All @@ -9,8 +9,8 @@
namespace FlexFlow {

template <typename T>
std::unordered_set<T> intersection(std::unordered_set<T> const &l,
std::unordered_set<T> const &r) {
std::unordered_set<T> set_intersection(std::unordered_set<T> const &l,
std::unordered_set<T> const &r) {
std::unordered_set<T> result;
for (T const &ll : l) {
if (contains(r, ll)) {
Expand All @@ -21,7 +21,7 @@ std::unordered_set<T> intersection(std::unordered_set<T> const &l,
}

template <typename T>
std::set<T> intersection(std::set<T> const &l, std::set<T> const &r) {
std::set<T> set_intersection(std::set<T> const &l, std::set<T> const &r) {
std::set<T> result;
for (T const &ll : l) {
if (contains(r, ll)) {
Expand All @@ -32,10 +32,10 @@ std::set<T> intersection(std::set<T> const &l, std::set<T> const &r) {
}

template <typename C, typename T = typename C::value_type>
std::optional<T> intersection(C const &c) {
std::optional<T> set_intersection(C const &c) {
std::optional<T> result;
for (T const &t : c) {
result = intersection(result.value_or(t), t);
result = set_intersection(result.value_or(t), t);
}

return result;
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/query_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "utils/containers/contains.h"
#include "utils/containers/filter.h"
#include "utils/containers/filter_keys.h"
#include "utils/containers/intersection.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_of.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"
Expand Down Expand Up @@ -146,7 +146,7 @@ query_set<T> query_intersection(query_set<T> const &lhs,
return lhs;
} else {
return query_set<T>::match_values_in(
set_of(intersection(allowed_values(lhs), allowed_values(rhs))));
set_of(set_intersection(allowed_values(lhs), allowed_values(rhs))));
}
}

Expand Down
20 changes: 0 additions & 20 deletions lib/utils/src/utils/containers/intersection.cc

This file was deleted.

21 changes: 21 additions & 0 deletions lib/utils/src/utils/containers/set_intersection.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "utils/containers/set_intersection.h"
#include "utils/archetypes/ordered_value_type.h"
#include "utils/archetypes/value_type.h"

namespace FlexFlow {

using T = value_type<0>;

template std::unordered_set<T> set_intersection(std::unordered_set<T> const &,
std::unordered_set<T> const &);
template std::optional<std::unordered_set<T>>
set_intersection(std::vector<std::unordered_set<T>> const &);

using T2 = ordered_value_type<0>;

template std::set<T2> set_intersection(std::set<T2> const &,
std::set<T2> const &);
template std::optional<std::set<T2>>
set_intersection(std::vector<std::set<T2>> const &);

} // namespace FlexFlow
1 change: 0 additions & 1 deletion lib/utils/src/utils/graph/algorithms.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "utils/graph/algorithms.h"
#include "utils/containers/flatmap.h"
#include "utils/containers/get_only.h"
#include "utils/containers/intersection.h"
#include "utils/containers/restrict_keys.h"
#include "utils/containers/set_difference.h"
#include "utils/containers/set_of.h"
Expand Down
12 changes: 6 additions & 6 deletions lib/utils/src/utils/graph/digraph/algorithms/get_dominators.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "utils/graph/digraph/algorithms/get_dominators.h"
#include "utils/containers/restrict_keys.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/values.h"
#include "utils/graph/digraph/algorithms/get_dominators_map.h"
#include "utils/hash/unordered_set.h"
#include "utils/optional.h"
#include <queue>

namespace FlexFlow {
Expand All @@ -13,14 +15,12 @@ std::unordered_set<Node> get_dominators(DiGraphView const &g, Node const &n) {

std::unordered_set<Node> get_dominators(DiGraphView const &g,
std::unordered_set<Node> const &n) {
if (n.empty()) {
throw mk_runtime_error("Cannot find dominators of no nodes");
}
ASSERT(n.size() > 0, "Cannot find dominators of no nodes");

std::optional<std::unordered_set<Node>> result =
intersection(values(restrict_keys(get_dominators_map(g), n)));
assert(result.has_value());
set_intersection(values(restrict_keys(get_dominators_map(g), n)));

return result.value();
return assert_unwrap(result);
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "utils/graph/digraph/algorithms/get_dominators_map.h"
#include "utils/containers/generate_map.h"
#include "utils/containers/restrict_keys.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/transform.h"
#include "utils/containers/values.h"
#include "utils/graph/digraph/algorithms/get_initial_nodes.h"
Expand Down Expand Up @@ -33,7 +34,7 @@ std::unordered_map<Node, std::unordered_set<Node>>
std::unordered_set<Node> old_result_entry = result.at(n);

result.at(n) =
intersection(transform(get_predecessors(g, n), [&](Node const &n) {
set_intersection(transform(get_predecessors(g, n), [&](Node const &n) {
return result.at(n);
})).value_or(std::unordered_set<Node>{});
result.at(n).insert(n);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include "utils/containers/generate_map.h"
#include "utils/containers/get_one_of.h"
#include "utils/containers/get_only.h"
#include "utils/containers/intersection.h"
#include "utils/containers/restrict_keys.h"
#include "utils/containers/values.h"
#include "utils/graph/digraph/algorithms/apply_contraction.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "utils/containers/filter.h"
#include "utils/containers/intersection.h"
#include "utils/containers/is_subseteq_of.h"
#include "utils/containers/maximum.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"
#include "utils/graph/digraph/algorithms/get_ancestors.h"
Expand All @@ -26,7 +26,8 @@ std::optional<std::unordered_set<Node>>
transform(nodes, [&](Node const &n) {
return set_union(get_ancestors(g, n), {n});
});
std::unordered_set<Node> common_ancestors = intersection(ancestors).value();
std::unordered_set<Node> common_ancestors =
set_intersection(ancestors).value();

if (common_ancestors.empty()) {
return std::unordered_set<Node>{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "utils/bidict/algorithms/bidict_from_enumerating.h"
#include "utils/bidict/algorithms/transform_keys.h"
#include "utils/containers/is_subseteq_of.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/vector_of.h"
#include "utils/graph/digraph/algorithms/get_edges.h"
#include "utils/graph/digraph/algorithms/materialize_digraph_view.h"
Expand All @@ -18,7 +19,7 @@ DirectedEdgeMaskView::DirectedEdgeMaskView(

std::unordered_set<DirectedEdge>
DirectedEdgeMaskView::query_edges(DirectedEdgeQuery const &q) const {
return intersection(g.query_edges(q), this->edge_mask);
return set_intersection(g.query_edges(q), this->edge_mask);
}

std::unordered_set<Node>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#include "utils/containers/filter_keys.h"
#include "utils/containers/get_only.h"
#include "utils/containers/group_by.h"
#include "utils/containers/intersection.h"
#include "utils/containers/map_values.h"
#include "utils/containers/maximum.h"
#include "utils/containers/range.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"
#include "utils/containers/values.h"
Expand Down Expand Up @@ -137,7 +137,7 @@ static std::unordered_set<Node>

auto subtrees_overlapping_with_component =
filter(subtrees, [&](std::unordered_set<Node> subtree) {
return intersection(subtree, component).size() > 0;
return set_intersection(subtree, component).size() > 0;
});

std::unordered_set<Node> forest =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
#include "utils/containers/filter.h"
#include "utils/containers/generate_map.h"
#include "utils/containers/get_only.h"
#include "utils/containers/intersection.h"
#include "utils/containers/is_subseteq_of.h"
#include "utils/containers/keys.h"
#include "utils/containers/maximum.h"
#include "utils/containers/set_difference.h"
#include "utils/containers/set_intersection.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"
#include "utils/containers/values.h"
Expand Down Expand Up @@ -66,7 +66,7 @@ static std::unordered_set<Node>

std::unordered_set<std::unordered_set<Node>> overlapping_subtrees =
filter(subtrees, [&](std::unordered_set<Node> const &subtree) {
return !intersection(subtree, component).empty();
return !set_intersection(subtree, component).empty();
});

std::unordered_set<Node> forest = set_union(overlapping_subtrees);
Expand All @@ -89,7 +89,7 @@ static UpDownPartition
node_roles);

std::unordered_set<Node> base_down = nodes;
std::unordered_set<Node> base_up = intersection(
std::unordered_set<Node> base_up = set_intersection(
set_union(transform(
nodes, [&](Node const &n) { return get_ancestors(sp_pure, n); })),
forest);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "utils/containers/intersection.h"
#include "utils/containers/set_intersection.h"
#include "test/utils/doctest/fmt/optional.h"
#include "test/utils/doctest/fmt/set.h"
#include "test/utils/doctest/fmt/unordered_set.h"
Expand All @@ -8,22 +8,22 @@ using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE_TEMPLATE(
"intersection(S, S)", S, std::unordered_set<int>, std::set<int>) {
"set_intersection(S, S)", S, std::unordered_set<int>, std::set<int>) {
S input_l = {1, 2, 3};
S input_r = {2, 3, 5};

S result = intersection(input_l, input_r);
S result = set_intersection(input_l, input_r);
S correct = {2, 3};

CHECK(result == correct);
}

TEST_CASE_TEMPLATE(
"intersection(C<S>)", S, std::unordered_set<int>, std::set<int>) {
"set_intersection(C<S>)", S, std::unordered_set<int>, std::set<int>) {
SUBCASE("input is empty container") {
std::vector<S> input = {};

std::optional<S> result = intersection(input);
std::optional<S> result = set_intersection(input);
std::optional<S> correct = std::nullopt;

CHECK(result == correct);
Expand All @@ -32,7 +32,7 @@ TEST_SUITE(FF_TEST_SUITE) {
SUBCASE("input is has only one set") {
std::vector<S> input = {{1, 2, 3}};

std::optional<S> result = intersection(input);
std::optional<S> result = set_intersection(input);
std::optional<S> correct = {{1, 2, 3}};

CHECK(result == correct);
Expand All @@ -41,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) {
SUBCASE("input has multiple sets") {
std::vector<S> input = {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}};

std::optional<S> result = intersection(input);
std::optional<S> result = set_intersection(input);
std::optional<S> correct = {{3}};

CHECK(result == correct);
Expand Down
Loading