Skip to content

Commit c361ad8

Browse files
authored
Python: Add checkpoint save and restore hooks to executor (#2097)
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
1 parent 1325979 commit c361ad8

22 files changed

Lines changed: 511 additions & 726 deletions

python/packages/core/agent_framework/_workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
ExecutorFailedEvent,
3838
ExecutorInvokedEvent,
3939
RequestInfoEvent,
40+
SuperStepCompletedEvent,
41+
SuperStepStartedEvent,
4042
WorkflowErrorDetails,
4143
WorkflowEvent,
4244
WorkflowEventSource,
@@ -152,6 +154,8 @@
152154
"StandardMagenticManager",
153155
"SubWorkflowRequestMessage",
154156
"SubWorkflowResponseMessage",
157+
"SuperStepCompletedEvent",
158+
"SuperStepStartedEvent",
155159
"SwitchCaseEdgeGroup",
156160
"SwitchCaseEdgeGroupCase",
157161
"SwitchCaseEdgeGroupDefault",

python/packages/core/agent_framework/_workflows/__init__.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ from ._events import (
3535
ExecutorFailedEvent,
3636
ExecutorInvokedEvent,
3737
RequestInfoEvent,
38+
SuperStepCompletedEvent,
39+
SuperStepStartedEvent,
3840
WorkflowErrorDetails,
3941
WorkflowEvent,
4042
WorkflowEventSource,
@@ -148,6 +150,8 @@ __all__ = [
148150
"StandardMagenticManager",
149151
"SubWorkflowRequestMessage",
150152
"SubWorkflowResponseMessage",
153+
"SuperStepCompletedEvent",
154+
"SuperStepStartedEvent",
151155
"SwitchCaseEdgeGroup",
152156
"SwitchCaseEdgeGroupCase",
153157
"SwitchCaseEdgeGroupDefault",

python/packages/core/agent_framework/_workflows/_agent_executor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import logging
4+
import sys
45
from dataclasses import dataclass
56
from typing import Any, cast
67

@@ -20,6 +21,11 @@
2021
from ._request_info_mixin import response_handler
2122
from ._workflow_context import WorkflowContext
2223

24+
if sys.version_info >= (3, 12):
25+
from typing import override
26+
else:
27+
from typing_extensions import override
28+
2329
logger = logging.getLogger(__name__)
2430

2531

@@ -179,7 +185,8 @@ async def handle_user_input_response(
179185
self._pending_responses_to_agent.clear()
180186
await self._run_agent_and_emit(ctx)
181187

182-
async def snapshot_state(self) -> dict[str, Any]:
188+
@override
189+
async def on_checkpoint_save(self) -> dict[str, Any]:
183190
"""Capture current executor state for checkpointing.
184191
185192
NOTE: if the thread storage is on the server side, the full thread state
@@ -196,9 +203,6 @@ async def snapshot_state(self) -> dict[str, Any]:
196203
client_module = self._agent.chat_client.__class__.__module__
197204

198205
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
199-
# TODO(TaoChenOSU): update this warning when we surface the hooks for
200-
# custom executor checkpointing.
201-
# https://github.com/microsoft/agent-framework/issues/1816
202206
logger.warning(
203207
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
204208
"Currently, checkpointing does not capture messages from server-side threads "
@@ -217,7 +221,8 @@ async def snapshot_state(self) -> dict[str, Any]:
217221
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
218222
}
219223

220-
async def restore_state(self, state: dict[str, Any]) -> None:
224+
@override
225+
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
221226
"""Restore executor state from checkpoint.
222227
223228
Args:

python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import inspect
66
import logging
7+
import sys
78
from abc import ABC, abstractmethod
89
from collections.abc import Awaitable, Callable, Sequence
910
from typing import Any
@@ -13,6 +14,12 @@
1314
from ._orchestrator_helpers import ParticipantRegistry
1415
from ._workflow_context import WorkflowContext
1516

17+
if sys.version_info >= (3, 12):
18+
from typing import override
19+
else:
20+
from typing_extensions import override
21+
22+
1623
logger = logging.getLogger(__name__)
1724

1825

@@ -210,11 +217,12 @@ def _check_round_limit(self) -> bool:
210217

211218
# State persistence (shared across all patterns)
212219

213-
def snapshot_state(self) -> dict[str, Any]:
220+
@override
221+
async def on_checkpoint_save(self) -> dict[str, Any]:
214222
"""Capture current orchestrator state for checkpointing.
215223
216224
Default implementation uses OrchestrationState to serialize common state.
217-
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
225+
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.
218226
219227
Returns:
220228
Serialized state dict
@@ -238,11 +246,12 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]:
238246
"""
239247
return {}
240248

241-
def restore_state(self, state: dict[str, Any]) -> None:
249+
@override
250+
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
242251
"""Restore orchestrator state from checkpoint.
243252
244253
Default implementation uses OrchestrationState to deserialize common state.
245-
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
254+
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.
246255
247256
Args:
248257
state: Serialized state dict

python/packages/core/agent_framework/_workflows/_conversation_history.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
dictionary snapshots so orchestrators can share logic without new mixins.
77
"""
88

9-
import json
10-
from collections.abc import Mapping, Sequence
11-
from typing import Any
9+
from collections.abc import Sequence
1210

1311
from .._types import ChatMessage
1412

@@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
2624
"""Attach `fallback` author if message is missing `author_name`."""
2725
message.author_name = message.author_name or fallback
2826
return message
29-
30-
31-
def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
32-
"""Build an immutable snapshot for checkpoint storage."""
33-
if hasattr(conversation, "to_dict"):
34-
result = conversation.to_dict() # type: ignore[attr-defined]
35-
if isinstance(result, dict):
36-
return result # type: ignore[return-value]
37-
if isinstance(result, Mapping):
38-
return dict(result) # type: ignore[arg-type]
39-
serialisable: list[dict[str, Any]] = []
40-
for message in conversation:
41-
if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined]
42-
msg_dict = message.to_dict() # type: ignore[attr-defined]
43-
serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type]
44-
elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined]
45-
json_payload = message.to_json() # type: ignore[attr-defined]
46-
parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload
47-
serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type]
48-
else:
49-
serialisable.append(dict(getattr(message, "__dict__", {}))) # type: ignore[arg-type]
50-
return {"messages": serialisable}

