Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions src/taskgraph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@


import collections
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class Graph:
class _Graph:
nodes: frozenset
edges: frozenset


class Graph(_Graph):
"""Generic representation of a directed acyclic graph with labeled edges
connecting the nodes. Graph operations are implemented in a functional
manner, so the data structure is immutable.
Expand All @@ -23,8 +29,8 @@ class Graph:
node `left` to node `right`..
"""

nodes: frozenset
edges: frozenset
def __init__(self, nodes, edges):
super().__init__(frozenset(nodes), frozenset(edges))

def transitive_closure(self, nodes, reverse=False):
"""Return the transitive closure of <nodes>: the graph containing all
Expand Down Expand Up @@ -67,23 +73,28 @@ def transitive_closure(self, nodes, reverse=False):
add_nodes = {(left if reverse else right) for (left, right, _) in add_edges}
new_nodes = nodes | add_nodes
new_edges = edges | add_edges
return Graph(new_nodes, new_edges) # type: ignore
return Graph(new_nodes, new_edges)

def _visit(self, reverse):
queue = collections.deque(sorted(self.nodes))
links_by_node = self.reverse_links_dict() if reverse else self.links_dict()
seen = set()
forward_links, reverse_links = self.links_and_reverse_links_dict()

dependencies = reverse_links if reverse else forward_links
dependents = forward_links if reverse else reverse_links

indegree = {node: len(dependencies[node]) for node in self.nodes}

queue = collections.deque(
node for node, degree in indegree.items() if degree == 0
)

while queue:
node = queue.popleft()
if node in seen:
continue
links = links_by_node[node]
if all((n in seen) for n in links):
seen.add(node)
yield node
else:
queue.extend(n for n in links if n not in seen)
queue.append(node)
yield node

for dependent in dependents[node]:
indegree[dependent] -= 1
if indegree[dependent] == 0:
queue.append(dependent)

def visit_postorder(self):
"""
Expand All @@ -102,6 +113,21 @@ def visit_preorder(self):
"""
return self._visit(True)

@functools.cache
def links_and_reverse_links_dict(self):
"""
Return both links and reverse_links dictionaries.
Returns a (forward_links, reverse_links) tuple where forward_links maps
each node to the set of nodes it links to, and reverse_links maps each
node to the set of nodes linking to it.
"""
forward = collections.defaultdict(set)
reverse = collections.defaultdict(set)
for left, right, _ in self.edges:
forward[left].add(right)
reverse[right].add(left)
return (forward, reverse)

def links_dict(self):
"""
Return a dictionary mapping each node to a set of the nodes it links to
Expand Down
2 changes: 1 addition & 1 deletion src/taskgraph/morph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def amend_taskgraph(taskgraph, label_to_taskid, to_add):
for depname, dep in task.dependencies.items():
new_edges.add((task.task_id, dep, depname))

taskgraph = TaskGraph(new_tasks, Graph(set(new_tasks), new_edges)) # type: ignore
taskgraph = TaskGraph(new_tasks, Graph(frozenset(new_tasks), new_edges))
return taskgraph, label_to_taskid


Expand Down
4 changes: 3 additions & 1 deletion src/taskgraph/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ def get_subgraph(
if left in tasks_by_taskid and right in tasks_by_taskid
}

return TaskGraph(tasks_by_taskid, Graph(set(tasks_by_taskid), edges_by_taskid)) # type: ignore
return TaskGraph(
tasks_by_taskid, Graph(frozenset(tasks_by_taskid), edges_by_taskid)
)


@register_strategy("never")
Expand Down
2 changes: 1 addition & 1 deletion src/taskgraph/taskgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ def from_json(cls, tasks_dict):
tasks[key].task_id = value["task_id"]
for depname, dep in value["dependencies"].items():
edges.add((key, dep, depname))
task_graph = cls(tasks, Graph(set(tasks), edges)) # type: ignore
task_graph = cls(tasks, Graph(frozenset(tasks), edges))
return tasks, task_graph
Loading