Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,10 +744,19 @@ def _compute_graph_signature(self) -> dict[str, Any]:
ignoring data/state changes. Used to verify that a workflow's structure hasn't
changed when resuming from checkpoints.
"""
executors_signature = {
executor_id: f"{executor.__class__.__module__}.{executor.__class__.__name__}"
for executor_id, executor in self.executors.items()
}
from ._workflow_executor import WorkflowExecutor

executors_signature = {}
for executor_id, executor in self.executors.items():
executor_sig: Any = f"{executor.__class__.__module__}.{executor.__class__.__name__}"

if isinstance(executor, WorkflowExecutor):
executor_sig = {
"type": executor_sig,
"sub_workflow": executor.workflow._graph_signature,
}

executors_signature[executor_id] = executor_sig

edge_groups_signature: list[dict[str, Any]] = []
for group in self.edge_groups:
Expand Down
85 changes: 85 additions & 0 deletions python/packages/core/tests/workflow/test_checkpoint_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
WorkflowBuilder,
WorkflowCheckpointException,
WorkflowContext,
WorkflowExecutor,
WorkflowRunState,
handler,
)
Expand Down Expand Up @@ -81,3 +82,87 @@ async def test_resume_succeeds_when_graph_matches() -> None:
]

assert any(event.type == "status" and event.state == WorkflowRunState.IDLE for event in events)


# -- Sub-workflow checkpoint validation tests --


class SubStartExecutor(Executor):
@handler
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(message)


class SubFinishExecutor(Executor):
@handler
async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
await ctx.yield_output(message)


def build_sub_workflow(sub_finish_id: str = "sub_finish"):
sub_start = SubStartExecutor(id="sub_start")
sub_finish = SubFinishExecutor(id=sub_finish_id)
return WorkflowBuilder(start_executor=sub_start).add_edge(sub_start, sub_finish).build()


def build_parent_workflow(storage: InMemoryCheckpointStorage, sub_finish_id: str = "sub_finish"):
sub_workflow = build_sub_workflow(sub_finish_id=sub_finish_id)
sub_executor = WorkflowExecutor(sub_workflow, id="sub_wf", allow_direct_output=True)

start = StartExecutor(id="start")
finish = FinishExecutor(id="finish")

builder = (
WorkflowBuilder(max_iterations=3, start_executor=start, checkpoint_storage=storage)
.add_edge(start, sub_executor)
.add_edge(sub_executor, finish)
)
return builder.build()


async def test_resume_succeeds_when_sub_workflow_matches() -> None:
storage = InMemoryCheckpointStorage()
workflow = build_parent_workflow(storage, sub_finish_id="sub_finish")

_ = [event async for event in workflow.run("hello", stream=True)]

checkpoints = await storage.list_checkpoints()
assert checkpoints, "expected at least one checkpoint to be created"
target_checkpoint = checkpoints[-1]

resumed_workflow = build_parent_workflow(storage, sub_finish_id="sub_finish")

events = [
event
async for event in resumed_workflow.run(
checkpoint_id=target_checkpoint.checkpoint_id,
checkpoint_storage=storage,
stream=True,
)
]

assert any(event.type == "status" and event.state == WorkflowRunState.IDLE for event in events)


async def test_resume_fails_when_sub_workflow_changes() -> None:
storage = InMemoryCheckpointStorage()
workflow = build_parent_workflow(storage, sub_finish_id="sub_finish")

_ = [event async for event in workflow.run("hello", stream=True)]

checkpoints = await storage.list_checkpoints()
assert checkpoints, "expected at least one checkpoint to be created"
target_checkpoint = checkpoints[-1]

# Build parent with a structurally different sub-workflow (different executor id inside)
mismatched_workflow = build_parent_workflow(storage, sub_finish_id="sub_finish_alt")

with pytest.raises(WorkflowCheckpointException, match="Workflow graph has changed"):
_ = [
event
async for event in mismatched_workflow.run(
checkpoint_id=target_checkpoint.checkpoint_id,
checkpoint_storage=storage,
stream=True,
)
]
Loading