Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})

response = await self._agent.run(
self._cache,
Expand All @@ -357,7 +357,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}

updates: list[AgentResponseUpdate] = []
user_input_requests: list[Content] = []
Expand Down
12 changes: 7 additions & 5 deletions python/packages/core/agent_framework/_workflows/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ class WorkflowCheckpoint:
workflow_id: Identifier of the workflow this checkpoint belongs to
timestamp: ISO 8601 timestamp when checkpoint was created
messages: Messages exchanged between executors
shared_state: Complete shared state including user data and executor states.
Executor states are stored under the reserved key '_executor_state'.
state: Committed workflow state including user data and executor states.
This contains only committed state; pending state changes are not
included in checkpoints. Executor states are stored under the
reserved key '_executor_state'.
iteration_count: Current iteration number when checkpoint was created
metadata: Additional metadata (e.g., superstep info, graph signature)
version: Checkpoint format version

Note:
The shared_state dict may contain reserved keys managed by the framework.
See SharedState class documentation for details on reserved keys.
The state dict may contain reserved keys managed by the framework.
See State class documentation for details on reserved keys.
"""

checkpoint_id: str = field(default_factory=lambda: str(uuid.uuid4()))
Expand All @@ -43,7 +45,7 @@ class WorkflowCheckpoint:

# Core workflow state
messages: dict[str, list[dict[str, Any]]] = field(default_factory=dict) # type: ignore[misc]
shared_state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc]
state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc]
pending_request_info_events: dict[str, dict[str, Any]] = field(default_factory=dict) # type: ignore[misc]

# Runtime state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class WorkflowCheckpointSummary:

def get_checkpoint_summary(checkpoint: WorkflowCheckpoint) -> WorkflowCheckpointSummary:
targets = sorted(checkpoint.messages.keys())
executor_ids = sorted(checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {}).keys())
executor_ids = sorted(checkpoint.state.get(EXECUTOR_STATE_KEY, {}).keys())
pending_request_info_events = [
RequestInfoEvent.from_dict(request) for request in checkpoint.pending_request_info_events.values()
]
Expand Down
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_workflows/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
# Default maximum iterations for workflow execution.
DEFAULT_MAX_ITERATIONS = 100

# Key used to store executor state in shared state.
# Key used to store executor state in state.
EXECUTOR_STATE_KEY = "_executor_state"

# Source identifier for internal workflow messages.
INTERNAL_SOURCE_PREFIX = "internal"

# SharedState key for storing run kwargs that should be passed to agent invocations.
# State key for storing run kwargs that should be passed to agent invocations.
# Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic)
# to pass kwargs from workflow.run_stream() through to agent.run_stream() and @tool functions.
WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs"
Expand Down
24 changes: 12 additions & 12 deletions python/packages/core/agent_framework/_workflows/_edge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from ._executor import Executor
from ._runner_context import Message, RunnerContext
from ._shared_state import SharedState
from ._state import State

logger = logging.getLogger(__name__)

Expand All @@ -38,12 +38,12 @@ def __init__(self, edge_group: EdgeGroup, executors: dict[str, Executor]) -> Non
self._executors = executors

@abstractmethod
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
"""Send a message through the edge group.

Args:
message: The message to send.
shared_state: The shared state to use for holding data.
state: The workflow state.
ctx: The context for the runner.