python/packages/core/agent_framework/_workflows/_events.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,36 @@ def __repr__(self) -> str:
294294
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"
295295

296296

297+
class SuperStepEvent(WorkflowEvent):
298+
"""Event triggered when a superstep starts or ends."""
299+
300+
def __init__(self, iteration: int, data: Any | None = None):
301+
"""Initialize the superstep event.
302+
303+
Args:
304+
iteration: The number of the superstep (1-based index).
305+
data: Optional data associated with the superstep event.
306+
"""
307+
super().__init__(data)
308+
self.iteration = iteration
309+
310+
def __repr__(self) -> str:
311+
"""Return a string representation of the superstep event."""
312+
return f"{self.__class__.__name__}(iteration={self.iteration}, data={self.data})"
313+
314+
315+
class SuperStepStartedEvent(SuperStepEvent):
316+
"""Event triggered when a superstep starts."""
317+
318+
...
319+
320+
321+
class SuperStepCompletedEvent(SuperStepEvent):
322+
"""Event triggered when a superstep ends."""
323+
324+
...
325+
326+
297327
class ExecutorEvent(WorkflowEvent):
298328
"""Base class for executor events."""
299329

@@ -310,17 +340,13 @@ def __repr__(self) -> str:
310340
class ExecutorInvokedEvent(ExecutorEvent):
311341
"""Event triggered when an executor handler is invoked."""
312342

313-
def __repr__(self) -> str:
314-
"""Return a string representation of the executor handler invoke event."""
315-
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
343+
...
316344

317345

318346
class ExecutorCompletedEvent(ExecutorEvent):
319347
"""Event triggered when an executor handler is completed."""
320348

321-
def __repr__(self) -> str:
322-
"""Return a string representation of the executor handler complete event."""
323-
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
349+
...
324350

325351

326352
class ExecutorFailedEvent(ExecutorEvent):

python/packages/core/agent_framework/_workflows/_executor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ def sync_process(text: str, ctx: WorkflowContext[str]) -> None:
155155
that parent workflows can intercept. See WorkflowExecutor documentation for details on
156156
workflow composition patterns and request/response handling.
157157
158+
## State Management
159+
Executors can contain states that persist across workflow runs and checkpoints. Override the
160+
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
161+
serialization and restoration logic.
162+
158163
## Implementation Notes
159164
- Do not call `execute()` directly - it's invoked by the workflow engine
160165
- Do not override `execute()` - define handlers using decorators instead
@@ -460,6 +465,32 @@ def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any
460465
return self._handlers[message_type]
461466
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
462467

