diff --git a/src/taskgraph/generator.py b/src/taskgraph/generator.py index 14142eb7b..66c228deb 100644 --- a/src/taskgraph/generator.py +++ b/src/taskgraph/generator.py @@ -2,9 +2,18 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +from collections import defaultdict import copy +from itertools import chain import logging import os +from concurrent.futures import ( + ALL_COMPLETED, + FIRST_COMPLETED, + ProcessPoolExecutor, + as_completed, + wait, +) from dataclasses import dataclass from typing import Callable, Dict, Optional, Union @@ -44,16 +53,20 @@ def _get_loader(self): loader = "taskgraph.loader.default:loader" return find_object(loader) - def load_tasks(self, parameters, loaded_tasks, write_artifacts): + def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts): + logger.debug(f"Loading tasks for kind {self.name}") + + parameters = Parameters(**parameters) loader = self._get_loader() config = copy.deepcopy(self.config) - kind_dependencies = config.get("kind-dependencies", []) - kind_dependencies_tasks = { - task.label: task for task in loaded_tasks if task.kind in kind_dependencies - } - - inputs = loader(self.name, self.path, config, parameters, loaded_tasks) + inputs = loader( + self.name, + self.path, + config, + parameters, + list(kind_dependencies_tasks.values()), + ) transforms = TransformSequence() for xform_path in config["transforms"]: @@ -87,6 +100,7 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts): ) for task_dict in transforms(trans_config, inputs) ] + logger.info(f"Generated {len(tasks)} tasks for kind {self.name}") return tasks @classmethod @@ -249,6 +263,59 @@ def _load_kinds(self, graph_config, target_kinds=None): except KindNotFound: continue + def _load_tasks(self, kinds, kind_graph, parameters): + all_tasks = {} + futures_to_kind = {} + futures = set() + edges = set(kind_graph.edges) + + def add_new_tasks(future): + for task in future.result(): + if task.label in all_tasks: + raise Exception("duplicate tasks with label " + task.label) + all_tasks[task.label] = task + + with ProcessPoolExecutor() as executor: + + def submit_ready_kinds(): + """Create the next batch of tasks for kinds without dependencies.""" + nonlocal kinds, edges, futures + loaded_tasks = all_tasks.copy() + kinds_with_deps = {edge[0] for edge in edges} + ready_kinds = ( + set(kinds) - kinds_with_deps - set(futures_to_kind.values()) + ) + for name in ready_kinds: + kind = kinds[name] + future = executor.submit( + kind.load_tasks, + dict(parameters), + { + k: t + for k, t in loaded_tasks.items() + if t.kind in kind.config.get("kind-dependencies", []) + }, + self._write_artifacts, + ) + future.add_done_callback(add_new_tasks) + futures.add(future) + futures_to_kind[future] = name + + submit_ready_kinds() + while futures: + for future in as_completed(futures): + kind = futures_to_kind.pop(future) + futures.remove(future) + + # Update state for next batch of futures. + del kinds[kind] + edges = {e for e in edges if e[1] != kind} + + # Submit any newly unblocked kinds + submit_ready_kinds() + + return all_tasks + def _run(self): logger.info("Loading graph configuration.") graph_config = load_graph_config(self.root_dir) @@ -303,24 +370,8 @@ def _run(self): ) logger.info("Generating full task set") - all_tasks = {} - for kind_name in kind_graph.visit_postorder(): - logger.debug(f"Loading tasks for kind {kind_name}") - kind = kinds[kind_name] - try: - new_tasks = kind.load_tasks( - parameters, - list(all_tasks.values()), - self._write_artifacts, - ) - except Exception: - logger.exception(f"Error loading tasks for kind {kind_name}:") - raise - for task in new_tasks: - if task.label in all_tasks: - raise Exception("duplicate tasks with label " + task.label) - all_tasks[task.label] = task - logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}") + all_tasks = self._load_tasks(kinds, kind_graph, parameters) + full_task_set = TaskGraph(all_tasks, Graph(frozenset(all_tasks), frozenset())) yield self.verify("full_task_set", full_task_set, graph_config, parameters)