diff --git a/packages/pytest-taskgraph/src/pytest_taskgraph/fixtures/gen.py b/packages/pytest-taskgraph/src/pytest_taskgraph/fixtures/gen.py index 8318f8c91..718f34a50 100644 --- a/packages/pytest-taskgraph/src/pytest_taskgraph/fixtures/gen.py +++ b/packages/pytest-taskgraph/src/pytest_taskgraph/fixtures/gen.py @@ -17,7 +17,7 @@ here = Path(__file__).parent -def fake_loader(kind, path, config, parameters, loaded_tasks): +def fake_loader(kind, path, config, parameters, loaded_tasks, write_artifacts): for i in range(3): dependencies = {} if i >= 1: diff --git a/src/taskgraph/generator.py b/src/taskgraph/generator.py index 9b794ee48..f1a427306 100644 --- a/src/taskgraph/generator.py +++ b/src/taskgraph/generator.py @@ -3,6 +3,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. import copy +import inspect import logging import multiprocessing import os @@ -60,12 +61,17 @@ def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts): loader = self._get_loader() config = copy.deepcopy(self.config) + if "write_artifacts" in inspect.signature(loader).parameters: + extra_args = (write_artifacts,) + else: + extra_args = () inputs = loader( self.name, self.path, config, parameters, list(kind_dependencies_tasks.values()), + *extra_args, ) transforms = TransformSequence() diff --git a/src/taskgraph/loader/default.py b/src/taskgraph/loader/default.py index f060a1d92..3fe74f621 100644 --- a/src/taskgraph/loader/default.py +++ b/src/taskgraph/loader/default.py @@ -16,7 +16,7 @@ ] -def loader(kind, path, config, params, loaded_tasks): +def loader(kind, path, config, params, loaded_tasks, write_artifacts): """ This default loader builds on the `transform` loader by providing sensible default transforms that the majority of simple tasks will need. @@ -30,4 +30,4 @@ def loader(kind, path, config, params, loaded_tasks): f"Transform {t} is already present in the loader's default transforms; it must not be defined in the kind" ) transform_refs.extend(DEFAULT_TRANSFORMS) - return transform_loader(kind, path, config, params, loaded_tasks) + return transform_loader(kind, path, config, params, loaded_tasks, write_artifacts) diff --git a/src/taskgraph/loader/transform.py b/src/taskgraph/loader/transform.py index a134ffd12..fe5a99ef9 100644 --- a/src/taskgraph/loader/transform.py +++ b/src/taskgraph/loader/transform.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -def loader(kind, path, config, params, loaded_tasks): +def loader(kind, path, config, params, loaded_tasks, write_artifacts): """ Get the input elements that will be transformed into tasks in a generic way. The elements themselves are free-form, and become the input to the diff --git a/test/test_generator.py b/test/test_generator.py index 701e5e02f..39a587c91 100644 --- a/test/test_generator.py +++ b/test/test_generator.py @@ -216,6 +216,18 @@ def test_load_tasks_for_kind(monkeypatch): ) +def test_loader_backwards_compat_interface(graph_config): + """Ensure loaders can be called even if they don't support a + `write_artifacts` argument.""" + + class OldLoaderKind(Kind): + def _get_loader(self): + return lambda kind, path, config, params, tasks: [] + + kind = OldLoaderKind("", "", {"transforms": []}, graph_config) + kind.load_tasks({}, {}, False) + + @pytest.mark.parametrize( "config,expected_transforms", ( @@ -243,7 +255,7 @@ def test_default_loader(config, expected_transforms): assert loader is default_loader, ( "Default Kind loader should be taskgraph.loader.default.loader" ) - loader("", "", config, {}, []) + loader("", "", config, {}, [], False) assert config["transforms"] == expected_transforms @@ -273,7 +285,7 @@ def test_default_loader(config, expected_transforms): def test_default_loader_errors(config): loader = Kind("", "", config, {})._get_loader() try: - loader("", "", config, {}, []) + loader("", "", config, {}, [], False) except KeyError: return