From 1ff30c284314e3d98e1d9e03d50e46095dbb5ced Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 29 May 2026 17:10:23 -0700 Subject: [PATCH] Fill out some utility functions for non-open kwarg graphs --- ...raph_data_from_kwarg_dataflow_graph_data.h | 38 ++-- .../get_kwarg_dataflow_graph_subgraph.h | 43 ++++ .../get_labelled_kwarg_dataflow_graph_data.h | 26 +++ ...lled_kwarg_dataflow_graph_node_label_map.h | 21 ++ ...ed_kwarg_dataflow_graph_output_label_map.h | 24 +++ ...t_labelled_kwarg_dataflow_graph_subgraph.h | 42 ++++ ...kwarg_dataflow_graph_view_with_labelling.h | 73 +++++++ ...abelled_kwarg_dataflow_graph_data.dtg.toml | 41 ++++ ...view_from_open_kwarg_dataflow_graph_data.h | 10 +- .../get_kwarg_dataflow_graph_subgraph.cc | 12 ++ .../get_labelled_kwarg_dataflow_graph_data.cc | 16 ++ ...led_kwarg_dataflow_graph_node_label_map.cc | 16 ++ ...d_kwarg_dataflow_graph_output_label_map.cc | 16 ++ ..._labelled_kwarg_dataflow_graph_subgraph.cc | 17 ++ ...warg_dataflow_graph_view_with_labelling.cc | 17 ++ .../get_kwarg_dataflow_graph_subgraph.cc | 103 ++++++++++ .../get_labelled_kwarg_dataflow_graph_data.cc | 133 +++++++++++++ ...led_kwarg_dataflow_graph_node_label_map.cc | 88 +++++++++ ...d_kwarg_dataflow_graph_output_label_map.cc | 90 +++++++++ ..._labelled_kwarg_dataflow_graph_subgraph.cc | 185 ++++++++++++++++++ ...warg_dataflow_graph_view_with_labelling.cc | 112 +++++++++++ ...iew_from_open_kwarg_dataflow_graph_data.cc | 122 ++++++++++++ 22 files changed, 1226 insertions(+), 19 deletions(-) create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.toml create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc create mode 100644 lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc create mode 100644 lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.cc diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h index c739bf0e82..154072eebf 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h @@ -26,30 +26,40 @@ DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data( std::unordered_set> all_outputs = kwarg_data.outputs; - OneToMany incoming_slots_by_node = - one_to_many_transform_values( + std::unordered_map> + incoming_slots_by_node = map_values( group_by(all_inputs, [](KwargDataflowInput const &i) -> Node { return i.node; - }), - [](KwargDataflowInput const &i) -> SlotName { - return i.slot_name; + }) + .l_to_r(), + [](nonempty_unordered_set> const &is) + -> std::unordered_set { + return transform(is.unwrap_as_unordered_set(), + [](KwargDataflowInput const &i) { + return i.slot_name; + }); }); - OneToMany outgoing_slots_by_node = - one_to_many_transform_values( + std::unordered_map> + outgoing_slots_by_node = map_values( group_by(all_outputs, [](KwargDataflowOutput const &o) -> Node { return o.node; - }), - [](KwargDataflowOutput const &o) -> SlotName { - return o.slot_name; + }) + .l_to_r(), + [](nonempty_unordered_set> const &os) + -> std::unordered_set { + return transform(os.unwrap_as_unordered_set(), + [](KwargDataflowOutput const &o) { + return o.slot_name; + }); }); auto dataflow_input_from_kwarg_input = [&](KwargDataflowInput const &i) -> DataflowInput { - std::vector slot_ordering = order_slots( - incoming_slots_by_node.at_l(i.node).unwrap_as_unordered_set()); + std::vector slot_ordering = + order_slots(incoming_slots_by_node.at(i.node)); return DataflowInput{ /*node=*/i.node, @@ -62,8 +72,8 @@ DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data( auto dataflow_output_from_kwarg_output = [&](KwargDataflowOutput const &o) -> DataflowOutput { - std::vector slot_ordering = order_slots( - outgoing_slots_by_node.at_l(o.node).unwrap_as_unordered_set()); + std::vector slot_ordering = + order_slots(outgoing_slots_by_node.at(o.node)); return DataflowOutput{ /*node=*/o.node, diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h new file mode 100644 index 0000000000..dba05a6170 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H + +#include "utils/containers/set_intersection.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h" + +namespace FlexFlow { + +template +KwargDataflowGraphView get_kwarg_dataflow_graph_subgraph( + KwargDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes) { + KwargDataflowGraphData g_data = get_kwarg_dataflow_graph_data(g); + + std::unordered_set nodes = + set_intersection(g_data.nodes, subgraph_nodes); + + std::unordered_set> edges = + filter(g_data.edges, [&](KwargDataflowEdge const &e) -> bool { + return contains(subgraph_nodes, e.src.node) && + contains(subgraph_nodes, e.dst.node); + }); + + std::unordered_set> outputs = filter( + g_data.outputs, [&](KwargDataflowOutput const &o) -> bool { + return contains(subgraph_nodes, o.node); + }); + + KwargDataflowGraphData subgraph_data = + KwargDataflowGraphData{ + /*nodes=*/nodes, + /*edges=*/edges, + /*outputs=*/outputs, + }; + + return view_from_kwarg_dataflow_graph_data(subgraph_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..f2c8d34963 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +LabelledKwargDataflowGraphData + get_labelled_kwarg_dataflow_graph_data( + LabelledKwargDataflowGraphView const + &g) { + return LabelledKwargDataflowGraphData{ + /*node_data=*/get_labelled_kwarg_dataflow_graph_node_label_map(g), + /*edges=*/get_all_kwarg_dataflow_edges(g), + /*output_data=*/get_labelled_kwarg_dataflow_graph_output_label_map(g), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h new file mode 100644 index 0000000000..fabb47a5ab --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABEL_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABEL_MAP_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +std::unordered_map + get_labelled_kwarg_dataflow_graph_node_label_map( + LabelledKwargDataflowGraphView const + &g) { + return generate_map(get_nodes(g), + [&](Node const &n) -> NodeLabel { return g.at(n); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h new file mode 100644 index 0000000000..45366d4d96 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_OUTPUT_LABEL_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_OUTPUT_LABEL_MAP_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map, OutputLabel> + get_labelled_kwarg_dataflow_graph_output_label_map( + LabelledKwargDataflowGraphView const + &g) { + return generate_map( + get_all_kwarg_dataflow_outputs(g), + [&](KwargDataflowOutput const &o) -> OutputLabel { + return g.at(o); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h new file mode 100644 index 0000000000..3aa7953446 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H + +#include "utils/containers/contains.h" +#include "utils/containers/filter_keys.h" +#include "utils/containers/restrict_keys.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +LabelledKwargDataflowGraphView + get_labelled_kwarg_dataflow_graph_subgraph( + LabelledKwargDataflowGraphView const + &g, + std::unordered_set const &subgraph_nodes) { + KwargDataflowGraphView unlabelled_subgraph = + get_kwarg_dataflow_graph_subgraph(g, subgraph_nodes); + + std::unordered_map g_node_labelling = + get_labelled_kwarg_dataflow_graph_node_label_map(g); + + std::unordered_map, OutputLabel> + g_output_labelling = + get_labelled_kwarg_dataflow_graph_output_label_map(g); + + return kwarg_dataflow_graph_view_with_labelling( + unlabelled_subgraph, + restrict_keys(g_node_labelling, subgraph_nodes), + filter_keys(g_output_labelling, + [&](KwargDataflowOutput const &o) -> bool { + return contains(subgraph_nodes, o.node); + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h new file mode 100644 index 0000000000..782e63889b --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct KwargDataflowGraphLabellingWrapper final + : public ILabelledKwargDataflowGraphView { +public: + KwargDataflowGraphLabellingWrapper() = delete; + KwargDataflowGraphLabellingWrapper( + KwargDataflowGraphView const &unlabelled, + std::unordered_map const &node_labels, + std::unordered_map, OutputLabel> const + &output_labels) + : unlabelled(unlabelled), node_labels(node_labels), + output_labels(output_labels) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->unlabelled.query_nodes(q); + } + + std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &q) const override { + return this->unlabelled.query_edges(q); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return this->unlabelled.query_outputs(q); + } + + NodeLabel at(Node const &n) const override { + return this->node_labels.at(n); + } + + OutputLabel at(KwargDataflowOutput const &v) const override { + return this->output_labels.at(v); + } + + KwargDataflowGraphLabellingWrapper *clone() const override { + return new KwargDataflowGraphLabellingWrapper{ + this->unlabelled, + this->node_labels, + this->output_labels, + }; + } + +private: + KwargDataflowGraphView unlabelled; + std::unordered_map node_labels; + std::unordered_map, OutputLabel> output_labels; +}; + +template +LabelledKwargDataflowGraphView + kwarg_dataflow_graph_view_with_labelling( + KwargDataflowGraphView const &g, + std::unordered_map const &node_labels, + std::unordered_map, OutputLabel> const + &value_labels) { + return LabelledKwargDataflowGraphView:: + template create< + KwargDataflowGraphLabellingWrapper>( + g, node_labels, value_labels); +} +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..470f2712f8 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "LabelledKwargDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "NodeLabel", + "ValueLabel", + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::KwargDataflowEdge>" + +[[fields]] +name = "output_data" +type = "std::unordered_map<::FlexFlow::KwargDataflowOutput, ValueLabel>" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h index 4dcde44f4d..178d64933a 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h @@ -11,9 +11,9 @@ namespace FlexFlow { template -struct ViewFromOpenKwargDataflowGraphView final +struct ViewFromOpenKwargDataflowGraphData final : virtual public IOpenKwargDataflowGraphView { - ViewFromOpenKwargDataflowGraphView( + ViewFromOpenKwargDataflowGraphData( OpenKwargDataflowGraphData const &data) : data(data) {} @@ -44,9 +44,9 @@ struct ViewFromOpenKwargDataflowGraphView final }); } - ViewFromOpenKwargDataflowGraphView * + ViewFromOpenKwargDataflowGraphData * clone() const override { - return new ViewFromOpenKwargDataflowGraphView{ + return new ViewFromOpenKwargDataflowGraphData{ this->data}; } @@ -61,7 +61,7 @@ OpenKwargDataflowGraphView require_open_kwarg_dataflow_graph_data_is_valid(data); return OpenKwargDataflowGraphView::template create< - ViewFromOpenKwargDataflowGraphView>(data); + ViewFromOpenKwargDataflowGraphData>(data); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc new file mode 100644 index 0000000000..a17f751e1e --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc @@ -0,0 +1,12 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template KwargDataflowGraphView + get_kwarg_dataflow_graph_subgraph(KwargDataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..60e550f180 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc @@ -0,0 +1,16 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using ValueLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template LabelledKwargDataflowGraphData + get_labelled_kwarg_dataflow_graph_data( + LabelledKwargDataflowGraphView const + &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc new file mode 100644 index 0000000000..dfeeab0aa0 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc @@ -0,0 +1,16 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using OutputLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template std::unordered_map + get_labelled_kwarg_dataflow_graph_node_label_map( + LabelledKwargDataflowGraphView const + &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc new file mode 100644 index 0000000000..bd287b5342 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc @@ -0,0 +1,16 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using OutputLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template std::unordered_map, OutputLabel> + get_labelled_kwarg_dataflow_graph_output_label_map( + LabelledKwargDataflowGraphView const + &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc new file mode 100644 index 0000000000..03d19cd52e --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc @@ -0,0 +1,17 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using OutputLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template LabelledKwargDataflowGraphView + get_labelled_kwarg_dataflow_graph_subgraph( + LabelledKwargDataflowGraphView const + &, + std::unordered_set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc new file mode 100644 index 0000000000..7901cdef2f --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc @@ -0,0 +1,17 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using OutputLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template LabelledKwargDataflowGraphView + kwarg_dataflow_graph_view_with_labelling( + KwargDataflowGraphView const &, + std::unordered_map const &, + std::unordered_map, OutputLabel> const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc new file mode 100644 index 0000000000..8a752e9887 --- /dev/null +++ b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.cc @@ -0,0 +1,103 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_kwarg_dataflow_graph_subgraph") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + + auto mk_edge = [](Node src, + int src_slot, + Node dst, + int dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{ + /*node=*/src, + /*slot_name=*/src_slot, + }, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }; + }; + + KwargDataflowGraphData g_data = + KwargDataflowGraphData{/*nodes=*/{n1, n2, n3, n4, n5}, + /*edges=*/ + { + mk_edge(n1, 0, n2, 1), + mk_edge(n1, 0, n3, 0), + mk_edge(n2, 3, n4, 1), + mk_edge(n3, 1, n5, 1), + }, + /*outputs=*/ + { + KwargDataflowOutput{n1, 0}, + KwargDataflowOutput{n2, 3}, + KwargDataflowOutput{n3, 1}, + KwargDataflowOutput{n5, 0}, + }}; + + KwargDataflowGraphView g = view_from_kwarg_dataflow_graph_data(g_data); + + SUBCASE("node set is contains all graph nodes") { + KwargDataflowGraphView result = get_kwarg_dataflow_graph_subgraph( + g, std::unordered_set{n1, n2, n3, n4, n5}); + KwargDataflowGraphData result_data = + get_kwarg_dataflow_graph_data(result); + + KwargDataflowGraphData correct_data = g_data; + + CHECK(result_data == correct_data); + } + + SUBCASE("node set is overlapping") { + KwargDataflowGraphView result = + get_kwarg_dataflow_graph_subgraph(g, std::unordered_set{n2, n3, n5}); + KwargDataflowGraphData result_data = + get_kwarg_dataflow_graph_data(result); + + KwargDataflowGraphData correct_data = KwargDataflowGraphData{ + /*nodes=*/{n2, n3, n5}, + /*edges=*/ + { + mk_edge(n3, 1, n5, 1), + }, + /*outputs=*/ + { + KwargDataflowOutput{n2, 3}, + KwargDataflowOutput{n3, 1}, + KwargDataflowOutput{n5, 0}, + }, + }; + + CHECK(result_data == correct_data); + } + + SUBCASE("node set is non-overlapping") { + KwargDataflowGraphView result = + get_kwarg_dataflow_graph_subgraph(g, std::unordered_set{}); + KwargDataflowGraphData result_data = + get_kwarg_dataflow_graph_data(result); + + KwargDataflowGraphData correct_data = KwargDataflowGraphData{ + /*nodes=*/std::unordered_set{}, + /*edges=*/std::unordered_set>{}, + /*outputs=*/std::unordered_set>{}, + }; + + CHECK(result_data == correct_data); + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..d9da80772e --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.cc @@ -0,0 +1,133 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_labelled_kwarg_dataflow_graph_data") { + LabelledKwargDataflowGraph g = + LabelledKwargDataflowGraph::template create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); + + SUBCASE("graph is empty") { + LabelledKwargDataflowGraphView input = g; + + LabelledKwargDataflowGraphData result = + get_labelled_kwarg_dataflow_graph_data(input); + + LabelledKwargDataflowGraphData correct = + LabelledKwargDataflowGraphData{ + /*node_data=*/std::unordered_map{}, + /*edges=*/std::unordered_set>{}, + /*output_data=*/ + std::unordered_map, float>{}, + }; + + ASSERT(result == correct); + } + + SUBCASE("graph is nonempty") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n1"; + + float n1_t1_label = 5.3; + float n2_t1_label = 12.1; + float n2_t2_label = 5.3; + float n3_t1_label = 1.7; + + KwargNodeAddedResult n1_added = g.add_node( + /*node_label=*/n1_label, + /*inputs=*/{}, + /*output_labels=*/ + std::unordered_map{ + {2, n1_t1_label}, + }); + Node n1 = n1_added.node; + KwargDataflowOutput n1_t1 = require_only_key(n1_added.outputs, 2); + + KwargNodeAddedResult n2_added = g.add_node( + /*node_label=*/n2_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + }, + /*output_labels=*/ + std::unordered_map{ + {0, n2_t1_label}, + {1, n2_t2_label}, + }); + Node n2 = n2_added.node; + KwargDataflowOutput n2_t1 = n2_added.outputs.at(0); + KwargDataflowOutput n2_t2 = n2_added.outputs.at(1); + + KwargNodeAddedResult n3_added = g.add_node( + /*node_label=*/n3_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + {1, n1_t1}, + {2, n2_t2}, + }, + /*output_labels=*/ + std::unordered_map{ + {4, n3_t1_label}, + }); + Node n3 = n3_added.node; + KwargDataflowOutput n3_t1 = require_only_key(n3_added.outputs, 4); + + LabelledKwargDataflowGraphView input = g; + + LabelledKwargDataflowGraphData result = + get_labelled_kwarg_dataflow_graph_data(input); + + auto mk_edge = [](Node src, + int src_slot, + Node dst, + int dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{ + /*node=*/src, + /*slot_name=*/src_slot, + }, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }; + }; + + LabelledKwargDataflowGraphData correct = + LabelledKwargDataflowGraphData{ + /*node_data=*/std::unordered_map{ + {n1, n1_label}, + {n2, n2_label}, + {n3, n3_label}, + }, + /*edges=*/ + std::unordered_set>{ + mk_edge(n1, 2, n2, 3), + mk_edge(n1, 2, n3, 3), + mk_edge(n1, 2, n3, 1), + mk_edge(n2, 1, n3, 2), + }, + /*output_data=*/ + std::unordered_map, float>{ + {n1_t1, n1_t1_label}, + {n2_t1, n2_t1_label}, + {n2_t2, n2_t2_label}, + {n3_t1, n3_t1_label}, + }, + }; + + ASSERT(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc new file mode 100644 index 0000000000..5df26765b0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.cc @@ -0,0 +1,88 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_labelled_kwarg_dataflow_graph_node_label_map") { + LabelledKwargDataflowGraph g = + LabelledKwargDataflowGraph::template create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); + + SUBCASE("graph is empty") { + std::unordered_map result = + get_labelled_kwarg_dataflow_graph_node_label_map( + static_cast< + LabelledKwargDataflowGraphView>(g)); + + std::unordered_map correct = {}; + + CHECK(result == correct); + } + + SUBCASE("graph is non-empty") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n1"; + + KwargNodeAddedResult n1_added = g.add_node( + /*node_label=*/n1_label, + /*inputs=*/{}, + /*output_labels=*/ + std::unordered_map{ + {2, 5.3}, + }); + Node n1 = n1_added.node; + KwargDataflowOutput n1_t1 = require_only_key(n1_added.outputs, 2); + + KwargNodeAddedResult n2_added = g.add_node( + /*node_label=*/n2_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + }, + /*output_labels=*/ + std::unordered_map{ + {0, 12.1}, + {1, 3.2}, + }); + Node n2 = n2_added.node; + KwargDataflowOutput n2_t1 = n2_added.outputs.at(0); + KwargDataflowOutput n2_t2 = n2_added.outputs.at(1); + + KwargNodeAddedResult n3_added = g.add_node( + /*node_label=*/n3_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + {1, n1_t1}, + {2, n2_t2}, + }, + /*output_labels=*/ + std::unordered_map{ + {4, 1.7}, + }); + Node n3 = n3_added.node; + KwargDataflowOutput n3_t1 = require_only_key(n3_added.outputs, 4); + + std::unordered_map result = + get_labelled_kwarg_dataflow_graph_node_label_map( + static_cast< + LabelledKwargDataflowGraphView>(g)); + + std::unordered_map correct = { + {n1, n1_label}, + {n2, n2_label}, + {n3, n3_label}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc new file mode 100644 index 0000000000..f4a63c2056 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.cc @@ -0,0 +1,90 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_labelled_kwarg_dataflow_graph_output_label_map") { + LabelledKwargDataflowGraph g = + LabelledKwargDataflowGraph::template create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); + + SUBCASE("graph is empty") { + std::unordered_map, float> result = + get_labelled_kwarg_dataflow_graph_output_label_map( + static_cast< + LabelledKwargDataflowGraphView>(g)); + + std::unordered_map, float> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("graph is non-empty") { + float n1_t1_label = 5.3; + float n2_t1_label = 12.1; + float n2_t2_label = 5.3; + float n3_t1_label = 1.7; + + KwargNodeAddedResult n1_added = g.add_node( + /*node_label=*/"n1", + /*inputs=*/{}, + /*output_labels=*/ + std::unordered_map{ + {2, n1_t1_label}, + }); + Node n1 = n1_added.node; + KwargDataflowOutput n1_t1 = require_only_key(n1_added.outputs, 2); + + KwargNodeAddedResult n2_added = g.add_node( + /*node_label=*/"n2", + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + }, + /*output_labels=*/ + std::unordered_map{ + {0, n2_t1_label}, + {1, n2_t2_label}, + }); + Node n2 = n2_added.node; + KwargDataflowOutput n2_t1 = n2_added.outputs.at(0); + KwargDataflowOutput n2_t2 = n2_added.outputs.at(1); + + KwargNodeAddedResult n3_added = g.add_node( + /*node_label=*/"n1", + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + {1, n1_t1}, + {2, n2_t2}, + }, + /*output_labels=*/ + std::unordered_map{ + {4, n3_t1_label}, + }); + Node n3 = n3_added.node; + KwargDataflowOutput n3_t1 = require_only_key(n3_added.outputs, 4); + + std::unordered_map, float> result = + get_labelled_kwarg_dataflow_graph_output_label_map( + static_cast< + LabelledKwargDataflowGraphView>(g)); + + std::unordered_map, float> correct = { + {n1_t1, n1_t1_label}, + {n2_t1, n2_t1_label}, + {n2_t2, n2_t2_label}, + {n3_t1, n3_t1_label}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc new file mode 100644 index 0000000000..0489d36312 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.cc @@ -0,0 +1,185 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_subgraph.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_labelled_kwarg_dataflow_graph_subgraph") { + LabelledKwargDataflowGraph g = + LabelledKwargDataflowGraph::template create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); + + std::string n1_label = "apple"; + std::string n2_label = "banana"; + std::string n3_label = "apple"; + + float n1_t1_label = 5.3; + float n2_t1_label = 12.1; + float n2_t2_label = 5.3; + float n3_t1_label = 1.7; + + KwargNodeAddedResult n1_added = g.add_node( + /*node_label=*/n1_label, + /*inputs=*/{}, + /*output_labels=*/ + std::unordered_map{ + {2, n1_t1_label}, + }); + Node n1 = n1_added.node; + KwargDataflowOutput n1_t1 = require_only_key(n1_added.outputs, 2); + + KwargNodeAddedResult n2_added = g.add_node( + /*node_label=*/n2_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + }, + /*output_labels=*/ + std::unordered_map{ + {0, n2_t1_label}, + {1, n2_t2_label}, + }); + Node n2 = n2_added.node; + KwargDataflowOutput n2_t1 = n2_added.outputs.at(0); + KwargDataflowOutput n2_t2 = n2_added.outputs.at(1); + + KwargNodeAddedResult n3_added = g.add_node( + /*node_label=*/n3_label, + /*inputs=*/ + std::unordered_map>{ + {3, n1_t1}, + {1, n1_t1}, + {2, n2_t2}, + }, + /*output_labels=*/ + std::unordered_map{ + {4, n3_t1_label}, + }); + Node n3 = n3_added.node; + KwargDataflowOutput n3_t1 = require_only_key(n3_added.outputs, 4); + + LabelledKwargDataflowGraphView input = g; + + auto mk_edge = [](Node src, + int src_slot, + Node dst, + int dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{ + /*node=*/src, + /*slot_name=*/src_slot, + }, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }; + }; + + SUBCASE("node set includes all graph nodes") { + LabelledKwargDataflowGraphView result = + get_labelled_kwarg_dataflow_graph_subgraph( + input, std::unordered_set{n1, n2, n3}); + LabelledKwargDataflowGraphData result_data = + get_labelled_kwarg_dataflow_graph_data(result); + + LabelledKwargDataflowGraphData correct_data = + LabelledKwargDataflowGraphData{ + /*node_data=*/{ + {n1, n1_label}, + {n2, n2_label}, + {n3, n3_label}, + }, + /*edges=*/ + { + mk_edge(n1, 2, n2, 3), + mk_edge(n1, 2, n3, 3), + mk_edge(n1, 2, n3, 1), + mk_edge(n2, 1, n3, 2), + }, + /*output_data=*/ + { + {n1_t1, n1_t1_label}, + {n2_t1, n2_t1_label}, + {n2_t2, n2_t2_label}, + {n3_t1, n3_t1_label}, + }, + }; + + CHECK(result_data == correct_data); + } + + SUBCASE("node set includes only some graph nodes") { + LabelledKwargDataflowGraphView result = + get_labelled_kwarg_dataflow_graph_subgraph( + input, std::unordered_set{n2, n3}); + LabelledKwargDataflowGraphData result_data = + get_labelled_kwarg_dataflow_graph_data(result); + + LabelledKwargDataflowGraphData correct_data = + LabelledKwargDataflowGraphData{ + /*node_data=*/{ + {n2, n2_label}, + {n3, n3_label}, + }, + /*edges=*/ + { + mk_edge(n2, 1, n3, 2), + }, + /*output_data=*/ + { + {n2_t1, n2_t1_label}, + {n2_t2, n2_t2_label}, + {n3_t1, n3_t1_label}, + }, + }; + + CHECK(result_data == correct_data); + } + + SUBCASE("node set includes no graph nodes") { + LabelledKwargDataflowGraphView result = + get_labelled_kwarg_dataflow_graph_subgraph( + input, std::unordered_set{}); + LabelledKwargDataflowGraphData result_data = + get_labelled_kwarg_dataflow_graph_data(result); + + LabelledKwargDataflowGraphData correct_data = + LabelledKwargDataflowGraphData{ + /*node_data=*/std::unordered_map{}, + /*edges=*/std::unordered_set>{}, + /*output_data=*/ + std::unordered_map, float>{}, + }; + + CHECK(result_data == correct_data); + } + + SUBCASE("node set includes nodes not in graph") { + LabelledKwargDataflowGraphView + with_invalid_node = get_labelled_kwarg_dataflow_graph_subgraph( + input, std::unordered_set{n2, n3, Node{100}}); + LabelledKwargDataflowGraphData + with_invalid_node_data = + get_labelled_kwarg_dataflow_graph_data(with_invalid_node); + + LabelledKwargDataflowGraphView + without_invalid_node = get_labelled_kwarg_dataflow_graph_subgraph( + input, std::unordered_set{n2, n3}); + LabelledKwargDataflowGraphData + without_invalid_node_data = + get_labelled_kwarg_dataflow_graph_data(without_invalid_node); + + CHECK(with_invalid_node_data == without_invalid_node_data); + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc new file mode 100644 index 0000000000..0cb75bb5da --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.cc @@ -0,0 +1,112 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("kwarg_dataflow_graph_view_with_labelling") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + + auto mk_edge = [](Node src, + int src_slot, + Node dst, + int dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{ + /*node=*/src, + /*slot_name=*/src_slot, + }, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }; + }; + + KwargDataflowOutput n1_0 = KwargDataflowOutput{n1, 0}; + KwargDataflowOutput n2_3 = KwargDataflowOutput{n2, 3}; + KwargDataflowOutput n3_1 = KwargDataflowOutput{n3, 1}; + KwargDataflowOutput n5_0 = KwargDataflowOutput{n5, 0}; + + KwargDataflowEdge e_n1_0_n2_1 = mk_edge(n1, 0, n2, 1); + KwargDataflowEdge e_n1_0_n3_0 = mk_edge(n1, 0, n3, 0); + KwargDataflowEdge e_n2_3_n4_1 = mk_edge(n2, 3, n4, 1); + KwargDataflowEdge e_n3_1_n5_1 = mk_edge(n3, 1, n5, 1); + + KwargDataflowGraphData g_data = + KwargDataflowGraphData{/*nodes=*/{n1, n2, n3, n4, n5}, + /*edges=*/ + { + e_n1_0_n2_1, + e_n1_0_n3_0, + e_n2_3_n4_1, + e_n3_1_n5_1, + }, + /*outputs=*/ + { + n1_0, + n2_3, + n3_1, + n5_0, + }}; + + KwargDataflowGraphView g = view_from_kwarg_dataflow_graph_data(g_data); + + float n1_label = 3.5; + float n2_label = 1.2; + float n3_label = 1.2; + float n4_label = 7.8; + float n5_label = 2.2; + + std::unordered_map node_labelling = { + {n1, 3.5}, + {n2, 1.2}, + {n3, 1.2}, + {n4, 7.8}, + {n5, 2.2}, + }; + + std::string n1_0_label = "a"; + std::string n2_3_label = "b"; + std::string n3_1_label = "c"; + std::string n5_0_label = "d"; + + std::unordered_map, std::string> value_labelling = + { + {n1_0, n1_0_label}, + {n2_3, n2_3_label}, + {n3_1, n3_1_label}, + {n5_0, n5_0_label}, + }; + + LabelledKwargDataflowGraphView result = + kwarg_dataflow_graph_view_with_labelling( + g, node_labelling, value_labelling); + LabelledKwargDataflowGraphData result_data = + get_labelled_kwarg_dataflow_graph_data(result); + + LabelledKwargDataflowGraphData correct_data = + LabelledKwargDataflowGraphData{ + /*node_data=*/node_labelling, + /*edges=*/ + { + e_n1_0_n2_1, + e_n1_0_n3_0, + e_n2_3_n4_1, + e_n3_1_n5_1, + }, + /*output_data=*/value_labelling, + }; + + CHECK(result_data == correct_data); + } +} diff --git a/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.cc b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..f3d5e83cc2 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.cc @@ -0,0 +1,122 @@ +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("view_from_open_kwarg_dataflow_graph_data") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + + auto mk_input_edge = + [](KwargDataflowGraphInput const &src, + Node dst, + int dst_slot) -> OpenKwargDataflowEdge { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + /*src=*/src, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }, + }; + }; + + auto mk_internal_edge = [](Node src, int src_slot, Node dst, int dst_slot) + -> OpenKwargDataflowEdge { + return OpenKwargDataflowEdge{ + KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{ + /*node=*/src, + /*slot_name=*/src_slot, + }, + /*dst=*/ + KwargDataflowInput{ + /*node=*/dst, + /*slot_name=*/dst_slot, + }, + }, + }; + }; + + KwargDataflowOutput n1_0 = KwargDataflowOutput{n1, 0}; + KwargDataflowOutput n1_3 = KwargDataflowOutput{n1, 3}; + KwargDataflowOutput n1_4 = KwargDataflowOutput{n1, 4}; + KwargDataflowOutput n2_3 = KwargDataflowOutput{n2, 3}; + KwargDataflowOutput n3_1 = KwargDataflowOutput{n3, 1}; + KwargDataflowOutput n5_0 = KwargDataflowOutput{n5, 0}; + KwargDataflowOutput n5_4 = KwargDataflowOutput{n5, 4}; + + OpenKwargDataflowEdge e_n1_0_n2_1 = + mk_internal_edge(n1, 0, n2, 1); + OpenKwargDataflowEdge e_n1_0_n3_0 = + mk_internal_edge(n1, 0, n3, 0); + OpenKwargDataflowEdge e_n2_3_n4_1 = + mk_internal_edge(n2, 3, n4, 1); + OpenKwargDataflowEdge e_n3_1_n5_1 = + mk_internal_edge(n3, 1, n5, 1); + + KwargDataflowGraphInput i1 = + KwargDataflowGraphInput{ + /*name=*/"a", + }; + KwargDataflowGraphInput i2 = + KwargDataflowGraphInput{ + /*name=*/"b", + }; + + OpenKwargDataflowEdge e_i1_n1_3 = + mk_input_edge(i1, n1, 3); + OpenKwargDataflowEdge e_i1_n1_4 = + mk_input_edge(i1, n1, 4); + OpenKwargDataflowEdge e_i2_n5_4 = + mk_input_edge(i2, n5, 4); + + OpenKwargDataflowGraphData input = + OpenKwargDataflowGraphData{ + /*nodes=*/{n1, n2, n3, n4, n5}, + /*edges=*/ + { + e_n1_0_n2_1, + e_n1_0_n3_0, + e_n2_3_n4_1, + e_n3_1_n5_1, + e_i1_n1_3, + e_i1_n1_4, + e_i2_n5_4, + }, + /*inputs=*/ + { + i1, + i2, + }, + /*outputs=*/ + { + n1_0, + n1_3, + n1_4, + n2_3, + n3_1, + n5_0, + n5_4, + }}; + + OpenKwargDataflowGraphView result = + view_from_open_kwarg_dataflow_graph_data(input); + OpenKwargDataflowGraphData result_data = + get_open_kwarg_dataflow_graph_data(result); + + OpenKwargDataflowGraphData correct_data = input; + + CHECK(result_data == correct_data); + } +}