diff --git a/src/taskgraph/graph.py b/src/taskgraph/graph.py index c0c089035..9951c696f 100644 --- a/src/taskgraph/graph.py +++ b/src/taskgraph/graph.py @@ -4,11 +4,17 @@ import collections +import functools from dataclasses import dataclass @dataclass(frozen=True) -class Graph: +class _Graph: + nodes: frozenset + edges: frozenset + + +class Graph(_Graph): """Generic representation of a directed acyclic graph with labeled edges connecting the nodes. Graph operations are implemented in a functional manner, so the data structure is immutable. @@ -23,8 +29,8 @@ class Graph: node `left` to node `right`.. """ - nodes: frozenset - edges: frozenset + def __init__(self, nodes, edges): + super().__init__(frozenset(nodes), frozenset(edges)) def transitive_closure(self, nodes, reverse=False): """Return the transitive closure of : the graph containing all @@ -67,23 +73,28 @@ def transitive_closure(self, nodes, reverse=False): add_nodes = {(left if reverse else right) for (left, right, _) in add_edges} new_nodes = nodes | add_nodes new_edges = edges | add_edges - return Graph(new_nodes, new_edges) # type: ignore + return Graph(new_nodes, new_edges) def _visit(self, reverse): - queue = collections.deque(sorted(self.nodes)) - links_by_node = self.reverse_links_dict() if reverse else self.links_dict() - seen = set() + forward_links, reverse_links = self.links_and_reverse_links_dict() + + dependencies = reverse_links if reverse else forward_links + dependents = forward_links if reverse else reverse_links + + indegree = {node: len(dependencies[node]) for node in self.nodes} + + queue = collections.deque( + node for node, degree in indegree.items() if degree == 0 + ) + while queue: node = queue.popleft() - if node in seen: - continue - links = links_by_node[node] - if all((n in seen) for n in links): - seen.add(node) - yield node - else: - queue.extend(n for n in links if n not in seen) - queue.append(node) + yield node + + for dependent in dependents[node]: + indegree[dependent] -= 1 + if indegree[dependent] == 0: + queue.append(dependent) def visit_postorder(self): """ @@ -102,6 +113,21 @@ def visit_preorder(self): """ return self._visit(True) + @functools.cache + def links_and_reverse_links_dict(self): + """ + Return both links and reverse_links dictionaries. + Returns a (forward_links, reverse_links) tuple where forward_links maps + each node to the set of nodes it links to, and reverse_links maps each + node to the set of nodes linking to it. + """ + forward = collections.defaultdict(set) + reverse = collections.defaultdict(set) + for left, right, _ in self.edges: + forward[left].add(right) + reverse[right].add(left) + return (forward, reverse) + def links_dict(self): """ Return a dictionary mapping each node to a set of the nodes it links to diff --git a/src/taskgraph/morph.py b/src/taskgraph/morph.py index 01af7778f..357741f24 100644 --- a/src/taskgraph/morph.py +++ b/src/taskgraph/morph.py @@ -51,7 +51,7 @@ def amend_taskgraph(taskgraph, label_to_taskid, to_add): for depname, dep in task.dependencies.items(): new_edges.add((task.task_id, dep, depname)) - taskgraph = TaskGraph(new_tasks, Graph(set(new_tasks), new_edges)) # type: ignore + taskgraph = TaskGraph(new_tasks, Graph(frozenset(new_tasks), new_edges)) return taskgraph, label_to_taskid diff --git a/src/taskgraph/optimize/base.py b/src/taskgraph/optimize/base.py index 34d54f503..3d2776a87 100644 --- a/src/taskgraph/optimize/base.py +++ b/src/taskgraph/optimize/base.py @@ -451,7 +451,9 @@ def get_subgraph( if left in tasks_by_taskid and right in tasks_by_taskid } - return TaskGraph(tasks_by_taskid, Graph(set(tasks_by_taskid), edges_by_taskid)) # type: ignore + return TaskGraph( + tasks_by_taskid, Graph(frozenset(tasks_by_taskid), edges_by_taskid) + ) @register_strategy("never") diff --git a/src/taskgraph/taskgraph.py b/src/taskgraph/taskgraph.py index a94f615a2..839fcb385 100644 --- a/src/taskgraph/taskgraph.py +++ b/src/taskgraph/taskgraph.py @@ -67,5 +67,5 @@ def from_json(cls, tasks_dict): tasks[key].task_id = value["task_id"] for depname, dep in value["dependencies"].items(): edges.add((key, dep, depname)) - task_graph = cls(tasks, Graph(set(tasks), edges)) # type: ignore + task_graph = cls(tasks, Graph(frozenset(tasks), edges)) return tasks, task_graph