Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "utils/containers/transform.h"
#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h"
#include "utils/nonempty_unordered_set/nonempty_unordered_set.h"
#include "utils/one_to_many/one_to_many_transform_values.h"

namespace FlexFlow {
Expand All @@ -26,30 +25,40 @@ DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data(
std::unordered_set<KwargDataflowOutput<SlotName>> all_outputs =
kwarg_data.outputs;

OneToMany<Node, SlotName> incoming_slots_by_node =
one_to_many_transform_values(
std::unordered_map<Node, std::unordered_set<SlotName>>
incoming_slots_by_node = map_values(
group_by(all_inputs,
[](KwargDataflowInput<SlotName> const &i) -> Node {
return i.node;
}),
[](KwargDataflowInput<SlotName> const &i) -> SlotName {
return i.slot_name;
})
.l_to_r(),
[](nonempty_unordered_set<KwargDataflowInput<SlotName>> const &is)
-> std::unordered_set<SlotName> {
return transform(is.unwrap_as_unordered_set(),
[](KwargDataflowInput<SlotName> const &i) {
return i.slot_name;
});
});

OneToMany<Node, SlotName> outgoing_slots_by_node =
one_to_many_transform_values(
std::unordered_map<Node, std::unordered_set<SlotName>>
outgoing_slots_by_node = map_values(
group_by(all_outputs,
[](KwargDataflowOutput<SlotName> const &o) -> Node {
return o.node;
}),
[](KwargDataflowOutput<SlotName> const &o) -> SlotName {
return o.slot_name;
})
.l_to_r(),
[](nonempty_unordered_set<KwargDataflowOutput<SlotName>> const &os)
-> std::unordered_set<SlotName> {
return transform(os.unwrap_as_unordered_set(),
[](KwargDataflowOutput<SlotName> const &o) {
return o.slot_name;
});
});

auto dataflow_input_from_kwarg_input =
[&](KwargDataflowInput<SlotName> const &i) -> DataflowInput {
std::vector<SlotName> slot_ordering = order_slots(
incoming_slots_by_node.at_l(i.node).unwrap_as_unordered_set());
std::vector<SlotName> slot_ordering =
order_slots(incoming_slots_by_node.at(i.node));

return DataflowInput{
/*node=*/i.node,
Expand All @@ -62,8 +71,8 @@ DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data(

auto dataflow_output_from_kwarg_output =
[&](KwargDataflowOutput<SlotName> const &o) -> DataflowOutput {
std::vector<SlotName> slot_ordering = order_slots(
outgoing_slots_by_node.at_l(o.node).unwrap_as_unordered_set());
std::vector<SlotName> slot_ordering =
order_slots(outgoing_slots_by_node.at(o.node));

return DataflowOutput{
/*node=*/o.node,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H

#include "utils/containers/set_intersection.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h"
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h"

namespace FlexFlow {

template <typename SlotName>
KwargDataflowGraphView<SlotName> get_kwarg_dataflow_graph_subgraph(
KwargDataflowGraphView<SlotName> const &g,
std::unordered_set<Node> const &subgraph_nodes) {
KwargDataflowGraphData<SlotName> g_data = get_kwarg_dataflow_graph_data(g);

std::unordered_set<Node> nodes =
set_intersection(g_data.nodes, subgraph_nodes);

std::unordered_set<KwargDataflowEdge<SlotName>> edges =
filter(g_data.edges, [&](KwargDataflowEdge<SlotName> const &e) -> bool {
return contains(subgraph_nodes, e.src.node) &&
contains(subgraph_nodes, e.dst.node);
});

std::unordered_set<KwargDataflowOutput<SlotName>> outputs = filter(
g_data.outputs, [&](KwargDataflowOutput<SlotName> const &o) -> bool {
return contains(subgraph_nodes, o.node);
});

KwargDataflowGraphData<SlotName> subgraph_data =
KwargDataflowGraphData<SlotName>{
/*nodes=*/nodes,
/*edges=*/edges,
/*outputs=*/outputs,
};

return view_from_kwarg_dataflow_graph_data(subgraph_data);
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_DATA_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_DATA_H

#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_data.dtg.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"

namespace FlexFlow {

template <typename NodeLabel, typename ValueLabel, typename SlotName>
LabelledKwargDataflowGraphData<NodeLabel, ValueLabel, SlotName>
get_labelled_kwarg_dataflow_graph_data(
LabelledKwargDataflowGraphView<NodeLabel, ValueLabel, SlotName> const
&g) {
return LabelledKwargDataflowGraphData<NodeLabel, ValueLabel, SlotName>{
/*node_data=*/get_labelled_kwarg_dataflow_graph_node_label_map(g),
/*edges=*/get_all_kwarg_dataflow_edges(g),
/*output_data=*/get_labelled_kwarg_dataflow_graph_output_label_map(g),
};
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABEL_MAP_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABEL_MAP_H

#include "utils/containers/generate_map.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"
#include "utils/graph/node/algorithms.h"

namespace FlexFlow {

template <typename NodeLabel, typename OutputLabel, typename SlotName>
std::unordered_map<Node, NodeLabel>
get_labelled_kwarg_dataflow_graph_node_label_map(
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> const
&g) {
return generate_map(get_nodes(g),
[&](Node const &n) -> NodeLabel { return g.at(n); });
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_OUTPUT_LABEL_MAP_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_OUTPUT_LABEL_MAP_H

#include "utils/containers/generate_map.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"

namespace FlexFlow {

template <typename NodeLabel, typename OutputLabel, typename SlotName>
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel>
get_labelled_kwarg_dataflow_graph_output_label_map(
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> const
&g) {
return generate_map(
get_all_kwarg_dataflow_outputs(g),
[&](KwargDataflowOutput<SlotName> const &o) -> OutputLabel {
return g.at(o);
});
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H

#include "utils/containers/contains.h"
#include "utils/containers/filter_keys.h"
#include "utils/containers/restrict_keys.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_node_label_map.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/get_labelled_kwarg_dataflow_graph_output_label_map.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h"
#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h"

namespace FlexFlow {

template <typename NodeLabel, typename OutputLabel, typename SlotName>
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>
get_labelled_kwarg_dataflow_graph_subgraph(
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> const
&g,
std::unordered_set<Node> const &subgraph_nodes) {
KwargDataflowGraphView<SlotName> unlabelled_subgraph =
get_kwarg_dataflow_graph_subgraph(g, subgraph_nodes);

std::unordered_map<Node, NodeLabel> g_node_labelling =
get_labelled_kwarg_dataflow_graph_node_label_map(g);

std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel>
g_output_labelling =
get_labelled_kwarg_dataflow_graph_output_label_map(g);

return kwarg_dataflow_graph_view_with_labelling(
unlabelled_subgraph,
restrict_keys(g_node_labelling, subgraph_nodes),
filter_keys(g_output_labelling,
[&](KwargDataflowOutput<SlotName> const &o) -> bool {
return contains(subgraph_nodes, o.node);
}));
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H

#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"

namespace FlexFlow {

template <typename NodeLabel, typename OutputLabel, typename SlotName>
struct KwargDataflowGraphLabellingWrapper final
: public ILabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> {
public:
KwargDataflowGraphLabellingWrapper() = delete;
KwargDataflowGraphLabellingWrapper(
KwargDataflowGraphView<SlotName> const &unlabelled,
std::unordered_map<Node, NodeLabel> const &node_labels,
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> const
&output_labels)
: unlabelled(unlabelled), node_labels(node_labels),
output_labels(output_labels) {}

std::unordered_set<Node> query_nodes(NodeQuery const &q) const override {
return this->unlabelled.query_nodes(q);
}

std::unordered_set<KwargDataflowEdge<SlotName>>
query_edges(KwargDataflowEdgeQuery<SlotName> const &q) const override {
return this->unlabelled.query_edges(q);
}

std::unordered_set<KwargDataflowOutput<SlotName>> query_outputs(
KwargDataflowOutputQuery<SlotName> const &q) const override {
return this->unlabelled.query_outputs(q);
}

NodeLabel at(Node const &n) const override {
return this->node_labels.at(n);
}

OutputLabel at(KwargDataflowOutput<SlotName> const &v) const override {
return this->output_labels.at(v);
}

KwargDataflowGraphLabellingWrapper *clone() const override {
return new KwargDataflowGraphLabellingWrapper{
this->unlabelled,
this->node_labels,
this->output_labels,
};
}

private:
KwargDataflowGraphView<SlotName> unlabelled;
std::unordered_map<Node, NodeLabel> node_labels;
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> output_labels;
};

template <typename NodeLabel, typename OutputLabel, typename SlotName>
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>
kwarg_dataflow_graph_view_with_labelling(
KwargDataflowGraphView<SlotName> const &g,
std::unordered_map<Node, NodeLabel> const &node_labels,
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> const
&value_labels) {
return LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>::
template create<
KwargDataflowGraphLabellingWrapper<NodeLabel, OutputLabel, SlotName>>(
g, node_labels, value_labels);
}
} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
namespace = "FlexFlow"
name = "LabelledKwargDataflowGraphData"
type = "struct"
features = [
"eq",
"hash",
"fmt",
]

template_params = [
"NodeLabel",
"ValueLabel",
"SlotName",
]

includes = [
"utils/graph/node/node.dtg.h",
"utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h",
"utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h",
"<unordered_map>",
"<unordered_set>",
]

src_includes = [
"utils/hash/unordered_map.h",
"utils/hash/unordered_set.h",
"utils/fmt/unordered_map.h",
"utils/fmt/unordered_set.h",
]

[[fields]]
name = "node_data"
type = "std::unordered_map<::FlexFlow::Node, NodeLabel>"

[[fields]]
name = "edges"
type = "std::unordered_set<::FlexFlow::KwargDataflowEdge<SlotName>>"

[[fields]]
name = "output_data"
type = "std::unordered_map<::FlexFlow::KwargDataflowOutput<SlotName>, ValueLabel>"
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
namespace FlexFlow {

template <typename GraphInputName, typename SlotName>
struct ViewFromOpenKwargDataflowGraphView final
struct ViewFromOpenKwargDataflowGraphData final
: virtual public IOpenKwargDataflowGraphView<GraphInputName, SlotName> {
ViewFromOpenKwargDataflowGraphView(
ViewFromOpenKwargDataflowGraphData(
OpenKwargDataflowGraphData<GraphInputName, SlotName> const &data)
: data(data) {}

Expand Down Expand Up @@ -44,9 +44,9 @@ struct ViewFromOpenKwargDataflowGraphView final
});
}

ViewFromOpenKwargDataflowGraphView<GraphInputName, SlotName> *
ViewFromOpenKwargDataflowGraphData<GraphInputName, SlotName> *
clone() const override {
return new ViewFromOpenKwargDataflowGraphView<GraphInputName, SlotName>{
return new ViewFromOpenKwargDataflowGraphData<GraphInputName, SlotName>{
this->data};
}

Expand All @@ -61,7 +61,7 @@ OpenKwargDataflowGraphView<GraphInputName, SlotName>
require_open_kwarg_dataflow_graph_data_is_valid(data);

return OpenKwargDataflowGraphView<GraphInputName, SlotName>::template create<
ViewFromOpenKwargDataflowGraphView<GraphInputName, SlotName>>(data);
ViewFromOpenKwargDataflowGraphData<GraphInputName, SlotName>>(data);
}

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_subgraph.h"
#include "utils/archetypes/ordered_value_type.h"

namespace FlexFlow {

using SlotName = ordered_value_type<0>;

template KwargDataflowGraphView<SlotName>
get_kwarg_dataflow_graph_subgraph(KwargDataflowGraphView<SlotName> const &,
std::unordered_set<Node> const &);

} // namespace FlexFlow
Loading
Loading