Skip to content
4 changes: 4 additions & 0 deletions python/packages/core/agent_framework/_workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
Expand Down Expand Up @@ -152,6 +154,8 @@
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
Expand Down
4 changes: 4 additions & 0 deletions python/packages/core/agent_framework/_workflows/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ from ._events import (
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
Expand Down Expand Up @@ -148,6 +150,8 @@ __all__ = [
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
Expand Down
15 changes: 10 additions & 5 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
import sys
from dataclasses import dataclass
from typing import Any, cast

Expand All @@ -20,6 +21,11 @@
from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext

if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -179,7 +185,8 @@ async def handle_user_input_response(
self._pending_responses_to_agent.clear()
await self._run_agent_and_emit(ctx)

async def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.

NOTE: if the thread storage is on the server side, the full thread state
Expand All @@ -196,9 +203,6 @@ async def snapshot_state(self) -> dict[str, Any]:
client_module = self._agent.chat_client.__class__.__module__

if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
# TODO(TaoChenOSU): update this warning when we surface the hooks for
# custom executor checkpointing.
# https://github.com/microsoft/agent-framework/issues/1816
logger.warning(
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
"Currently, checkpointing does not capture messages from server-side threads "
Expand All @@ -217,7 +221,8 @@ async def snapshot_state(self) -> dict[str, Any]:
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
}

async def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import inspect
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
Expand All @@ -13,6 +14,12 @@
from ._orchestrator_helpers import ParticipantRegistry
from ._workflow_context import WorkflowContext

if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -210,11 +217,12 @@ def _check_round_limit(self) -> bool:

# State persistence (shared across all patterns)

def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing.

Default implementation uses OrchestrationState to serialize common state.
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.

Returns:
Serialized state dict
Expand All @@ -238,11 +246,12 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""
return {}

def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore orchestrator state from checkpoint.

Default implementation uses OrchestrationState to deserialize common state.
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.

Args:
state: Serialized state dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
dictionary snapshots so orchestrators can share logic without new mixins.
"""

import json
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Sequence

from .._types import ChatMessage

Expand All @@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
"""Attach `fallback` author if message is missing `author_name`."""
message.author_name = message.author_name or fallback
return message


def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
    """Build a detached, dictionary-shaped snapshot of a conversation for checkpoint storage.

    Resolution order:

    1. If ``conversation`` exposes a callable ``to_dict`` whose result is a mapping,
       a shallow ``dict`` copy of that mapping is returned as the snapshot.
    2. Otherwise every message is serialised individually, preferring a callable
       ``to_dict``, then a callable ``to_json`` (string payloads are parsed with
       ``json.loads``), and finally the message's ``__dict__``.

    Args:
        conversation: The messages (or a conversation container) to snapshot.

    Returns:
        Either the container's own dictionary form, or ``{"messages": [...]}`` with
        one entry per message.
    """
    # Guard with `callable`, exactly like the per-message branches below, so a
    # non-callable `to_dict` attribute falls through instead of raising TypeError.
    container_to_dict = getattr(conversation, "to_dict", None)
    if callable(container_to_dict):
        result = container_to_dict()
        if isinstance(result, Mapping):
            # Always copy: the snapshot must not alias state owned by the container.
            return dict(result)

    serialisable: list[dict[str, Any]] = []
    for message in conversation:
        if callable(message_to_dict := getattr(message, "to_dict", None)):
            payload = message_to_dict()
        elif callable(message_to_json := getattr(message, "to_json", None)):
            payload = message_to_json()
            if isinstance(payload, str):
                payload = json.loads(payload)
        else:
            # Objects without `__dict__` (e.g. slotted classes) contribute an empty entry.
            payload = getattr(message, "__dict__", {})
        # Mappings are copied into plain dicts; anything else is stored as produced.
        serialisable.append(dict(payload) if isinstance(payload, Mapping) else payload)  # type: ignore[arg-type]
    return {"messages": serialisable}
38 changes: 32 additions & 6 deletions python/packages/core/agent_framework/_workflows/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,36 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"


class SuperStepEvent(WorkflowEvent):
    """Workflow event raised at a superstep boundary (its start or its end)."""

    def __init__(self, iteration: int, data: Any | None = None):
        """Create the event for a given superstep.

        Args:
            iteration: The number of the superstep (1-based index).
            data: Optional payload forwarded to the base event initializer.
        """
        super().__init__(data)
        self.iteration = iteration

    def __repr__(self) -> str:
        """Render the concrete class name together with the iteration and data."""
        class_name = self.__class__.__name__
        return f"{class_name}(iteration={self.iteration}, data={self.data})"


class SuperStepStartedEvent(SuperStepEvent):
    """Superstep event emitted at the start of a superstep.

    Adds nothing to `SuperStepEvent`; the subclass exists so the two boundaries
    can be told apart by type.
    """

    ...


class SuperStepCompletedEvent(SuperStepEvent):
    """Superstep event emitted at the end of a superstep.

    Adds nothing to `SuperStepEvent`; the subclass exists so the two boundaries
    can be told apart by type.
    """

    ...


class ExecutorEvent(WorkflowEvent):
"""Base class for executor events."""

Expand All @@ -310,17 +340,13 @@ def __repr__(self) -> str:
class ExecutorInvokedEvent(ExecutorEvent):
"""Event triggered when an executor handler is invoked."""

# NOTE(review): this body shows both an explicit `__repr__` and a bare `...`
# placeholder; presumably only one of them is meant to remain — confirm which.
def __repr__(self) -> str:
"""Return a string representation of the executor handler invoke event."""
# Uses the runtime class name, so subclasses are rendered under their own name.
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...


class ExecutorCompletedEvent(ExecutorEvent):
"""Event triggered when an executor handler is completed."""

# NOTE(review): this body shows both an explicit `__repr__` and a bare `...`
# placeholder; presumably only one of them is meant to remain — confirm which.
def __repr__(self) -> str:
"""Return a string representation of the executor handler complete event."""
# Uses the runtime class name, so subclasses are rendered under their own name.
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...


class ExecutorFailedEvent(ExecutorEvent):
Expand Down
31 changes: 31 additions & 0 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def sync_process(text: str, ctx: WorkflowContext[str]) -> None:
that parent workflows can intercept. See WorkflowExecutor documentation for details on
workflow composition patterns and request/response handling.

## State Management
Executors can contain states that persist across workflow runs and checkpoints. Override the
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
serialization and restoration logic.

## Implementation Notes
- Do not call `execute()` directly - it's invoked by the workflow engine
- Do not override `execute()` - define handlers using decorators instead
Expand Down Expand Up @@ -460,6 +465,32 @@ def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any
return self._handlers[message_type]
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")

async def on_checkpoint_save(self) -> dict[str, Any]:
    """Produce the executor state to store when the workflow is checkpointed.

    The base implementation contributes no state. Subclasses override this hook
    to return whatever they need to persist. The returned dictionary is passed
    to `on_checkpoint_restore` when the workflow is restored from the
    checkpoint, so it should only contain JSON-serializable data.

    Returns:
        A state dictionary to be saved during checkpointing (empty by default).
    """
    default_state: dict[str, Any] = {}
    return default_state

async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
    """Receive previously saved state when the workflow is restored from a checkpoint.

    The base implementation ignores `state`. Subclasses override this hook to
    rebuild their internal state from it.

    Args:
        state: The state dictionary that was saved during checkpointing.
    """
    return None


# endregion: Executor

Expand Down
38 changes: 10 additions & 28 deletions python/packages/core/agent_framework/_workflows/_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import re
import sys
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
Expand Down Expand Up @@ -50,6 +51,12 @@
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext

if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -307,15 +314,6 @@ async def handle_agent_response(
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
) -> None:
"""Process an agent's response and determine whether to route, request input, or terminate."""
# Hydrate coordinator state (and detect new run) using checkpointable executor state
state = await ctx.get_executor_state()
if not state:
self._clear_conversation()
elif not self._get_conversation():
restored = self._restore_conversation_from_state(state)
if restored:
self._conversation = list(restored)

source = ctx.get_source_executor_id()
is_starting_agent = source == self._starting_agent_id

Expand Down Expand Up @@ -343,7 +341,7 @@ async def handle_agent_response(
# Update current agent when handoff occurs
self._current_agent_id = target
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
await self._persist_state(ctx)

# Clean tool-related content before sending to next agent
cleaned = clean_conversation_for_handoff(conversation)
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
Expand All @@ -360,7 +358,6 @@ async def handle_agent_response(
f"Agent '{source}' responded without handoff. "
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
)
await self._persist_state(ctx)

if await self._check_termination():
# Clean the output conversation for display
Expand Down Expand Up @@ -388,7 +385,6 @@ async def handle_user_input(
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
# Update authoritative conversation
self._conversation = list(message.full_conversation)
await self._persist_state(ctx)

# Check termination before sending to agent
if await self._check_termination():
Expand Down Expand Up @@ -473,11 +469,7 @@ def _conversation_from_response(self, response: AgentExecutorResponse) -> list[C
)
return list(conversation)

async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
"""Store authoritative conversation snapshot without losing rich metadata."""
state_payload = self.snapshot_state()
await ctx.set_executor_state(state_payload)

@override
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.

Expand All @@ -492,6 +484,7 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]:
}
return {}

@override
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.

Expand All @@ -503,17 +496,6 @@ def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
if self._return_to_previous and "current_agent_id" in metadata:
self._current_agent_id = metadata["current_agent_id"]

def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
    """Rehydrate the coordinator's conversation history from checkpointed state.

    DEPRECATED: Use restore_state() instead. Kept for backward compatibility.

    Args:
        state: Checkpointed state; the conversation is read from the
            "full_conversation" key, falling back to "conversation", then to
            an empty list.

    Returns:
        A new list holding the conversation decoded by `OrchestrationState`.
    """
    # Imported inside the function, as in the original code.
    from ._orchestration_state import OrchestrationState

    fallback_conversation = state.get("conversation", [])
    raw_conversation = state.get("full_conversation", fallback_conversation)
    decoded = OrchestrationState.from_dict({"conversation": raw_conversation})
    return list(decoded.conversation)

def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
"""Merge top-level response metadata into the latest assistant message."""
if not agent_response.additional_properties:
Expand Down
Loading