468+
async def on_checkpoint_save(self) -> dict[str, Any]:
469+
"""Hook called when the workflow is being saved to a checkpoint.
470+
471+
Override this method in subclasses to implement custom logic that should
472+
return state to be saved in the checkpoint.
473+
474+
The returned state dictionary will be passed to `on_checkpoint_restore`
475+
when the workflow is restored from the checkpoint. The dictionary should
476+
only contain JSON-serializable data.
477+
478+
Returns:
479+
A state dictionary to be saved during checkpointing.
480+
"""
481+
return {}
482+
483+
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
484+
"""Hook called when the workflow is restored from a checkpoint.
485+
486+
Override this method in subclasses to implement custom logic that should
487+
run when the workflow is restored from a checkpoint.
488+
489+
Args:
490+
state: The state dictionary that was saved during checkpointing.
491+
"""
492+
...
493+
463494

464495
# endregion: Executor
465496

python/packages/core/agent_framework/_workflows/_handoff.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import re
19+
import sys
1920
from collections.abc import Awaitable, Callable, Mapping, Sequence
2021
from dataclasses import dataclass, field
2122
from typing import Any
@@ -50,6 +51,12 @@
5051
from ._workflow_builder import WorkflowBuilder
5152
from ._workflow_context import WorkflowContext
5253

54+
if sys.version_info >= (3, 12):
55+
from typing import override
56+
else:
57+
from typing_extensions import override
58+
59+
5360
logger = logging.getLogger(__name__)
5461

5562

@@ -307,15 +314,6 @@ async def handle_agent_response(
307314
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
308315
) -> None:
309316
"""Process an agent's response and determine whether to route, request input, or terminate."""
310-
# Hydrate coordinator state (and detect new run) using checkpointable executor state
311-
state = await ctx.get_executor_state()
312-
if not state:
313-
self._clear_conversation()
314-
elif not self._get_conversation():
315-
restored = self._restore_conversation_from_state(state)
316-
if restored:
317-
self._conversation = list(restored)
318-
319317
source = ctx.get_source_executor_id()
320318
is_starting_agent = source == self._starting_agent_id
321319

@@ -343,7 +341,7 @@ async def handle_agent_response(
343341
# Update current agent when handoff occurs
344342
self._current_agent_id = target
345343
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
346-
await self._persist_state(ctx)
344+
347345
# Clean tool-related content before sending to next agent
348346
cleaned = clean_conversation_for_handoff(conversation)
349347
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
@@ -360,7 +358,6 @@ async def handle_agent_response(
360358
f"Agent '{source}' responded without handoff. "
361359
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
362360
)
363-
await self._persist_state(ctx)
364361

365362
if await self._check_termination():
366363
# Clean the output conversation for display
@@ -388,7 +385,6 @@ async def handle_user_input(
388385
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
389386
# Update authoritative conversation
390387
self._conversation = list(message.full_conversation)
391-
await self._persist_state(ctx)
392388

393389
# Check termination before sending to agent
394390
if await self._check_termination():
@@ -473,11 +469,7 @@ def _conversation_from_response(self, response: AgentExecutorResponse) -> list[C
473469
)
474470
return list(conversation)
475471

476-
async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
477-
"""Store authoritative conversation snapshot without losing rich metadata."""
478-
state_payload = self.snapshot_state()
479-
await ctx.set_executor_state(state_payload)
480-
472+
@override
481473
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
482474
"""Serialize pattern-specific state.
483475
@@ -492,6 +484,7 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]:
492484
}
493485
return {}
494486

487+
@override
495488
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
496489
"""Restore pattern-specific state.
497490
@@ -503,17 +496,6 @@ def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
503496
if self._return_to_previous and "current_agent_id" in metadata:
504497
self._current_agent_id = metadata["current_agent_id"]
505498

506-
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
507-
"""Rehydrate the coordinator's conversation history from checkpointed state.
508-
509-
DEPRECATED: Use restore_state() instead. Kept for backward compatibility.
510-
"""
511-
from ._orchestration_state import OrchestrationState
512-
513-
orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))}
514-
temp_state = OrchestrationState.from_dict(orch_state_dict)
515-
return list(temp_state.conversation)
516-
517499
def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
518500
"""Merge top-level response metadata into the latest assistant message."""
519501
if not agent_response.additional_properties:

0 commit comments

Comments
 (0)