From 1b7121e0045f61ef95bf6c7a24b7cd6e56b82d59 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 10 Nov 2025 16:25:41 -0800 Subject: [PATCH 1/9] Add checkpoint hooks --- .../_workflows/_agent_executor.py | 11 +- .../_base_group_chat_orchestrator.py | 10 +- .../_workflows/_conversation_history.py | 26 +---- .../agent_framework/_workflows/_executor.py | 26 +++++ .../agent_framework/_workflows/_handoff.py | 9 +- .../agent_framework/_workflows/_magentic.py | 109 ++++-------------- .../agent_framework/_workflows/_runner.py | 105 ++++++++++------- 7 files changed, 125 insertions(+), 171 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 0c05abbb69..a00623a258 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Any +from typing import Any, override from .._agents import AgentProtocol, ChatAgent from .._threads import AgentThread @@ -192,7 +192,8 @@ async def from_messages( self._cache = normalize_messages_input(messages) await self._run_agent_and_emit(ctx) - async def snapshot_state(self) -> dict[str, Any]: + @override + async def on_checkpoint_save(self): """Capture current executor state for checkpointing. NOTE: if the thread storage is on the server side, the full thread state @@ -209,9 +210,6 @@ async def snapshot_state(self) -> dict[str, Any]: client_module = self._agent.chat_client.__class__.__module__ if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module: - # TODO(TaoChenOSU): update this warning when we surface the hooks for - # custom executor checkpointing. - # https://github.com/microsoft/agent-framework/issues/1816 logger.warning( "Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. 
" "Currently, checkpointing does not capture messages from server-side threads " @@ -228,7 +226,8 @@ async def snapshot_state(self) -> dict[str, Any]: "agent_thread": serialized_thread, } - async def restore_state(self, state: dict[str, Any]) -> None: + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore executor state from checkpoint. Args: diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index 5752febab5..dde104cce8 100644 --- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence -from typing import Any +from typing import Any, override from .._types import ChatMessage from ._executor import Executor @@ -210,11 +210,12 @@ def _check_round_limit(self) -> bool: # State persistence (shared across all patterns) - def snapshot_state(self) -> dict[str, Any]: + @override + async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for checkpointing. Default implementation uses OrchestrationState to serialize common state. - Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data. + Subclasses should override `on_checkpoint_save()` to add pattern-specific data. Returns: Serialized state dict @@ -238,7 +239,8 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]: """ return {} - def restore_state(self, state: dict[str, Any]) -> None: + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore orchestrator state from checkpoint. Default implementation uses OrchestrationState to deserialize common state. 
diff --git a/python/packages/core/agent_framework/_workflows/_conversation_history.py b/python/packages/core/agent_framework/_workflows/_conversation_history.py index 7e19671b27..52d7d99c74 100644 --- a/python/packages/core/agent_framework/_workflows/_conversation_history.py +++ b/python/packages/core/agent_framework/_workflows/_conversation_history.py @@ -6,9 +6,7 @@ dictionary snapshots so orchestrators can share logic without new mixins. """ -import json -from collections.abc import Mapping, Sequence -from typing import Any +from collections.abc import Sequence from .._types import ChatMessage @@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage: """Attach `fallback` author if message is missing `author_name`.""" message.author_name = message.author_name or fallback return message - - -def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]: - """Build an immutable snapshot for checkpoint storage.""" - if hasattr(conversation, "to_dict"): - result = conversation.to_dict() # type: ignore[attr-defined] - if isinstance(result, dict): - return result # type: ignore[return-value] - if isinstance(result, Mapping): - return dict(result) # type: ignore[arg-type] - serialisable: list[dict[str, Any]] = [] - for message in conversation: - if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined] - msg_dict = message.to_dict() # type: ignore[attr-defined] - serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type] - elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined] - json_payload = message.to_json() # type: ignore[attr-defined] - parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload - serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type] - else: - serialisable.append(dict(getattr(message, "__dict__", {}))) # type: 
ignore[arg-type] - return {"messages": serialisable} diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 1563dd7c53..45fb831711 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -460,6 +460,32 @@ def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any return self._handlers[message_type] raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.") + async def on_checkpoint_save(self) -> dict[str, Any]: + """Hook called when the workflow is being saved to a checkpoint. + + Override this method in subclasses to implement custom logic that should + return state to be saved in the checkpoint. + + The returned state dictionary will be passed to `on_checkpoint_restore` + when the workflow is restored from the checkpoint. The dictionary should + only contain JSON-serializable data. + + Returns: + A state dictionary to be saved during checkpointing. + """ + return {} + + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Hook called when the workflow is restored from a checkpoint. + + Override this method in subclasses to implement custom logic that should + run when the workflow is restored from a checkpoint. + + Args: + state: The state dictionary that was saved during checkpointing. + """ + ... + # endregion: Executor diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index c29e3f55ad..4d35b30444 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -343,7 +343,7 @@ async def handle_agent_response( # Update current agent when handoff occurs self._current_agent_id = target logger.info(f"Handoff detected: {source} -> {target}. 
Routing control to specialist '{target}'.") - await self._persist_state(ctx) + # Clean tool-related content before sending to next agent cleaned = clean_conversation_for_handoff(conversation) request = AgentExecutorRequest(messages=cleaned, should_respond=True) @@ -360,7 +360,6 @@ async def handle_agent_response( f"Agent '{source}' responded without handoff. " f"Requesting user input. Return-to-previous: {self._return_to_previous}" ) - await self._persist_state(ctx) if await self._check_termination(): # Clean the output conversation for display @@ -388,7 +387,6 @@ async def handle_user_input( """Receive full conversation with new user input from gateway, update history, trim for agent.""" # Update authoritative conversation self._conversation = list(message.full_conversation) - await self._persist_state(ctx) # Check termination before sending to agent if await self._check_termination(): @@ -473,11 +471,6 @@ def _conversation_from_response(self, response: AgentExecutorResponse) -> list[C ) return list(conversation) - async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None: - """Store authoritative conversation snapshot without losing rich metadata.""" - state_payload = self.snapshot_state() - await ctx.set_executor_state(state_payload) - def _snapshot_pattern_metadata(self) -> dict[str, Any]: """Serialize pattern-specific state. 
diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 9d21391ad8..3c4c36b757 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -10,7 +10,7 @@ from collections.abc import AsyncIterable, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, Protocol, TypeVar, Union, cast +from typing import Any, Protocol, TypeVar, Union, cast, override from uuid import uuid4 from agent_framework import ( @@ -673,11 +673,11 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatM """Prepare the final answer.""" ... - def snapshot_state(self) -> dict[str, Any]: + def on_checkpoint_save(self) -> dict[str, Any]: """Serialize runtime state for checkpointing.""" return {} - def restore_state(self, state: dict[str, Any]) -> None: + def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore runtime state from checkpoint data.""" return @@ -695,15 +695,13 @@ class StandardMagenticManager(MagenticManagerBase): task_ledger: _MagenticTaskLedger | None - def snapshot_state(self) -> dict[str, Any]: - state = super().snapshot_state() + def on_checkpoint_save(self) -> dict[str, Any]: + state: dict[str, Any] = {} if self.task_ledger is not None: - state = dict(state) state["task_ledger"] = self.task_ledger.to_dict() return state - def restore_state(self, state: dict[str, Any]) -> None: - super().restore_state(state) + def on_checkpoint_restore(self, state: dict[str, Any]) -> None: ledger = state.get("task_ledger") if ledger is not None: try: @@ -997,7 +995,6 @@ def __init__( # Terminal state marker to stop further processing after completion/limits self._terminated = False # Tracks whether checkpoint state has been applied for this run - self._state_restored = False def _get_author_name(self) -> str: """Get the magentic manager name for 
orchestrator-generated messages.""" @@ -1036,7 +1033,8 @@ async def _emit_orchestrator_message( ) await ctx.add_event(event) - def snapshot_state(self) -> dict[str, Any]: + @override + async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for checkpointing. Uses OrchestrationState for structure but maintains Magentic's complex metadata @@ -1055,14 +1053,14 @@ def snapshot_state(self) -> dict[str, Any]: state["magentic_context"] = self._context.to_dict() if self._task_ledger is not None: state["task_ledger"] = _message_to_payload(self._task_ledger) - manager_state: dict[str, Any] | None = None + with contextlib.suppress(Exception): - manager_state = self._manager.snapshot_state() - if manager_state: - state["manager_state"] = manager_state + state["manager_state"] = self._manager.on_checkpoint_save() + return state - def restore_state(self, state: dict[str, Any]) -> None: + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore orchestrator state from checkpoint. Maintains backward compatibility with existing Magentic checkpoints @@ -1112,7 +1110,7 @@ def restore_state(self, state: dict[str, Any]) -> None: manager_state = state.get("manager_state") if manager_state is not None: try: - self._manager.restore_state(manager_state) + self._manager.on_checkpoint_restore(manager_state) except Exception as exc: # pragma: no cover logger.warning("Failed to restore manager state: %s", exc) @@ -1142,49 +1140,6 @@ def _reconcile_restored_participants(self) -> None: for name, description in expected.items(): restored[name] = description - def _snapshot_pattern_metadata(self) -> dict[str, Any]: - """Serialize pattern-specific state. - - Magentic uses custom snapshot_state() instead of base class hooks. - This method exists to satisfy the base class contract. 
- - Returns: - Empty dict (Magentic manages its own state) - """ - return {} - - def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None: - """Restore pattern-specific state. - - Magentic uses custom restore_state() instead of base class hooks. - This method exists to satisfy the base class contract. - - Args: - metadata: Pattern-specific state dict (ignored) - """ - pass - - async def _ensure_state_restored( - self, - context: WorkflowContext[Any, Any], - ) -> None: - if self._state_restored and self._context is not None: - return - state = await context.get_executor_state() - if not state: - self._state_restored = True - return - if not isinstance(state, dict): - self._state_restored = True - return - try: - self.restore_state(state) - except Exception as exc: # pragma: no cover - logger.warning("Magentic Orchestrator: Failed to apply checkpoint state: %s", exc, exc_info=True) - raise - else: - self._state_restored = True - @handler async def handle_start_message( self, @@ -1204,7 +1159,7 @@ async def handle_start_message( ) if message.messages: self._context.chat_history.extend(message.messages) - self._state_restored = True + # Non-streaming callback for the orchestrator receipt of the task await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK) @@ -1269,7 +1224,7 @@ async def handle_response_message( """Handle responses from agents.""" if getattr(self, "_terminated", False): return - await self._ensure_state_restored(context) + if self._context is None: raise RuntimeError("Magentic Orchestrator: Received response but not initialized") @@ -1301,7 +1256,7 @@ async def handle_plan_review_response( ) -> None: if getattr(self, "_terminated", False): return - await self._ensure_state_restored(context) + if self._context is None: return @@ -1636,9 +1591,9 @@ def __init__( self._agent = agent self._agent_id = agent_id self._chat_history: list[ChatMessage] = [] - self._state_restored = False - def snapshot_state(self) -> 
dict[str, Any]: + @override + async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current executor state for checkpointing. Returns: @@ -1650,7 +1605,8 @@ def snapshot_state(self) -> dict[str, Any]: "chat_history": encode_chat_messages(self._chat_history), } - def restore_state(self, state: dict[str, Any]) -> None: + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore executor state from checkpoint. Args: @@ -1668,24 +1624,6 @@ def restore_state(self, state: dict[str, Any]) -> None: else: self._chat_history = [] - async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None: - if self._state_restored and self._chat_history: - return - state = await context.get_executor_state() - if not state: - self._state_restored = True - return - if not isinstance(state, dict): - self._state_restored = True - return - try: - self.restore_state(state) - except Exception as exc: # pragma: no cover - logger.warning("Agent %s: Failed to apply checkpoint state: %s", self._agent_id, exc, exc_info=True) - raise - else: - self._state_restored = True - @handler async def handle_response_message( self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage] @@ -1693,8 +1631,6 @@ async def handle_response_message( """Handle response message (task ledger broadcast).""" logger.debug("Agent %s: Received response message", self._agent_id) - await self._ensure_state_restored(context) - # Check if this message is intended for this agent if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast: # Message is targeted to a different agent, ignore it @@ -1735,8 +1671,6 @@ async def handle_request_message( logger.info("Agent %s: Received request to respond", self._agent_id) - await self._ensure_state_restored(context) - # Add persona adoption message with appropriate role persona_role = self._get_persona_adoption_role() persona_msg = ChatMessage( @@ 
-1783,7 +1717,6 @@ def reset(self) -> None: """Reset the internal chat history of the agent (internal operation).""" logger.debug("Agent %s: Resetting chat history", self._agent_id) self._chat_history.clear() - self._state_restored = True async def _emit_agent_delta_event( self, diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 51ff79a864..0c8cc7f6ee 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -183,8 +183,8 @@ async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | Non return None try: - # Auto-snapshot executor states - await self._auto_snapshot_executor_states() + # Snapshot executor states + await self._save_executor_states() checkpoint_category = "initial" if checkpoint_type == "after_initial_execution" else "superstep" metadata = { "superstep": self._iteration, @@ -203,41 +203,6 @@ async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | Non logger.warning(f"Failed to create {checkpoint_type} checkpoint: {e}") return None - async def _auto_snapshot_executor_states(self) -> None: - """Populate executor state by calling snapshot hooks on executors if available. - - TODO(@taochen#1614): this method is potentially problematic if executors also call - set_executor_state on the context directly. We should clarify the intended usage - pattern for executor state management. - - Convention: - - If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it. - - Else if it has a plain attribute `state` that is a dict, use that. - Only JSON-serializable dicts should be provided by executors. 
- """ - for exec_id, executor in self._executors.items(): - state_dict: dict[str, Any] | None = None - snapshot = getattr(executor, "snapshot_state", None) - try: - if callable(snapshot): - maybe = snapshot() - if asyncio.iscoroutine(maybe): # type: ignore[arg-type] - maybe = await maybe # type: ignore[assignment] - if isinstance(maybe, dict): - state_dict = maybe # type: ignore[assignment] - else: - state_attr = getattr(executor, "state", None) - if isinstance(state_attr, dict): - state_dict = state_attr # type: ignore[assignment] - except Exception as ex: # pragma: no cover - logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}") - - if state_dict is not None: - try: - await self._set_executor_state(exec_id, state_dict) - except Exception as ex: # pragma: no cover - logger.debug(f"Failed to persist state for executor {exec_id}: {ex}") - async def restore_from_checkpoint( self, checkpoint_id: str, @@ -300,7 +265,59 @@ async def restore_from_checkpoint( logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}") return False + async def _save_executor_states(self) -> None: + """Populate executor state by calling checkpoint hooks on executors. + + TODO(@taochen#1614): this method is potentially problematic if executors also call + set_executor_state on the context directly. We should clarify the intended usage + pattern for executor state management. + + Backward compatibility behavior: + - If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it. + - Else if it has a plain attribute `state` that is a dict, use that. + + Updated behavior: + - Executors should implement `on_checkpoint_save(self) -> dict` to provide state. + + Only JSON-serializable dicts should be provided by executors. 
+ """ + for exec_id, executor in self._executors.items(): + # TODO(@taochen) Remove backward compatibility + state_dict: dict[str, Any] | None = None + snapshot = getattr(executor, "snapshot_state", None) + try: + if callable(snapshot): + maybe = snapshot() + if asyncio.iscoroutine(maybe): # type: ignore[arg-type] + maybe = await maybe # type: ignore[assignment] + if isinstance(maybe, dict): + state_dict = maybe # type: ignore[assignment] + else: + state_attr = getattr(executor, "state", None) + if isinstance(state_attr, dict): + state_dict = state_attr # type: ignore[assignment] + except Exception as ex: # pragma: no cover + logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}") + + # Updated behavior: use on_checkpoint_save() method + state_dict = state_dict or {} + state_dict |= await executor.on_checkpoint_save() + + try: + await self._set_executor_state(exec_id, state_dict) + except Exception as ex: # pragma: no cover + logger.debug(f"Failed to persist state for executor {exec_id}: {ex}") + async def _restore_executor_states(self) -> None: + """Restore executor state by calling restore hooks on executors. + + Backward compatibility behavior: + - If an executor defines an async or sync method `restore_state(self, state: dict)`, use it. + - Else, skip restoration for that executor. + + Updated behavior: + - Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state. + """ has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY) if not has_executor_states: return @@ -309,16 +326,17 @@ async def _restore_executor_states(self) -> None: if not isinstance(executor_states, dict): raise ValueError("Executor states in shared state is not a dictionary. 
Unable to restore.") - for executor_id, state in executor_states.items(): + for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType] if not isinstance(executor_id, str): raise ValueError("Executor ID in executor states is not a string. Unable to restore.") - if not isinstance(state, dict): - raise ValueError(f"Executor state for {executor_id} is not a dictionary. Unable to restore.") + if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType] + raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.") executor = self._executors.get(executor_id) if not executor: raise ValueError(f"Executor {executor_id} not found during state restoration.") + # TODO(@taochen) Remove backward compatibility restored = False restore_method = getattr(executor, "restore_state", None) try: @@ -330,6 +348,13 @@ async def _restore_executor_states(self) -> None: except Exception as ex: # pragma: no cover - defensive raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex + # Updated behavior: use on_checkpoint_restore() method + try: + await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType] + restored = True + except Exception as ex: # pragma: no cover - defensive + raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex + if not restored: logger.debug(f"Executor {executor_id} does not support state restoration; skipping.") From 66e0f4fd9abaf9ea5dfd41bd29f02ace777e3b80 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 11:36:45 -0800 Subject: [PATCH 2/9] Deprecate get_executor_state and set_executor_state --- .../_workflows/_agent_executor.py | 2 +- .../agent_framework/_workflows/_runner.py | 4 +- .../_workflows/_workflow_context.py | 8 +- .../_workflows/_workflow_executor.py | 112 ++++++------------ .../tests/workflow/test_agent_executor.py | 8 +- 
.../core/tests/workflow/test_handoff.py | 4 +- .../core/tests/workflow/test_magentic.py | 13 +- .../packages/devui/tests/test_checkpoints.py | 9 +- .../devui/spam_workflow/workflow.py | 70 ++++------- .../checkpoint_with_human_in_the_loop.py | 38 +++--- .../checkpoint/checkpoint_with_resume.py | 46 ++----- .../checkpoint/sub_workflow_checkpoint.py | 28 +++-- 12 files changed, 135 insertions(+), 207 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index a00623a258..7302d3d911 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -193,7 +193,7 @@ async def from_messages( await self._run_agent_and_emit(ctx) @override - async def on_checkpoint_save(self): + async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current executor state for checkpointing. NOTE: if the thread storage is on the server side, the full thread state diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 0c8cc7f6ee..3a36c85782 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -282,7 +282,7 @@ async def _save_executor_states(self) -> None: Only JSON-serializable dicts should be provided by executors. 
""" for exec_id, executor in self._executors.items(): - # TODO(@taochen) Remove backward compatibility + # TODO(@taochen): Remove backward compatibility state_dict: dict[str, Any] | None = None snapshot = getattr(executor, "snapshot_state", None) try: @@ -336,7 +336,7 @@ async def _restore_executor_states(self) -> None: if not executor: raise ValueError(f"Executor {executor_id} not found during state restoration.") - # TODO(@taochen) Remove backward compatibility + # TODO(@taochen): Remove backward compatibility restored = False restore_method = getattr(executor, "restore_state", None) try: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index d2a3648298..1d1c8e8c49 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -8,7 +8,7 @@ from opentelemetry.propagate import inject from opentelemetry.trace import SpanKind -from typing_extensions import Never, TypeVar +from typing_extensions import Never, TypeVar, deprecated from ..observability import OtelAttr, create_workflow_span from ._const import EXECUTOR_STATE_KEY @@ -410,6 +410,9 @@ def shared_state(self) -> SharedState: """Get the shared state.""" return self._shared_state + @deprecated( + "Use executor member properties directly instead. For state sharing between executors, use `set_shared_state`." + ) async def set_executor_state(self, state: dict[str, Any]) -> None: """Store executor state in shared state under a reserved key. @@ -428,6 +431,9 @@ async def set_executor_state(self, state: dict[str, Any]) -> None: existing_states[self._executor_id] = state await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states) + @deprecated( + "Use executor member properties directly instead. For state sharing between executors, use `get_shared_state`." 
+ ) async def get_executor_state(self) -> dict[str, Any] | None: """Retrieve previously persisted state for this executor, if any.""" has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 77acbc5a58..636687428a 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -1,11 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -import contextlib import logging import uuid from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override if TYPE_CHECKING: from ._workflow import Workflow @@ -181,8 +180,7 @@ class WorkflowExecutor(Executor): # Includes all sub-workflow output types # Plus SubWorkflowRequestMessage if sub-workflow can make requests - output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable - ``` + output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable ## Error Handling WorkflowExecutor propagates sub-workflow failures: @@ -221,23 +219,10 @@ class WorkflowExecutor(Executor): ### Important Considerations **Shared Workflow Instance**: All concurrent executions use the same underlying workflow instance. - For proper isolation, ensure that: - - The wrapped workflow and its executors are stateless - - Executors use WorkflowContext state management instead of instance variables - - Any shared state is managed through WorkflowContext.get_shared_state/set_shared_state + For proper isolation, ensure that the wrapped workflow and its executors are stateless. .. 
code-block:: python - # Good: Stateless executor using context state - class StatelessExecutor(Executor): - @handler - async def process(self, data: str, ctx: WorkflowContext[str]) -> None: - # Use context state instead of instance variables - state = await ctx.get_executor_state() or {} - state["processed"] = data - await ctx.set_executor_state(state) - - # Avoid: Stateful executor with instance variables class StatefulExecutor(Executor): def __init__(self): @@ -246,23 +231,23 @@ def __init__(self): ## Integration with Parent Workflows Parent workflows can intercept sub-workflow requests: - ```python - class ParentExecutor(Executor): - @handler - async def handle_subworkflow_request( - self, - request: SubWorkflowRequestMessage, - ctx: WorkflowContext[SubWorkflowResponseMessage], - ) -> None: - # Handle request locally or forward to external source - if self.can_handle_locally(request): - # Send response back to sub-workflow - response = request.create_response(data="local response data") - await ctx.send_message(response, target_id=request.source_executor_id) - else: - # Forward to external handler - await ctx.request_info(request.source_event, response_type=request.source_event.response_type) - ``` + + .. 
code-block:: python + class ParentExecutor(Executor): + @handler + async def handle_subworkflow_request( + self, + request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + # Handle request locally or forward to external source + if self.can_handle_locally(request): + # Send response back to sub-workflow + response = request.create_response(data="local response data") + await ctx.send_message(response, target_id=request.source_executor_id) + else: + # Forward to external handler + await ctx.request_info(request.source_event, response_type=request.source_event.response_type) ## Implementation Notes - Sub-workflows run to completion before processing their results @@ -296,7 +281,6 @@ def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = Fa self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext # Map request_id to execution_id for response routing self._request_to_execution: dict[str, str] = {} # request_id -> execution_id - self._state_loaded: bool = False @property def input_types(self) -> list[type[Any]]: @@ -362,8 +346,6 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) input_data: The input data to send to the sub-workflow. ctx: The workflow context from the parent. """ - await self._ensure_state_loaded(ctx) - # Create execution context for this sub-workflow run execution_id = str(uuid.uuid4()) execution_context = ExecutionContext( @@ -405,8 +387,6 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf response: The response to a previous request. ctx: The workflow context. 
""" - await self._ensure_state_loaded(ctx) - # Find the execution context for this request original_request = response.source_event execution_id = self._request_to_execution.get(original_request.request_id) @@ -434,8 +414,6 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf # Accumulate the response in this execution's context execution_context.collected_responses[original_request.request_id] = response.data - await self._persist_execution_state(ctx) - # Check if we have all expected responses for this execution if len(execution_context.collected_responses) < execution_context.expected_response_count: logger.debug( @@ -459,25 +437,20 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf if not execution_context.pending_requests: del self._execution_contexts[execution_id] - async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None: - if self._state_loaded: - return - - state: dict[str, Any] | None = None - try: - state = await ctx.get_executor_state() - except Exception: - state = None - - if isinstance(state, dict) and state: - with contextlib.suppress(Exception): - await self.restore_state(state) - self._state_loaded = True - else: - self._state_loaded = True + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Get the current state of the WorkflowExecutor for checkpointing purposes.""" + return { + "execution_contexts": { + execution_id: encode_checkpoint_value(execution_context) + for execution_id, execution_context in self._execution_contexts.items() + }, + "request_to_execution": dict(self._request_to_execution), + } - async def restore_state(self, state: dict[str, Any]) -> None: - """Restore pending request bookkeeping from a checkpoint snapshot.""" + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore the WorkflowExecutor state from a checkpoint snapshot.""" # Validate the state contains the right keys if 
"execution_contexts" not in state: raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.") @@ -529,23 +502,6 @@ async def restore_state(self, state: dict[str, Any]) -> None: for event in request_info_events ]) - self._state_loaded = True - - async def _persist_execution_state(self, ctx: WorkflowContext) -> None: - """Persist the state of the WorkflowExecutor for checkpointing purposes.""" - state = { - "execution_contexts": { - execution_id: encode_checkpoint_value(execution_context) - for execution_id, execution_context in self._execution_contexts.items() - }, - "request_to_execution": dict(self._request_to_execution), - } - - try: - await ctx.set_executor_state(state) - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}") - async def _process_workflow_result( self, result: WorkflowRunResult, @@ -635,5 +591,3 @@ async def _process_workflow_result( ) else: raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}") - - await self._persist_execution_state(ctx) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 3bda2fcaad..f37b3e0e2b 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -154,8 +154,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert thread_messages[1].text == "Initial response 1" -async def test_agent_executor_snapshot_and_restore_state_directly() -> None: - """Test AgentExecutor's snapshot_state and restore_state methods directly.""" +async def test_agent_executor_save_and_restore_state_directly() -> None: + """Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly.""" # Create agent with thread containing messages agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent") thread 
= AgentThread(message_store=ChatMessageStore()) @@ -178,7 +178,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None: executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage] # Snapshot the state - state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType] + state = await executor.on_checkpoint_save() # Verify snapshot contains both cache and thread assert "cache" in state @@ -202,7 +202,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None: assert len(initial_thread_msgs) == 0 # Restore state - await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType] + await new_executor.on_checkpoint_restore(state) # Verify cache is restored restored_cache = new_executor._cache # type: ignore[reportPrivateUsage] diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index a799fb6f73..4cf33856aa 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -585,7 +585,7 @@ async def test_return_to_previous_state_serialization(): coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage] # Snapshot the state - state = coordinator.snapshot_state() + state = await coordinator.on_checkpoint_save() # Verify pattern metadata includes current_agent_id assert "metadata" in state @@ -603,7 +603,7 @@ async def test_return_to_previous_state_serialization(): ) # Restore state - coordinator2.restore_state(state) + await coordinator2.on_checkpoint_restore(state) # Verify current_agent_id was restored assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage] diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index eda0675361..66e816bf10 100644 --- 
a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -185,7 +185,6 @@ async def test_standard_manager_progress_ledger_and_fallback(): assert ledger2.is_request_satisfied.answer is False -@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()") async def test_magentic_workflow_plan_review_approval_to_completion(): manager = FakeManager(max_round_count=10) wf = ( @@ -204,7 +203,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion(): completed = False output: ChatMessage | None = None - async for ev in wf.run_stream( + async for ev in wf.send_responses_streaming( responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)} ): if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: @@ -218,7 +217,6 @@ async def test_magentic_workflow_plan_review_approval_to_completion(): assert isinstance(output, ChatMessage) -@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()") async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds(): class CountingManager(FakeManager): # Declare as a model field so assignment is allowed under Pydantic @@ -250,7 +248,7 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ # Reply APPROVE with comments (no edited text). Expect one replan and no second review round. 
saw_second_review = False completed = False - async for ev in wf.run_stream( + async for ev in wf.send_responses_streaming( responses={ req_event.request_id: MagenticPlanReviewReply( decision=MagenticPlanReviewDecision.APPROVE, @@ -298,7 +296,6 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result(): assert data.role == Role.ASSISTANT -@pytest.mark.skip(reason="Response handling refactored - send_responses_streaming no longer exists") async def test_magentic_checkpoint_resume_round_trip(): storage = InMemoryCheckpointStorage() @@ -369,7 +366,7 @@ async def _noop(self, message: object, ctx: WorkflowContext[object]) -> None: # pass -def test_magentic_agent_executor_snapshot_roundtrip(): +async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip(): backing_executor = _DummyExec("backing") agent_exec = MagenticAgentExecutor(backing_executor, "agentA") agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage] @@ -377,10 +374,10 @@ def test_magentic_agent_executor_snapshot_roundtrip(): ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"), ]) - state = agent_exec.snapshot_state() + state = await agent_exec.on_checkpoint_save() restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA") - restored_executor.restore_state(state) + await restored_executor.on_checkpoint_restore(state) assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage] assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage] diff --git a/python/packages/devui/tests/test_checkpoints.py b/python/packages/devui/tests/test_checkpoints.py index 473b4faac0..838195fdd3 100644 --- a/python/packages/devui/tests/test_checkpoints.py +++ b/python/packages/devui/tests/test_checkpoints.py @@ -37,10 +37,14 @@ class WorkflowHILRequest: class WorkflowTestExecutor(Executor): """Test executor with HIL.""" + def __init__(self, id: str) -> None: + super().__init__(id=id) + 
self._data_value: str | None = None + @handler async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None: """Process data and request approval.""" - await ctx.set_executor_state({"data_value": data.value}) + self._data_value = data.value # Request HIL (checkpoint created here) await ctx.request_info(request_data=WorkflowHILRequest(question=f"Approve {data.value}?"), response_type=str) @@ -50,8 +54,7 @@ async def handle_response( self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str] ) -> None: """Handle HIL response.""" - state = await ctx.get_executor_state() or {} - value = state.get("data_value", "") + value = self._data_value or "" await ctx.send_message(f"{value}_approved" if response.lower() == "yes" else f"{value}_rejected") diff --git a/python/samples/getting_started/devui/spam_workflow/workflow.py b/python/samples/getting_started/devui/spam_workflow/workflow.py index 020d22a4c7..73be349cc6 100644 --- a/python/samples/getting_started/devui/spam_workflow/workflow.py +++ b/python/samples/getting_started/devui/spam_workflow/workflow.py @@ -17,7 +17,7 @@ import asyncio import logging from dataclasses import dataclass -from typing import Literal, Annotated +from typing import Literal from agent_framework import ( Case, @@ -31,9 +31,11 @@ from pydantic import BaseModel, Field from typing_extensions import Never + # Define response model with clear user guidance class SpamDecision(BaseModel): """User's decision on whether the email is spam.""" + decision: Literal["spam", "not spam"] = Field( description="Enter 'spam' to mark as spam, or 'not spam' to mark as legitimate" ) @@ -71,10 +73,11 @@ def __post_init__(self): class SpamApprovalRequest: """Human-in-the-loop approval request for spam classification.""" - email_message: str = "" - detected_as_spam: bool = False - confidence: float = 0.0 - reasons: str = "" + email_message: str + detected_as_spam: bool + confidence: float + reasons: list[str] + 
full_email_content: EmailContent @dataclass @@ -128,8 +131,6 @@ async def handle_email(self, email: EmailRequest, ctx: WorkflowContext[EmailCont await ctx.send_message(result) - - class SpamDetector(Executor): """Step 2: An executor that analyzes content and determines if a message is spam.""" @@ -139,7 +140,9 @@ def __init__(self, spam_keywords: list[str], id: str): self._spam_keywords = spam_keywords @handler - async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]) -> None: + async def handle_email_content( + self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest] + ) -> None: """Analyze email content and determine if the message is spam, then request human approval.""" await asyncio.sleep(2.0) # Simulate analysis and detection time @@ -186,25 +189,13 @@ async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowC is_spam = spam_score >= 0.5 - # Store detection result in executor state for later use - # Store minimal data needed (not complex objects that don't serialize well) - await ctx.set_executor_state({ - "original_message": email_content.original_message, - "cleaned_message": email_content.cleaned_message, - "word_count": email_content.word_count, - "has_suspicious_patterns": email_content.has_suspicious_patterns, - "is_spam": is_spam, - "ai_original_classification": is_spam, # Store original AI decision - "confidence_score": spam_score, - "spam_reasons": spam_reasons - }) - # Request human approval before proceeding using new API approval_request = SpamApprovalRequest( email_message=email_text[:200], # First 200 chars detected_as_spam=is_spam, confidence=spam_score, - reasons=", ".join(spam_reasons) if spam_reasons else "no specific reasons" + reasons=spam_reasons, + full_email_content=email_content, ) await ctx.request_info( @@ -214,20 +205,15 @@ async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowC @response_handler async def 
handle_human_response( - self, - original_request: SpamApprovalRequest, - response: SpamDecision, - ctx: WorkflowContext[SpamDetectorResponse] + self, original_request: SpamApprovalRequest, response: SpamDecision, ctx: WorkflowContext[SpamDetectorResponse] ) -> None: """Process human approval response and continue workflow.""" print(f"[SpamDetector] handle_human_response called with response: {response}") # Get stored detection result - state = await ctx.get_executor_state() or {} - print(f"[SpamDetector] Retrieved state: {state}") - ai_original = state.get("ai_original_classification", False) - confidence_score = state.get("confidence_score", 0.0) - spam_reasons = state.get("spam_reasons", []) + ai_original = original_request.detected_as_spam + confidence_score = original_request.confidence + spam_reasons = original_request.reasons # Parse human decision from the response model human_decision = response.decision.strip().lower() @@ -241,27 +227,21 @@ async def handle_human_response( # Default to AI decision if unclear is_spam = ai_original - # Reconstruct EmailContent from stored primitives - email_content = EmailContent( - original_message=state.get("original_message", ""), - cleaned_message=state.get("cleaned_message", ""), - word_count=state.get("word_count", 0), - has_suspicious_patterns=state.get("has_suspicious_patterns", False) - ) - result = SpamDetectorResponse( - email_content=email_content, + email_content=original_request.full_email_content, is_spam=is_spam, confidence_score=confidence_score, spam_reasons=spam_reasons, human_reviewed=True, human_decision=response.decision, - ai_original_classification=ai_original + ai_original_classification=ai_original, ) - print(f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True") + print( + f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True" + ) await ctx.send_message(result) - 
print(f"[SpamDetector] Message sent successfully") + print("[SpamDetector] Message sent successfully") class SpamHandler(Executor): @@ -427,7 +407,9 @@ async def handle_processing_result( spam_detector, [ Case(condition=lambda x: isinstance(x, SpamDetectorResponse) and x.is_spam, target=spam_handler), - Default(target=legitimate_message_handler), # Default handles non-spam and non-SpamDetectorResponse messages + Default( + target=legitimate_message_handler + ), # Default handles non-spam and non-SpamDetectorResponse messages ], ) .add_edge(spam_handler, final_processor) diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index 9fb870bf01..7effcdd28e 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -3,6 +3,7 @@ import asyncio from dataclasses import dataclass from pathlib import Path +from typing import Any, override # NOTE: the Azure client imports above are real dependencies. When running this # sample outside of Azure-enabled environments you may wish to swap in the @@ -116,19 +117,19 @@ class ReviewGateway(Executor): def __init__(self, id: str, writer_id: str) -> None: super().__init__(id=id) self._writer_id = writer_id + self._iteration = 0 @handler async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None: # Capture the agent output so we can surface it to the reviewer and persist iterations. - draft = response.agent_run_response.text or "" - iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1 - await ctx.set_executor_state({"iteration": iteration, "last_draft": draft}) + self._iteration += 1 + # Emit a human approval request. 
await ctx.request_info( request_data=HumanApprovalRequest( prompt="Review the draft. Reply 'approve' or provide edit instructions.", - draft=draft, - iteration=iteration, + draft=response.agent_run_response.text, + iteration=self._iteration, ), response_type=str, ) @@ -142,28 +143,33 @@ async def on_human_feedback( ) -> None: # The `original_request` is the request we sent earlier that is now being answered. reply = feedback.strip() - state = await ctx.get_executor_state() or {} - draft = state.get("last_draft") or (original_request.draft or "") - if reply.lower() == "approve": + if len(reply) == 0 or reply.lower() == "approve": # Workflow is completed when the human approves. - await ctx.yield_output(draft) + await ctx.yield_output(original_request.draft) return # Any other response loops us back to the writer with fresh guidance. - guidance = reply or "Tighten the copy and emphasise customer benefit." - iteration = int(state.get("iteration", 1)) + 1 - await ctx.set_executor_state({"iteration": iteration, "last_draft": draft}) prompt = ( "Revise the launch note. Respond with the new copy only.\n\n" - f"Previous draft:\n{draft}\n\n" - f"Human guidance: {guidance}" + f"Previous draft:\n{original_request.draft}\n\n" + f"Human guidance: {reply}" ) await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True), target_id=self._writer_id, ) + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + # Save the current iteration count in executor state for checkpointing. + return {"iteration": self._iteration} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + # Restore the iteration count from executor state during checkpoint recovery. 
+ self._iteration = state.get("iteration", 0) + def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow: """Assemble the workflow graph used by both the initial run and resume.""" @@ -247,10 +253,10 @@ async def run_interactive_session( else: if initial_message: print(f"\nStarting workflow with brief: {initial_message}\n") - event_stream = workflow.run_stream(initial_message) + event_stream = workflow.run_stream(message=initial_message) elif checkpoint_id: print("\nStarting workflow from checkpoint...\n") - event_stream = workflow.run_stream(checkpoint_id) + event_stream = workflow.run_stream(checkpoint_id=checkpoint_id) else: raise ValueError("Either initial_message or checkpoint_id must be provided") diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py index cb0c7705c5..d0848ca9e7 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py @@ -29,20 +29,19 @@ Sample: Checkpointing and Resuming a Workflow (with an Agent stage) Purpose: -This sample shows how to enable checkpointing at superstep boundaries, persist both -executor-local state and shared workflow state, and then resume execution from a specific -checkpoint. The workflow demonstrates a simple text-processing pipeline that includes -an LLM-backed AgentExecutor stage. +This sample shows how to enable checkpointing at superstep boundaries, persist shared +workflow state, and then resume execution from a specific checkpoint. The workflow +demonstrates a simple text-processing pipeline that includes an LLM-backed AgentExecutor +stage. Pipeline: -1) UpperCaseExecutor converts input to uppercase and records state. +1) UpperCaseExecutor converts input to uppercase. 2) ReverseTextExecutor reverses the string. 
3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent. 4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI. 5) FinalizeFromAgent yields the final result. What you learn: -- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state. - How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility. - How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder. - How to list and inspect checkpoints programmatically. @@ -66,21 +65,14 @@ class UpperCaseExecutor(Executor): """Uppercases the input text and persists both local and shared state.""" + def __init__(self, id: str): + super().__init__(id=id) + @handler async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None: result = text.upper() print(f"UpperCaseExecutor: '{text}' -> '{result}'") - # Persist executor-local state so it is captured in checkpoints - # and available after resume for observability or logic. - prev = await ctx.get_executor_state() or {} - count = int(prev.get("count", 0)) + 1 - await ctx.set_executor_state({ - "count": count, - "last_input": text, - "last_output": result, - }) - # Write to shared_state so downstream executors and any resumed runs can read it. await ctx.set_shared_state("original_input", text) await ctx.set_shared_state("upper_output", result) @@ -119,19 +111,8 @@ class FinalizeFromAgent(Executor): @handler async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None: - result = response.agent_run_response.text or "" - - # Persist executor-local state for auditability when inspecting checkpoints. - prev = await ctx.get_executor_state() or {} - count = int(prev.get("count", 0)) + 1 - await ctx.set_executor_state({ - "count": count, - "last_output": result, - "final": True, - }) - # Yield the final result so external consumers see the final value. 
- await ctx.yield_output(result) + await ctx.yield_output(response.agent_run_response.text) class ReverseTextExecutor(Executor): @@ -142,15 +123,6 @@ async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None: result = text[::-1] print(f"ReverseTextExecutor: '{text}' -> '{result}'") - # Persist executor-local state so checkpoint inspection can reveal progress. - prev = await ctx.get_executor_state() or {} - count = int(prev.get("count", 0)) + 1 - await ctx.set_executor_state({ - "count": count, - "last_input": text, - "last_output": result, - }) - # Forward the reversed string to the next stage. await ctx.send_message(result) diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index 62c88bf49f..12b82d3b9d 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field, replace from datetime import datetime, timedelta from pathlib import Path +from typing import Any, override from agent_framework import ( Executor, @@ -205,6 +206,7 @@ class LaunchCoordinator(Executor): def __init__(self) -> None: super().__init__(id="launch_coordinator") + self._pending_requests: dict[str, SubWorkflowRequestMessage] = {} @handler async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None: @@ -244,11 +246,9 @@ async def handler_sub_workflow_request( if not isinstance(request.source_event.data, ReviewRequest): raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}") - # Record the request to response matching + # Record the request for response matching review_request = request.source_event.data - executor_state = await ctx.get_executor_state() or {} - executor_state[review_request.id] = request - await 
ctx.set_executor_state(executor_state) + self._pending_requests[review_request.id] = request # Send the request without modification await ctx.request_info(request_data=review_request, response_type=str) @@ -265,17 +265,25 @@ async def handle_request_response( Note that the response must be sent back using SubWorkflowResponseMessage to route the response back to the sub-workflow. """ - executor_state = await ctx.get_executor_state() or {} - request_message = executor_state.pop(original_request.id, None) - - # Save the executor state back to the context - await ctx.set_executor_state(executor_state) + request_message = self._pending_requests.pop(original_request.id, None) if request_message is None: raise ValueError("No matching pending request found for the resource response") await ctx.send_message(request_message.create_response(response)) + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Capture any additional state needed for checkpointing.""" + return { + "pending_requests": self._pending_requests, + } + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore any additional state needed from checkpointing.""" + self._pending_requests = state.get("pending_requests", {}) + # --------------------------------------------------------------------------- # Workflow construction helpers @@ -357,7 +365,7 @@ async def main() -> None: request_info_event: RequestInfoEvent | None = None async for event in workflow2.run_stream( - resume_checkpoint.checkpoint_id, + checkpoint_id=resume_checkpoint.checkpoint_id, ): if isinstance(event, RequestInfoEvent): request_info_event = event From 27e5589480a06635ff4c570270fed1e68d419cff Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 13:25:28 -0800 Subject: [PATCH 3/9] Fix tests and samples --- .../agent_framework/_workflows/_handoff.py | 24 ++------- .../agent_framework/_workflows/_magentic.py | 30 ++++++----- .../agent_framework/_workflows/_runner.py | 4 -- 
.../core/tests/workflow/test_handoff.py | 51 ------------------- .../core/tests/workflow/test_magentic.py | 12 +++-- .../checkpoint/sub_workflow_checkpoint.py | 4 +- 6 files changed, 27 insertions(+), 98 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 4d35b30444..78e938f217 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -18,7 +18,7 @@ import re from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field -from typing import Any +from typing import Any, override from agent_framework import ( AgentProtocol, @@ -307,15 +307,6 @@ async def handle_agent_response( ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput], ) -> None: """Process an agent's response and determine whether to route, request input, or terminate.""" - # Hydrate coordinator state (and detect new run) using checkpointable executor state - state = await ctx.get_executor_state() - if not state: - self._clear_conversation() - elif not self._get_conversation(): - restored = self._restore_conversation_from_state(state) - if restored: - self._conversation = list(restored) - source = ctx.get_source_executor_id() is_starting_agent = source == self._starting_agent_id @@ -471,6 +462,7 @@ def _conversation_from_response(self, response: AgentExecutorResponse) -> list[C ) return list(conversation) + @override def _snapshot_pattern_metadata(self) -> dict[str, Any]: """Serialize pattern-specific state. @@ -485,6 +477,7 @@ def _snapshot_pattern_metadata(self) -> dict[str, Any]: } return {} + @override def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None: """Restore pattern-specific state. 
@@ -496,17 +489,6 @@ def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None: if self._return_to_previous and "current_agent_id" in metadata: self._current_agent_id = metadata["current_agent_id"] - def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]: - """Rehydrate the coordinator's conversation history from checkpointed state. - - DEPRECATED: Use restore_state() instead. Kept for backward compatibility. - """ - from ._orchestration_state import OrchestrationState - - orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))} - temp_state = OrchestrationState.from_dict(orch_state_dict) - return list(temp_state.conversation) - def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None: """Merge top-level response metadata into the latest assistant message.""" if not agent_response.additional_properties: diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 3c4c36b757..e02d45ea46 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -695,20 +695,6 @@ class StandardMagenticManager(MagenticManagerBase): task_ledger: _MagenticTaskLedger | None - def on_checkpoint_save(self) -> dict[str, Any]: - state: dict[str, Any] = {} - if self.task_ledger is not None: - state["task_ledger"] = self.task_ledger.to_dict() - return state - - def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - ledger = state.get("task_ledger") - if ledger is not None: - try: - self.task_ledger = _MagenticTaskLedger.from_dict(ledger) - except Exception: # pragma: no cover - defensive - logger.warning("Failed to restore manager task ledger from checkpoint state") - def __init__( self, chat_client: ChatClientProtocol, @@ -938,6 +924,22 @@ async def prepare_final_answer(self, 
magentic_context: MagenticContext) -> ChatM author_name=response.author_name or MAGENTIC_MANAGER_NAME, ) + @override + def on_checkpoint_save(self) -> dict[str, Any]: + state: dict[str, Any] = {} + if self.task_ledger is not None: + state["task_ledger"] = self.task_ledger.to_dict() + return state + + @override + def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + ledger = state.get("task_ledger") + if ledger is not None: + try: + self.task_ledger = _MagenticTaskLedger.from_dict(ledger) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore manager task ledger from checkpoint state") + # endregion Magentic Manager diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 3a36c85782..9c90e8f949 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -268,10 +268,6 @@ async def restore_from_checkpoint( async def _save_executor_states(self) -> None: """Populate executor state by calling checkpoint hooks on executors. - TODO(@taochen#1614): this method is potentially problematic if executors also call - set_executor_state on the context directly. We should clarify the intended usage - pattern for executor state management. - Backward compatibility behavior: - If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it. - Else if it has a plain attribute `state` that is a dict, use that. 
diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index 4cf33856aa..5dfd7522df 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -288,57 +288,6 @@ def test_build_fails_without_participants(): HandoffBuilder().build() -async def test_multiple_runs_dont_leak_conversation(): - """Verify that running the same workflow multiple times doesn't leak conversation history.""" - triage = _RecordingAgent(name="triage", handoff_to="specialist") - specialist = _RecordingAgent(name="specialist") - - workflow = ( - HandoffBuilder(participants=[triage, specialist]) - .set_coordinator("triage") - .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) - .build() - ) - - # First run - events = await _drain(workflow.run_stream("First run message")) - requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] - assert requests - events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"})) - outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] - assert outputs, "First run should emit output" - - first_run_conversation = outputs[-1].data - assert isinstance(first_run_conversation, list) - first_run_conv_list = cast(list[ChatMessage], first_run_conversation) - first_run_user_messages = [msg for msg in first_run_conv_list if msg.role == Role.USER] - assert len(first_run_user_messages) == 2 - assert any("First run message" in msg.text for msg in first_run_user_messages if msg.text) - - # Second run - should start fresh, not include first run's messages - triage.calls.clear() - specialist.calls.clear() - - events = await _drain(workflow.run_stream("Second run different message")) - requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] - assert requests - events = await 
_drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"})) - outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] - assert outputs, "Second run should emit output" - - second_run_conversation = outputs[-1].data - assert isinstance(second_run_conversation, list) - second_run_conv_list = cast(list[ChatMessage], second_run_conversation) - second_run_user_messages = [msg for msg in second_run_conv_list if msg.role == Role.USER] - assert len(second_run_user_messages) == 2, ( - "Second run should have exactly 2 user messages, not accumulate first run" - ) - assert any("Second run different message" in msg.text for msg in second_run_user_messages if msg.text) - assert not any("First run message" in msg.text for msg in second_run_user_messages if msg.text), ( - "Second run should NOT contain first run's messages" - ) - - async def test_handoff_async_termination_condition() -> None: """Test that async termination conditions work correctly.""" termination_call_count = 0 diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 66e816bf10..037bca94c7 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -2,7 +2,7 @@ from collections.abc import AsyncIterable from dataclasses import dataclass -from typing import Any, cast +from typing import Any, cast, override import pytest @@ -101,8 +101,9 @@ class FakeManager(MagenticManagerBase): next_speaker_name: str = "agentA" instruction_text: str = "Proceed with step 1" - def snapshot_state(self) -> dict[str, Any]: - state = super().snapshot_state() + @override + def on_checkpoint_save(self) -> dict[str, Any]: + state = super().on_checkpoint_save() if self.task_ledger is not None: state = dict(state) state["task_ledger"] = { @@ -111,8 +112,9 @@ def snapshot_state(self) -> dict[str, Any]: } return state - def restore_state(self, state: dict[str, 
Any]) -> None: - super().restore_state(state) + @override + def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + super().on_checkpoint_restore(state) ledger_state = state.get("task_ledger") if isinstance(ledger_state, dict): ledger_dict = cast(dict[str, Any], ledger_state) diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index 12b82d3b9d..8fb8226933 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -364,9 +364,7 @@ async def main() -> None: workflow2 = build_parent_workflow(storage) request_info_event: RequestInfoEvent | None = None - async for event in workflow2.run_stream( - checkpoint_id=resume_checkpoint.checkpoint_id, - ): + async for event in workflow2.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id): if isinstance(event, RequestInfoEvent): request_info_event = event From 41f9c624c2c48f8a69d86a01ebca0d84d6a4f160 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 13:38:56 -0800 Subject: [PATCH 4/9] Add doc strings --- .../core/agent_framework/_workflows/_executor.py | 5 +++++ .../core/agent_framework/_workflows/_workflow.py | 10 +++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 45fb831711..80df16592b 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -155,6 +155,11 @@ def sync_process(text: str, ctx: WorkflowContext[str]) -> None: that parent workflows can intercept. See WorkflowExecutor documentation for details on workflow composition patterns and request/response handling. 
+ ## State Management + Executors can contain states that persist across workflow runs and checkpoints. Override the + `on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state + serialization and restoration logic. + ## Implementation Notes - Do not call `execute()` directly - it's invoked by the workflow engine - Do not override `execute()` - define handlers using decorators instead diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index e5fd02a611..a14542b2a6 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -109,9 +109,9 @@ class Workflow(DictConvertible): """A graph-based execution engine that orchestrates connected executors. ## Overview - A workflow executes a directed graph of executors connected via edge groups using a Pregel-like model, - running in supersteps until the graph becomes idle. Workflows are created using the - WorkflowBuilder class - do not instantiate this class directly. + A workflow executes a directed graph of executors connected via edge groups using a + Pregel-like model, running in supersteps until the graph becomes idle. Workflows + are created using the WorkflowBuilder class - do not instantiate this class directly. ## Execution Model Executors run in synchronized supersteps where each executor: @@ -142,6 +142,10 @@ class Workflow(DictConvertible): - HIL continuation: Provide `responses` to continue after RequestInfoExecutor requests - Runtime checkpointing: Provide `checkpoint_storage` to enable/override checkpointing for this run + ## State Management + Workflow instances contain states and states are preserved across calls to `run` and `run_stream`. + To execute multiple independent runs, create separate Workflow instances via WorkflowBuilder. 
+ ## External Input Requests Executors within a workflow can request external input using `ctx.request_info()`: 1. Executor calls `ctx.request_info()` to request input From bd10dda244d8fc87661f1c92bd65bf6841b05c57 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 17:02:21 -0800 Subject: [PATCH 5/9] Add sample --- .../agent_framework/_workflows/__init__.py | 6 + .../agent_framework/_workflows/__init__.pyi | 4 + .../agent_framework/_workflows/_events.py | 38 +- .../agent_framework/_workflows/_runner.py | 6 +- .../checkpoint/checkpoint_with_resume.py | 378 ++++++------------ 5 files changed, 166 insertions(+), 266 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 6c4948f0ea..cbb6de84e4 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -37,6 +37,8 @@ ExecutorFailedEvent, ExecutorInvokedEvent, RequestInfoEvent, + SuperStepCompletedEvent, + SuperStepStartedEvent, WorkflowErrorDetails, WorkflowEvent, WorkflowEventSource, @@ -152,6 +154,10 @@ "StandardMagenticManager", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", + "SuperStepCompletedEvent", + "SuperStepCompletedEvent", + "SuperStepStartedEvent", + "SuperStepStartedEvent", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", diff --git a/python/packages/core/agent_framework/_workflows/__init__.pyi b/python/packages/core/agent_framework/_workflows/__init__.pyi index 44247d685c..c9f8c6cb62 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.pyi +++ b/python/packages/core/agent_framework/_workflows/__init__.pyi @@ -35,6 +35,8 @@ from ._events import ( ExecutorFailedEvent, ExecutorInvokedEvent, RequestInfoEvent, + SuperStepCompletedEvent, + SuperStepStartedEvent, WorkflowErrorDetails, WorkflowEvent, WorkflowEventSource, @@ -148,6 +150,8 @@ __all__ = [ 
"StandardMagenticManager", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", + "SuperStepCompletedEvent", + "SuperStepStartedEvent", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 76ae7f8a4f..b681544876 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -294,6 +294,36 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})" +class SuperStepEvent(WorkflowEvent): + """Event triggered when a superstep starts or ends.""" + + def __init__(self, iteration: int, data: Any | None = None): + """Initialize the superstep event. + + Args: + iteration: The number of the superstep (1-based index). + data: Optional data associated with the superstep event. + """ + super().__init__(data) + self.iteration = iteration + + def __repr__(self) -> str: + """Return a string representation of the superstep event.""" + return f"{self.__class__.__name__}(iteration={self.iteration}, data={self.data})" + + +class SuperStepStartedEvent(SuperStepEvent): + """Event triggered when a superstep starts.""" + + ... + + +class SuperStepCompletedEvent(SuperStepEvent): + """Event triggered when a superstep ends.""" + + ... + + class ExecutorEvent(WorkflowEvent): """Base class for executor events.""" @@ -310,17 +340,13 @@ def __repr__(self) -> str: class ExecutorInvokedEvent(ExecutorEvent): """Event triggered when an executor handler is invoked.""" - def __repr__(self) -> str: - """Return a string representation of the executor handler invoke event.""" - return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})" + ... 
class ExecutorCompletedEvent(ExecutorEvent): """Event triggered when an executor handler is completed.""" - def __repr__(self) -> str: - """Return a string representation of the executor handler complete event.""" - return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})" + ... class ExecutorFailedEvent(ExecutorEvent): diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 9c90e8f949..0aaa7af17b 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -11,7 +11,7 @@ from ._const import EXECUTOR_STATE_KEY from ._edge import EdgeGroup from ._edge_runner import EdgeRunner, create_edge_runner -from ._events import WorkflowEvent +from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent from ._executor import Executor from ._runner_context import ( Message, @@ -92,6 +92,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: while self._iteration < self._max_iterations: logger.info(f"Starting superstep {self._iteration + 1}") + yield SuperStepStartedEvent(iteration=self._iteration + 1) # Run iteration concurrently with live event streaming: we poll # for new events while the iteration coroutine progresses. 
@@ -126,6 +127,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Create checkpoint after each superstep iteration await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}") + yield SuperStepCompletedEvent(iteration=self._iteration) + + # Check for convergence: no more messages to process if not await self._ctx.has_messages(): break diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py index d0848ca9e7..0ea09e2f48 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py @@ -1,294 +1,154 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio -import os -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from agent_framework import ( - AgentExecutor, - AgentExecutorRequest, - AgentExecutorResponse, - ChatMessage, - Executor, - FileCheckpointStorage, - Role, - WorkflowBuilder, - WorkflowContext, - get_checkpoint_summary, - handler, -) -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential - -if TYPE_CHECKING: - from agent_framework import Workflow - from agent_framework._workflows._checkpoint import WorkflowCheckpoint - """ -Sample: Checkpointing and Resuming a Workflow (with an Agent stage) +Sample: Checkpointing and Resuming a Workflow Purpose: -This sample shows how to enable checkpointing at superstep boundaries, persist shared -workflow state, and then resume execution from a specific checkpoint. The workflow -demonstrates a simple text-processing pipeline that includes an LLM-backed AgentExecutor -stage. - -Pipeline: -1) UpperCaseExecutor converts input to uppercase. -2) ReverseTextExecutor reverses the string. -3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent. 
-4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI. -5) FinalizeFromAgent yields the final result. +This sample shows how to enable checkpointing for a long-running workflow +that can be paused and resumed. What you learn: -- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility. -- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder. -- How to list and inspect checkpoints programmatically. -- How to interactively choose a checkpoint to resume from (instead of always resuming - from the most recent or a hard-coded one) using run_stream. -- How workflows complete by yielding outputs when idle, not via explicit completion events. +- How to configure checkpointing storage (InMemoryCheckpointStorage for testing) +- How to resume a workflow from a checkpoint after interruption +- How to implement executor state management with checkpoint hooks +- How to handle workflow interruptions and automatic recovery + +Pipeline: +This sample shows a workflow that computes factor pairs for numbers up to a given limit: +1) A start executor that receives the upper limit and creates the initial task +2) A worker executor that processes each number to find its factor pairs +3) The worker uses checkpoint hooks to save/restore its internal state Prerequisites: -- Azure AI or Azure OpenAI available for AzureOpenAIChatClient. -- Authentication with azure-identity via AzureCliCredential. Run az login locally. -- Filesystem access for writing JSON checkpoint files in a temp directory. +- Basic understanding of workflow concepts, including executors, edges, events, etc. """ -# Define the temporary directory for storing checkpoints. -# These files allow the workflow to be resumed later. 
-DIR = os.path.dirname(__file__) -TEMP_DIR = os.path.join(DIR, "tmp", "checkpoints") -os.makedirs(TEMP_DIR, exist_ok=True) - - -class UpperCaseExecutor(Executor): - """Uppercases the input text and persists both local and shared state.""" - - def __init__(self, id: str): - super().__init__(id=id) +import asyncio +from dataclasses import dataclass +from random import random +from typing import Any, override - @handler - async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None: - result = text.upper() - print(f"UpperCaseExecutor: '{text}' -> '{result}'") +from agent_framework import ( + Executor, + InMemoryCheckpointStorage, + SuperStepCompletedEvent, + WorkflowBuilder, + WorkflowCheckpoint, + WorkflowContext, + WorkflowOutputEvent, + handler, +) - # Write to shared_state so downstream executors and any resumed runs can read it. - await ctx.set_shared_state("original_input", text) - await ctx.set_shared_state("upper_output", result) - # Send transformed text to the next executor. - await ctx.send_message(result) +@dataclass +class ComputeTask: + """Task containing the list of numbers remaining to be processed.""" + remaining_numbers: list[int] -class SubmitToLowerAgent(Executor): - """Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility.""" - def __init__(self, id: str, agent_id: str): - super().__init__(id=id) - self._agent_id = agent_id +class StartExecutor(Executor): + """Initiates the workflow by providing the upper limit for factor pair computation.""" @handler - async def submit(self, text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: - # Demonstrate reading shared_state written by UpperCaseExecutor. - # Shared state survives across checkpoints and is visible to all executors. 
- orig = await ctx.get_shared_state("original_input") - upper = await ctx.get_shared_state("upper_output") - print(f"LowerAgent (shared_state): original_input='{orig}', upper_output='{upper}'") - - # Build a minimal, deterministic prompt for the AgentExecutor. - prompt = f"Convert the following text to lowercase. Return ONLY the transformed text.\n\nText: {text}" - - # Send to the AgentExecutor. should_respond=True instructs the agent to produce a reply. - await ctx.send_message( - AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True), - target_id=self._agent_id, - ) + async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None: + """Start the workflow with a list of numbers to process.""" + print(f"StartExecutor: Starting factor pair computation up to {upper_limit}") + await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1)))) -class FinalizeFromAgent(Executor): - """Consumes the AgentExecutorResponse and yields the final result.""" +class WorkerExecutor(Executor): + """Processes numbers to compute their factor pairs and manages executor state for checkpointing.""" - @handler - async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None: - # Yield the final result so external consumers see the final value. 
- await ctx.yield_output(response.agent_run_response.text) - - -class ReverseTextExecutor(Executor): - """Reverses the input text and persists local state.""" + def __init__(self, id: str) -> None: + super().__init__(id=id) + self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {} @handler - async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None: - result = text[::-1] - print(f"ReverseTextExecutor: '{text}' -> '{result}'") + async def compute( + self, + task: ComputeTask, + ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]], + ) -> None: + """Process the next number in the task, computing its factor pairs.""" + next_number = task.remaining_numbers.pop(0) + + print(f"WorkerExecutor: Computing factor pairs for {next_number}") + pairs: list[tuple[int, int]] = [] + for i in range(1, next_number): + if next_number % i == 0: + pairs.append((i, next_number // i)) + self._composite_number_pairs[next_number] = pairs + + if not task.remaining_numbers: + # All numbers processed - output the results + await ctx.yield_output(self._composite_number_pairs) + else: + # More numbers to process - continue with remaining task + await ctx.send_message(task) + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Save the executor's internal state for checkpointing.""" + return {"composite_number_pairs": self._composite_number_pairs} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore the executor's internal state from a checkpoint.""" + self._composite_number_pairs = state.get("composite_number_pairs", {}) - # Forward the reversed string to the next stage. - await ctx.send_message(result) - -def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow": - # Instantiate the pipeline executors. 
- upper_case_executor = UpperCaseExecutor(id="upper-case") - reverse_text_executor = ReverseTextExecutor(id="reverse-text") - - # Configure the agent stage that lowercases the text. - chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - lower_agent = AgentExecutor( - chat_client.create_agent( - instructions=("You transform text to lowercase. Reply with ONLY the transformed text.") - ), - id="lower_agent", - ) - - # Bridge to the agent and terminalization stage. - submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id) - finalize = FinalizeFromAgent(id="finalize") - - # Build the workflow with checkpointing enabled. - return ( - WorkflowBuilder(max_iterations=5) - .add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse - .add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request - .add_edge(submit_lower, lower_agent) # Submit to AgentExecutor - .add_edge(lower_agent, finalize) # Agent output -> Finalize - .set_start_executor(upper_case_executor) # Entry point - .with_checkpointing(checkpoint_storage=checkpoint_storage) # Enable persistence - .build() +async def main(): + # Create workflow executors + start_executor = StartExecutor(id="start") + worker_executor = WorkerExecutor(id="worker") + + # Build workflow with checkpointing enabled + workflow_builder = ( + WorkflowBuilder() + .set_start_executor(start_executor) + .add_edge(start_executor, worker_executor) + .add_edge(worker_executor, worker_executor) # Self-loop for iterative processing ) - - -def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: - """Display human-friendly checkpoint metadata using framework summaries.""" - - if not checkpoints: - return - - print("\nCheckpoint summary:") - for cp in sorted(checkpoints, key=lambda c: c.timestamp): - summary = get_checkpoint_summary(cp) - msg_count = sum(len(v) for v in cp.messages.values()) - state_keys = sorted(summary.executor_ids) - orig = 
cp.shared_state.get("original_input") - upper = cp.shared_state.get("upper_output") - - line = ( - f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}" + checkpoint_storage = InMemoryCheckpointStorage() + workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage) + + # Run workflow with automatic checkpoint recovery + latest_checkpoint: WorkflowCheckpoint | None = None + while True: + workflow = workflow_builder.build() + + # Start from checkpoint or fresh execution + print(f"\n** Workflow {workflow.id} started **") + event_stream = ( + workflow.run_stream(message=10) + if latest_checkpoint is None + else workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id) ) - if summary.status: - line += f" | status={summary.status}" - line += f" | shared_state: original_input='{orig}', upper_output='{upper}'" - print(line) - - -async def main(): - # Clear existing checkpoints in this sample directory for a clean run. - checkpoint_dir = Path(TEMP_DIR) - for file in checkpoint_dir.glob("*.json"): # noqa: ASYNC240 - file.unlink() - - # Backing store for checkpoints written by with_checkpointing. - checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR) - - workflow = create_workflow(checkpoint_storage=checkpoint_storage) - - # Run the full workflow once and observe events as they stream. - print("Running workflow with initial message...") - async for event in workflow.run_stream(message="hello world"): - print(f"Event: {event}") - # Inspect checkpoints written during the run. - all_checkpoints = await checkpoint_storage.list_checkpoints() - if not all_checkpoints: - print("No checkpoints found!") - return - - # All checkpoints created by this run share the same workflow_id. - workflow_id = all_checkpoints[0].workflow_id - - _render_checkpoint_summary(all_checkpoints) - - # Offer an interactive selection of checkpoints to resume from. 
- sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp) - - print("\nAvailable checkpoints to resume from:") - for idx, cp in enumerate(sorted_cps): - summary = get_checkpoint_summary(cp) - line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}" - if summary.status: - line += f" status={summary.status}" - msg_count = sum(len(v) for v in cp.messages.values()) - line += f" messages={msg_count}" - print(line) - - user_input = input( # noqa: ASYNC250 - "\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: " - ).strip() - - if not user_input: - print("No checkpoint selected. Exiting without resuming.") - return - - chosen_cp_id: str | None = None - - # Try as index first - if user_input.isdigit(): - idx = int(user_input) - if 0 <= idx < len(sorted_cps): - chosen_cp_id = sorted_cps[idx].checkpoint_id - # Fall back to direct id match - if chosen_cp_id is None: - for cp in sorted_cps: - if cp.checkpoint_id.startswith(user_input): # allow prefix match for convenience - chosen_cp_id = cp.checkpoint_id + output: str | None = None + async for event in event_stream: + if isinstance(event, WorkflowOutputEvent): + output = event.data + break + if isinstance(event, SuperStepCompletedEvent) and random() < 0.5: + # Randomly simulate system interruptions + print("\n** Simulating workflow interruption. Stopping execution. **") break - if chosen_cp_id is None: - print("Input did not match any checkpoint. Exiting without resuming.") - return - - # You can reuse the same workflow graph definition and resume from a prior checkpoint. - # This second workflow instance does not enable checkpointing to show that resumption - # reads from stored state but need not write new checkpoints. 
- new_workflow = create_workflow(checkpoint_storage=checkpoint_storage) - - print(f"\nResuming from checkpoint: {chosen_cp_id}") - async for event in new_workflow.run_stream(checkpoint_id=chosen_cp_id, checkpoint_storage=checkpoint_storage): - print(f"Resumed Event: {event}") - - """ - Sample Output: - - Running workflow with initial message... - UpperCaseExecutor: 'hello world' -> 'HELLO WORLD' - Event: ExecutorInvokeEvent(executor_id=upper_case_executor) - Event: ExecutorCompletedEvent(executor_id=upper_case_executor) - ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH' - Event: ExecutorInvokeEvent(executor_id=reverse_text_executor) - Event: ExecutorCompletedEvent(executor_id=reverse_text_executor) - LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD' - Event: ExecutorInvokeEvent(executor_id=submit_lower) - Event: ExecutorInvokeEvent(executor_id=lower_agent) - Event: ExecutorInvokeEvent(executor_id=finalize) - - Checkpoint summary: - - dfc63e72-8e8d-454f-9b6d-0d740b9062e6 | label='after_initial_execution' | iter=0 | messages=1 | states=['upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD' - - a78c345a-e5d9-45ba-82c0-cb725452d91b | label='superstep_1' | iter=1 | messages=1 | states=['reverse_text_executor', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD' - - 637c1dbd-a525-4404-9583-da03980537a2 | label='superstep_2' | iter=2 | messages=0 | states=['finalize', 'lower_agent', 'reverse_text_executor', 'submit_lower', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD' - - Available checkpoints to resume from: - [0] id=dfc63e72-... iter=0 messages=1 label='after_initial_execution' - [1] id=a78c345a-... iter=1 messages=1 label='superstep_1' - [2] id=637c1dbd-... 
iter=2 messages=0 label='superstep_2' - - Enter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: 1 + # Find the latest checkpoint to resume from + all_checkpoints = await checkpoint_storage.list_checkpoints() + if not all_checkpoints: + raise RuntimeError("No checkpoints available to resume from.") + latest_checkpoint = all_checkpoints[-1] + print( + f"Checkpoint {latest_checkpoint.checkpoint_id}: " + f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})" + ) - Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b - LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD' - Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower) - Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent) - Resumed Event: ExecutorInvokeEvent(executor_id=finalize) - """ # noqa: E501 + if output is not None: + print(f"\nWorkflow completed successfully with output: {output}") + break if __name__ == "__main__": From 3344cdc6e5cc398c6715cad295dc40e81bb1e86f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 17:17:40 -0800 Subject: [PATCH 6/9] Fix import --- .../_workflows/_agent_executor.py | 4 +- .../_base_group_chat_orchestrator.py | 4 +- .../agent_framework/_workflows/_handoff.py | 4 +- .../agent_framework/_workflows/_magentic.py | 4 +- .../_workflows/_workflow_executor.py | 4 +- .../core/tests/workflow/test_magentic.py | 3 +- .../checkpoint_with_resume_from_failure.py | 217 ++++++++++++++++++ 7 files changed, 234 insertions(+), 6 deletions(-) create mode 100644 python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 729979e4fd..2c5b1b1c06 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ 
b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,7 +2,9 @@ import logging from dataclasses import dataclass -from typing import Any, cast, override +from typing import Any, cast + +from typing_extensions import override from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index dde104cce8..779456136c 100644 --- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -6,7 +6,9 @@ import logging from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence -from typing import Any, override +from typing import Any + +from typing_extensions import override from .._types import ChatMessage from ._executor import Executor diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 78e938f217..1a1acbb0c3 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -18,7 +18,9 @@ import re from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field -from typing import Any, override +from typing import Any + +from typing_extensions import override from agent_framework import ( AgentProtocol, diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index e02d45ea46..aad514c838 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -10,9 +10,11 @@ from collections.abc import AsyncIterable, Sequence from dataclasses 
import dataclass, field from enum import Enum -from typing import Any, Protocol, TypeVar, Union, cast, override +from typing import Any, Protocol, TypeVar, Union, cast from uuid import uuid4 +from typing_extensions import override + from agent_framework import ( AgentProtocol, AgentRunResponse, diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 636687428a..089ba299fc 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -4,7 +4,9 @@ import logging import uuid from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, override +from typing import TYPE_CHECKING, Any + +from typing_extensions import override if TYPE_CHECKING: from ._workflow import Workflow diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 037bca94c7..5e64898a53 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -2,9 +2,10 @@ from collections.abc import AsyncIterable from dataclasses import dataclass -from typing import Any, cast, override +from typing import Any, cast import pytest +from typing_extensions import override from agent_framework import ( AgentRunResponse, diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py new file mode 100644 index 0000000000..9ff5e292a8 --- /dev/null +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ + +""" +Sample: Checkpointing and Resuming a Workflow + +Purpose: +This sample shows how to enable checkpointing for a long-running workflow +that may result in intermittent failures. If a failure occurs, the workflow +can be resumed from the last successful checkpoint rather than starting +over from the beginning. + +What you learn: +- How to configure different checkpointing storages: + - FilesystemCheckpointStorage for local filesystem storage of checkpoints + - InMemoryCheckpointStorage for ephemeral in-memory checkpointing (useful for testing) +- How to resume a workflow from a checkpoint +- How to inspect checkpoints programmatically + +Pipeline: +This sample shows a workflow that will run the samples in one of the directories under `samples/getting_started/`. +Note: For demonstration purposes, the workflow will not actually run all the samples, but will simulate +running them by printing their names. +1) A start executor that will read the list of sample files to run. +2) A distributor executor that will distribute the sample files to multiple worker executors. +3) Multiple worker executors that will simulate running the sample files. These executors will + randomly fail to demonstrate checkpointing and resuming. +4) A collector executor that will collect the results from the worker executors. + +Prerequisites: +- Basic understanding of workflow concepts, including executors, edges, events, etc. 
+""" + +import asyncio +import random +from typing import Any, Never, override + +from agent_framework import ( + Executor, + InMemoryCheckpointStorage, + WorkflowBuilder, + WorkflowContext, + WorkflowOutputEvent, + handler, +) +from anyio import Path + + +class StartExecutor(Executor): + """Executor that starts the workflow by providing a list of sample files to run.""" + + @handler + async def start(self, directory: str, ctx: WorkflowContext[list[str]]) -> None: + """Start the workflow by listing sample files in the given directory.""" + # Validate the input directory + directory_path = Path(directory) + if not await directory_path.is_dir(): + raise ValueError(f"Directory '{directory}' does not exist or is not a directory.") + # Validate that the directory is a subdirectory of samples/getting_started + expected_parent = Path(__file__).parent.parent.parent + if expected_parent not in directory_path.parents: + raise ValueError(f"Directory '{directory}' is not a subdirectory of '{expected_parent}'.") + + sample_files = await self._find_all_sample_files(directory_path) + print(f"StartExecutor: Found {len(sample_files)} sample files to run.") + + # Save the expected files in shared state for the collector to validate + await ctx.set_shared_state("expected_files", sample_files) + # Send the list of sample files to the next executor + await ctx.send_message(sample_files) + + async def _find_all_sample_files(self, directory: Path) -> list[str]: + """Recursively find all Python sample files in the given directory.""" + sample_files: list[str] = [] + async for file_path in directory.rglob("*.py"): + sample_files.append(str(file_path)) + return sample_files + + +class Distributor(Executor): + """Executor that distributes sample files to worker executors.""" + + def __init__(self, id: str, worker_ids: list[str]): + super().__init__(id=id) + self._worker_ids = worker_ids + + @handler + async def distribute(self, sample_files: list[str], ctx: WorkflowContext[list[str]]) -> 
None: + """Distribute sample files to worker executors.""" + distribution_lists: list[list[str]] = [[] for _ in range(len(self._worker_ids))] + for index, sample_file in enumerate(sample_files): + distribution_lists[index % len(self._worker_ids)].append(sample_file) + + for worker_id, files in zip(self._worker_ids, distribution_lists, strict=True): + print(f"Distributor: Distributing {len(files)} files to worker '{worker_id}'.") + await ctx.send_message(files, target_id=worker_id) + + +class WorkerExecutor(Executor): + """Executor that simulates running sample files.""" + + def __init__(self, id: str): + super().__init__(id=id) + self._processed_files: list[str] = [] + + @handler + async def run_samples(self, sample_files: list[str], ctx: WorkflowContext[list[str]]) -> None: + """Simulate running the sample files.""" + for sample_file in sample_files: + if sample_file in self._processed_files: + continue + + # Simulate some processing time + await asyncio.sleep(0.5) + # Simulate random failure + if random.random() < 0.2: # 20% chance to fail + break + + self._processed_files.append(sample_file) + + print(f"WorkerExecutor '{self.id}': Processed {len(self._processed_files)} of {len(sample_files)} files.") + await ctx.send_message(self._processed_files) + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Save the executor state for checkpointing.""" + return {"processed_files": self._processed_files} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore the executor state from a checkpoint.""" + self._processed_files = state.get("processed_files", []) + + +class Collector(Executor): + """Executor that collects results from worker executors.""" + + def __init__(self, id: str): + super().__init__(id=id) + self._processed_files: list[str] = [] + + @handler + async def collect(self, worker_results: list[str], ctx: WorkflowContext[Never, str]) -> None: + """Collect results from worker executors.""" + 
expected_files: list[str] = await ctx.get_shared_state("expected_files") + + for file in worker_results: + if file not in expected_files: + raise ValueError(f"Collector: Received unexpected file result '{file}'.") + if file in self._processed_files: + raise ValueError(f"Collector: Duplicate result for file '{file}'.") + + self._processed_files.extend(worker_results) + + if len(self._processed_files) == len(expected_files): + print("Collector: All sample files have been processed.") + await ctx.yield_output("All samples processed successfully.") + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Save the executor state for checkpointing.""" + return {"processed_files": self._processed_files} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore the executor state from a checkpoint.""" + self._processed_files = state.get("processed_files", []) + + +async def main(): + # Create the executors + start_executor = StartExecutor(id="start") + workers = [WorkerExecutor(id=f"worker_{i}") for i in range(3)] + distributor = Distributor(id="distributor", worker_ids=[w.id for w in workers]) + collector = Collector(id="collector") + + # Create the workflow builder with a start executor + workflow_builder = WorkflowBuilder().set_start_executor(start_executor).add_edge(start_executor, distributor) + # Connect the distributor to each worker and each worker to the collector + for worker in workers: + workflow_builder = workflow_builder.add_edge(distributor, worker) + workflow_builder = workflow_builder.add_edge(worker, collector) + # Add checkpointing with in-memory storage + checkpoint_storage = InMemoryCheckpointStorage() + workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage) + + # Build the workflow + workflow = workflow_builder.build() + + directory_to_run = str(Path(__file__).parent.parent) # samples/getting_started/workflows/ + + event_stream = 
workflow.run_stream(message=directory_to_run) + + while True: + async for event in event_stream: + if isinstance(event, WorkflowOutputEvent): + print(f"Workflow completed successfully with output: {event.data}") + break + + print( + "Workflow did not produce a final output, attempting to resume from the " + "second checkpoint that is created right after the distributor runs." + ) + + # Attempt to restore from the last checkpoint + all_checkpoints = await checkpoint_storage.list_checkpoints() + if not all_checkpoints: + raise RuntimeError("No checkpoints available to resume from.") + + # Checkpoints are ordered by creation time, so pick the second one + latest_checkpoint = all_checkpoints[1] + print(f"Resuming from checkpoint: {latest_checkpoint.checkpoint_id}") + event_stream = workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id) + + +if __name__ == "__main__": + asyncio.run(main()) From bd8f5f005eddff75662c094ecd1cfbd9ed716c4f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 11 Nov 2025 21:33:49 -0800 Subject: [PATCH 7/9] Address comments and fix tests --- .../core/agent_framework/_workflows/__init__.py | 2 -- .../_workflows/_base_group_chat_orchestrator.py | 4 ++-- .../core/agent_framework/_workflows/_runner.py | 6 ++++++ python/packages/core/tests/workflow/test_workflow.py | 12 +++++++++--- .../workflows/checkpoint/sub_workflow_checkpoint.py | 1 + 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index cbb6de84e4..18dd674a92 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -155,8 +155,6 @@ "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SuperStepCompletedEvent", - "SuperStepCompletedEvent", - "SuperStepStartedEvent", "SuperStepStartedEvent", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", diff --git 
a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index 779456136c..3209beacfb 100644 --- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -217,7 +217,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for checkpointing. Default implementation uses OrchestrationState to serialize common state. - Subclasses should override `on_checkpoint_save()` to add pattern-specific data. + Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data. Returns: Serialized state dict @@ -246,7 +246,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore orchestrator state from checkpoint. Default implementation uses OrchestrationState to deserialize common state. - Subclasses should override _restore_pattern_metadata() to restore pattern-specific data. + Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data. Args: state: Serialized state dict diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 0aaa7af17b..f80340b74a 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -279,6 +279,9 @@ async def _save_executor_states(self) -> None: Updated behavior: - Executors should implement `on_checkpoint_save(self) -> dict` to provide state. + Both behaviors can coexist; the updated behavior takes precedence, i.e., overrides + any state provided by the backward compatibility methods. + Only JSON-serializable dicts should be provided by executors. 
""" for exec_id, executor in self._executors.items(): @@ -317,6 +320,9 @@ async def _restore_executor_states(self) -> None: Updated behavior: - Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state. + + Both behaviors can coexist; the updated behavior takes precedence, i.e., overrides + any restoration performed by the backward compatibility methods. """ has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY) if not has_executor_states: diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 7f1a7fdce6..059c94803e 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -199,7 +199,10 @@ async def test_fan_out(): # Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent # executor_b will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore) - assert len(events) == 7 + # Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent + # This workflow will converge in 2 supersteps because executor_c will send one more message + # after executor_b completes + assert len(events) == 11 assert events.get_final_state() == WorkflowRunState.IDLE outputs = events.get_outputs() @@ -220,7 +223,9 @@ async def test_fan_out_multiple_completed_events(): # Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent # executor_b and executor_c will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore) - assert len(events) == 8 + # Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent + # This workflow will converge in 1 superstep because executor_a and executor_b will not send further messages + assert len(events) == 10 # Multiple outputs are expected from both executors outputs = events.get_outputs() @@ -246,7 +251,8 @@ async def test_fan_in(): # Each executor will 
emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
     # aggregator will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
-    assert len(events) == 9
+    # Each superstep will also emit a SuperStepStartedEvent and a SuperStepCompletedEvent
+    assert len(events) == 13
     assert events.get_final_state() == WorkflowRunState.IDLE
 
     outputs = events.get_outputs()
diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py
index 8fb8226933..2d90db8c2c 100644
--- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py
+++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py
@@ -206,6 +206,7 @@ class LaunchCoordinator(Executor):
 
     def __init__(self) -> None:
         super().__init__(id="launch_coordinator")
+        # Track pending requests to match responses
         self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
 
     @handler
From 749b2bfc961d6fffe55a7395549f2a0acd618924 Mon Sep 17 00:00:00 2001
From: Tao Chen
Date: Wed, 12 Nov 2025 10:23:37 -0800
Subject: [PATCH 8/9] Address comments

---
 .../agent_framework/_workflows/_magentic.py   |   4 +-
 .../agent_framework/_workflows/_runner.py     |  34 +--
 .../_workflows/_workflow_context.py           |   8 +-
 .../checkpoint/checkpoint_with_resume.py      |   3 +
 .../checkpoint_with_resume_from_failure.py    | 217 ------------------
 5 files changed, 32 insertions(+), 234 deletions(-)
 delete mode 100644 python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py

diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py
index aad514c838..67658adeee 100644
--- a/python/packages/core/agent_framework/_workflows/_magentic.py
+++ b/python/packages/core/agent_framework/_workflows/_magentic.py
@@ -1058,8 +1058,10 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
         if self._task_ledger
is not None: state["task_ledger"] = _message_to_payload(self._task_ledger) - with contextlib.suppress(Exception): + try: state["manager_state"] = self._manager.on_checkpoint_save() + except Exception as exc: + logger.warning("Failed to save manager state for checkpoint: %s\nSkipping...", exc) return state diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index f80340b74a..8cc01c23cf 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -279,14 +279,15 @@ async def _save_executor_states(self) -> None: Updated behavior: - Executors should implement `on_checkpoint_save(self) -> dict` to provide state. - Both behaviors can coexist; the updated behavior takes precedence, i.e., overrides - any state provided by the backward compatibility methods. + This method will try the backward compatibility behavior first; if that does not yield state, + it falls back to the updated behavior. Only JSON-serializable dicts should be provided by executors. 
""" for exec_id, executor in self._executors.items(): - # TODO(@taochen): Remove backward compatibility state_dict: dict[str, Any] | None = None + # Try backward compatibility behavior first + # TODO(@taochen): Remove backward compatibility snapshot = getattr(executor, "snapshot_state", None) try: if callable(snapshot): @@ -302,9 +303,12 @@ async def _save_executor_states(self) -> None: except Exception as ex: # pragma: no cover logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}") - # Updated behavior: use on_checkpoint_save() method - state_dict = state_dict or {} - state_dict |= await executor.on_checkpoint_save() + if state_dict is None: + # Try the updated behavior only if backward compatibility did not yield state + try: + state_dict = await executor.on_checkpoint_save() + except Exception as ex: # pragma: no cover + raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex try: await self._set_executor_state(exec_id, state_dict) @@ -321,8 +325,8 @@ async def _restore_executor_states(self) -> None: Updated behavior: - Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state. - Both behaviors can coexist; the updated behavior takes precedence, i.e., overrides - any restoration performed by the backward compatibility methods. + This method will try the backward compatibility behavior first; if that does not restore state, + it falls back to the updated behavior. 
""" has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY) if not has_executor_states: @@ -342,6 +346,7 @@ async def _restore_executor_states(self) -> None: if not executor: raise ValueError(f"Executor {executor_id} not found during state restoration.") + # Try backward compatibility behavior first # TODO(@taochen): Remove backward compatibility restored = False restore_method = getattr(executor, "restore_state", None) @@ -354,12 +359,13 @@ async def _restore_executor_states(self) -> None: except Exception as ex: # pragma: no cover - defensive raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex - # Updated behavior: use on_checkpoint_restore() method - try: - await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType] - restored = True - except Exception as ex: # pragma: no cover - defensive - raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex + if not restored: + # Try the updated behavior only if backward compatibility did not restore + try: + await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType] + restored = True + except Exception as ex: # pragma: no cover - defensive + raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex if not restored: logger.debug(f"Executor {executor_id} does not support state restoration; skipping.") diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 1d1c8e8c49..dcf6715d62 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -411,7 +411,9 @@ def shared_state(self) -> SharedState: return self._shared_state @deprecated( - "Use executor member properties directly instead. For state sharing between executors, use `set_shared_state`." 
+ "Override `on_checkpoint_save()` methods instead. " + "For cross-executor state sharing, use set_shared_state() instead. " + "This API will be removed after 12/01/2025." ) async def set_executor_state(self, state: dict[str, Any]) -> None: """Store executor state in shared state under a reserved key. @@ -432,7 +434,9 @@ async def set_executor_state(self, state: dict[str, Any]) -> None: await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states) @deprecated( - "Use executor member properties directly instead. For state sharing between executors, use `get_shared_state`." + "Override `on_checkpoint_restore()` methods instead. " + "For cross-executor state sharing, use get_shared_state() instead. " + "This API will be removed after 12/01/2025." ) async def get_executor_state(self) -> dict[str, Any] | None: """Retrieve previously persisted state for this executor, if any.""" diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py index 0ea09e2f48..0b7e874a1a 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume.py @@ -133,6 +133,9 @@ async def main(): break if isinstance(event, SuperStepCompletedEvent) and random() < 0.5: # Randomly simulate system interruptions + # The `SuperStepCompletedEvent` ensures we only interrupt after + # the current super-step is fully complete and checkpointed. + # If we interrupt mid-step, the workflow may resume from an earlier point. print("\n** Simulating workflow interruption. Stopping execution. 
**") break diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py deleted file mode 100644 index 9ff5e292a8..0000000000 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_resume_from_failure.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - - -""" -Sample: Checkpointing and Resuming a Workflow - -Purpose: -This sample shows how to enable checkpointing for a long-running workflow -that may result in intermittent failures. If a failure occurs, the workflow -can be resumed from the last successful checkpoint rather than starting -over from the beginning. - -What you learn: -- How to configure different checkpointing storages: - - FilesystemCheckpointStorage for local filesystem storage of checkpoints - - InMemoryCheckpointStorage for ephemeral in-memory checkpointing (useful for testing) -- How to resume a workflow from a checkpoint -- How to inspect checkpoints programmatically - -Pipeline: -This sample shows a workflow that will run the samples in one of the directories under `samples/getting_started/`. -Note: For demonstration purposes, the workflow will not actually run all the samples, but will simulate -running them by printing their names. -1) A start executor that will read the list of sample files to run. -2) A distributor executor that will distribute the sample files to multiple worker executors. -3) Multiple worker executors that will simulate running the sample files. These executors will - randomly fail to demonstrate checkpointing and resuming. -4) A collector executor that will collect the results from the worker executors. - -Prerequisites: -- Basic understanding of workflow concepts, including executors, edges, events, etc. 
-""" - -import asyncio -import random -from typing import Any, Never, override - -from agent_framework import ( - Executor, - InMemoryCheckpointStorage, - WorkflowBuilder, - WorkflowContext, - WorkflowOutputEvent, - handler, -) -from anyio import Path - - -class StartExecutor(Executor): - """Executor that starts the workflow by providing a list of sample files to run.""" - - @handler - async def start(self, directory: str, ctx: WorkflowContext[list[str]]) -> None: - """Start the workflow by listing sample files in the given directory.""" - # Validate the input directory - directory_path = Path(directory) - if not await directory_path.is_dir(): - raise ValueError(f"Directory '{directory}' does not exist or is not a directory.") - # Validate that the directory is a subdirectory of samples/getting_started - expected_parent = Path(__file__).parent.parent.parent - if expected_parent not in directory_path.parents: - raise ValueError(f"Directory '{directory}' is not a subdirectory of '{expected_parent}'.") - - sample_files = await self._find_all_sample_files(directory_path) - print(f"StartExecutor: Found {len(sample_files)} sample files to run.") - - # Save the expected files in shared state for the collector to validate - await ctx.set_shared_state("expected_files", sample_files) - # Send the list of sample files to the next executor - await ctx.send_message(sample_files) - - async def _find_all_sample_files(self, directory: Path) -> list[str]: - """Recursively find all Python sample files in the given directory.""" - sample_files: list[str] = [] - async for file_path in directory.rglob("*.py"): - sample_files.append(str(file_path)) - return sample_files - - -class Distributor(Executor): - """Executor that distributes sample files to worker executors.""" - - def __init__(self, id: str, worker_ids: list[str]): - super().__init__(id=id) - self._worker_ids = worker_ids - - @handler - async def distribute(self, sample_files: list[str], ctx: WorkflowContext[list[str]]) -> 
None: - """Distribute sample files to worker executors.""" - distribution_lists: list[list[str]] = [[] for _ in range(len(self._worker_ids))] - for index, sample_file in enumerate(sample_files): - distribution_lists[index % len(self._worker_ids)].append(sample_file) - - for worker_id, files in zip(self._worker_ids, distribution_lists, strict=True): - print(f"Distributor: Distributing {len(files)} files to worker '{worker_id}'.") - await ctx.send_message(files, target_id=worker_id) - - -class WorkerExecutor(Executor): - """Executor that simulates running sample files.""" - - def __init__(self, id: str): - super().__init__(id=id) - self._processed_files: list[str] = [] - - @handler - async def run_samples(self, sample_files: list[str], ctx: WorkflowContext[list[str]]) -> None: - """Simulate running the sample files.""" - for sample_file in sample_files: - if sample_file in self._processed_files: - continue - - # Simulate some processing time - await asyncio.sleep(0.5) - # Simulate random failure - if random.random() < 0.2: # 20% chance to fail - break - - self._processed_files.append(sample_file) - - print(f"WorkerExecutor '{self.id}': Processed {len(self._processed_files)} of {len(sample_files)} files.") - await ctx.send_message(self._processed_files) - - @override - async def on_checkpoint_save(self) -> dict[str, Any]: - """Save the executor state for checkpointing.""" - return {"processed_files": self._processed_files} - - @override - async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore the executor state from a checkpoint.""" - self._processed_files = state.get("processed_files", []) - - -class Collector(Executor): - """Executor that collects results from worker executors.""" - - def __init__(self, id: str): - super().__init__(id=id) - self._processed_files: list[str] = [] - - @handler - async def collect(self, worker_results: list[str], ctx: WorkflowContext[Never, str]) -> None: - """Collect results from worker executors.""" - 
expected_files: list[str] = await ctx.get_shared_state("expected_files") - - for file in worker_results: - if file not in expected_files: - raise ValueError(f"Collector: Received unexpected file result '{file}'.") - if file in self._processed_files: - raise ValueError(f"Collector: Duplicate result for file '{file}'.") - - self._processed_files.extend(worker_results) - - if len(self._processed_files) == len(expected_files): - print("Collector: All sample files have been processed.") - await ctx.yield_output("All samples processed successfully.") - - @override - async def on_checkpoint_save(self) -> dict[str, Any]: - """Save the executor state for checkpointing.""" - return {"processed_files": self._processed_files} - - @override - async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore the executor state from a checkpoint.""" - self._processed_files = state.get("processed_files", []) - - -async def main(): - # Create the executors - start_executor = StartExecutor(id="start") - workers = [WorkerExecutor(id=f"worker_{i}") for i in range(3)] - distributor = Distributor(id="distributor", worker_ids=[w.id for w in workers]) - collector = Collector(id="collector") - - # Create the workflow builder with a start executor - workflow_builder = WorkflowBuilder().set_start_executor(start_executor).add_edge(start_executor, distributor) - # Connect the distributor to each worker and each worker to the collector - for worker in workers: - workflow_builder = workflow_builder.add_edge(distributor, worker) - workflow_builder = workflow_builder.add_edge(worker, collector) - # Add checkpointing with in-memory storage - checkpoint_storage = InMemoryCheckpointStorage() - workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage) - - # Build the workflow - workflow = workflow_builder.build() - - directory_to_run = str(Path(__file__).parent.parent) # samples/getting_started/workflows/ - - event_stream = 
workflow.run_stream(message=directory_to_run) - - while True: - async for event in event_stream: - if isinstance(event, WorkflowOutputEvent): - print(f"Workflow completed successfully with output: {event.data}") - break - - print( - "Workflow did not produce a final output, attempting to resume from the " - "second checkpoint that is created right after the distributor runs." - ) - - # Attempt to restore from the last checkpoint - all_checkpoints = await checkpoint_storage.list_checkpoints() - if not all_checkpoints: - raise RuntimeError("No checkpoints available to resume from.") - - # Checkpoints are ordered by creation time, so pick the second one - latest_checkpoint = all_checkpoints[1] - print(f"Resuming from checkpoint: {latest_checkpoint.checkpoint_id}") - event_stream = workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id) - - -if __name__ == "__main__": - asyncio.run(main()) From 8dc82d60289bdd6a40b50e3a3d384068db866395 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 13 Nov 2025 09:37:55 -0800 Subject: [PATCH 9/9] conditional import --- .../agent_framework/_workflows/_agent_executor.py | 8 ++++++-- .../_workflows/_base_group_chat_orchestrator.py | 9 +++++++-- .../core/agent_framework/_workflows/_handoff.py | 9 +++++++-- .../core/agent_framework/_workflows/_magentic.py | 12 ++++++++---- .../agent_framework/_workflows/_workflow_executor.py | 9 +++++++-- python/packages/core/tests/workflow/test_magentic.py | 7 ++++++- 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 2c5b1b1c06..358cee94dd 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -1,11 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. 
import logging +import sys from dataclasses import dataclass from typing import Any, cast -from typing_extensions import override - from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent from .._agents import AgentProtocol, ChatAgent @@ -22,6 +21,11 @@ from ._request_info_mixin import response_handler from ._workflow_context import WorkflowContext +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index 3209beacfb..8b49de740c 100644 --- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -4,17 +4,22 @@ import inspect import logging +import sys from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence from typing import Any -from typing_extensions import override - from .._types import ChatMessage from ._executor import Executor from ._orchestrator_helpers import ParticipantRegistry from ._workflow_context import WorkflowContext +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 1a1acbb0c3..d18bc59562 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -16,12 +16,11 @@ import logging import re +import sys from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import Any -from typing_extensions import override - from 
agent_framework import ( AgentProtocol, AgentRunResponse, @@ -52,6 +51,12 @@ from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 67658adeee..ea6fb259a6 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -13,8 +13,6 @@ from typing import Any, Protocol, TypeVar, Union, cast from uuid import uuid4 -from typing_extensions import override - from agent_framework import ( AgentProtocol, AgentRunResponse, @@ -47,9 +45,15 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self +else: + from typing_extensions import Self + +if sys.version_info >= (3, 12): + from typing import override else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import override + logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 089ba299fc..cc028f337c 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -2,12 +2,11 @@ import asyncio import logging +import sys import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from typing_extensions import override - if TYPE_CHECKING: from ._workflow import Workflow @@ -27,6 +26,12 @@ from ._workflow import WorkflowRunResult from ._workflow_context import WorkflowContext +if sys.version_info >= (3, 12): + from typing import override 
+else: + from typing_extensions import override + + logger = logging.getLogger(__name__) diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 5e64898a53..cc1e8ad132 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. +import sys from collections.abc import AsyncIterable from dataclasses import dataclass from typing import Any, cast import pytest -from typing_extensions import override from agent_framework import ( AgentRunResponse, @@ -43,6 +43,11 @@ _MagenticStartMessage, # type: ignore ) +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + def test_magentic_start_message_from_string(): msg = _MagenticStartMessage.from_string("Do the thing")