Returns:
Expand All @@ -63,7 +63,7 @@ async def _execute_on_target(
target_id: str,
source_ids: list[str],
message: Message,
shared_state: SharedState,
state: State,
ctx: RunnerContext,
) -> None:
"""Execute a message on a target executor with trace context."""
Expand All @@ -76,7 +76,7 @@ async def _execute_on_target(
await target_executor.execute(
message,
source_ids, # source_executor_ids
shared_state, # shared_state
state, # state
ctx, # runner_context
trace_contexts=message.trace_contexts, # Pass trace contexts
source_span_ids=message.source_span_ids, # Pass source span IDs for linking
Expand All @@ -90,7 +90,7 @@ def __init__(self, edge_group: SingleEdgeGroup | InternalEdgeGroup, executors: d
super().__init__(edge_group, executors)
self._edge = edge_group.edges[0]

async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
"""Send a message through the single edge."""
should_execute = False
target_id: str | None = None
Expand Down Expand Up @@ -144,7 +144,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R

# Execute outside the span
if should_execute and target_id and source_id:
await self._execute_on_target(target_id, [source_id], message, shared_state, ctx)
await self._execute_on_target(target_id, [source_id], message, state, ctx)
return True

return False
Expand All @@ -162,7 +162,7 @@ def __init__(self, edge_group: FanOutEdgeGroup, executors: dict[str, Executor])
Callable[[Any, list[str]], list[str]] | None, getattr(edge_group, "selection_func", None)
)

async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
"""Send a message through all edges in the fan-out edge group."""
deliverable_edges: list[Edge] = []
single_target_edge: Edge | None = None
Expand Down Expand Up @@ -253,14 +253,14 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R
# Execute outside the span
if single_target_edge:
await self._execute_on_target(
single_target_edge.target_id, [single_target_edge.source_id], message, shared_state, ctx
single_target_edge.target_id, [single_target_edge.source_id], message, state, ctx
)
return True

if deliverable_edges:

async def send_to_edge(edge: Edge) -> bool:
await self._execute_on_target(edge.target_id, [edge.source_id], message, shared_state, ctx)
await self._execute_on_target(edge.target_id, [edge.source_id], message, state, ctx)
return True

tasks = [send_to_edge(edge) for edge in deliverable_edges]
Expand All @@ -285,7 +285,7 @@ def __init__(self, edge_group: FanInEdgeGroup, executors: dict[str, Executor]) -
# Key is the source executor ID, value is a list of messages
self._buffer: dict[str, list[Message]] = defaultdict(list)

async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
"""Send a message through all edges in the fan-in edge group."""
execution_data: dict[str, Any] | None = None
with create_edge_group_processing_span(
Expand Down Expand Up @@ -362,7 +362,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R
# Execute outside the span if needed
if execution_data:
await self._execute_on_target(
execution_data["target_id"], execution_data["source_ids"], execution_data["message"], shared_state, ctx
execution_data["target_id"], execution_data["source_ids"], execution_data["message"], state, ctx
)
return True

Expand Down
14 changes: 7 additions & 7 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._model_utils import DictConvertible
from ._request_info_mixin import RequestInfoMixin
from ._runner_context import Message, MessageType, RunnerContext
from ._shared_state import SharedState
from ._state import State
from ._typing_utils import is_instance_of, normalize_type_to_list, resolve_type_annotation
from ._workflow_context import WorkflowContext, validate_workflow_context_annotation

Expand Down Expand Up @@ -221,7 +221,7 @@ async def execute(
self,
message: Any,
source_executor_ids: list[str],
shared_state: SharedState,
state: State,
runner_context: RunnerContext,
trace_contexts: list[dict[str, str]] | None = None,
source_span_ids: list[str] | None = None,
Expand All @@ -234,7 +234,7 @@ async def execute(
Args:
message: The message to be processed by the executor.
source_executor_ids: The IDs of the source executors that sent messages to this executor.
shared_state: The shared state for the workflow.
state: The state for the workflow.
runner_context: The runner context that provides methods to send messages and events.
trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation.
source_span_ids: Optional source span IDs from multiple sources for linking.
Expand Down Expand Up @@ -262,7 +262,7 @@ async def execute(
# Create the appropriate WorkflowContext based on handler specs
context = self._create_context_for_handler(
source_executor_ids=source_executor_ids,
shared_state=shared_state,
state=state,
runner_context=runner_context,
trace_contexts=trace_contexts,
source_span_ids=source_span_ids,
Expand Down Expand Up @@ -295,7 +295,7 @@ async def execute(
def _create_context_for_handler(
self,
source_executor_ids: list[str],
shared_state: SharedState,
state: State,
runner_context: RunnerContext,
trace_contexts: list[dict[str, str]] | None = None,
source_span_ids: list[str] | None = None,
Expand All @@ -305,7 +305,7 @@ def _create_context_for_handler(

Args:
source_executor_ids: The IDs of the source executors that sent messages to this executor.
shared_state: The shared state for the workflow.
state: The state for the workflow.
runner_context: The runner context that provides methods to send messages and events.
trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation.
source_span_ids: Optional source span IDs from multiple sources for linking.
Expand All @@ -318,7 +318,7 @@ def _create_context_for_handler(
return WorkflowContext(
executor=self,
source_executor_ids=source_executor_ids,
shared_state=shared_state,
state=state,
runner_context=runner_context,
trace_contexts=trace_contexts,
source_span_ids=source_span_ids,
Expand Down
39 changes: 19 additions & 20 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
Message,
RunnerContext,
)
from ._shared_state import SharedState
from ._state import State

logger = logging.getLogger(__name__)

Expand All @@ -39,17 +39,17 @@ def __init__(
self,
edge_groups: Sequence[EdgeGroup],
executors: dict[str, Executor],
shared_state: SharedState,
state: State,
ctx: RunnerContext,
max_iterations: int = 100,
workflow_id: str | None = None,
) -> None:
"""Initialize the runner with edges, shared state, and context.
"""Initialize the runner with edges, state, and context.

Args:
edge_groups: The edge groups of the workflow.
executors: Map of executor IDs to executor instances.
shared_state: The shared state for the workflow.
state: The state for the workflow.
ctx: The runner context for the workflow.
max_iterations: The maximum number of iterations to run.
workflow_id: The workflow ID for checkpointing.
Expand All @@ -60,7 +60,7 @@ def __init__(
self._ctx = ctx
self._iteration = 0
self._max_iterations = max_iterations
self._shared_state = shared_state
self._state = state
self._workflow_id = workflow_id
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed
Expand Down Expand Up @@ -141,6 +141,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:

logger.info(f"Completed superstep {self._iteration}")

# Commit pending state changes at superstep boundary
self._state.commit()

# Create checkpoint after each superstep iteration
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")

Expand All @@ -164,7 +167,7 @@ async def _deliver_messages(source_executor_id: str, messages: list[Message]) ->

async def _deliver_message_inner(edge_runner: EdgeRunner, message: Message) -> bool:
"""Inner loop to deliver a single message through an edge runner."""
return await edge_runner.send_message(message, self._shared_state, self._ctx)
return await edge_runner.send_message(message, self._state, self._ctx)

def _normalize_message_payload(message: Message) -> None:
data = message.data
Expand Down Expand Up @@ -212,7 +215,7 @@ async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | Non
if self.graph_signature_hash:
metadata["graph_signature"] = self.graph_signature_hash
checkpoint_id = await self._ctx.create_checkpoint(
self._shared_state,
self._state,
self._iteration,
metadata=metadata,
)
Expand Down Expand Up @@ -271,9 +274,9 @@ async def restore_from_checkpoint(
)

self._workflow_id = checkpoint.workflow_id
# Restore shared state
await self._shared_state.import_state(decode_checkpoint_value(checkpoint.shared_state))
# Restore executor states using the restored shared state
# Restore state
self._state.import_state(decode_checkpoint_value(checkpoint.state))
# Restore executor states using the restored state
await self._restore_executor_states()
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
Expand Down Expand Up @@ -346,11 +349,11 @@ async def _restore_executor_states(self) -> None:
This method will try the backward compatibility behavior first; if that does not restore state,
it falls back to the updated behavior.
"""
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
has_executor_states = self._state.has(EXECUTOR_STATE_KEY)
if not has_executor_states:
return

executor_states = await self._shared_state.get(EXECUTOR_STATE_KEY)
executor_states = self._state.get(EXECUTOR_STATE_KEY)
if not isinstance(executor_states, dict):
raise WorkflowCheckpointException("Executor states in shared state is not a dictionary. Unable to restore.")

Expand Down Expand Up @@ -416,19 +419,15 @@ def _mark_resumed(self, iteration: int) -> None:
self._iteration = iteration

async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in shared state under a reserved key.
"""Store executor state in state under a reserved key.

Executors call this with a JSON-serializable dict capturing the minimal
state needed to resume. It replaces any previously stored state.
"""
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
if has_existing_states:
existing_states = await self._shared_state.get(EXECUTOR_STATE_KEY)
else:
existing_states = {}
existing_states = self._state.get(EXECUTOR_STATE_KEY, {})

if not isinstance(existing_states, dict):
raise WorkflowCheckpointException("Existing executor states in shared state is not a dictionary.")
raise WorkflowCheckpointException("Existing executor states in state is not a dictionary.")

existing_states[executor_id] = state
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
self._state.set(EXECUTOR_STATE_KEY, existing_states)
Loading
Loading