diff --git a/pyproject.toml b/pyproject.toml
index 002c4e1..52ac32c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,7 @@ version_provider = "pep621"
 branch = true
 source = ["src/mozilla_taskgraph/", "mozilla_taskgraph"]
 
-[tool.ruff]
+[tool.ruff.lint]
 select = [
     "E", "W",  # pycodestyle
     "F",  # pyflakes
@@ -77,5 +77,5 @@ ignore = [
     "E741",
 ]
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 known-first-party = ["mozilla_taskgraph"]
diff --git a/src/mozilla_taskgraph/transforms/replicate.py b/src/mozilla_taskgraph/transforms/replicate.py
new file mode 100644
index 0000000..c25b55d
--- /dev/null
+++ b/src/mozilla_taskgraph/transforms/replicate.py
@@ -0,0 +1,241 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+from __future__ import annotations
+
+import os
+import re
+from textwrap import dedent
+
+from requests.exceptions import HTTPError
+from taskgraph.transforms.base import TransformSequence
+from taskgraph.util.attributes import attrmatch
+from taskgraph.util.schema import Schema
+from taskgraph.util.taskcluster import (
+    find_task_id,
+    get_artifact,
+    get_task_definition,
+)
+from voluptuous import ALLOW_EXTRA, Any, Optional, Required
+
+REPLICATE_SCHEMA = Schema(
+    {
+        Required(
+            "replicate",
+            description=dedent(
+                """
+                Configuration for the replicate transforms.
+                """.lstrip(),
+            ),
+        ): {
+            Required(
+                "target",
+                description=dedent(
+                    """
+                    Define which tasks to target for replication.
+
+                    Each item in the list can be either:
+
+                    1. A taskId
+                    2. An index path that points to a single task
+
+                    If any of the resolved tasks are a Decision task, targeted
+                    tasks will be derived from the `task-graph.json` artifact.
+                    """.lstrip()
+                ),
+            ): [str],
+            Optional(
+                "include-attrs",
+                description=dedent(
+                    """
+                    A dict of attribute key/value pairs that targeted tasks will be
+                    filtered on. Targeted tasks must *match all* of the given
+                    attributes or they will be ignored.
+
+                    Matching is performed by the :func:`~taskgraph.util.attrmatch`
+                    utility function.
+                    """.lstrip(),
+                ),
+            ): {str: Any(str, [str])},
+            Optional(
+                "exclude-attrs",
+                description=dedent(
+                    """
+                    A dict of attribute key/value pairs that targeted tasks will be
+                    filtered on. Targeted tasks must *not match any* of the given
+                    attributes or they will be ignored.
+
+                    Matching is performed by the :func:`~taskgraph.util.attrmatch`
+                    utility function.
+                    """.lstrip(),
+                ),
+            ): {str: Any(str, [str])},
+        },
+    },
+    extra=ALLOW_EXTRA,
+)
+
+TASK_ID_RE = re.compile(
+    r"^[A-Za-z0-9_-]{8}[Q-T][A-Za-z0-9_-][CGKOSWaeimquy26-][A-Za-z0-9_-]{10}[AQgw]$"
+)
+
+transforms = TransformSequence()
+transforms.add_validate(REPLICATE_SCHEMA)
+
+
+@transforms.add
+def resolve_targets(config, tasks):
+    for task in tasks:
+        replicate_config = task.pop("replicate")
+
+        task_defs = []
+        for target in replicate_config["target"]:
+            if TASK_ID_RE.match(target):
+                # target is a task id
+                task_id = target
+            else:
+                # target is an index path
+                task_id = find_task_id(target)
+
+            try:
+                # we have a decision task, add all tasks from task-graph.json
+                result = get_artifact(task_id, "public/task-graph.json").values()
+                task_defs.extend(result)
+            except HTTPError as e:
+                if e.response.status_code != 404:
+                    raise
+
+                # we have a regular task, just append its definition and move on
+                task_defs.append(get_task_definition(task_id))
+
+        for task_def in task_defs:
+            attributes = task_def.get("attributes", {})
+
+            # filter out some unsupported / undesired cases implicitly
+            if task_def["task"]["provisionerId"] == "releng-hardware":
+                continue
+
+            if (
+                task_def["task"]["payload"]
+                .get("features", {})
+                .get("runAsAdministrator")
+            ):
+                continue
+
+            # filter out tasks that don't satisfy include-attrs
+            if not attrmatch(attributes, **replicate_config.get("include-attrs", {})):
+                continue
+
+            # filter out tasks that satisfy exclude-attrs
+            if exclude_attrs := replicate_config.get("exclude-attrs"):
+                excludes = {
+                    key: lambda attr, values=values: any(attr.startswith(v) for v in values)
+                    for key, values in exclude_attrs.items()
+                }
+                if attrmatch(attributes, **excludes):
+                    continue
+
+            task_def["name-prefix"] = task["name"]
+            yield task_def
+
+
+def _rewrite_datestamps(task_def):
+    """Rewrite absolute datestamps from a concrete task definition into
+    relative ones that can then be used to schedule a new task."""
+    # Arguably, we should try to figure out what these values should be from
+    # the repo that created them originally. In practice it probably doesn't
+    # matter.
+    task_def["created"] = {"relative-datestamp": "0 seconds"}
+    task_def["deadline"] = {"relative-datestamp": "1 day"}
+    task_def["expires"] = {"relative-datestamp": "1 month"}
+
+    if artifacts := task_def.get("payload", {}).get("artifacts"):
+        artifacts = artifacts.values() if isinstance(artifacts, dict) else artifacts
+        for artifact in artifacts:
+            if "expires" in artifact:
+                artifact["expires"] = {"relative-datestamp": "1 month"}
+
+
+def _remove_revisions(task_def):
+    """Rewrite revisions in task payloads to ensure that tasks do not refer to
+    out of date revisions."""
+    to_remove = set()
+    for k in task_def.get("payload", {}).get("env", {}):
+        if k.endswith("_REV"):
+            to_remove.add(k)
+
+    for k in to_remove:
+        del task_def["payload"]["env"][k]
+
+
+@transforms.add
+def rewrite_task(config, task_defs):
+    assert "TASK_ID" in os.environ
+
+    trust_domain = config.graph_config["trust-domain"]
+    level = config.params["level"]
+
+    # Replace strings like `gecko-level-3` with the active trust domain and
+    # level.
+    pattern = re.compile(r"[a-z]+-level-[1-3]")
+    repl = f"{trust_domain}-level-{level}"
+
+    for task_def in task_defs:
+        task = task_def["task"]
+
+        task.update(
+            {
+                "schedulerId": repl,
+                "taskGroupId": os.environ["TASK_ID"],
+                "priority": "low",
+                "routes": ["checks"],
+            }
+        )
+
+        # Remove treeherder config
+        if "treeherder" in task.get("extra", {}):
+            del task["extra"]["treeherder"]
+
+        cache = task["payload"].get("cache", {})
+        for name, value in cache.copy().items():
+            del cache[name]
+            name = pattern.sub(repl, name)
+            cache[name] = value
+
+        for mount in task["payload"].get("mounts", []):
+            if "cacheName" in mount:
+                mount["cacheName"] = pattern.sub(
+                    repl,
+                    mount["cacheName"],
+                )
+
+        for i, scope in enumerate(task.get("scopes", [])):
+            task["scopes"][i] = pattern.sub(repl, scope)
+
+        # Rewrite queue / provisioner ids down to the current level.
+        for key in ("taskQueueId", "provisionerId", "worker-type"):
+            if key in task:
+                task[key] = task[key].replace("3", level)
+
+        # All datestamps come in as absolute ones, many of which
+        # will be in the past. We need to rewrite these to relative
+        # ones to make the task reschedulable.
+        _rewrite_datestamps(task)
+
+        # We also need to remove absolute revisions from payloads
+        # to avoid issues with revisions not matching the refs
+        # that are given.
+        _remove_revisions(task)
+
+        name_prefix = task_def.pop("name-prefix")
+        task["metadata"]["name"] = f"{name_prefix}-{task['metadata']['name']}"
+        taskdesc = {
+            "label": task["metadata"]["name"],
+            "dependencies": {},
+            "description": task["metadata"]["description"],
+            "task": task,
+            "attributes": {"replicate": name_prefix},
+        }
+
+        yield taskdesc
diff --git a/test/conftest.py b/test/conftest.py
index f521212..a8e59e9 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -26,7 +26,7 @@ def set_taskcluster_url(session_mocker):
 
 @pytest.fixture
 def responses():
-    with RequestsMock() as rsps:
+    with RequestsMock(assert_all_requests_are_fired=True) as rsps:
         yield rsps
 
 
diff --git a/test/transforms/test_replicate.py b/test/transforms/test_replicate.py
new file mode 100644
index 0000000..abb5731
--- /dev/null
+++ b/test/transforms/test_replicate.py
@@ -0,0 +1,275 @@
+from itertools import count
+from pprint import pprint
+from unittest.mock import Mock
+
+import pytest
+from requests import HTTPError
+from taskgraph.util.templates import merge
+
+from mozilla_taskgraph.transforms.replicate import transforms as replicate_transforms
+
+TC_ROOT_URL = "https://tc-tests.example.com"
+
+
+def get_target_defs(*task_defs):
+    default = {
+        "task": {
+            "extra": {
+                "treeherder": "1",
+            },
+            "metadata": {"name": "task-b", "description": "description"},
+            "payload": {
+                "artifacts": {
+                    "foo": {
+                        "expires": "some datestamp",
+                    },
+                },
+                "cache": {
+                    "foo-level-3": "1",
+                },
+                "env": {
+                    "SHOULD_NOT_BE_REMOVED": "1",
+                    "SHOULD_BE_REMOVED_REV": "1",
+                },
+                "mounts": [
+                    {
+                        "cacheName": "cache-foo-level-3-name",
+                    }
+                ],
+            },
+            "provisionerId": "foo",
+            "scopes": [
+                "test:foo-level-3:scope",
+            ],
+        }
+    }
+
+    task_defs = task_defs or [{}]
+    return [merge(default, task_def) for task_def in task_defs]
+
+
+def get_expected(prefix, *task_defs):
+    expected = []
+    for task_def in task_defs:
+        expected_task = merge(
+            task_def,
+            {
+                "attributes": {
+                    "replicate": prefix,
+                },
+                "dependencies": {},
+                "description": "description",
+                "label": f"{prefix}-{task_def['task']['metadata']['name']}",
+                "task": {
+                    "created": {"relative-datestamp": "0 seconds"},
+                    "deadline": {"relative-datestamp": "1 day"},
+                    "expires": {"relative-datestamp": "1 month"},
+                    "metadata": {
+                        "name": f"{prefix}-{task_def['task']['metadata']['name']}",
+                    },
+                    "payload": {
+                        "artifacts": {
+                            "foo": {
+                                "expires": {
+                                    "relative-datestamp": "1 month",
+                                }
+                            }
+                        },
+                        "cache": {
+                            "test-level-1": "1",
+                        },
+                        "mounts": [{"cacheName": "cache-test-level-1-name"}],
+                    },
+                    "priority": "low",
+                    "routes": [
+                        "checks",
+                    ],
+                    "schedulerId": "test-level-1",
+                    "scopes": [
+                        "test:test-level-1:scope",
+                    ],
+                    "taskGroupId": "abc",
+                },
+            },
+        )
+        del expected_task["task"]["extra"]["treeherder"]
+        del expected_task["task"]["payload"]["cache"]["foo-level-3"]
+        del expected_task["task"]["payload"]["env"]["SHOULD_BE_REMOVED_REV"]
+        del expected_task["task"]["payload"]["mounts"][0]
+        del expected_task["task"]["scopes"][0]
+        expected.append(expected_task)
+
+    return expected
+
+
+@pytest.fixture
+def run_replicate(monkeypatch, run_transform):
+    task_id = "abc"
+    monkeypatch.setenv("TASK_ID", task_id)
+
+    def inner(task):
+        result = run_transform(replicate_transforms, task)
+        pprint(result, indent=2)
+        return result
+
+    return inner
+
+
+def test_missing_config(run_replicate):
+    task = {}
+    with pytest.raises(Exception):
+        run_replicate(task)
+
+    task["replicate"] = {}
+    with pytest.raises(Exception):
+        run_replicate(task)
+
+    task["replicate"]["target"] = []
+    assert run_replicate(task) == []
+
+
+def test_requests_error(responses, run_replicate):
+    task_id = "fwp41cUkRmara7CD6l2U3A"
+    task = {
+        "name": "foo",
+        "replicate": {
+            "target": [
+                task_id,
+            ]
+        },
+    }
+    responses.get(
+        f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}/artifacts/public/task-graph.json",
+        body=HTTPError("Artifact not found!", response=Mock(status_code=403)),
+    )
+
+    with pytest.raises(HTTPError):
+        run_replicate(task)
+
+
+def test_task_id(responses, run_replicate):
+    task_id = "fwp41cUkRmara7CD6l2U3A"
+    prefix = "kind-a"
+    task = {
+        "name": prefix,
+        "replicate": {
+            "target": [
+                task_id,
+            ]
+        },
+    }
+    task_def = get_target_defs()[0]
+    expected = get_expected(prefix, task_def)[0]
+
+    responses.get(
+        f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}/artifacts/public/task-graph.json",
+        body=HTTPError("Artifact not found!", response=Mock(status_code=404)),
+    )
+    responses.get(f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}", json=task_def)
+
+    result = run_replicate(task)
+    assert len(result) == 1
+    assert result[0] == expected
+
+
+def test_index_path(responses, run_replicate):
+    prefix = "kind-a"
+    task_id = "def"
+    index_path = "foo.bar"
+    task = {
+        "name": prefix,
+        "replicate": {"target": [index_path]},
+    }
+    task_def = get_target_defs()[0]
+    expected = get_expected(prefix, task_def)[0]
+
+    responses.get(
+        f"{TC_ROOT_URL}/api/index/v1/task/{index_path}", json={"taskId": task_id}
+    )
+    responses.get(
+        f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}/artifacts/public/task-graph.json",
+        body=HTTPError("Artifact not found!", response=Mock(status_code=404)),
+    )
+    responses.get(f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}", json=task_def)
+
+    result = run_replicate(task)
+    assert len(result) == 1
+    assert result[0] == expected
+
+
+def test_decision_task(responses, run_replicate):
+    prefix = "kind-a"
+    task_id = "fwp41cUkRmara7CD6l2U3A"
+    task = {
+        "name": prefix,
+        "replicate": {
+            "target": [
+                task_id,
+            ]
+        },
+    }
+    task_defs = get_target_defs({}, {"task": {"metadata": {"name": "task-c"}}})
+    expected = get_expected(prefix, *task_defs)
+
+    counter = count()
+    responses.get(
+        f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}/artifacts/public/task-graph.json",
+        json={next(counter): task_def for task_def in task_defs},
+    )
+    result = run_replicate(task)
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "target_def",
+    (
+        pytest.param(
+            {
+                "attributes": {"foo": "bar"},
+                "task": {"provisionerId": "releng-hardware"},
+            },
+            id="releng-hardware",
+        ),
+        pytest.param(
+            {
+                "attributes": {"foo": "bar"},
+                "task": {"payload": {"features": {"runAsAdministrator": True}}},
+            },
+            id="runAsAdministrator",
+        ),
+        pytest.param(
+            {},  # doesn't match 'include-attrs'
+            id="include-attrs",
+        ),
+        pytest.param(
+            {"attributes": {"foo": "bar", "baz": "1"}},
+            id="exclude-attrs",
+        ),
+    ),
+)
+def test_filtered_out(responses, run_replicate, target_def):
+    prefix = "kind-a"
+    task_id = "fwp41cUkRmara7CD6l2U3A"
+    task = {
+        "name": prefix,
+        "replicate": {
+            "target": [
+                task_id,
+            ],
+            "include-attrs": {
+                "foo": "bar",
+            },
+            "exclude-attrs": {
+                "baz": "1",
+            },
+        },
+    }
+    task_defs = get_target_defs(target_def)
+
+    counter = count()
+    responses.get(
+        f"{TC_ROOT_URL}/api/queue/v1/task/{task_id}/artifacts/public/task-graph.json",
+        json={next(counter): task_def for task_def in task_defs},
+    )
+    result = run_replicate(task)
+    assert len(result) == 0