Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/packages/core/agent_framework/_workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@
WorkflowStartedEvent,
WorkflowStatusEvent,
)
from ._exceptions import (
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowException,
WorkflowRunnerException,
)
from ._executor import (
Executor,
handler,
Expand Down Expand Up @@ -109,7 +115,11 @@
from ._workflow import Workflow, WorkflowRunResult
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
from ._workflow_executor import SubWorkflowRequestMessage, SubWorkflowResponseMessage, WorkflowExecutor
from ._workflow_executor import (
SubWorkflowRequestMessage,
SubWorkflowResponseMessage,
WorkflowExecutor,
)

__all__ = [
"DEFAULT_MAX_ITERATIONS",
Expand Down Expand Up @@ -185,17 +195,21 @@
"WorkflowAgent",
"WorkflowBuilder",
"WorkflowCheckpoint",
"WorkflowCheckpointException",
"WorkflowCheckpointSummary",
"WorkflowContext",
"WorkflowConvergenceException",
"WorkflowErrorDetails",
"WorkflowEvent",
"WorkflowEventSource",
"WorkflowException",
"WorkflowExecutor",
"WorkflowFailedEvent",
"WorkflowLifecycleEvent",
"WorkflowOutputEvent",
"WorkflowRunResult",
"WorkflowRunState",
"WorkflowRunnerException",
"WorkflowStartedEvent",
"WorkflowStatusEvent",
"WorkflowValidationError",
Expand Down
27 changes: 27 additions & 0 deletions python/packages/core/agent_framework/_workflows/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft. All rights reserved.

from ..exceptions import AgentFrameworkException


class WorkflowException(AgentFrameworkException):
"""Base exception for workflow errors."""

pass


class WorkflowRunnerException(WorkflowException):
"""Base exception for workflow runner errors."""

pass


class WorkflowConvergenceException(WorkflowRunnerException):
"""Exception raised when a workflow runner fails to converge within the maximum iterations."""

pass


class WorkflowCheckpointException(WorkflowRunnerException):
"""Exception raised for errors related to workflow checkpoints."""

pass
57 changes: 34 additions & 23 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
from typing import Any

from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value
from ._checkpoint_encoding import (
DATACLASS_MARKER,
MODEL_MARKER,
decode_checkpoint_value,
)
from ._const import EXECUTOR_STATE_KEY
from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent
from ._exceptions import (
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowRunnerException,
)
from ._executor import Executor
from ._runner_context import (
Message,
Expand Down Expand Up @@ -72,7 +81,7 @@ def reset_iteration_count(self) -> None:
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
if self._running:
raise RuntimeError("Runner is already running.")
raise WorkflowRunnerException("Runner is already running.")

self._running = True
try:
Expand Down Expand Up @@ -134,12 +143,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
break

if self._iteration >= self._max_iterations and await self._ctx.has_messages():
raise RuntimeError(f"Runner did not converge after {self._max_iterations} iterations.")
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")

logger.info(f"Workflow completed after {self._iteration} supersteps")
# TODO(@taochen): iteration is reset to zero, even in the event of a request info event.
# Should iteration be preserved in the event of a request info event?
self._iteration = 0
self._resumed_from_checkpoint = False # Reset resume flag for next run
finally:
self._running = False
Expand Down Expand Up @@ -212,7 +218,7 @@ async def restore_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
) -> bool:
) -> None:
"""Restore workflow state from a checkpoint.

Args:
Expand All @@ -221,7 +227,10 @@ async def restore_from_checkpoint(
runner context itself is not configured with checkpointing.

Returns:
True if restoration was successful, False otherwise
None on success.

Raises:
WorkflowCheckpointException on failure.
"""
try:
# Load the checkpoint
Expand All @@ -231,18 +240,19 @@ async def restore_from_checkpoint(
elif checkpoint_storage is not None:
checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id)
else:
logger.warning("Context does not support checkpointing and no external storage was provided")
return False
raise WorkflowCheckpointException(
"Cannot load checkpoint: no checkpointing configured in context or external storage provided."
)

if not checkpoint:
logger.error(f"Checkpoint {checkpoint_id} not found")
return False
raise WorkflowCheckpointException(f"Checkpoint {checkpoint_id} not found")

# Validate the loaded checkpoint against the workflow
graph_hash = getattr(self, "graph_signature_hash", None)
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
raise ValueError(
raise WorkflowCheckpointException(
"Workflow graph has changed since the checkpoint was created. "
"Please rebuild the original workflow before resuming."
)
Expand All @@ -263,12 +273,11 @@ async def restore_from_checkpoint(
self._mark_resumed(checkpoint.iteration_count)

logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
return True
except ValueError:
except WorkflowCheckpointException:
raise
except Exception as e:
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
return False
raise WorkflowCheckpointException(f"Failed to restore from checkpoint {checkpoint_id}") from e

async def _save_executor_states(self) -> None:
"""Populate executor state by calling checkpoint hooks on executors.
Expand Down Expand Up @@ -309,7 +318,7 @@ async def _save_executor_states(self) -> None:
try:
state_dict = await executor.on_checkpoint_save()
except Exception as ex: # pragma: no cover
raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex
raise WorkflowCheckpointException(f"Executor {exec_id} on_checkpoint_save failed") from ex

try:
await self._set_executor_state(exec_id, state_dict)
Expand All @@ -335,17 +344,19 @@ async def _restore_executor_states(self) -> None:

executor_states = await self._shared_state.get(EXECUTOR_STATE_KEY)
if not isinstance(executor_states, dict):
raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.")
raise WorkflowCheckpointException("Executor states in shared state is not a dictionary. Unable to restore.")

for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType]
if not isinstance(executor_id, str):
raise ValueError("Executor ID in executor states is not a string. Unable to restore.")
raise WorkflowCheckpointException("Executor ID in executor states is not a string. Unable to restore.")
if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType]
raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.")
raise WorkflowCheckpointException(
f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore."
)

executor = self._executors.get(executor_id)
if not executor:
raise ValueError(f"Executor {executor_id} not found during state restoration.")
raise WorkflowCheckpointException(f"Executor {executor_id} not found during state restoration.")

# Try backward compatibility behavior first
# TODO(@taochen): Remove backward compatibility
Expand All @@ -358,15 +369,15 @@ async def _restore_executor_states(self) -> None:
await maybe # type: ignore[arg-type]
restored = True
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex
raise WorkflowCheckpointException(f"Executor {executor_id} restore_state failed") from ex

if not restored:
# Try the updated behavior only if backward compatibility did not restore
try:
await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType]
restored = True
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex
raise WorkflowCheckpointException(f"Executor {executor_id} on_checkpoint_restore failed") from ex

if not restored:
logger.debug(f"Executor {executor_id} does not support state restoration; skipping.")
Expand Down Expand Up @@ -409,7 +420,7 @@ async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) ->
existing_states = {}

if not isinstance(existing_states, dict):
raise ValueError("Existing executor states in shared state is not a dictionary.")
raise WorkflowCheckpointException("Existing executor states in shared state is not a dictionary.")

existing_states[executor_id] = state
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
Original file line number Diff line number Diff line change
Expand Up @@ -414,16 +414,7 @@ async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
self._messages.clear()
messages_data = checkpoint.messages
for source_id, message_list in messages_data.items():
self._messages[source_id] = [
Message(
data=decode_checkpoint_value(msg.get("data")),
source_id=msg.get("source_id", ""),
target_id=msg.get("target_id"),
trace_contexts=msg.get("trace_contexts"),
source_span_ids=msg.get("source_span_ids"),
)
for msg in message_list
]
self._messages[source_id] = [Message.from_dict(msg) for msg in message_list]

# Restore pending request info events
self._pending_request_info_events.clear()
Expand Down
5 changes: 1 addition & 4 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,7 @@ async def _execute_with_message_or_checkpoint(
"or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)."
)

restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)

if not restored:
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")
await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)

# Handle initial message
elif message is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import pytest
from typing_extensions import Never

from agent_framework import WorkflowBuilder, WorkflowContext, WorkflowRunState, WorkflowStatusEvent, handler
from agent_framework import (
WorkflowBuilder,
WorkflowCheckpointException,
WorkflowContext,
WorkflowRunState,
WorkflowStatusEvent,
handler,
)
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework._workflows._executor import Executor

Expand Down Expand Up @@ -43,7 +50,7 @@ async def test_resume_fails_when_graph_mismatch() -> None:
# Build a structurally different workflow (different finish executor id)
mismatched_workflow = build_workflow(storage, finish_id="finish_alt")

with pytest.raises(ValueError, match="Workflow graph has changed"):
with pytest.raises(WorkflowCheckpointException, match="Workflow graph has changed"):
_ = [
event
async for event in mismatched_workflow.run_stream(
Expand Down
22 changes: 16 additions & 6 deletions python/packages/core/tests/workflow/test_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TextContent,
Workflow,
WorkflowCheckpoint,
WorkflowCheckpointException,
WorkflowContext,
WorkflowEvent,
WorkflowOutputEvent,
Expand Down Expand Up @@ -341,7 +342,8 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result():
events.append(ev)

idle_status = next(
(e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), None
(e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE),
None,
)
assert idle_status is not None
# Check that we got workflow output via WorkflowOutputEvent
Expand Down Expand Up @@ -586,7 +588,9 @@ async def _collect_agent_responses_setup(participant: AgentProtocol) -> list[Cha
if isinstance(ev, AgentRunUpdateEvent):
captured.append(
ChatMessage(
role=ev.data.role or Role.ASSISTANT, text=ev.data.text or "", author_name=ev.data.author_name
role=ev.data.role or Role.ASSISTANT,
text=ev.data.text or "",
author_name=ev.data.author_name,
)
)

Expand All @@ -606,7 +610,9 @@ async def test_agent_executor_invoke_with_assistants_client_messages():
assert any((m.author_name == agent.name and "ok" in (m.text or "")) for m in captured)


async def _collect_checkpoints(storage: InMemoryCheckpointStorage) -> list[WorkflowCheckpoint]:
async def _collect_checkpoints(
storage: InMemoryCheckpointStorage,
) -> list[WorkflowCheckpoint]:
checkpoints = await storage.list_checkpoints()
assert checkpoints
checkpoints.sort(key=lambda cp: cp.timestamp)
Expand Down Expand Up @@ -721,7 +727,7 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames():
.build()
)

with pytest.raises(ValueError, match="Workflow graph has changed"):
with pytest.raises(WorkflowCheckpointException, match="Workflow graph has changed"):
async for _ in renamed_workflow.run_stream(
checkpoint_id=target_checkpoint.checkpoint_id, # type: ignore[reportUnknownMemberType]
):
Expand Down Expand Up @@ -762,7 +768,8 @@ async def test_magentic_stall_and_reset_reach_limits():
events.append(ev)

idle_status = next(
(e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), None
(e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE),
None,
)
assert idle_status is not None
output_event = next((e for e in events if isinstance(e, WorkflowOutputEvent)), None)
Expand Down Expand Up @@ -801,7 +808,10 @@ async def test_magentic_checkpoint_runtime_overrides_buildtime() -> None:
"""Test that runtime checkpoint storage overrides build-time configuration."""
import tempfile

with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2:
with (
tempfile.TemporaryDirectory() as temp_dir1,
tempfile.TemporaryDirectory() as temp_dir2,
):
from agent_framework._workflows._checkpoint import FileCheckpointStorage

buildtime_storage = FileCheckpointStorage(temp_dir1)
Expand Down
Loading
Loading