Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/taskgraph/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def _run(self):

# Initial verifications that don't depend on any generation state.
self.verify("initial")
self.verify("graph_config", graph_config)

if callable(self._parameters):
parameters = self._parameters(graph_config)
Expand Down
7 changes: 7 additions & 0 deletions src/taskgraph/util/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def verify(self):
self.func()


@dataclass(frozen=True)
class GraphConfigVerification(Verification):
def verify(self, graph_config: GraphConfig):
self.func(graph_config)


@dataclass(frozen=True)
class GraphVerification(Verification):
"""Verification for a TaskGraph object."""
Expand Down Expand Up @@ -95,6 +101,7 @@ class VerificationSequence:
_verifications: Dict = field(default_factory=dict)
_verification_types = {
"graph": GraphVerification,
"graph_config": GraphConfigVerification,
"initial": InitialVerification,
"kinds": KindsVerification,
"parameters": ParametersVerification,
Expand Down
2 changes: 1 addition & 1 deletion test/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_verifications(mocker, maketgg):
m = mocker.patch.object(generator, "verifications")
tgg = maketgg(["_fake-t-2"], enable_verifications=True)
tgg.morphed_task_graph
assert m.call_count == 9
assert m.call_count == 10

m = mocker.patch.object(generator, "verifications")
tgg = maketgg(["_fake-t-2"], enable_verifications=False)
Expand Down
10 changes: 10 additions & 0 deletions test/test_util_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from taskgraph.task import Task
from taskgraph.util.treeherder import split_symbol
from taskgraph.util.verify import (
GraphConfigVerification,
GraphVerification,
ParametersVerification,
VerificationSequence,
Expand Down Expand Up @@ -38,6 +39,8 @@ def inner(name, **kwargs):

if isinstance(v, GraphVerification):
assert "graph" in kwargs

if isinstance(v, (GraphVerification, GraphConfigVerification)):
kwargs.setdefault("graph_config", graph_config)

if isinstance(v, (GraphVerification, ParametersVerification)):
Expand Down Expand Up @@ -94,6 +97,13 @@ def assert_simple_verification(arg):
1,
id="ParametersVerification",
),
pytest.param(
"graph_config",
("passed-thru",),
assert_simple_verification,
1,
id="GraphConfigVerification",
),
),
)
def test_verification_types(name, input, run_assertions, expected_called):
Expand Down
Loading