From 12b1324a542f64250001bca2f261e5c952be9ae9 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 4 Feb 2026 15:31:03 +0900 Subject: [PATCH 1/5] Refactor SharedState to State with sync methods and superstep caching --- .../agent_framework/_workflows/__init__.py | 4 +- .../_workflows/_agent_executor.py | 4 +- .../agent_framework/_workflows/_checkpoint.py | 10 +- .../_workflows/_checkpoint_summary.py | 2 +- .../core/agent_framework/_workflows/_const.py | 4 +- .../_workflows/_edge_runner.py | 24 +- .../agent_framework/_workflows/_executor.py | 14 +- .../agent_framework/_workflows/_runner.py | 39 +- .../_workflows/_runner_context.py | 28 +- .../_workflows/_shared_state.py | 101 -- .../core/agent_framework/_workflows/_state.py | 112 ++ .../agent_framework/_workflows/_workflow.py | 21 +- .../_workflows/_workflow_context.py | 48 +- .../_workflows/_workflow_executor.py | 4 +- .../tests/workflow/test_agent_executor.py | 4 +- .../core/tests/workflow/test_checkpoint.py | 20 +- .../packages/core/tests/workflow/test_edge.py | 152 ++- .../core/tests/workflow/test_magentic.py | 2 +- .../test_request_info_event_rehydrate.py | 12 +- .../core/tests/workflow/test_runner.py | 26 +- .../core/tests/workflow/test_serialization.py | 4 +- .../core/tests/workflow/test_workflow.py | 27 +- .../tests/workflow/test_workflow_context.py | 6 +- .../tests/workflow/test_workflow_kwargs.py | 50 +- .../workflow/test_workflow_observability.py | 14 +- .../tests/workflow/test_workflow_states.py | 6 +- .../_workflows/_declarative_base.py | 127 ++- .../_workflows/_executors_agents.py | 95 +- .../_workflows/_executors_basic.py | 72 +- .../_workflows/_executors_control_flow.py | 40 +- .../_workflows/_executors_external_input.py | 28 +- .../declarative/tests/test_graph_coverage.py | 996 +++++++++--------- .../declarative/tests/test_graph_executors.py | 528 +++++----- .../tests/test_powerfx_yaml_compatibility.py | 548 +++++----- .../workflow/checkpoint-info-modal.tsx | 8 +- .../devui/frontend/src/types/index.ts | 2 +- .../packages/devui/tests/test_checkpoints.py | 20 +- python/packages/devui/tests/test_server.py | 2 +- .../devui/fanout_workflow/workflow.py | 30 +- .../getting_started/workflows/README.md | 14 +- .../agents/workflow_as_agent_kwargs.py | 2 +- .../checkpoint_with_human_in_the_loop.py | 4 +- .../multi_selection_edge_group.py | 22 +- .../control-flow/switch_case_edge_group.py | 18 +- .../map_reduce_and_visualization.py | 28 +- ...es_with_agents.py => state_with_agents.py} | 25 +- .../state-management/workflow_kwargs.py | 2 +- 47 files changed, 1657 insertions(+), 1692 deletions(-) delete mode 100644 python/packages/core/agent_framework/_workflows/_shared_state.py create mode 100644 python/packages/core/agent_framework/_workflows/_state.py rename python/samples/getting_started/workflows/state-management/{shared_states_with_agents.py => state_with_agents.py} (89%) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 70706ff827..8dc5a71054 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -102,7 +102,7 @@ RunnerContext, ) from ._sequential import SequentialBuilder -from ._shared_state import SharedState +from ._state import State from ._validation import ( EdgeDuplicationError, GraphConnectivityError, @@ -179,9 +179,9 @@ "Runner", "RunnerContext", "SequentialBuilder", - "SharedState", "SingleEdgeGroup", "StandardMagenticManager", + "State", 
"SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SuperStepCompletedEvent", diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 80bd4aba43..9777affef5 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -333,7 +333,7 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentResponse | None: Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {} response = await self._agent.run( self._cache, @@ -360,7 +360,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentResponse | No Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {} updates: list[AgentResponseUpdate] = [] user_input_requests: list[Content] = [] diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint.py b/python/packages/core/agent_framework/_workflows/_checkpoint.py index ac04885579..c249c02bdc 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint.py @@ -26,15 +26,15 @@ class WorkflowCheckpoint: workflow_id: Identifier of the workflow this checkpoint belongs to timestamp: ISO 8601 timestamp when checkpoint was created messages: Messages exchanged between executors - shared_state: Complete shared state including user data and executor states. - Executor states are stored under the reserved key '_executor_state'. + state: Complete workflow state including user data and executor states. + Executor states are stored under the reserved key '_executor_state'. iteration_count: Current iteration number when checkpoint was created metadata: Additional metadata (e.g., superstep info, graph signature) version: Checkpoint format version Note: - The shared_state dict may contain reserved keys managed by the framework. - See SharedState class documentation for details on reserved keys. + The state dict may contain reserved keys managed by the framework. + See State class documentation for details on reserved keys. 
""" checkpoint_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -43,7 +43,7 @@ class WorkflowCheckpoint: # Core workflow state messages: dict[str, list[dict[str, Any]]] = field(default_factory=dict) # type: ignore[misc] - shared_state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc] + state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc] pending_request_info_events: dict[str, dict[str, Any]] = field(default_factory=dict) # type: ignore[misc] # Runtime state diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py index ebcf2ff83b..b1fd6896ab 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py @@ -25,7 +25,7 @@ class WorkflowCheckpointSummary: def get_checkpoint_summary(checkpoint: WorkflowCheckpoint) -> WorkflowCheckpointSummary: targets = sorted(checkpoint.messages.keys()) - executor_ids = sorted(checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {}).keys()) + executor_ids = sorted(checkpoint.state.get(EXECUTOR_STATE_KEY, {}).keys()) pending_request_info_events = [ RequestInfoEvent.from_dict(request) for request in checkpoint.pending_request_info_events.values() ] diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index 4d27c609b1..3a6d24aefe 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -3,13 +3,13 @@ # Default maximum iterations for workflow execution. DEFAULT_MAX_ITERATIONS = 100 -# Key used to store executor state in shared state. +# Key used to store executor state in state. EXECUTOR_STATE_KEY = "_executor_state" # Source identifier for internal workflow messages. INTERNAL_SOURCE_PREFIX = "internal" -# SharedState key for storing run kwargs that should be passed to agent invocations. +# State key for storing run kwargs that should be passed to agent invocations. # Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic) # to pass kwargs from workflow.run_stream() through to agent.run_stream() and @tool functions. WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs" diff --git a/python/packages/core/agent_framework/_workflows/_edge_runner.py b/python/packages/core/agent_framework/_workflows/_edge_runner.py index 8255f8f79c..c87994b4b4 100644 --- a/python/packages/core/agent_framework/_workflows/_edge_runner.py +++ b/python/packages/core/agent_framework/_workflows/_edge_runner.py @@ -19,7 +19,7 @@ ) from ._executor import Executor from ._runner_context import Message, RunnerContext -from ._shared_state import SharedState +from ._state import State logger = logging.getLogger(__name__) @@ -38,12 +38,12 @@ def __init__(self, edge_group: EdgeGroup, executors: dict[str, Executor]) -> Non self._executors = executors @abstractmethod - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool: """Send a message through the edge group. Args: message: The message to send. - shared_state: The shared state to use for holding data. + state: The workflow state. ctx: The context for the runner. 
Returns: @@ -63,7 +63,7 @@ async def _execute_on_target( target_id: str, source_ids: list[str], message: Message, - shared_state: SharedState, + state: State, ctx: RunnerContext, ) -> None: """Execute a message on a target executor with trace context.""" @@ -76,7 +76,7 @@ async def _execute_on_target( await target_executor.execute( message, source_ids, # source_executor_ids - shared_state, # shared_state + state, # state ctx, # runner_context trace_contexts=message.trace_contexts, # Pass trace contexts source_span_ids=message.source_span_ids, # Pass source span IDs for linking @@ -90,7 +90,7 @@ def __init__(self, edge_group: SingleEdgeGroup | InternalEdgeGroup, executors: d super().__init__(edge_group, executors) self._edge = edge_group.edges[0] - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool: """Send a message through the single edge.""" should_execute = False target_id: str | None = None @@ -144,7 +144,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # Execute outside the span if should_execute and target_id and source_id: - await self._execute_on_target(target_id, [source_id], message, shared_state, ctx) + await self._execute_on_target(target_id, [source_id], message, state, ctx) return True return False @@ -162,7 +162,7 @@ def __init__(self, edge_group: FanOutEdgeGroup, executors: dict[str, Executor]) Callable[[Any, list[str]], list[str]] | None, getattr(edge_group, "selection_func", None) ) - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool: """Send a message through all edges in the fan-out edge group.""" deliverable_edges: list[Edge] = [] single_target_edge: Edge | None = None @@ -253,14 +253,14 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # Execute outside the span if single_target_edge: await self._execute_on_target( - single_target_edge.target_id, [single_target_edge.source_id], message, shared_state, ctx + single_target_edge.target_id, [single_target_edge.source_id], message, state, ctx ) return True if deliverable_edges: async def send_to_edge(edge: Edge) -> bool: - await self._execute_on_target(edge.target_id, [edge.source_id], message, shared_state, ctx) + await self._execute_on_target(edge.target_id, [edge.source_id], message, state, ctx) return True tasks = [send_to_edge(edge) for edge in deliverable_edges] @@ -285,7 +285,7 @@ def __init__(self, edge_group: FanInEdgeGroup, executors: dict[str, Executor]) - # Key is the source executor ID, value is a list of messages self._buffer: dict[str, list[Message]] = defaultdict(list) - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool: """Send a message through all edges in the fan-in edge group.""" execution_data: dict[str, Any] | None = None with create_edge_group_processing_span( @@ -362,7 +362,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # Execute outside the span if needed if execution_data: await self._execute_on_target( - execution_data["target_id"], execution_data["source_ids"], execution_data["message"], shared_state, ctx + execution_data["target_id"], execution_data["source_ids"], 
execution_data["message"], state, ctx ) return True diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 18adc4b904..6481203adc 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -20,7 +20,7 @@ from ._model_utils import DictConvertible from ._request_info_mixin import RequestInfoMixin from ._runner_context import Message, MessageType, RunnerContext -from ._shared_state import SharedState +from ._state import State from ._typing_utils import is_instance_of, normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation @@ -221,7 +221,7 @@ async def execute( self, message: Any, source_executor_ids: list[str], - shared_state: SharedState, + state: State, runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, @@ -234,7 +234,7 @@ async def execute( Args: message: The message to be processed by the executor. source_executor_ids: The IDs of the source executors that sent messages to this executor. - shared_state: The shared state for the workflow. + state: The state for the workflow. runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking. @@ -262,7 +262,7 @@ async def execute( # Create the appropriate WorkflowContext based on handler specs context = self._create_context_for_handler( source_executor_ids=source_executor_ids, - shared_state=shared_state, + state=state, runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, @@ -295,7 +295,7 @@ async def execute( def _create_context_for_handler( self, source_executor_ids: list[str], - shared_state: SharedState, + state: State, runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, @@ -305,7 +305,7 @@ def _create_context_for_handler( Args: source_executor_ids: The IDs of the source executors that sent messages to this executor. - shared_state: The shared state for the workflow. + state: The state for the workflow. runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking. 
@@ -318,7 +318,7 @@ def _create_context_for_handler( return WorkflowContext( executor=self, source_executor_ids=source_executor_ids, - shared_state=shared_state, + state=state, runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 227f0f7fe7..9ea52932c1 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -26,7 +26,7 @@ Message, RunnerContext, ) -from ._shared_state import SharedState +from ._state import State logger = logging.getLogger(__name__) @@ -38,17 +38,17 @@ def __init__( self, edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], - shared_state: SharedState, + state: State, ctx: RunnerContext, max_iterations: int = 100, workflow_id: str | None = None, ) -> None: - """Initialize the runner with edges, shared state, and context. + """Initialize the runner with edges, state, and context. Args: edge_groups: The edge groups of the workflow. executors: Map of executor IDs to executor instances. - shared_state: The shared state for the workflow. + state: The state for the workflow. ctx: The runner context for the workflow. max_iterations: The maximum number of iterations to run. workflow_id: The workflow ID for checkpointing. @@ -59,7 +59,7 @@ def __init__( self._ctx = ctx self._iteration = 0 self._max_iterations = max_iterations - self._shared_state = shared_state + self._state = state self._workflow_id = workflow_id self._running = False self._resumed_from_checkpoint = False # Track whether we resumed @@ -133,6 +133,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: logger.info(f"Completed superstep {self._iteration}") + # Commit pending state changes at superstep boundary + self._state.commit() + # Create checkpoint after each superstep iteration await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}") @@ -156,7 +159,7 @@ async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> async def _deliver_message_inner(edge_runner: EdgeRunner, message: Message) -> bool: """Inner loop to deliver a single message through an edge runner.""" - return await edge_runner.send_message(message, self._shared_state, self._ctx) + return await edge_runner.send_message(message, self._state, self._ctx) def _normalize_message_payload(message: Message) -> None: data = message.data @@ -204,7 +207,7 @@ async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | Non if self.graph_signature_hash: metadata["graph_signature"] = self.graph_signature_hash checkpoint_id = await self._ctx.create_checkpoint( - self._shared_state, + self._state, self._iteration, metadata=metadata, ) @@ -263,9 +266,9 @@ async def restore_from_checkpoint( ) self._workflow_id = checkpoint.workflow_id - # Restore shared state - await self._shared_state.import_state(decode_checkpoint_value(checkpoint.shared_state)) - # Restore executor states using the restored shared state + # Restore state + self._state.import_state(decode_checkpoint_value(checkpoint.state)) + # Restore executor states using the restored state await self._restore_executor_states() # Apply the checkpoint to the context await self._ctx.apply_checkpoint(checkpoint) @@ -338,11 +341,11 @@ async def _restore_executor_states(self) -> None: This method will try the backward compatibility behavior first; if that does not 
restore state, it falls back to the updated behavior. """ - has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY) + has_executor_states = self._state.has(EXECUTOR_STATE_KEY) if not has_executor_states: return - executor_states = await self._shared_state.get(EXECUTOR_STATE_KEY) + executor_states = self._state.get(EXECUTOR_STATE_KEY) if not isinstance(executor_states, dict): raise WorkflowCheckpointException("Executor states in shared state is not a dictionary. Unable to restore.") @@ -408,19 +411,15 @@ def _mark_resumed(self, iteration: int) -> None: self._iteration = iteration async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: - """Store executor state in shared state under a reserved key. + """Store executor state in state under a reserved key. Executors call this with a JSON-serializable dict capturing the minimal state needed to resume. It replaces any previously stored state. """ - has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY) - if has_existing_states: - existing_states = await self._shared_state.get(EXECUTOR_STATE_KEY) - else: - existing_states = {} + existing_states = self._state.get(EXECUTOR_STATE_KEY, {}) if not isinstance(existing_states, dict): - raise WorkflowCheckpointException("Existing executor states in shared state is not a dictionary.") + raise WorkflowCheckpointException("Existing executor states in state is not a dictionary.") existing_states[executor_id] = state - await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states) + self._state.set(EXECUTOR_STATE_KEY, existing_states) diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index ce9fff6617..54a3b9b37e 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -13,7 +13,7 @@ from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import INTERNAL_SOURCE_ID from ._events import RequestInfoEvent, WorkflowEvent -from ._shared_state import SharedState +from ._state import State from ._typing_utils import is_instance_of if sys.version_info >= (3, 11): @@ -104,7 +104,7 @@ class _WorkflowState(TypedDict): """ messages: dict[str, list[dict[str, Any]]] - shared_state: dict[str, Any] + state: dict[str, Any] iteration_count: int pending_request_info_events: dict[str, dict[str, Any]] @@ -217,16 +217,16 @@ def is_streaming(self) -> bool: async def create_checkpoint( self, - shared_state: SharedState, + state: State, iteration_count: int, metadata: dict[str, Any] | None = None, ) -> str: """Create a checkpoint of the current workflow state. Args: - shared_state: The shared state to include in the checkpoint. - This is needed to capture the full state of the workflow. - The shared state is not managed by the context itself. + state: The state to include in the checkpoint. + This is needed to capture the full state of the workflow. + The state is not managed by the context itself. iteration_count: The current iteration count of the workflow. metadata: Optional metadata to associate with the checkpoint. 
@@ -374,7 +374,7 @@ def has_checkpointing(self) -> bool: async def create_checkpoint( self, - shared_state: SharedState, + state: State, iteration_count: int, metadata: dict[str, Any] | None = None, ) -> str: @@ -383,14 +383,14 @@ async def create_checkpoint( raise ValueError("Checkpoint storage not configured") self._workflow_id = self._workflow_id or str(uuid.uuid4()) - state = await self._get_serialized_workflow_state(shared_state, iteration_count) + workflow_state = self._get_serialized_workflow_state(state, iteration_count) checkpoint = WorkflowCheckpoint( workflow_id=self._workflow_id, - messages=state["messages"], - shared_state=state["shared_state"], - pending_request_info_events=state["pending_request_info_events"], - iteration_count=state["iteration_count"], + messages=workflow_state["messages"], + state=workflow_state["state"], + pending_request_info_events=workflow_state["pending_request_info_events"], + iteration_count=workflow_state["iteration_count"], metadata=metadata or {}, ) checkpoint_id = await storage.save_checkpoint(checkpoint) @@ -454,7 +454,7 @@ def is_streaming(self) -> bool: """ return self._streaming - async def _get_serialized_workflow_state(self, shared_state: SharedState, iteration_count: int) -> _WorkflowState: + def _get_serialized_workflow_state(self, state: State, iteration_count: int) -> _WorkflowState: serialized_messages: dict[str, list[dict[str, Any]]] = {} for source_id, message_list in self._messages.items(): serialized_messages[source_id] = [msg.to_dict() for msg in message_list] @@ -465,7 +465,7 @@ async def _get_serialized_workflow_state(self, shared_state: SharedState, iterat return { "messages": serialized_messages, - "shared_state": encode_checkpoint_value(await shared_state.export_state()), + "state": encode_checkpoint_value(state.export_state()), "iteration_count": iteration_count, "pending_request_info_events": serialized_pending_request_info_events, } diff --git a/python/packages/core/agent_framework/_workflows/_shared_state.py b/python/packages/core/agent_framework/_workflows/_shared_state.py deleted file mode 100644 index 93057021fb..0000000000 --- a/python/packages/core/agent_framework/_workflows/_shared_state.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from typing import Any - - -class SharedState: - """A class to manage shared state in a workflow. - - SharedState provides thread-safe access to workflow state data that needs to be - shared across executors during workflow execution. - - Reserved Keys: - The following keys are reserved for internal framework use and should not be - modified by user code: - - - `_executor_state`: Stores executor state for checkpointing (managed by Runner) - - Warning: - Do not use keys starting with underscore (_) as they may be reserved for - internal framework operations. 
- """ - - def __init__(self) -> None: - """Initialize the shared state.""" - self._state: dict[str, Any] = {} - self._shared_state_lock = asyncio.Lock() - - async def set(self, key: str, value: Any) -> None: - """Set a value in the shared state.""" - async with self._shared_state_lock: - await self.set_within_hold(key, value) - - async def get(self, key: str) -> Any: - """Get a value from the shared state.""" - async with self._shared_state_lock: - return await self.get_within_hold(key) - - async def has(self, key: str) -> bool: - """Check if a key exists in the shared state.""" - async with self._shared_state_lock: - return await self.has_within_hold(key) - - async def delete(self, key: str) -> None: - """Delete a key from the shared state.""" - async with self._shared_state_lock: - await self.delete_within_hold(key) - - async def clear(self) -> None: - """Clear the entire shared state.""" - async with self._shared_state_lock: - self._state.clear() - - async def export_state(self) -> dict[str, Any]: - """Get a serialized copy of the entire shared state.""" - async with self._shared_state_lock: - return dict(self._state) - - async def import_state(self, state: dict[str, Any]) -> None: - """Populate the shared state from a serialized state dictionary. - - This replaces the entire current state with the provided state. - """ - async with self._shared_state_lock: - self._state.update(state) - - @asynccontextmanager - async def hold(self) -> AsyncIterator["SharedState"]: - """Context manager to hold the shared state lock for multiple operations. - - Usage: - async with shared_state.hold(): - await shared_state.set_within_hold("key", value) - value = await shared_state.get_within_hold("key") - """ - async with self._shared_state_lock: - yield self - - # Unsafe methods that don't acquire locks (for use within hold() context) - async def set_within_hold(self, key: str, value: Any) -> None: - """Set a value without acquiring the lock (unsafe - use within hold() context).""" - self._state[key] = value - - async def get_within_hold(self, key: str) -> Any: - """Get a value without acquiring the lock (unsafe - use within hold() context).""" - if key not in self._state: - raise KeyError(f"Key '{key}' not found in shared state.") - return self._state[key] - - async def has_within_hold(self, key: str) -> bool: - """Check if a key exists without acquiring the lock (unsafe - use within hold() context).""" - return key in self._state - - async def delete_within_hold(self, key: str) -> None: - """Delete a key without acquiring the lock (unsafe - use within hold() context).""" - if key in self._state: - del self._state[key] - else: - raise KeyError(f"Key '{key}' not found in shared state.") diff --git a/python/packages/core/agent_framework/_workflows/_state.py b/python/packages/core/agent_framework/_workflows/_state.py new file mode 100644 index 0000000000..8e8a793389 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_state.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + + +class State: + """Manages shared state across executors within a workflow. + + State provides access to workflow state data that is shared across executors + during workflow execution. It implements superstep caching semantics where + writes are staged in a pending buffer and only committed to the actual state + at superstep boundaries. 
+
+    Superstep Semantics:
+    - `set()` writes to a pending buffer, not directly to committed state
+    - `get()` checks pending buffer first, then committed state
+    - `commit()` moves all pending changes to committed state (called by Runner at superstep boundary)
+    - `discard()` clears pending changes without committing
+
+    Reserved Keys:
+    Keys starting with underscore (_) are reserved for internal framework use.
+    Do not use these in user code.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the state."""
+        self._committed: dict[str, Any] = {}
+        self._pending: dict[str, Any] = {}
+
+    def set(self, key: str, value: Any) -> None:
+        """Set a value in the pending state buffer.
+
+        The value will be visible to subsequent `get()` calls but won't be
+        committed to the actual state until `commit()` is called.
+        """
+        self._pending[key] = value
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Get a value from state, checking pending first then committed.
+
+        Args:
+            key: The key to retrieve.
+            default: Value to return if key is not found. Defaults to None.
+
+        Returns:
+            The value if found, otherwise the default value.
+        """
+        # A pending delete marker means the key is logically absent.
+        value = self._pending.get(key, self._committed.get(key, default))
+        return default if value is _DeleteSentinel else value
+
+    def has(self, key: str) -> bool:
+        """Check if a key exists in state and is not pending deletion."""
+        return self.get(key, _DeleteSentinel) is not _DeleteSentinel
+
+    def delete(self, key: str) -> None:
+        """Mark a key for deletion.
+
+        Any pending value for the key is removed. If the key also exists in
+        committed state, a sentinel is stored in pending so the committed
+        value is removed at commit time.
+        """
+        if not self.has(key):
+            raise KeyError(f"Key '{key}' not found in state.")
+        if key in self._pending:
+            del self._pending[key]
+        if key in self._committed:
+            self._pending[key] = _DeleteSentinel
+
+    def clear(self) -> None:
+        """Clear both committed and pending state."""
+        self._committed.clear()
+        self._pending.clear()
+
+    def commit(self) -> None:
+        """Commit pending changes to the committed state.
+
+        Called by the Runner at superstep boundaries after successful execution.
+        """
+        for key, value in self._pending.items():
+            if value is _DeleteSentinel:
+                self._committed.pop(key, None)
+            else:
+                self._committed[key] = value
+        self._pending.clear()
+
+    def discard(self) -> None:
+        """Discard all pending changes without committing."""
+        self._pending.clear()
+
+    def export_state(self) -> dict[str, Any]:
+        """Export a serialized copy of the committed state.
+
+        Note: Does not include pending changes.
+        """
+        return dict(self._committed)
+
+    def import_state(self, state: dict[str, Any]) -> None:
+        """Import state from a serialized dictionary.
+
+        Merges into committed state. Does not affect pending changes.
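To make the superstep semantics above concrete, a minimal usage sketch against the State API defined in this file; the key names and values are placeholders, not part of the patch:

    from agent_framework._workflows._state import State

    state = State()
    state.set("votes", 1)                    # staged in the pending buffer
    assert state.get("votes") == 1           # reads see pending values immediately
    assert state.export_state() == {}        # nothing committed yet

    state.commit()                           # the Runner calls this at the superstep boundary
    assert state.export_state() == {"votes": 1}

    state.set("votes", 2)
    state.discard()                          # drop uncommitted changes
    assert state.get("votes") == 1

    state.delete("votes")                    # deletion is applied at the next commit
    state.commit()
    assert state.get("votes", "missing") == "missing"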
+ """ + self._committed.update(state) + + +class _DeleteSentinelType: + """Sentinel type to mark keys for deletion in pending state.""" + + pass + + +_DeleteSentinel = _DeleteSentinelType() diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index bd14dc6bcc..063d8af796 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -33,7 +33,7 @@ from ._model_utils import DictConvertible from ._runner import Runner from ._runner_context import RunnerContext -from ._shared_state import SharedState +from ._state import State from ._typing_utils import is_instance_of logger = logging.getLogger(__name__) @@ -204,11 +204,11 @@ def __init__( # Store non-serializable runtime objects as private attributes self._runner_context = runner_context - self._shared_state = SharedState() + self._state = State() self._runner: Runner = Runner( self.edge_groups, self.executors, - self._shared_state, + self._state, runner_context, max_iterations=max_iterations, workflow_id=self.id, @@ -297,7 +297,7 @@ async def _run_workflow_with_tracing( initial_executor_fn: Optional function to execute initial executor reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents - run_kwargs: Optional kwargs to store in SharedState for agent invocations + run_kwargs: Optional kwargs to store in State for agent invocations Yields: WorkflowEvent: The events generated during the workflow execution. @@ -330,11 +330,12 @@ async def _run_workflow_with_tracing( if reset_context: self._runner.reset_iteration_count() self._runner.context.reset_for_new_run() - await self._shared_state.clear() + self._state.clear() - # Store run kwargs in SharedState so executors can access them + # Store run kwargs in State so executors can access them # Always store (even empty dict) so retrieval is deterministic - await self._shared_state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {}) + self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {}) + self._state.commit() # Commit immediately so kwargs are available # Set streaming mode after reset self._runner_context.set_streaming(streaming) @@ -428,7 +429,7 @@ async def _execute_with_message_or_checkpoint( await executor.execute( message, [self.__class__.__name__], - self._shared_state, + self._state, self._runner.context, trace_contexts=None, source_span_ids=None, @@ -457,7 +458,7 @@ async def run_stream( - Without checkpoint_id: Enables checkpointing for this run, overriding build-time configuration **kwargs: Additional keyword arguments to pass through to agent invocations. - These are stored in SharedState and accessible in @tool functions + These are stored in State and accessible in @tool functions via the **kwargs parameter. Yields: @@ -591,7 +592,7 @@ async def run( build-time configuration include_status_events: Whether to include WorkflowStatusEvent instances in the result list. **kwargs: Additional keyword arguments to pass through to agent invocations. - These are stored in SharedState and accessible in @tool functions + These are stored in State and accessible in @tool functions via the **kwargs parameter. 
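A small sketch of that kwargs flow, mirroring the calls this patch adds in _run_workflow_with_tracing and AgentExecutor._run_agent; the kwarg names and values are placeholders:

    from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
    from agent_framework._workflows._state import State

    # What the workflow does with the caller's **kwargs before the first superstep:
    state = State()
    state.set(WORKFLOW_RUN_KWARGS_KEY, {"user_id": "u-123", "tenant": "contoso"})
    state.commit()  # committed immediately so executors can read them right away

    # What AgentExecutor does before invoking the agent:
    run_kwargs = state.get(WORKFLOW_RUN_KWARGS_KEY) or {}
    assert run_kwargs["tenant"] == "contoso"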
Returns: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 708cdf3c51..f902b25a30 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -26,7 +26,7 @@ _framework_event_origin, # type: ignore ) from ._runner_context import Message, RunnerContext -from ._shared_state import SharedState +from ._state import State if TYPE_CHECKING: from ._executor import Executor @@ -267,7 +267,7 @@ def __init__( self, executor: "Executor", source_executor_ids: list[str], - shared_state: SharedState, + state: State, runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, @@ -280,7 +280,7 @@ def __init__( source_executor_ids: The IDs of the source executors that sent messages to this executor. This is a list to support fan_in scenarios where multiple sources send aggregated messages to the same executor. - shared_state: The shared state for the workflow. + state: The workflow state. runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking (not for nesting). @@ -290,7 +290,7 @@ def __init__( self._executor_id = executor.id self._source_executor_ids = source_executor_ids self._runner_context = runner_context - self._shared_state = shared_state + self._state = state # Track messages sent via send_message() for ExecutorCompletedEvent self._sent_messages: list[Any] = [] @@ -410,13 +410,13 @@ async def request_info(self, request_data: object, response_type: type, *, reque ) await self._runner_context.add_request_info_event(request_info_event) - async def get_shared_state(self, key: str) -> Any: - """Get a value from the shared state.""" - return await self._shared_state.get(key) + def get_state(self, key: str, default: Any = None) -> Any: + """Get a value from the workflow state.""" + return self._state.get(key, default) - async def set_shared_state(self, key: str, value: Any) -> None: - """Set a value in the shared state.""" - await self._shared_state.set(key, value) + def set_state(self, key: str, value: Any) -> None: + """Set a value in the workflow state.""" + self._state.set(key, value) def get_source_executor_id(self) -> str: """Get the ID of the source executor that sent the message to this executor. @@ -437,9 +437,9 @@ def source_executor_ids(self) -> list[str]: return self._source_executor_ids @property - def shared_state(self) -> SharedState: - """Get the shared state.""" - return self._shared_state + def state(self) -> State: + """Get the workflow state.""" + return self._state def get_sent_messages(self) -> list[Any]: """Get all messages sent via send_message() during this handler execution. @@ -459,41 +459,37 @@ def get_yielded_outputs(self) -> list[Any]: @deprecated( "Override `on_checkpoint_save()` methods instead. " - "For cross-executor state sharing, use set_shared_state() instead. " + "For cross-executor state sharing, use set_state() instead. " "This API will be removed after 12/01/2025." ) async def set_executor_state(self, state: dict[str, Any]) -> None: - """Store executor state in shared state under a reserved key. + """Store executor state in state under a reserved key. 
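For illustration, a sketch of the new synchronous context API inside executor handlers; the executor classes and key names are placeholders, the imports assume the package's public exports, and Summarizer is assumed to run in a later superstep than Collector:

    from agent_framework import Executor, WorkflowContext, handler

    class Collector(Executor):
        @handler
        async def collect(self, item: str, ctx: WorkflowContext) -> None:
            # set_state() stages the value; it is committed at the superstep boundary.
            seen: list[str] = ctx.get_state("seen_items", [])
            ctx.set_state("seen_items", [*seen, item])

    class Summarizer(Executor):
        @handler
        async def summarize(self, trigger: str, ctx: WorkflowContext) -> None:
            # Runs in a later superstep, so the committed list is visible here.
            items = ctx.get_state("seen_items", [])
            ctx.set_state("summary", f"{len(items)} item(s) collected")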
Executors call this with a JSON-serializable dict capturing the minimal state needed to resume. It replaces any previously stored state. """ - has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY) - if has_existing_states: - existing_states = await self._shared_state.get(EXECUTOR_STATE_KEY) - else: - existing_states = {} + existing_states = self._state.get(EXECUTOR_STATE_KEY, {}) if not isinstance(existing_states, dict): - raise ValueError("Existing executor states in shared state is not a dictionary.") + raise ValueError("Existing executor states in state is not a dictionary.") existing_states[self._executor_id] = state - await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states) + self._state.set(EXECUTOR_STATE_KEY, existing_states) @deprecated( "Override `on_checkpoint_restore()` methods instead. " - "For cross-executor state sharing, use get_shared_state() instead. " + "For cross-executor state sharing, use get_state() instead. " "This API will be removed after 12/01/2025." ) async def get_executor_state(self) -> dict[str, Any] | None: """Retrieve previously persisted state for this executor, if any.""" - has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY) + has_existing_states = self._state.has(EXECUTOR_STATE_KEY) if not has_existing_states: return None - existing_states = await self._shared_state.get(EXECUTOR_STATE_KEY) + existing_states = self._state.get(EXECUTOR_STATE_KEY) if not isinstance(existing_states, dict): - raise ValueError("Existing executor states in shared state is not a dictionary.") + raise ValueError("Existing executor states in state is not a dictionary.") return existing_states.get(self._executor_id) # type: ignore diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 2453620cfd..d04a632352 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -386,8 +386,8 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) logger.debug(f"WorkflowExecutor {self.id} starting sub-workflow {self.workflow.id} execution {execution_id}") try: - # Get kwargs from parent workflow's SharedState to propagate to subworkflow - parent_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) or {} + # Get kwargs from parent workflow's State to propagate to subworkflow + parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {} # Run the sub-workflow and collect all events, passing parent kwargs result = await self.workflow.run(input_data, **parent_kwargs) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 0fa2bfd952..9cb186be31 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -96,8 +96,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: ) # Verify checkpoint contains executor state with both cache and thread - assert "_executor_state" in restore_checkpoint.shared_state - executor_states = restore_checkpoint.shared_state["_executor_state"] + assert "_executor_state" in restore_checkpoint.state + executor_states = restore_checkpoint.state["_executor_state"] assert isinstance(executor_states, dict) assert executor.id in executor_states diff --git 
a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index 74ac524883..9f6d57b2e1 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -19,7 +19,7 @@ def test_workflow_checkpoint_default_values(): assert checkpoint.workflow_id == "" assert checkpoint.timestamp != "" assert checkpoint.messages == {} - assert checkpoint.shared_state == {} + assert checkpoint.state == {} assert checkpoint.pending_request_info_events == {} assert checkpoint.iteration_count == 0 assert checkpoint.metadata == {} @@ -34,7 +34,7 @@ def test_workflow_checkpoint_custom_values(): timestamp=custom_timestamp, messages={"executor1": [{"data": "test"}]}, pending_request_info_events={"req123": {"data": "test"}}, - shared_state={"key": "value"}, + state={"key": "value"}, iteration_count=5, metadata={"test": True}, version="2.0", @@ -44,7 +44,7 @@ def test_workflow_checkpoint_custom_values(): assert checkpoint.workflow_id == "test-workflow-456" assert checkpoint.timestamp == custom_timestamp assert checkpoint.messages == {"executor1": [{"data": "test"}]} - assert checkpoint.shared_state == {"key": "value"} + assert checkpoint.state == {"key": "value"} assert checkpoint.pending_request_info_events == {"req123": {"data": "test"}} assert checkpoint.iteration_count == 5 assert checkpoint.metadata == {"test": True} @@ -159,7 +159,7 @@ async def test_file_checkpoint_storage_save_and_load(): checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, - shared_state={"key": "value"}, + state={"key": "value"}, pending_request_info_events={"req123": {"data": "test"}}, ) @@ -177,7 +177,7 @@ async def test_file_checkpoint_storage_save_and_load(): assert loaded_checkpoint.checkpoint_id == checkpoint.checkpoint_id assert loaded_checkpoint.workflow_id == checkpoint.workflow_id assert loaded_checkpoint.messages == checkpoint.messages - assert loaded_checkpoint.shared_state == checkpoint.shared_state + assert loaded_checkpoint.state == checkpoint.state assert loaded_checkpoint.pending_request_info_events == checkpoint.pending_request_info_events @@ -293,7 +293,7 @@ async def test_file_checkpoint_storage_json_serialization(): checkpoint = WorkflowCheckpoint( workflow_id="complex-workflow", messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, - shared_state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None}, + state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None}, pending_request_info_events={"req123": {"data": "test"}}, ) @@ -303,7 +303,7 @@ async def test_file_checkpoint_storage_json_serialization(): assert loaded is not None assert loaded.messages == checkpoint.messages - assert loaded.shared_state == checkpoint.shared_state + assert loaded.state == checkpoint.state # Verify the JSON file is properly formatted file_path = Path(temp_dir) / f"{checkpoint.checkpoint_id}.json" @@ -311,9 +311,9 @@ async def test_file_checkpoint_storage_json_serialization(): data = json.load(f) assert data["messages"]["executor1"][0]["data"]["nested"]["value"] == 42 - assert data["shared_state"]["list"] == [1, 2, 3] - assert data["shared_state"]["bool"] is True - assert data["shared_state"]["null"] is None + assert data["state"]["list"] == [1, 2, 3] + assert data["state"]["bool"] is True + assert data["state"]["null"] is 
None assert data["pending_request_info_events"]["req123"]["data"] == "test" diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index 42e3893a73..42ff6e5d36 100644 --- a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -10,7 +10,6 @@ Executor, InProcRunnerContext, Message, - SharedState, WorkflowContext, handler, ) @@ -24,10 +23,9 @@ SwitchCaseEdgeGroupDefault, ) from agent_framework._workflows._edge_runner import create_edge_runner +from agent_framework._workflows._state import State from agent_framework.observability import EdgeGroupDeliveryStatus -# Add for test - @dataclass class MockMessage: @@ -191,13 +189,13 @@ async def test_single_edge_group_send_message() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True @@ -210,13 +208,13 @@ async def test_single_edge_group_send_message_with_target() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True @@ -229,13 +227,13 @@ async def test_single_edge_group_send_message_with_invalid_target() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id="invalid_target") - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -248,13 +246,13 @@ async def test_single_edge_group_send_message_with_invalid_data() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -268,13 +266,13 @@ async def test_single_edge_group_send_message_with_condition_pass() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test") edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert target.call_count == 1 assert target.last_message.data == 
"test" @@ -290,13 +288,13 @@ async def test_single_edge_group_send_message_with_condition_fail() -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test") edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="different") message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) # Should return True because message was processed, but condition failed assert success is True # Target should not be called because condition failed @@ -312,7 +310,7 @@ async def test_single_edge_group_tracing_success(span_exporter) -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information @@ -325,7 +323,7 @@ async def test_single_edge_group_tracing_success(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True spans = span_exporter.get_finished_spans() @@ -361,7 +359,7 @@ async def test_single_edge_group_tracing_condition_failure(span_exporter) -> Non edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "pass") edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="fail") @@ -370,7 +368,7 @@ async def test_single_edge_group_tracing_condition_failure(span_exporter) -> Non # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True # Returns True but condition failed spans = span_exporter.get_finished_spans() @@ -395,7 +393,7 @@ async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() # Send incompatible data type @@ -405,7 +403,7 @@ async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False spans = span_exporter.get_finished_spans() @@ -430,7 +428,7 @@ async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None: edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") @@ -439,7 +437,7 @@ async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False spans = 
span_exporter.get_finished_spans() @@ -498,13 +496,13 @@ async def test_source_edge_group_send_message() -> None: edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert target1.call_count == 1 @@ -521,13 +519,13 @@ async def test_source_edge_group_send_message_with_target() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target1.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert target1.call_count == 1 @@ -544,13 +542,13 @@ async def test_source_edge_group_send_message_with_invalid_target() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id="invalid_target") - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -564,13 +562,13 @@ async def test_source_edge_group_send_message_with_invalid_data() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -584,13 +582,13 @@ async def test_source_edge_group_send_message_only_one_successful_send() -> None executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert target1.call_count == 1 # target1 can handle MockMessage @@ -633,14 +631,14 @@ async def test_source_edge_group_with_selection_func_send_message() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) with patch("agent_framework._workflows._edge_runner.EdgeRunner._execute_on_target") as mock_send: - success = await 
edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True @@ -661,14 +659,14 @@ async def test_source_edge_group_with_selection_func_send_message_with_invalid_s executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) with pytest.raises(RuntimeError): - await edge_runner.send_message(message, shared_state, ctx) + await edge_runner.send_message(message, state, ctx) async def test_source_edge_group_with_selection_func_send_message_with_target() -> None: @@ -686,14 +684,14 @@ async def test_source_edge_group_with_selection_func_send_message_with_target() executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target1.id) with patch("agent_framework._workflows._edge_runner.EdgeRunner._execute_on_target") as mock_send: - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert mock_send.call_count == 1 @@ -715,13 +713,13 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_no executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target2.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -740,13 +738,13 @@ async def test_source_edge_group_with_selection_func_send_message_with_invalid_d executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -765,13 +763,13 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_in executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id, target_id=target1.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -785,7 +783,7 @@ async def test_fan_out_edge_group_tracing_success(span_exporter) -> None: edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = 
InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information @@ -798,7 +796,7 @@ async def test_fan_out_edge_group_tracing_success(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True spans = span_exporter.get_finished_spans() @@ -835,7 +833,7 @@ async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None: edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information @@ -854,7 +852,7 @@ async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True spans = span_exporter.get_finished_spans() @@ -922,7 +920,7 @@ async def test_target_edge_group_send_message_buffer() -> None: executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") @@ -930,7 +928,7 @@ async def test_target_edge_group_send_message_buffer() -> None: with patch("agent_framework._workflows._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message( Message(data=data, source_id=source1.id), - shared_state, + state, ctx, ) @@ -940,7 +938,7 @@ async def test_target_edge_group_send_message_buffer() -> None: success = await edge_runner.send_message( Message(data=data, source_id=source2.id), - shared_state, + state, ctx, ) assert success is True @@ -961,13 +959,13 @@ async def test_target_edge_group_send_message_with_invalid_target() -> None: executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source1.id, target_id="invalid_target") - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -982,13 +980,13 @@ async def test_target_edge_group_send_message_with_invalid_data() -> None: executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source1.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -1002,7 +1000,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = 
MockMessage(data="test") @@ -1020,7 +1018,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: # Send first message (should be buffered) success = await edge_runner.send_message( Message(data=data, source_id=source1.id, trace_contexts=trace_contexts1, source_span_ids=source_span_ids1), - shared_state, + state, ctx, ) assert success is True @@ -1052,7 +1050,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: success = await edge_runner.send_message( Message(data=data, source_id=source2.id, trace_contexts=trace_contexts2, source_span_ids=source_span_ids2), - shared_state, + state, ctx, ) assert success is True @@ -1090,7 +1088,7 @@ async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None: edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() # Send incompatible data type @@ -1100,7 +1098,7 @@ async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None: # Clear any build spans span_exporter.clear() - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False spans = span_exporter.get_finished_spans() @@ -1126,14 +1124,14 @@ async def test_fan_in_edge_group_with_multiple_message_types() -> None: executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") success = await edge_runner.send_message( Message(data=data, source_id=source1.id), - shared_state, + state, ctx, ) assert success @@ -1141,7 +1139,7 @@ async def test_fan_in_edge_group_with_multiple_message_types() -> None: data2 = MockMessageSecondary(data="test") success = await edge_runner.send_message( Message(data=data2, source_id=source2.id), - shared_state, + state, ctx, ) assert success @@ -1157,14 +1155,14 @@ async def test_fan_in_edge_group_with_multiple_message_types_failed() -> None: executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data="test") success = await edge_runner.send_message( Message(data=data, source_id=source1.id), - shared_state, + state, ctx, ) assert success @@ -1178,7 +1176,7 @@ async def test_fan_in_edge_group_with_multiple_message_types_failed() -> None: data2 = MockMessageSecondary(data="test") _ = await edge_runner.send_message( Message(data=data2, source_id=source2.id), - shared_state, + state, ctx, ) @@ -1273,14 +1271,14 @@ async def test_switch_case_edge_group_send_message() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data=-1) message = Message(data=data, source_id=source.id) with patch("agent_framework._workflows._edge_runner.EdgeRunner._execute_on_target") as mock_send: - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert mock_send.call_count == 1 @@ 
-1289,7 +1287,7 @@ async def test_switch_case_edge_group_send_message() -> None: data = MockMessage(data=1) message = Message(data=data, source_id=source.id) with patch("agent_framework._workflows._edge_runner.EdgeRunner._execute_on_target") as mock_send: - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True assert mock_send.call_count == 1 @@ -1312,13 +1310,13 @@ async def test_switch_case_edge_group_send_message_with_invalid_target() -> None executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data=-1) message = Message(data=data, source_id=source.id, target_id="invalid_target") - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False @@ -1339,18 +1337,18 @@ async def test_switch_case_edge_group_send_message_with_valid_target() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = MockMessage(data=1) # Condition will fail message = Message(data=data, source_id=source.id, target_id=target1.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False data = MockMessage(data=-1) # Condition will pass message = Message(data=data, source_id=source.id, target_id=target1.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is True @@ -1371,13 +1369,13 @@ async def test_switch_case_edge_group_send_message_with_invalid_data() -> None: executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) - success = await edge_runner.send_message(message, shared_state, ctx) + success = await edge_runner.send_message(message, state, ctx) assert success is False diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 9c6a2521b1..2ef43e246c 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -897,7 +897,7 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): latest_checkpoint = checkpoints[-1] - # Load checkpoint and verify no duplicates in shared state + # Load checkpoint and verify no duplicates in state checkpoint_data = await storage.load_checkpoint(latest_checkpoint.checkpoint_id) assert checkpoint_data is not None diff --git a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py index c0fd8e198f..8442af9445 100644 --- a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py +++ b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py @@ -10,7 +10,7 @@ from agent_framework._workflows._checkpoint_encoding 
import DATACLASS_MARKER, encode_checkpoint_value from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary from agent_framework._workflows._events import RequestInfoEvent -from agent_framework._workflows._shared_state import SharedState +from agent_framework._workflows._state import State @dataclass @@ -46,7 +46,7 @@ async def test_rehydrate_request_info_event() -> None: runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) await runner_context.add_request_info_event(request_info_event) - checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint_id = await runner_context.create_checkpoint(State(), iteration_count=1) checkpoint = await runner_context.load_checkpoint(checkpoint_id) assert checkpoint is not None @@ -79,7 +79,7 @@ async def test_rehydrate_fails_when_request_type_missing() -> None: runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) await runner_context.add_request_info_event(request_info_event) - checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint_id = await runner_context.create_checkpoint(State(), iteration_count=1) checkpoint = await runner_context.load_checkpoint(checkpoint_id) assert checkpoint is not None @@ -107,7 +107,7 @@ async def test_rehydrate_fails_when_request_type_mismatch() -> None: runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) await runner_context.add_request_info_event(request_info_event) - checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint_id = await runner_context.create_checkpoint(State(), iteration_count=1) checkpoint = await runner_context.load_checkpoint(checkpoint_id) assert checkpoint is not None @@ -137,7 +137,7 @@ async def test_pending_requests_in_summary() -> None: runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) await runner_context.add_request_info_event(request_info_event) - checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint_id = await runner_context.create_checkpoint(State(), iteration_count=1) checkpoint = await runner_context.load_checkpoint(checkpoint_id) assert checkpoint is not None @@ -175,7 +175,7 @@ async def test_request_info_event_serializes_non_json_payloads() -> None: await runner_context.add_request_info_event(req_1) await runner_context.add_request_info_event(req_2) - checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint_id = await runner_context.create_checkpoint(State(), iteration_count=1) checkpoint = await runner_context.load_checkpoint(checkpoint_id) # Should be JSON serializable despite datetime/slots diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index f6a031e5a3..0800f6f199 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -25,7 +25,7 @@ Message, RunnerContext, ) -from agent_framework._workflows._shared_state import SharedState +from agent_framework._workflows._state import State @dataclass @@ -48,7 +48,7 @@ async def mock_handler(self, message: MockMessage, ctx: WorkflowContext[MockMess def test_create_runner(): - """Test creating a runner with edges and shared state.""" + """Test creating a runner with edges and state.""" executor_a = MockExecutor(id="executor_a") executor_b = MockExecutor(id="executor_b") @@ -63,7 +63,7 @@ def test_create_runner(): 
executor_b.id: executor_b, } - runner = Runner(edge_groups, executors, shared_state=SharedState(), ctx=InProcRunnerContext()) + runner = Runner(edge_groups, executors, state=State(), ctx=InProcRunnerContext()) assert runner.context is not None and isinstance(runner.context, RunnerContext) @@ -83,16 +83,16 @@ async def test_runner_run_until_convergence(): executor_a.id: executor_a, executor_b.id: executor_b, } - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() - runner = Runner(edges, executors, shared_state, ctx) + runner = Runner(edges, executors, state, ctx) result: int | None = None await executor_a.execute( MockMessage(data=0), ["START"], # source_executor_ids - shared_state, # shared_state + state, # state ctx, # runner_context ) async for event in runner.run_until_convergence(): @@ -121,15 +121,15 @@ async def test_runner_run_until_convergence_not_completed(): executor_a.id: executor_a, executor_b.id: executor_b, } - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() - runner = Runner(edges, executors, shared_state, ctx, max_iterations=5) + runner = Runner(edges, executors, state, ctx, max_iterations=5) await executor_a.execute( MockMessage(data=0), ["START"], # source_executor_ids - shared_state, # shared_state + state, # state ctx, # runner_context ) with pytest.raises( @@ -155,15 +155,15 @@ async def test_runner_already_running(): executor_a.id: executor_a, executor_b.id: executor_b, } - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() - runner = Runner(edges, executors, shared_state, ctx) + runner = Runner(edges, executors, state, ctx) await executor_a.execute( MockMessage(data=0), ["START"], # source_executor_ids - shared_state, # shared_state + state, # state ctx, # runner_context ) @@ -178,7 +178,7 @@ async def _run(): async def test_runner_emits_runner_completion_for_agent_response_without_targets(): ctx = InProcRunnerContext() - runner = Runner([], {}, SharedState(), ctx) + runner = Runner([], {}, State(), ctx) await ctx.send_message( Message( diff --git a/python/packages/core/tests/workflow/test_serialization.py b/python/packages/core/tests/workflow/test_serialization.py index 2bb8f305e9..b22de85cc0 100644 --- a/python/packages/core/tests/workflow/test_serialization.py +++ b/python/packages/core/tests/workflow/test_serialization.py @@ -623,7 +623,7 @@ def test_workflow_serialization_excludes_non_serializable_fields(self) -> None: # These private runtime fields should not be in the serialized data assert "_runner_context" not in data - assert "_shared_state" not in data + assert "_state" not in data assert "_runner" not in data def test_workflow_name_description_serialization(self) -> None: @@ -760,7 +760,7 @@ def test_comprehensive_edge_groups_workflow_serialization() -> None: # Verify that serialization excludes non-serializable fields assert "_runner_context" not in data - assert "_shared_state" not in data + assert "_state" not in data assert "_runner" not in data # Test that we can identify each edge group type by examining their structure diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 6b08b7b22a..f837e3183e 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -90,7 +90,7 @@ class MockExecutorRequestApproval(Executor): @handler async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext) -> None: """A mock handler that requests 
approval.""" - await ctx.set_shared_state(self.id, message.data) + ctx.set_state(self.id, message.data) await ctx.request_info(MockRequest(prompt="Mock approval request"), ApprovalMessage) @response_handler @@ -101,7 +101,7 @@ async def mock_handler_b( ctx: WorkflowContext[NumberMessage, int], ) -> None: """A mock handler that processes the approval response.""" - data = await ctx.get_shared_state(self.id) + data = ctx.get_state(self.id) assert isinstance(data, int) if response.approved: await ctx.yield_output(data) @@ -371,7 +371,7 @@ async def test_workflow_run_stream_from_checkpoint_with_external_storage( test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, - shared_state={}, + state={}, iteration_count=0, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) @@ -406,7 +406,7 @@ async def test_workflow_run_from_checkpoint_non_streaming(simple_executor: Execu test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, - shared_state={}, + state={}, iteration_count=0, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) @@ -439,7 +439,7 @@ async def test_workflow_run_stream_from_checkpoint_with_responses( test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, - shared_state={}, + state={}, pending_request_info_events={ "request_123": RequestInfoEvent( request_id="request_123", @@ -483,7 +483,7 @@ class StateTrackingMessage: class StateTrackingExecutor(Executor): - """An executor that tracks state in shared state to test context reset behavior.""" + """An executor that tracks state in workflow state to test context reset behavior.""" @handler async def handle_message( @@ -491,19 +491,16 @@ async def handle_message( message: StateTrackingMessage, ctx: WorkflowContext[StateTrackingMessage, list[str]], ) -> None: - """Handle the message and track it in shared state.""" - # Get existing messages from shared state - try: - existing_messages = await ctx.get_shared_state("processed_messages") - except KeyError: - existing_messages = [] + """Handle the message and track it in workflow state.""" + # Get existing messages from workflow state + existing_messages = ctx.get_state("processed_messages") or [] # Record this message message_record = f"{message.run_id}:{message.data}" existing_messages.append(message_record) # type: ignore - # Update shared state - await ctx.set_shared_state("processed_messages", existing_messages) + # Update workflow state + ctx.set_state("processed_messages", existing_messages) # Yield output await ctx.yield_output(existing_messages.copy()) # type: ignore @@ -514,7 +511,7 @@ async def test_workflow_multiple_runs_no_state_collision(): with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) - # Create executor that tracks state in shared state + # Create executor that tracks state in workflow state state_executor = StateTrackingExecutor(id="state_executor") # Build workflow with checkpointing diff --git a/python/packages/core/tests/workflow/test_workflow_context.py b/python/packages/core/tests/workflow/test_workflow_context.py index b63742d16f..e3fafc4144 100644 --- a/python/packages/core/tests/workflow/test_workflow_context.py +++ b/python/packages/core/tests/workflow/test_workflow_context.py @@ -41,15 +41,15 @@ async def make_context( executor_id: str = "exec", ) -> AsyncIterator[tuple[WorkflowContext[object], "InProcRunnerContext"]]: from agent_framework._workflows._runner_context import InProcRunnerContext - from 
agent_framework._workflows._shared_state import SharedState + from agent_framework._workflows._state import State mock_executor = MockExecutor(executor_id) runner_ctx = InProcRunnerContext() - shared_state = SharedState() + state = State() workflow_ctx: WorkflowContext[object] = WorkflowContext( mock_executor, ["source"], - shared_state, + state, runner_ctx, ) try: diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 79aa009f57..ab12ecd815 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -209,48 +209,48 @@ def simple_selector(state: GroupChatState) -> str: # endregion -# region SharedState Verification Tests +# region State Verification Tests -async def test_kwargs_stored_in_shared_state() -> None: - """Test that kwargs are stored in SharedState with the correct key.""" +async def test_kwargs_stored_in_state() -> None: + """Test that kwargs are stored in State with the correct key.""" from agent_framework import Executor, WorkflowContext, handler stored_kwargs: dict[str, Any] | None = None - class _SharedStateInspector(Executor): + class _StateInspector(Executor): @handler async def inspect(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: nonlocal stored_kwargs - stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + stored_kwargs = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) await ctx.send_message(msgs) - inspector = _SharedStateInspector(id="inspector") + inspector = _StateInspector(id="inspector") workflow = SequentialBuilder().participants([inspector]).build() async for event in workflow.run_stream("test", my_kwarg="my_value", another=123): if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: break - assert stored_kwargs is not None, "kwargs should be stored in SharedState" + assert stored_kwargs is not None, "kwargs should be stored in State" assert stored_kwargs.get("my_kwarg") == "my_value" assert stored_kwargs.get("another") == 123 async def test_empty_kwargs_stored_as_empty_dict() -> None: - """Test that empty kwargs are stored as empty dict in SharedState.""" + """Test that empty kwargs are stored as empty dict in State.""" from agent_framework import Executor, WorkflowContext, handler stored_kwargs: Any = "NOT_CHECKED" - class _SharedStateChecker(Executor): + class _StateChecker(Executor): @handler async def check(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: nonlocal stored_kwargs - stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + stored_kwargs = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) await ctx.send_message(msgs) - checker = _SharedStateChecker(id="checker") + checker = _StateChecker(id="checker") workflow = SequentialBuilder().participants([checker]).build() # Run without any kwargs @@ -258,7 +258,7 @@ async def check(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMes if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: break - # SharedState should have empty dict when no kwargs provided + # State should have empty dict when no kwargs provided assert stored_kwargs == {}, f"Expected empty dict, got: {stored_kwargs}" @@ -421,8 +421,8 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatM # A more comprehensive integration test would require the manager to select an agent. 
-async def test_magentic_kwargs_stored_in_shared_state() -> None:
-    """Test that kwargs are stored in SharedState when using MagenticWorkflow.run_stream()."""
+async def test_magentic_kwargs_stored_in_state() -> None:
+    """Test that kwargs are stored in State when using MagenticWorkflow.run_stream()."""
     from agent_framework import MagenticBuilder
     from agent_framework._workflows._magentic import (
         MagenticContext,
@@ -640,10 +640,10 @@ async def test_subworkflow_kwargs_propagation() -> None:
     )
 
 
-async def test_subworkflow_kwargs_accessible_via_shared_state() -> None:
-    """Test that kwargs are accessible via SharedState within subworkflow.
+async def test_subworkflow_kwargs_accessible_via_state() -> None:
+    """Test that kwargs are accessible via State within subworkflow.
 
-    Verifies that WORKFLOW_RUN_KWARGS_KEY is populated in the subworkflow's SharedState
+    Verifies that WORKFLOW_RUN_KWARGS_KEY is populated in the subworkflow's State
     with kwargs from the parent workflow.
     """
     from agent_framework import Executor, WorkflowContext, handler
@@ -651,17 +651,17 @@ async def test_subworkflow_kwargs_accessible_via_shared_state() -> None:
 
     captured_kwargs_from_state: list[dict[str, Any]] = []
 
-    class _SharedStateReader(Executor):
-        """Executor that reads kwargs from SharedState for verification."""
+    class _StateReader(Executor):
+        """Executor that reads kwargs from State for verification."""
 
         @handler
         async def read_kwargs(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
-            kwargs_from_state = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
+            kwargs_from_state = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY)
             captured_kwargs_from_state.append(kwargs_from_state or {})
             await ctx.send_message(msgs)
 
-    # Build inner workflow with SharedState reader
-    state_reader = _SharedStateReader(id="state_reader")
+    # Build inner workflow with State reader
+    state_reader = _StateReader(id="state_reader")
     inner_workflow = SequentialBuilder().participants([state_reader]).build()
 
     # Wrap as subworkflow
@@ -680,15 +680,15 @@ async def read_kwargs(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[C
             break
 
     # Verify the state reader was invoked
-    assert len(captured_kwargs_from_state) >= 1, "SharedState reader should have been invoked"
+    assert len(captured_kwargs_from_state) >= 1, "State reader should have been invoked"
 
     kwargs_in_subworkflow = captured_kwargs_from_state[0]
     assert kwargs_in_subworkflow.get("my_custom_kwarg") == "should_be_propagated", (
-        f"Expected 'my_custom_kwarg' in subworkflow SharedState, got: {kwargs_in_subworkflow}"
+        f"Expected 'my_custom_kwarg' in subworkflow State, got: {kwargs_in_subworkflow}"
     )
     assert kwargs_in_subworkflow.get("another_kwarg") == 42, (
-        f"Expected 'another_kwarg'=42 in subworkflow SharedState, got: {kwargs_in_subworkflow}"
+        f"Expected 'another_kwarg'=42 in subworkflow State, got: {kwargs_in_subworkflow}"
     )
 
 
diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py
index 4c97b850b8..123c0ddf04 100644
--- a/python/packages/core/tests/workflow/test_workflow_observability.py
+++ b/python/packages/core/tests/workflow/test_workflow_observability.py
@@ -9,7 +9,7 @@
 from agent_framework import InMemoryCheckpointStorage, WorkflowBuilder
 from agent_framework._workflows._executor import Executor, handler
 from agent_framework._workflows._runner_context import InProcRunnerContext, Message, MessageType
-from agent_framework._workflows._shared_state import SharedState
+from 
agent_framework._workflows._state import State from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_context import WorkflowContext from agent_framework.observability import ( @@ -170,7 +170,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> None: """Test trace context propagation and handling in messages and executors.""" - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() executor = MockExecutor("test-executor") @@ -180,7 +180,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No workflow_ctx: WorkflowContext[str] = WorkflowContext( executor, ["source"], - shared_state, + state, ctx, trace_contexts=[{"traceparent": "00-12345678901234567890123456789012-1234567890123456-01"}], source_span_ids=["1234567890123456"], @@ -202,7 +202,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No await executor.execute( "test message", ["source"], # source_executor_ids - shared_state, # shared_state + state, # state ctx, # runner_context trace_contexts=[{"traceparent": "00-12345678901234567890123456789012-1234567890123456-01"}], source_span_ids=["1234567890123456"], @@ -236,13 +236,13 @@ async def test_trace_context_disabled_when_tracing_disabled( """Test that no trace context is added when tracing is disabled.""" # Tracing should be disabled by default executor = MockExecutor("test-executor") - shared_state = SharedState() + state = State() ctx = InProcRunnerContext() workflow_ctx: WorkflowContext[str] = WorkflowContext( executor, ["source"], - shared_state, + state, ctx, ) @@ -452,7 +452,7 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx await ctx.send_message(message) # Create a checkpoint that includes the message - checkpoint_id = await ctx.create_checkpoint(SharedState(), 0) + checkpoint_id = await ctx.create_checkpoint(State(), 0) checkpoint = await ctx.load_checkpoint(checkpoint_id) assert checkpoint is not None diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 53baf86383..1c354c0d7d 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -8,7 +8,6 @@ ExecutorFailedEvent, InProcRunnerContext, RequestInfoEvent, - SharedState, Workflow, WorkflowBuilder, WorkflowContext, @@ -20,6 +19,7 @@ WorkflowStatusEvent, handler, ) +from agent_framework._workflows._state import State class FailingExecutor(Executor): @@ -62,12 +62,12 @@ async def test_executor_failed_and_workflow_failed_events_streaming(): async def test_executor_failed_event_emitted_on_direct_execute(): failing = FailingExecutor(id="f") ctx = InProcRunnerContext() - shared_state = SharedState() + state = State() with pytest.raises(RuntimeError, match="boom"): await failing.execute( 0, ["START"], - shared_state, + state, ctx, ) drained = await ctx.drain_events() diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 309a71a4b7..1b1ca6ae04 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -3,7 +3,7 @@ """Base 
classes for graph-based declarative workflow executors. This module provides: -- DeclarativeWorkflowState: Manages workflow variables via SharedState +- DeclarativeWorkflowState: Manages workflow variables via State - DeclarativeActionExecutor: Base class for action executors - Message types for inter-executor communication @@ -32,9 +32,9 @@ from agent_framework._workflows import ( Executor, - SharedState, WorkflowContext, ) +from agent_framework._workflows._state import State from powerfx import Engine if sys.version_info >= (3, 11): @@ -61,10 +61,10 @@ class ConversationData(TypedDict): class DeclarativeStateData(TypedDict, total=False): - """Structure for the declarative workflow state stored in SharedState. + """Structure for the declarative workflow state stored in State. This TypedDict defines the schema for workflow variables stored - under the DECLARATIVE_STATE_KEY in SharedState. + under the DECLARATIVE_STATE_KEY in State. Variable Scopes (matching .NET naming conventions): Inputs: Initial workflow inputs (read-only after initialization). @@ -87,7 +87,7 @@ class DeclarativeStateData(TypedDict, total=False): _declarative_loop_state: dict[str, Any] -# Key used in SharedState to store declarative workflow variables +# Key used in State to store declarative workflow variables DECLARATIVE_STATE_KEY = "_declarative_workflow_state" @@ -126,10 +126,10 @@ def _make_powerfx_safe(value: Any) -> Any: class DeclarativeWorkflowState: - """Manages workflow variables stored in SharedState. + """Manages workflow variables stored in State. This class provides the same interface as the interpreter-based WorkflowState - but stores all data in SharedState for checkpointing support. + but stores all data in State for checkpointing support. The state is organized into namespaces (matching .NET naming conventions): - Workflow.Inputs: Initial inputs (read-only) @@ -140,15 +140,15 @@ class DeclarativeWorkflowState: - Conversation: Conversation history """ - def __init__(self, shared_state: SharedState): - """Initialize with a SharedState instance. + def __init__(self, state: State): + """Initialize with a State instance. Args: - shared_state: The workflow's shared state for persistence + state: The workflow's state for persistence """ - self._shared_state = shared_state + self._state = state - async def initialize(self, inputs: "Mapping[str, Any] | None" = None) -> None: + def initialize(self, inputs: "Mapping[str, Any] | None" = None) -> None: """Initialize the declarative state with inputs. 
Args: @@ -168,23 +168,22 @@ async def initialize(self, inputs: "Mapping[str, Any] | None" = None) -> None: "Conversation": {"messages": [], "history": []}, "Custom": {}, } - await self._shared_state.set(DECLARATIVE_STATE_KEY, state_data) + self._state.set(DECLARATIVE_STATE_KEY, state_data) - async def get_state_data(self) -> DeclarativeStateData: - """Get the full state data dict from shared state.""" - try: - result: DeclarativeStateData = await self._shared_state.get(DECLARATIVE_STATE_KEY) - return result - except KeyError: + def get_state_data(self) -> DeclarativeStateData: + """Get the full state data dict from state.""" + result = self._state.get(DECLARATIVE_STATE_KEY) + if result is None: # Initialize if not present - await self.initialize() - return cast(DeclarativeStateData, await self._shared_state.get(DECLARATIVE_STATE_KEY)) + self.initialize() + result = self._state.get(DECLARATIVE_STATE_KEY) + return cast(DeclarativeStateData, result) - async def set_state_data(self, data: DeclarativeStateData) -> None: - """Set the full state data dict in shared state.""" - await self._shared_state.set(DECLARATIVE_STATE_KEY, data) + def set_state_data(self, data: DeclarativeStateData) -> None: + """Set the full state data dict in state.""" + self._state.set(DECLARATIVE_STATE_KEY, data) - async def get(self, path: str, default: Any = None) -> Any: + def get(self, path: str, default: Any = None) -> Any: """Get a value from the state using a dot-notated path. Args: @@ -194,7 +193,7 @@ async def get(self, path: str, default: Any = None) -> Any: Returns: The value at the path, or default if not found """ - state_data = await self.get_state_data() + state_data = self.get_state_data() parts = path.split(".") if not parts: return default @@ -240,7 +239,7 @@ async def get(self, path: str, default: Any = None) -> Any: return obj # type: ignore[return-value] - async def set(self, path: str, value: Any) -> None: + def set(self, path: str, value: Any) -> None: """Set a value in the state using a dot-notated path. Args: @@ -250,7 +249,7 @@ async def set(self, path: str, value: Any) -> None: Raises: ValueError: If attempting to set Workflow.Inputs (which is read-only) """ - state_data = await self.get_state_data() + state_data = self.get_state_data() parts = path.split(".") if not parts: return @@ -296,9 +295,9 @@ async def set(self, path: str, value: Any) -> None: # Set the final value target[remaining[-1]] = value - await self.set_state_data(state_data) + self.set_state_data(state_data) - async def append(self, path: str, value: Any) -> None: + def append(self, path: str, value: Any) -> None: """Append a value to a list at the specified path. If the path doesn't exist, creates a new list with the value. @@ -310,17 +309,17 @@ async def append(self, path: str, value: Any) -> None: path: Dot-notated path to a list value: The value to append """ - existing = await self.get(path) + existing = self.get(path) if existing is None: - await self.set(path, [value]) + self.set(path, [value]) elif isinstance(existing, list): existing_list: list[Any] = list(existing) # type: ignore[arg-type] existing_list.append(value) - await self.set(path, existing_list) + self.set(path, existing_list) else: raise ValueError(f"Cannot append to non-list at path '{path}'") - async def eval(self, expression: str) -> Any: + def eval(self, expression: str) -> Any: """Evaluate a PowerFx expression with the current state. Expressions starting with '=' are evaluated as PowerFx. 
@@ -354,16 +353,16 @@ async def eval(self, expression: str) -> Any: # Handle custom functions not supported by PowerFx # First check if the entire formula is a custom function - result = await self._eval_custom_function(formula) + result = self._eval_custom_function(formula) if result is not None: return result # Pre-process nested custom functions (e.g., Upper(MessageText(...))) # Replace them with their evaluated results before sending to PowerFx - formula = await self._preprocess_custom_functions(formula) + formula = self._preprocess_custom_functions(formula) engine = Engine() - symbols = await self._to_powerfx_symbols() + symbols = self._to_powerfx_symbols() try: return engine.eval(formula, symbols=symbols) except ValueError as e: @@ -375,7 +374,7 @@ async def eval(self, expression: str) -> Any: return None raise - async def _eval_custom_function(self, formula: str) -> Any | None: + def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. The standard PowerFx library supports these functions but the Python wrapper @@ -404,7 +403,7 @@ async def _eval_custom_function(self, formula: str) -> Any | None: evaluated_args.append(arg[1:-1]) else: # Variable reference - evaluate it - result = await self.eval(f"={arg}") + result = self.eval(f"={arg}") evaluated_args.append(str(result) if result is not None else "") return "".join(evaluated_args) @@ -413,14 +412,14 @@ async def _eval_custom_function(self, formula: str) -> Any | None: if match: inner_expr = match.group(1).strip() # Evaluate the inner expression - text = await self.eval(f"={inner_expr}") + text = self.eval(f"={inner_expr}") return {"role": "user", "text": str(text) if text else ""} # AgentMessage(expr) - creates an assistant message dict match = re.match(r"AgentMessage\((.+)\)$", formula.strip()) if match: inner_expr = match.group(1).strip() - text = await self.eval(f"={inner_expr}") + text = self.eval(f"={inner_expr}") return {"role": "assistant", "text": str(text) if text else ""} # MessageText(expr) - extracts text from the last message @@ -428,11 +427,11 @@ async def _eval_custom_function(self, formula: str) -> Any | None: if match: inner_expr = match.group(1).strip() # Reuse the helper method for consistent text extraction - return await self._eval_and_replace_message_text(inner_expr) + return self._eval_and_replace_message_text(inner_expr) return None - async def _preprocess_custom_functions(self, formula: str) -> str: + def _preprocess_custom_functions(self, formula: str) -> str: """Pre-process custom functions nested inside other PowerFx functions. Custom functions like MessageText() are not supported by the PowerFx engine. 
@@ -509,7 +508,7 @@ async def _preprocess_custom_functions(self, formula: str) -> str: inner_expr = formula[paren_start + 1 : end - 1] # Evaluate and get replacement - replacement = await handler(inner_expr) + replacement = handler(inner_expr) # Replace in formula if isinstance(replacement, str): @@ -517,7 +516,7 @@ async def _preprocess_custom_functions(self, formula: str) -> str: # Store long strings in a temp variable to avoid PowerFx expression limit temp_var_name = f"_TempMessageText{temp_var_counter}" temp_var_counter += 1 - await self.set(f"Local.{temp_var_name}", replacement) + self.set(f"Local.{temp_var_name}", replacement) replacement_str = f"Local.{temp_var_name}" logger.debug( f"Stored long MessageText result ({len(replacement)} chars) " @@ -534,7 +533,7 @@ async def _preprocess_custom_functions(self, formula: str) -> str: return formula - async def _eval_and_replace_message_text(self, inner_expr: str) -> str: + def _eval_and_replace_message_text(self, inner_expr: str) -> str: """Evaluate MessageText() and return the text result. Args: @@ -543,7 +542,7 @@ async def _eval_and_replace_message_text(self, inner_expr: str) -> str: Returns: The extracted text from the messages """ - messages: Any = await self.eval(f"={inner_expr}") + messages: Any = self.eval(f"={inner_expr}") if isinstance(messages, list) and messages: last_msg: Any = messages[-1] if isinstance(last_msg, dict): @@ -603,13 +602,13 @@ def _parse_function_args(self, args_str: str) -> list[str]: return args - async def _to_powerfx_symbols(self) -> dict[str, Any]: + def _to_powerfx_symbols(self) -> dict[str, Any]: """Convert the current state to a PowerFx symbols dictionary. Uses .NET-style PascalCase names (System, Local, Workflow) matching the .NET declarative workflow implementation. """ - state_data = await self.get_state_data() + state_data = self.get_state_data() local_data = state_data.get("Local", {}) agent_data = state_data.get("Agent", {}) conversation_data = state_data.get("Conversation", {}) @@ -642,19 +641,19 @@ async def _to_powerfx_symbols(self) -> dict[str, Any]: result = _make_powerfx_safe(symbols) return cast(dict[str, Any], result) - async def eval_if_expression(self, value: Any) -> Any: + def eval_if_expression(self, value: Any) -> Any: """Evaluate a value if it's a PowerFx expression, otherwise return as-is.""" if isinstance(value, str): - return await self.eval(value) + return self.eval(value) if isinstance(value, dict): value_dict: dict[str, Any] = dict(value) # type: ignore[arg-type] - return {k: await self.eval_if_expression(v) for k, v in value_dict.items()} + return {k: self.eval_if_expression(v) for k, v in value_dict.items()} if isinstance(value, list): value_list: list[Any] = list(value) # type: ignore[arg-type] - return [await self.eval_if_expression(item) for item in value_list] + return [self.eval_if_expression(item) for item in value_list] return value - async def interpolate_string(self, text: str) -> str: + def interpolate_string(self, text: str) -> str: """Interpolate {Variable.Path} references in a string. 
This handles template-style variable substitution like: @@ -669,18 +668,18 @@ async def interpolate_string(self, text: str) -> str: """ import re - async def replace_var(match: re.Match[str]) -> str: + def replace_var(match: re.Match[str]) -> str: var_path: str = match.group(1) - value = await self.get(var_path) + value = self.get(var_path) return str(value) if value is not None else "" # Match {Variable.Path} patterns pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" - # re.sub doesn't support async, so we need to do it manually + # Replace all matches result = text for match in re.finditer(pattern, text): - replacement = await replace_var(match) + replacement = replace_var(match) result = result.replace(match.group(0), replacement, 1) return result @@ -802,9 +801,9 @@ def display_name(self) -> str | None: """Get the display name for logging.""" return self._action_def.get("displayName") - def _get_state(self, shared_state: SharedState) -> DeclarativeWorkflowState: + def _get_state(self, state: State) -> DeclarativeWorkflowState: """Get the declarative workflow state wrapper.""" - return DeclarativeWorkflowState(shared_state) + return DeclarativeWorkflowState(state) async def _ensure_state_initialized( self, @@ -826,18 +825,18 @@ async def _ensure_state_initialized( Returns: The initialized DeclarativeWorkflowState """ - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) if isinstance(trigger, dict): # Structured inputs - use directly - await state.initialize(trigger) # type: ignore + state.initialize(trigger) # type: ignore elif isinstance(trigger, str): # String input - wrap in dict - await state.initialize({"input": trigger}) + state.initialize({"input": trigger}) elif not isinstance( trigger, (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl) ): # Any other type - convert to string like .NET's DefaultTransform - await state.initialize({"input": str(trigger)}) + state.initialize({"input": str(trigger)}) return state diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py index 18685ef401..a5b692c5a1 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py @@ -348,7 +348,7 @@ class AgentExternalInputResponse: class ExternalLoopState: """State saved for external loop resumption. - Stored in shared_state to allow the response_handler to + Stored in workflow state to allow the response_handler to continue the loop with the same configuration. 
""" @@ -534,7 +534,7 @@ async def _get_conversation_messages_path( return "Conversation.messages" # Evaluate the conversation ID expression - evaluated_id = await state.eval_if_expression(conversation_id_expr) + evaluated_id = state.eval_if_expression(conversation_id_expr) if not evaluated_id: return "Conversation.messages" @@ -555,11 +555,11 @@ async def _build_input_text(self, state: Any, arguments: dict[str, Any], message # Evaluate arguments evaluated_args: dict[str, Any] = {} for key, value in arguments.items(): - evaluated_args[key] = await state.eval_if_expression(value) + evaluated_args[key] = state.eval_if_expression(value) # Evaluate messages/input if messages_expr: - evaluated_input: Any = await state.eval_if_expression(messages_expr) + evaluated_input: Any = state.eval_if_expression(messages_expr) if isinstance(evaluated_input, str): return evaluated_input if isinstance(evaluated_input, list) and evaluated_input: @@ -581,17 +581,17 @@ async def _build_input_text(self, state: Any, arguments: dict[str, Any], message # 1. Local.input / Local.userInput (explicit turn state) # 2. System.LastMessage.Text (previous agent's response) # 3. Workflow.Inputs (first agent gets workflow inputs) - input_text: str = str(await state.get("Local.input") or await state.get("Local.userInput") or "") + input_text: str = str(state.get("Local.input") or state.get("Local.userInput") or "") if not input_text: # Try System.LastMessage.Text (used by external loop and agent chaining) - last_message: Any = await state.get("System.LastMessage") + last_message: Any = state.get("System.LastMessage") if isinstance(last_message, dict): last_msg_dict = cast(dict[str, Any], last_message) text_val: Any = last_msg_dict.get("Text", "") input_text = str(text_val) if text_val else "" if not input_text: # Fall back to workflow inputs (for first agent in chain) - inputs: Any = await state.get("Workflow.Inputs") + inputs: Any = state.get("Workflow.Inputs") if isinstance(inputs, dict): inputs_dict = cast(dict[str, Any], inputs) # If single input, use its value directly @@ -643,11 +643,11 @@ async def _invoke_agent_and_store_results( # Add user input to conversation history first (via state.append only) if input_text: user_message = ChatMessage(role="user", text=input_text) - await state.append(messages_path, user_message) + state.append(messages_path, user_message) # Get conversation history from state AFTER adding user message # Note: We get a fresh copy to avoid mutation issues - conversation_history: list[ChatMessage] = await state.get(messages_path) or [] + conversation_history: list[ChatMessage] = state.get(messages_path) or [] # Build messages list for agent (use history if available, otherwise just input) messages_for_agent: list[ChatMessage] | str = conversation_history if conversation_history else input_text @@ -704,7 +704,7 @@ async def _invoke_agent_and_store_results( role, content_types, ) - await state.append(messages_path, msg) + state.append(messages_path, msg) elif accumulated_response: # No messages returned, create a simple assistant message logger.debug( @@ -712,24 +712,24 @@ async def _invoke_agent_and_store_results( agent_name, ) assistant_message = ChatMessage(role="assistant", text=accumulated_response) - await state.append(messages_path, assistant_message) + state.append(messages_path, assistant_message) # Store results in state - support both schema formats: # - Graph mode: Agent.response, Agent.name # - Interpreter mode: Agent.text, Agent.messages, Agent.toolCalls - await 
state.set("Agent.response", accumulated_response) - await state.set("Agent.name", agent_name) - await state.set("Agent.text", accumulated_response) - await state.set("Agent.messages", all_messages if all_messages else []) - await state.set("Agent.toolCalls", tool_calls if tool_calls else []) + state.set("Agent.response", accumulated_response) + state.set("Agent.name", agent_name) + state.set("Agent.text", accumulated_response) + state.set("Agent.messages", all_messages if all_messages else []) + state.set("Agent.toolCalls", tool_calls if tool_calls else []) # Store System.LastMessage for externalLoop.when condition evaluation - await state.set("System.LastMessage", {"Text": accumulated_response}) + state.set("System.LastMessage", {"Text": accumulated_response}) # Store in output variables (.NET style) if messages_var: output_path = _normalize_variable_path(messages_var) - await state.set(output_path, all_messages if all_messages else accumulated_response) + state.set(output_path, all_messages if all_messages else accumulated_response) if response_obj_var: output_path = _normalize_variable_path(response_obj_var) @@ -737,14 +737,14 @@ async def _invoke_agent_and_store_results( try: parsed = _extract_json_from_response(accumulated_response) if accumulated_response else None logger.debug(f"InvokeAzureAgent: parsed responseObject for '{output_path}': type={type(parsed)}") - await state.set(output_path, parsed) + state.set(output_path, parsed) except (json.JSONDecodeError, TypeError) as e: logger.warning(f"InvokeAzureAgent: failed to parse JSON for '{output_path}': {e}, storing as string") - await state.set(output_path, accumulated_response) + state.set(output_path, accumulated_response) # Store in result property (Python style) if result_property: - await state.set(result_property, accumulated_response) + state.set(result_property, accumulated_response) return accumulated_response, all_messages, tool_calls @@ -788,7 +788,7 @@ async def handle_action( agent: Any = self._agents.get(agent_name) if self._agents else None if agent is None: try: - agent_registry: dict[str, Any] | None = await ctx.shared_state.get(AGENT_REGISTRY_KEY) + agent_registry: dict[str, Any] | None = ctx.state.get(AGENT_REGISTRY_KEY) except KeyError: agent_registry = {} agent = agent_registry.get(agent_name) if agent_registry else None @@ -796,9 +796,9 @@ async def handle_action( if agent is None: error_msg = f"Agent '{agent_name}' not found in registry" logger.error(f"InvokeAzureAgent: {error_msg}") - await state.set("Agent.error", error_msg) + state.set("Agent.error", error_msg) if result_property: - await state.set(result_property, {"error": error_msg}) + state.set(result_property, {"error": error_msg}) raise AgentInvocationError(agent_name, "not found in registry") iteration = 0 @@ -820,14 +820,14 @@ async def handle_action( raise # Re-raise our own errors except Exception as e: logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}': {e}") - await state.set("Agent.error", str(e)) + state.set("Agent.error", str(e)) if result_property: - await state.set(result_property, {"error": str(e)}) + state.set(result_property, {"error": str(e)}) raise AgentInvocationError(agent_name, str(e)) from e # Check external loop condition if external_loop_when: - should_continue = await state.eval(external_loop_when) + should_continue = state.eval(external_loop_when) should_continue = bool(should_continue) if should_continue is not None else False logger.debug( @@ -848,7 +848,7 @@ async def handle_action( 
messages_path=messages_path, max_iterations=max_iterations, ) - await ctx.shared_state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) + ctx.state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) # Emit request for external input - workflow will yield here request = AgentExternalInputRequest( @@ -883,12 +883,11 @@ async def handle_external_input_response( "handle_external_input_response: resuming with user_input='%s'", response.user_input[:100] if response.user_input else None, ) - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) # Retrieve saved loop state - try: - loop_state: ExternalLoopState = await ctx.shared_state.get(EXTERNAL_LOOP_STATE_KEY) - except KeyError: + loop_state: ExternalLoopState | None = ctx.state.get(EXTERNAL_LOOP_STATE_KEY) + if loop_state is None: logger.error("InvokeAzureAgent: external loop state not found, cannot resume") await ctx.send_message(ActionComplete()) return @@ -910,12 +909,12 @@ async def handle_external_input_response( input_text = response.user_input # Store the user input in state for condition evaluation - await state.set("Local.userInput", input_text) - await state.set("System.LastMessage", {"Text": input_text}) + state.set("Local.userInput", input_text) + state.set("System.LastMessage", {"Text": input_text}) # Check if we should continue BEFORE invoking the agent # This matches .NET behavior where the condition checks the user's input - should_continue = await state.eval(external_loop_when) + should_continue = state.eval(external_loop_when) should_continue = bool(should_continue) if should_continue is not None else False logger.debug( @@ -926,7 +925,7 @@ async def handle_external_input_response( if not should_continue: # User input caused loop to exit - clean up and complete with contextlib.suppress(KeyError): - await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + ctx.state.delete(EXTERNAL_LOOP_STATE_KEY) await ctx.send_message(ActionComplete()) return @@ -934,7 +933,7 @@ async def handle_external_input_response( agent: Any = self._agents.get(agent_name) if self._agents else None if agent is None: try: - agent_registry: dict[str, Any] | None = await ctx.shared_state.get(AGENT_REGISTRY_KEY) + agent_registry: dict[str, Any] | None = ctx.state.get(AGENT_REGISTRY_KEY) except KeyError: agent_registry = {} agent = agent_registry.get(agent_name) if agent_registry else None @@ -960,12 +959,12 @@ async def handle_external_input_response( raise # Re-raise our own errors except Exception as e: logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}' during loop: {e}") - await state.set("Agent.error", str(e)) + state.set("Agent.error", str(e)) raise AgentInvocationError(agent_name, str(e)) from e # Re-evaluate the condition AFTER the agent responds # This is critical: the agent's response may have set NeedsTicket=true or IsResolved=true - should_continue = await state.eval(external_loop_when) + should_continue = state.eval(external_loop_when) should_continue = bool(should_continue) if should_continue is not None else False logger.debug( @@ -980,7 +979,7 @@ async def handle_external_input_response( "(sending ActionComplete to continue workflow)" ) with contextlib.suppress(KeyError): - await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + ctx.state.delete(EXTERNAL_LOOP_STATE_KEY) await ctx.send_message(ActionComplete()) return @@ -988,7 +987,7 @@ async def handle_external_input_response( if iteration < max_iterations: # Update loop state for next iteration loop_state.iteration = iteration + 1 - await 
ctx.shared_state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) + ctx.state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) # Emit another request for external input request = AgentExternalInputRequest( @@ -1007,7 +1006,7 @@ async def handle_external_input_response( # Loop complete - clean up and send completion with contextlib.suppress(KeyError): - await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + ctx.state.delete(EXTERNAL_LOOP_STATE_KEY) await ctx.send_message(ActionComplete()) @@ -1035,7 +1034,7 @@ async def handle_action( # Get tools registry try: - tool_registry: dict[str, Any] | None = await ctx.shared_state.get(TOOL_REGISTRY_KEY) + tool_registry: dict[str, Any] | None = ctx.state.get(TOOL_REGISTRY_KEY) except KeyError: tool_registry = {} @@ -1044,18 +1043,18 @@ async def handle_action( if tool is None: error_msg = f"Tool '{tool_name}' not found in registry" if output_property: - await state.set(output_property, {"error": error_msg}) + state.set(output_property, {"error": error_msg}) await ctx.send_message(ActionComplete()) return # Build parameters params: dict[str, Any] = {} for param_name, param_expression in parameters.items(): - params[param_name] = await state.eval_if_expression(param_expression) + params[param_name] = state.eval_if_expression(param_expression) # Add main input if specified if input_expr: - params["input"] = await state.eval_if_expression(input_expr) + params["input"] = state.eval_if_expression(input_expr) try: # Invoke the tool @@ -1068,11 +1067,11 @@ async def handle_action( # Store result if output_property: - await state.set(output_property, result) + state.set(output_property, result) except Exception as e: if output_property: - await state.set(output_property, {"error": str(e)}) + state.set(output_property, {"error": str(e)}) await ctx.send_message(ActionComplete()) return diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py index 6603357478..f4fed64791 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py @@ -52,8 +52,8 @@ async def handle_action( if path: # Evaluate value if it's an expression - evaluated_value = await state.eval_if_expression(value) - await state.set(path, evaluated_value) + evaluated_value = state.eval_if_expression(value) + state.set(path, evaluated_value) await ctx.send_message(ActionComplete()) @@ -74,8 +74,8 @@ async def handle_action( value = self._action_def.get("value") if path: - evaluated_value = await state.eval_if_expression(value) - await state.set(path, evaluated_value) + evaluated_value = state.eval_if_expression(value) + state.set(path, evaluated_value) await ctx.send_message(ActionComplete()) @@ -96,8 +96,8 @@ async def handle_action( text = self._action_def.get("text", "") if path: - evaluated_text = await state.eval_if_expression(text) - await state.set(path, str(evaluated_text) if evaluated_text is not None else "") + evaluated_text = state.eval_if_expression(text) + state.set(path, str(evaluated_text) if evaluated_text is not None else "") await ctx.send_message(ActionComplete()) @@ -126,8 +126,8 @@ async def handle_action( path = assignment.get("path") value = assignment.get("value") if path: - evaluated_value = await state.eval_if_expression(value) - await state.set(path, evaluated_value) + evaluated_value = state.eval_if_expression(value) + 
state.set(path, evaluated_value) await ctx.send_message(ActionComplete()) @@ -148,8 +148,8 @@ async def handle_action( value = self._action_def.get("value") if path: - evaluated_value = await state.eval_if_expression(value) - await state.append(path, evaluated_value) + evaluated_value = state.eval_if_expression(value) + state.append(path, evaluated_value) await ctx.send_message(ActionComplete()) @@ -170,7 +170,7 @@ async def handle_action( if path: # Reset to None/empty - await state.set(path, None) + state.set(path, None) await ctx.send_message(ActionComplete()) @@ -188,9 +188,9 @@ async def handle_action( state = await self._ensure_state_initialized(ctx, trigger) # Get state data and clear Local variables - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data["Local"] = {} - await state.set_state_data(state_data) + state.set_state_data(state_data) await ctx.send_message(ActionComplete()) @@ -217,10 +217,10 @@ async def handle_action( if isinstance(text, str): # First evaluate any =expression syntax - text = await state.eval_if_expression(text) + text = state.eval_if_expression(text) # Then interpolate any {Variable.Path} template syntax if isinstance(text, str): - text = await state.interpolate_string(text) + text = state.interpolate_string(text) # Yield the text as workflow output if text: @@ -258,8 +258,8 @@ async def handle_action( event_value = event_def.get("data") if event_name: - evaluated_name = await state.eval_if_expression(event_name) - evaluated_value = await state.eval_if_expression(event_value) + evaluated_name = state.eval_if_expression(event_name) + evaluated_value = state.eval_if_expression(event_value) event_data = { "eventName": evaluated_name, @@ -300,16 +300,16 @@ async def handle_action( if table_path: # Get current table value - current_table = await state.get(table_path) + current_table = state.get(table_path) if current_table is None: current_table = [] elif not isinstance(current_table, list): current_table = [current_table] if operation == "add" or operation == "insert": - evaluated_value = await state.eval_if_expression(value) + evaluated_value = state.eval_if_expression(value) if index is not None: - evaluated_index = await state.eval_if_expression(index) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else len(current_table) current_table.insert(idx, evaluated_value) else: @@ -318,12 +318,12 @@ async def handle_action( elif operation == "remove": if value is not None: # Remove by value - evaluated_value = await state.eval_if_expression(value) + evaluated_value = state.eval_if_expression(value) if evaluated_value in current_table: current_table.remove(evaluated_value) elif index is not None: # Remove by index - evaluated_index = await state.eval_if_expression(index) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else -1 if 0 <= idx < len(current_table): current_table.pop(idx) @@ -334,13 +334,13 @@ async def handle_action( elif operation == "set" or operation == "update": # Update item at index if index is not None: - evaluated_value = await state.eval_if_expression(value) - evaluated_index = await state.eval_if_expression(index) + evaluated_value = state.eval_if_expression(value) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else 0 if 0 <= idx < len(current_table): current_table[idx] = evaluated_value - await state.set(table_path, 
current_table) + state.set(table_path, current_table) await ctx.send_message(ActionComplete()) @@ -377,16 +377,16 @@ async def handle_action( if table_path: # Get current table value - current_table = await state.get(table_path) + current_table = state.get(table_path) if current_table is None: current_table = [] elif not isinstance(current_table, list): current_table = [current_table] if operation == "add": - evaluated_item = await state.eval_if_expression(item) + evaluated_item = state.eval_if_expression(item) if index is not None: - evaluated_index = await state.eval_if_expression(index) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else len(current_table) current_table.insert(idx, evaluated_item) else: @@ -394,7 +394,7 @@ async def handle_action( elif operation == "remove": if item is not None: - evaluated_item = await state.eval_if_expression(item) + evaluated_item = state.eval_if_expression(item) if key_field and isinstance(evaluated_item, dict): # Remove by key match key_value = evaluated_item.get(key_field) @@ -404,7 +404,7 @@ async def handle_action( elif evaluated_item in current_table: current_table.remove(evaluated_item) elif index is not None: - evaluated_index = await state.eval_if_expression(index) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else -1 if 0 <= idx < len(current_table): current_table.pop(idx) @@ -413,7 +413,7 @@ async def handle_action( current_table = [] elif operation == "addorupdate": - evaluated_item = await state.eval_if_expression(item) + evaluated_item = state.eval_if_expression(item) if key_field and isinstance(evaluated_item, dict): key_value = evaluated_item.get(key_field) # Find existing item with same key @@ -433,9 +433,9 @@ async def handle_action( current_table.append(evaluated_item) elif operation == "update": - evaluated_item = await state.eval_if_expression(item) + evaluated_item = state.eval_if_expression(item) if index is not None: - evaluated_index = await state.eval_if_expression(index) + evaluated_index = state.eval_if_expression(index) idx = int(evaluated_index) if evaluated_index is not None else 0 if 0 <= idx < len(current_table): current_table[idx] = evaluated_item @@ -446,7 +446,7 @@ async def handle_action( current_table[i] = evaluated_item break - await state.set(table_path, current_table) + state.set(table_path, current_table) await ctx.send_message(ActionComplete()) @@ -479,13 +479,13 @@ async def handle_action( if path and value is not None: # Evaluate the value expression - evaluated_value = await state.eval_if_expression(value) + evaluated_value = state.eval_if_expression(value) # Convert to target type if specified if value_type: evaluated_value = self._convert_to_type(evaluated_value, value_type) - await state.set(path, evaluated_value) + state.set(path, evaluated_value) await ctx.send_message(ActionComplete()) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py index 48aeabb58b..f63e3ada50 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py @@ -7,7 +7,7 @@ returns a ConditionResult with the first-matching branch index. Edge conditions then check the branch_index to route to the correct branch. 
This ensures only one branch executes (first-match semantics), matching the interpreter behavior. -- Foreach: Loop iteration state managed in SharedState + loop edges +- Foreach: Loop iteration state managed in State + loop edges - Goto: Edge to target action (handled by builder) - Break/Continue: Special signals for loop control @@ -30,7 +30,7 @@ LoopIterationResult, ) -# Keys for loop state in SharedState +# Keys for loop state in State LOOP_STATE_KEY = "_declarative_loop_state" # Index value indicating the else/default branch @@ -88,7 +88,7 @@ async def handle_action( elif isinstance(condition_expr, str) and not condition_expr.startswith("="): condition_expr = f"={condition_expr}" - result = await state.eval(condition_expr) + result = state.eval(condition_expr) if bool(result): # First matching condition found await ctx.send_message(ConditionResult(matched=True, branch_index=index, value=result)) @@ -143,7 +143,7 @@ async def handle_action( return # Evaluate the switch value once - switch_value = await state.eval_if_expression(value_expr) + switch_value = state.eval_if_expression(value_expr) # Compare against each case's match value for index, case_item in enumerate(self._cases): @@ -152,7 +152,7 @@ async def handle_action( continue # Evaluate the match value - match_value = await state.eval_if_expression(match_expr) + match_value = state.eval_if_expression(match_expr) if switch_value == match_value: # Found matching case @@ -196,7 +196,7 @@ async def handle_action( """Evaluate the condition and output the result.""" state = await self._ensure_state_initialized(ctx, trigger) - result = await state.eval(self._condition_expr) + result = state.eval(self._condition_expr) is_truthy = bool(result) if is_truthy: @@ -208,7 +208,7 @@ async def handle_action( class ForeachInitExecutor(DeclarativeActionExecutor): """Initializes a foreach loop. - Sets up the loop state in SharedState and determines if there are items. + Sets up the loop state in State and determines if there are items. 
""" @handler @@ -226,7 +226,7 @@ async def handle_action( items_expr = ( self._action_def.get("itemsSource") or self._action_def.get("items") or self._action_def.get("source") ) - items_raw: Any = await state.eval_if_expression(items_expr) or [] + items_raw: Any = state.eval_if_expression(items_expr) or [] items: list[Any] items = (list(items_raw) if items_raw else []) if not isinstance(items_raw, (list, tuple)) else list(items_raw) # type: ignore @@ -234,14 +234,14 @@ async def handle_action( loop_id = self.id # Store loop state - state_data = await state.get_state_data() + state_data = state.get_state_data() loop_states: dict[str, Any] = cast(dict[str, Any], state_data).setdefault(LOOP_STATE_KEY, {}) loop_states[loop_id] = { "items": items, "index": 0, "length": len(items), } - await state.set_state_data(state_data) + state.set_state_data(state_data) # Check if we have items if items: @@ -263,9 +263,9 @@ async def handle_action( index_name = self._action_def.get("indexName", "index") index_var = f"Local.{index_name}" - await state.set(item_var, items[0]) + state.set(item_var, items[0]) if index_var: - await state.set(index_var, 0) + state.set(index_var, 0) await ctx.send_message(LoopIterationResult(has_next=True, current_item=items[0], current_index=0)) else: @@ -307,7 +307,7 @@ async def handle_action( loop_id = self._init_executor_id # Get loop state - state_data = await state.get_state_data() + state_data = state.get_state_data() loop_states: dict[str, Any] = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) loop_state = loop_states.get(loop_id) @@ -322,7 +322,7 @@ async def handle_action( if current_index < len(items): # Update loop state loop_state["index"] = current_index - await state.set_state_data(state_data) + state.set_state_data(state_data) # Set the iteration variable # Support multiple schema formats: @@ -342,9 +342,9 @@ async def handle_action( index_name = self._action_def.get("indexName", "index") index_var = f"Local.{index_name}" - await state.set(item_var, items[current_index]) + state.set(item_var, items[current_index]) if index_var: - await state.set(index_var, current_index) + state.set(index_var, current_index) await ctx.send_message( LoopIterationResult(has_next=True, current_item=items[current_index], current_index=current_index) @@ -354,7 +354,7 @@ async def handle_action( loop_states_dict = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) if loop_id in loop_states_dict: del loop_states_dict[loop_id] - await state.set_state_data(state_data) + state.set_state_data(state_data) await ctx.send_message(LoopIterationResult(has_next=False)) @@ -365,15 +365,15 @@ async def handle_loop_control( ctx: WorkflowContext[LoopIterationResult], ) -> None: """Handle break/continue signals.""" - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) if control.action == "break": # Clean up loop state and signal done - state_data = await state.get_state_data() + state_data = state.get_state_data() loop_states: dict[str, Any] = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) if self._init_executor_id in loop_states: del loop_states[self._init_executor_id] - await state.set_state_data(state_data) + state.set_state_data(state_data) await ctx.send_message(LoopIterationResult(has_next=False)) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py index c499f133ea..2c3f5c0e91 100644 --- 
a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py @@ -84,7 +84,7 @@ async def handle_action( allow_free_text = self._action_def.get("allowFreeText", True) # Evaluate the question text if it's an expression - evaluated_question = await state.eval_if_expression(question_text) + evaluated_question = state.eval_if_expression(question_text) # Build choices metadata choices_data: list[dict[str, str]] | None = None @@ -101,8 +101,8 @@ async def handle_action( choices_data.append({"value": str(c), "label": str(c)}) # Store output property in shared state for response handler - await ctx.shared_state.set("_question_output_property", output_property) - await ctx.shared_state.set("_question_default_value", default_value) + ctx.state.set("_question_output_property", output_property) + ctx.state.set("_question_default_value", default_value) # Request external input - workflow pauses here await ctx.request_info( @@ -128,13 +128,13 @@ async def handle_response( ctx: WorkflowContext[ActionComplete], ) -> None: """Handle the user's response to the question.""" - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) output_property = original_request.metadata.get("output_property", "Local.answer") answer = response.value if response.value is not None else response.user_input if output_property: - await state.set(output_property, answer) + state.set(output_property, answer) await ctx.send_message(ActionComplete()) @@ -163,7 +163,7 @@ async def handle_action( default_value = self._action_def.get("defaultValue", False) # Evaluate the message if it's an expression - evaluated_message = await state.eval_if_expression(message) + evaluated_message = state.eval_if_expression(message) # Request confirmation - workflow pauses here await ctx.request_info( @@ -189,7 +189,7 @@ async def handle_response( ctx: WorkflowContext[ActionComplete], ) -> None: """Handle the user's confirmation response.""" - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) output_property = original_request.metadata.get("output_property", "Local.confirmed") @@ -202,7 +202,7 @@ async def handle_response( confirmed = user_input_lower in ("yes", "y", "true", "1", "confirm", "ok") if output_property: - await state.set(output_property, confirmed) + state.set(output_property, confirmed) await ctx.send_message(ActionComplete()) @@ -231,7 +231,7 @@ async def handle_action( # Emit prompt if specified if prompt: - evaluated_prompt = await state.eval_if_expression(prompt) + evaluated_prompt = state.eval_if_expression(prompt) await ctx.yield_output(str(evaluated_prompt)) # Request user input - workflow pauses here @@ -256,12 +256,12 @@ async def handle_response( ctx: WorkflowContext[ActionComplete, str], ) -> None: """Handle the user's input.""" - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) output_property = original_request.metadata.get("output_property", "Local.input") if output_property: - await state.set(output_property, response.user_input) + state.set(output_property, response.user_input) await ctx.send_message(ActionComplete()) @@ -292,7 +292,7 @@ async def handle_action( metadata = self._action_def.get("metadata", {}) # Evaluate the message if it's an expression - evaluated_message = await state.eval_if_expression(message) + evaluated_message = state.eval_if_expression(message) # Build request metadata 
request_metadata: dict[str, Any] = { @@ -323,14 +323,14 @@ async def handle_response( ctx: WorkflowContext[ActionComplete], ) -> None: """Handle the external input response.""" - state = self._get_state(ctx.shared_state) + state = self._get_state(ctx.state) output_property = original_request.metadata.get("output_property", "Local.externalInput") # Store the response value or user_input result = response.value if response.value is not None else response.user_input if output_property: - await state.set(output_property, result) + state.set(output_property, result) await ctx.send_message(ActionComplete()) diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index 8f9211e850..ad03fc9b97 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -27,35 +27,37 @@ @pytest.fixture -def mock_shared_state() -> MagicMock: - """Create a mock shared state with async get/set/delete methods.""" - shared_state = MagicMock() - shared_state._data = {} +def mock_state() -> MagicMock: + """Create a mock state with sync get/set/delete methods.""" + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key: str) -> Any: - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key: str, default: Any = None) -> Any: + return mock_state._data.get(key, default) - async def mock_set(key: str, value: Any) -> None: - shared_state._data[key] = value + def mock_set(key: str, value: Any) -> None: + mock_state._data[key] = value - async def mock_delete(key: str) -> None: - if key in shared_state._data: - del shared_state._data[key] + def mock_has(key: str) -> bool: + return key in mock_state._data - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - shared_state.delete = AsyncMock(side_effect=mock_delete) + def mock_delete(key: str) -> None: + if key in mock_state._data: + del mock_state._data[key] - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + mock_state.has = MagicMock(side_effect=mock_has) + mock_state.delete = MagicMock(side_effect=mock_delete) + + return mock_state @pytest.fixture -def mock_context(mock_shared_state: MagicMock) -> MagicMock: +def mock_context(mock_state: MagicMock) -> MagicMock: """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() ctx.request_info = AsyncMock() @@ -70,73 +72,73 @@ def mock_context(mock_shared_state: MagicMock) -> MagicMock: class TestDeclarativeWorkflowStateExtended: """Extended tests for DeclarativeWorkflowState covering uncovered code paths.""" - async def test_get_with_local_namespace(self, mock_shared_state): + async def test_get_with_local_namespace(self, mock_state): """Test Local. namespace mapping.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.myVar", "value123") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.myVar", "value123") # Access via Local. 
namespace - result = await state.get("Local.myVar") + result = state.get("Local.myVar") assert result == "value123" - async def test_get_with_system_namespace(self, mock_shared_state): + async def test_get_with_system_namespace(self, mock_state): """Test System. namespace mapping.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("System.ConversationId", "conv-123") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("System.ConversationId", "conv-123") - result = await state.get("System.ConversationId") + result = state.get("System.ConversationId") assert result == "conv-123" - async def test_get_with_workflow_namespace(self, mock_shared_state): + async def test_get_with_workflow_namespace(self, mock_state): """Test Workflow. namespace mapping.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"query": "test"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"query": "test"}) - result = await state.get("Workflow.Inputs.query") + result = state.get("Workflow.Inputs.query") assert result == "test" - async def test_get_with_inputs_shorthand(self, mock_shared_state): + async def test_get_with_inputs_shorthand(self, mock_state): """Test inputs. shorthand namespace mapping.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"query": "test"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"query": "test"}) - result = await state.get("Workflow.Inputs.query") + result = state.get("Workflow.Inputs.query") assert result == "test" - async def test_get_agent_namespace(self, mock_shared_state): + async def test_get_agent_namespace(self, mock_state): """Test agent namespace access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Agent.response", "Hello!") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Agent.response", "Hello!") - result = await state.get("Agent.response") + result = state.get("Agent.response") assert result == "Hello!" 
- async def test_get_conversation_namespace(self, mock_shared_state): + async def test_get_conversation_namespace(self, mock_state): """Test conversation namespace access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Conversation.messages", [{"role": "user", "text": "hi"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Conversation.messages", [{"role": "user", "text": "hi"}]) - result = await state.get("Conversation.messages") + result = state.get("Conversation.messages") assert result == [{"role": "user", "text": "hi"}] - async def test_get_custom_namespace(self, mock_shared_state): + async def test_get_custom_namespace(self, mock_state): """Test custom namespace access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Set via direct state data manipulation to create custom namespace - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data["Custom"] = {"myns": {"value": 42}} - await state.set_state_data(state_data) + state.set_state_data(state_data) - result = await state.get("myns.value") + result = state.get("myns.value") assert result == 42 - async def test_get_object_attribute_access(self, mock_shared_state): + async def test_get_object_attribute_access(self, mock_state): """Test accessing object attributes via hasattr/getattr path.""" @dataclass @@ -144,258 +146,258 @@ class MockObj: name: str value: int - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.obj", MockObj(name="test", value=99)) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.obj", MockObj(name="test", value=99)) - result = await state.get("Local.obj.name") + result = state.get("Local.obj.name") assert result == "test" - async def test_set_with_local_namespace(self, mock_shared_state): + async def test_set_with_local_namespace(self, mock_state): """Test Local. namespace mapping for set.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("Local.myVar", "value123") - result = await state.get("Local.myVar") + state.set("Local.myVar", "value123") + result = state.get("Local.myVar") assert result == "value123" - async def test_set_with_system_namespace(self, mock_shared_state): + async def test_set_with_system_namespace(self, mock_state): """Test System. 
namespace mapping for set.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("System.ConversationId", "conv-456") - result = await state.get("System.ConversationId") + state.set("System.ConversationId", "conv-456") + result = state.get("System.ConversationId") assert result == "conv-456" - async def test_set_workflow_outputs(self, mock_shared_state): + async def test_set_workflow_outputs(self, mock_state): """Test setting workflow outputs.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("Workflow.Outputs.result", "done") - outputs = await state.get("Workflow.Outputs") + state.set("Workflow.Outputs.result", "done") + outputs = state.get("Workflow.Outputs") assert outputs.get("result") == "done" - async def test_set_workflow_inputs_raises_error(self, mock_shared_state): + async def test_set_workflow_inputs_raises_error(self, mock_state): """Test that setting Workflow.Inputs raises an error (read-only).""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"query": "test"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"query": "test"}) with pytest.raises(ValueError, match="Cannot modify Workflow.Inputs"): - await state.set("Workflow.Inputs.query", "modified") + state.set("Workflow.Inputs.query", "modified") - async def test_set_workflow_directly_raises_error(self, mock_shared_state): + async def test_set_workflow_directly_raises_error(self, mock_state): """Test that setting 'Workflow' directly raises an error.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() with pytest.raises(ValueError, match="Cannot set 'Workflow' directly"): - await state.set("Workflow", {}) + state.set("Workflow", {}) - async def test_set_unknown_workflow_subnamespace_raises_error(self, mock_shared_state): + async def test_set_unknown_workflow_subnamespace_raises_error(self, mock_state): """Test unknown workflow sub-namespace raises error.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() with pytest.raises(ValueError, match="Unknown Workflow namespace"): - await state.set("Workflow.unknown.field", "value") + state.set("Workflow.unknown.field", "value") - async def test_set_creates_custom_namespace(self, mock_shared_state): + async def test_set_creates_custom_namespace(self, mock_state): """Test setting value in custom namespace creates it.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("myns.field.nested", "value") - result = await state.get("myns.field.nested") + state.set("myns.field.nested", "value") + result = state.get("myns.field.nested") assert result == "value" - async def test_set_cannot_replace_entire_namespace(self, mock_shared_state): + async def test_set_cannot_replace_entire_namespace(self, mock_state): """Test that replacing entire namespace raises error.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() with pytest.raises(ValueError, match="Cannot replace entire namespace"): - await 
state.set("turn", {}) + state.set("turn", {}) - async def test_append_to_nonlist_raises_error(self, mock_shared_state): + async def test_append_to_nonlist_raises_error(self, mock_state): """Test appending to non-list raises error.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.scalar", "string value") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.scalar", "string value") with pytest.raises(ValueError, match="Cannot append to non-list"): - await state.append("Local.scalar", "new item") + state.append("Local.scalar", "new item") - async def test_eval_empty_string(self, mock_shared_state): + async def test_eval_empty_string(self, mock_state): """Test evaluating empty string returns as-is.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - result = await state.eval("") + result = state.eval("") assert result == "" - async def test_eval_non_string_returns_as_is(self, mock_shared_state): + async def test_eval_non_string_returns_as_is(self, mock_state): """Test evaluating non-string returns as-is.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Cast to Any to test the runtime behavior with non-string inputs - result = await state.eval(42) # type: ignore[arg-type] + result = state.eval(42) # type: ignore[arg-type] assert result == 42 - result = await state.eval([1, 2, 3]) # type: ignore[arg-type] + result = state.eval([1, 2, 3]) # type: ignore[arg-type] assert result == [1, 2, 3] - async def test_eval_simple_and_operator(self, mock_shared_state): + async def test_eval_simple_and_operator(self, mock_state): """Test simple And operator evaluation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.a", True) - await state.set("Local.b", False) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.a", True) + state.set("Local.b", False) - result = await state.eval("=Local.a And Local.b") + result = state.eval("=Local.a And Local.b") assert result is False - await state.set("Local.b", True) - result = await state.eval("=Local.a And Local.b") + state.set("Local.b", True) + result = state.eval("=Local.a And Local.b") assert result is True - async def test_eval_simple_or_operator(self, mock_shared_state): + async def test_eval_simple_or_operator(self, mock_state): """Test simple Or operator evaluation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.a", True) - await state.set("Local.b", False) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.a", True) + state.set("Local.b", False) - result = await state.eval("=Local.a Or Local.b") + result = state.eval("=Local.a Or Local.b") assert result is True - await state.set("Local.a", False) - result = await state.eval("=Local.a Or Local.b") + state.set("Local.a", False) + result = state.eval("=Local.a Or Local.b") assert result is False - async def test_eval_negation(self, mock_shared_state): + async def test_eval_negation(self, mock_state): """Test negation (!) 
evaluation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.flag", True) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", True) - result = await state.eval("=!Local.flag") + result = state.eval("=!Local.flag") assert result is False - async def test_eval_not_function(self, mock_shared_state): + async def test_eval_not_function(self, mock_state): """Test Not() function evaluation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.flag", True) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", True) - result = await state.eval("=Not(Local.flag)") + result = state.eval("=Not(Local.flag)") assert result is False - async def test_eval_comparison_operators(self, mock_shared_state): + async def test_eval_comparison_operators(self, mock_state): """Test comparison operators.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 5) - await state.set("Local.y", 10) - - assert await state.eval("=Local.x < Local.y") is True - assert await state.eval("=Local.x > Local.y") is False - assert await state.eval("=Local.x <= 5") is True - assert await state.eval("=Local.x >= 5") is True - assert await state.eval("=Local.x <> Local.y") is True - assert await state.eval("=Local.x = 5") is True - - async def test_eval_arithmetic_operators(self, mock_shared_state): + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 5) + state.set("Local.y", 10) + + assert state.eval("=Local.x < Local.y") is True + assert state.eval("=Local.x > Local.y") is False + assert state.eval("=Local.x <= 5") is True + assert state.eval("=Local.x >= 5") is True + assert state.eval("=Local.x <> Local.y") is True + assert state.eval("=Local.x = 5") is True + + async def test_eval_arithmetic_operators(self, mock_state): """Test arithmetic operators.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 10) - await state.set("Local.y", 3) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 10) + state.set("Local.y", 3) - assert await state.eval("=Local.x + Local.y") == 13 - assert await state.eval("=Local.x - Local.y") == 7 - assert await state.eval("=Local.x * Local.y") == 30 - assert await state.eval("=Local.x / Local.y") == pytest.approx(3.333, rel=0.01) + assert state.eval("=Local.x + Local.y") == 13 + assert state.eval("=Local.x - Local.y") == 7 + assert state.eval("=Local.x * Local.y") == 30 + assert state.eval("=Local.x / Local.y") == pytest.approx(3.333, rel=0.01) - async def test_eval_string_literal(self, mock_shared_state): + async def test_eval_string_literal(self, mock_state): """Test string literal evaluation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - result = await state.eval('="hello world"') + result = state.eval('="hello world"') assert result == "hello world" - async def test_eval_float_literal(self, mock_shared_state): + async def test_eval_float_literal(self, mock_state): """Test float literal evaluation.""" from decimal import Decimal - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - result = await state.eval("=3.14") 
+ result = state.eval("=3.14") # Accepts both float (Python fallback) and Decimal (pythonnet/PowerFx) assert result == 3.14 or result == Decimal("3.14") - async def test_eval_variable_reference_with_namespace_mappings(self, mock_shared_state): + async def test_eval_variable_reference_with_namespace_mappings(self, mock_state): """Test variable reference with PowerFx symbols.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"query": "test"}) - await state.set("Local.myVar", "localValue") + state = DeclarativeWorkflowState(mock_state) + state.initialize({"query": "test"}) + state.set("Local.myVar", "localValue") # Test Local namespace (PowerFx symbol) - result = await state.eval("=Local.myVar") + result = state.eval("=Local.myVar") assert result == "localValue" # Test Workflow.Inputs (PowerFx symbol) - result = await state.eval("=Workflow.Inputs.query") + result = state.eval("=Workflow.Inputs.query") assert result == "test" - async def test_eval_if_expression_with_dict(self, mock_shared_state): + async def test_eval_if_expression_with_dict(self, mock_state): """Test eval_if_expression recursively evaluates dicts.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.name", "Alice") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.name", "Alice") - result = await state.eval_if_expression({"greeting": "=Local.name", "static": "hello"}) + result = state.eval_if_expression({"greeting": "=Local.name", "static": "hello"}) assert result == {"greeting": "Alice", "static": "hello"} - async def test_eval_if_expression_with_list(self, mock_shared_state): + async def test_eval_if_expression_with_list(self, mock_state): """Test eval_if_expression recursively evaluates lists.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 10) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 10) - result = await state.eval_if_expression(["=Local.x", "static", "=5"]) + result = state.eval_if_expression(["=Local.x", "static", "=5"]) assert result == [10, "static", 5] - async def test_interpolate_string_with_local_vars(self, mock_shared_state): + async def test_interpolate_string_with_local_vars(self, mock_state): """Test string interpolation with Local. variables.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.TicketId", "TKT-001") - await state.set("Local.TeamName", "Support") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.TicketId", "TKT-001") + state.set("Local.TeamName", "Support") - result = await state.interpolate_string("Created ticket #{Local.TicketId} for team {Local.TeamName}") + result = state.interpolate_string("Created ticket #{Local.TicketId} for team {Local.TeamName}") assert result == "Created ticket #TKT-001 for team Support" - async def test_interpolate_string_with_system_vars(self, mock_shared_state): + async def test_interpolate_string_with_system_vars(self, mock_state): """Test string interpolation with System. 
variables.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("System.ConversationId", "conv-789") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("System.ConversationId", "conv-789") - result = await state.interpolate_string("Conversation: {System.ConversationId}") + result = state.interpolate_string("Conversation: {System.ConversationId}") assert result == "Conversation: conv-789" - async def test_interpolate_string_with_none_value(self, mock_shared_state): + async def test_interpolate_string_with_none_value(self, mock_state): """Test string interpolation with None value returns empty string.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - result = await state.interpolate_string("Value: {Local.Missing}") + result = state.interpolate_string("Value: {Local.Missing}") assert result == "Value: " @@ -407,14 +409,14 @@ async def test_interpolate_string_with_none_value(self, mock_shared_state): class TestBasicExecutorsCoverage: """Tests for basic executors covering uncovered code paths.""" - async def test_set_variable_executor(self, mock_context, mock_shared_state): + async def test_set_variable_executor(self, mock_context, mock_state): """Test SetVariableExecutor (distinct from SetValueExecutor).""" from agent_framework_declarative._workflows._executors_basic import ( SetVariableExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SetVariable", @@ -424,17 +426,17 @@ async def test_set_variable_executor(self, mock_context, mock_shared_state): executor = SetVariableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.result") + result = state.get("Local.result") assert result == "test value" - async def test_set_variable_executor_with_nested_variable(self, mock_context, mock_shared_state): + async def test_set_variable_executor_with_nested_variable(self, mock_context, mock_state): """Test SetVariableExecutor with nested variable object.""" from agent_framework_declarative._workflows._executors_basic import ( SetVariableExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SetVariable", @@ -444,18 +446,18 @@ async def test_set_variable_executor_with_nested_variable(self, mock_context, mo executor = SetVariableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.nested") + result = state.get("Local.nested") assert result == 42 - async def test_set_text_variable_executor(self, mock_context, mock_shared_state): + async def test_set_text_variable_executor(self, mock_context, mock_state): """Test SetTextVariableExecutor.""" from agent_framework_declarative._workflows._executors_basic import ( SetTextVariableExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.name", "World") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.name", "World") action_def = { "kind": "SetTextVariable", @@ -465,17 +467,17 @@ async def test_set_text_variable_executor(self, mock_context, mock_shared_state) executor = SetTextVariableExecutor(action_def) await 
executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.greeting") + result = state.get("Local.greeting") assert result == "World" - async def test_set_multiple_variables_executor(self, mock_context, mock_shared_state): + async def test_set_multiple_variables_executor(self, mock_context, mock_state): """Test SetMultipleVariablesExecutor.""" from agent_framework_declarative._workflows._executors_basic import ( SetMultipleVariablesExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SetMultipleVariables", @@ -488,19 +490,19 @@ async def test_set_multiple_variables_executor(self, mock_context, mock_shared_s executor = SetMultipleVariablesExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - assert await state.get("Local.a") == 1 - assert await state.get("Local.b") == 2 - assert await state.get("Local.c") == 3 + assert state.get("Local.a") == 1 + assert state.get("Local.b") == 2 + assert state.get("Local.c") == 3 - async def test_append_value_executor(self, mock_context, mock_shared_state): + async def test_append_value_executor(self, mock_context, mock_state): """Test AppendValueExecutor.""" from agent_framework_declarative._workflows._executors_basic import ( AppendValueExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a"]) action_def = { "kind": "AppendValue", @@ -510,18 +512,18 @@ async def test_append_value_executor(self, mock_context, mock_shared_state): executor = AppendValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "b"] - async def test_reset_variable_executor(self, mock_context, mock_shared_state): + async def test_reset_variable_executor(self, mock_context, mock_state): """Test ResetVariableExecutor.""" from agent_framework_declarative._workflows._executors_basic import ( ResetVariableExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.myVar", "some value") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.myVar", "some value") action_def = { "kind": "ResetVariable", @@ -530,37 +532,37 @@ async def test_reset_variable_executor(self, mock_context, mock_shared_state): executor = ResetVariableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.myVar") + result = state.get("Local.myVar") assert result is None - async def test_clear_all_variables_executor(self, mock_context, mock_shared_state): + async def test_clear_all_variables_executor(self, mock_context, mock_state): """Test ClearAllVariablesExecutor.""" from agent_framework_declarative._workflows._executors_basic import ( ClearAllVariablesExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.a", 1) - await state.set("Local.b", 2) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.a", 1) + state.set("Local.b", 2) action_def = {"kind": "ClearAllVariables"} executor = ClearAllVariablesExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) # Turn 
namespace should be cleared - assert await state.get("Local.a") is None - assert await state.get("Local.b") is None + assert state.get("Local.a") is None + assert state.get("Local.b") is None - async def test_send_activity_with_dict_activity(self, mock_context, mock_shared_state): + async def test_send_activity_with_dict_activity(self, mock_context, mock_state): """Test SendActivityExecutor with dict activity containing text field.""" from agent_framework_declarative._workflows._executors_basic import ( SendActivityExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.name", "Alice") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.name", "Alice") action_def = { "kind": "SendActivity", @@ -571,14 +573,14 @@ async def test_send_activity_with_dict_activity(self, mock_context, mock_shared_ mock_context.yield_output.assert_called_once_with("Hello, Alice!") - async def test_send_activity_with_string_activity(self, mock_context, mock_shared_state): + async def test_send_activity_with_string_activity(self, mock_context, mock_state): """Test SendActivityExecutor with string activity.""" from agent_framework_declarative._workflows._executors_basic import ( SendActivityExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SendActivity", @@ -589,15 +591,15 @@ async def test_send_activity_with_string_activity(self, mock_context, mock_share mock_context.yield_output.assert_called_once_with("Plain text message") - async def test_send_activity_with_expression(self, mock_context, mock_shared_state): + async def test_send_activity_with_expression(self, mock_context, mock_state): """Test SendActivityExecutor evaluates expressions.""" from agent_framework_declarative._workflows._executors_basic import ( SendActivityExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.msg", "Dynamic message") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.msg", "Dynamic message") action_def = { "kind": "SendActivity", @@ -608,14 +610,14 @@ async def test_send_activity_with_expression(self, mock_context, mock_shared_sta mock_context.yield_output.assert_called_once_with("Dynamic message") - async def test_emit_event_executor_graph_mode(self, mock_context, mock_shared_state): + async def test_emit_event_executor_graph_mode(self, mock_context, mock_state): """Test EmitEventExecutor with graph-mode schema (eventName/eventValue).""" from agent_framework_declarative._workflows._executors_basic import ( EmitEventExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "EmitEvent", @@ -630,14 +632,14 @@ async def test_emit_event_executor_graph_mode(self, mock_context, mock_shared_st assert event_data["eventName"] == "myEvent" assert event_data["eventValue"] == {"key": "value"} - async def test_emit_event_executor_interpreter_mode(self, mock_context, mock_shared_state): + async def test_emit_event_executor_interpreter_mode(self, mock_context, mock_state): """Test EmitEventExecutor with interpreter-mode schema (event.name/event.data).""" from agent_framework_declarative._workflows._executors_basic import ( EmitEventExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) 
- await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "EmitEvent", @@ -684,7 +686,7 @@ async def test_normalize_variable_path_all_cases(self): # No namespace - default to Local. assert _normalize_variable_path("simpleVar") == "Local.simpleVar" - async def test_agent_executor_get_agent_name_string(self, mock_context, mock_shared_state): + async def test_agent_executor_get_agent_name_string(self, mock_context, mock_state): """Test agent name extraction from simple string config.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -696,13 +698,13 @@ async def test_agent_executor_get_agent_name_string(self, mock_context, mock_sha } executor = InvokeAzureAgentExecutor(action_def) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() name = executor._get_agent_name(state) assert name == "MyAgent" - async def test_agent_executor_get_agent_name_dict(self, mock_context, mock_shared_state): + async def test_agent_executor_get_agent_name_dict(self, mock_context, mock_state): """Test agent name extraction from nested dict config.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -714,13 +716,13 @@ async def test_agent_executor_get_agent_name_dict(self, mock_context, mock_share } executor = InvokeAzureAgentExecutor(action_def) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() name = executor._get_agent_name(state) assert name == "NestedAgent" - async def test_agent_executor_get_agent_name_legacy(self, mock_context, mock_shared_state): + async def test_agent_executor_get_agent_name_legacy(self, mock_context, mock_state): """Test agent name extraction from agentName (legacy).""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -732,13 +734,13 @@ async def test_agent_executor_get_agent_name_legacy(self, mock_context, mock_sha } executor = InvokeAzureAgentExecutor(action_def) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() name = executor._get_agent_name(state) assert name == "LegacyAgent" - async def test_agent_executor_get_input_config_simple(self, mock_context, mock_shared_state): + async def test_agent_executor_get_input_config_simple(self, mock_context, mock_state): """Test input config parsing with simple non-dict input.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -757,7 +759,7 @@ async def test_agent_executor_get_input_config_simple(self, mock_context, mock_s assert external_loop is None assert max_iterations == 100 # Default - async def test_agent_executor_get_input_config_full(self, mock_context, mock_shared_state): + async def test_agent_executor_get_input_config_full(self, mock_context, mock_state): """Test input config parsing with full structured input.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -780,7 +782,7 @@ async def test_agent_executor_get_input_config_full(self, mock_context, mock_sha assert external_loop == "=Local.needsMore" assert max_iterations == 50 - async def test_agent_executor_get_output_config_simple(self, mock_context, mock_shared_state): + async def 
test_agent_executor_get_output_config_simple(self, mock_context, mock_state): """Test output config parsing with simple resultProperty.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -799,7 +801,7 @@ async def test_agent_executor_get_output_config_simple(self, mock_context, mock_ assert result_prop == "Local.result" assert auto_send is True - async def test_agent_executor_get_output_config_full(self, mock_context, mock_shared_state): + async def test_agent_executor_get_output_config_full(self, mock_context, mock_state): """Test output config parsing with full structured output.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -823,15 +825,15 @@ async def test_agent_executor_get_output_config_full(self, mock_context, mock_sh assert result_prop == "Local.result" assert auto_send is False - async def test_agent_executor_build_input_text_from_string_messages(self, mock_context, mock_shared_state): + async def test_agent_executor_build_input_text_from_string_messages(self, mock_context, mock_state): """Test _build_input_text with string messages expression.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.userInput", "Hello agent!") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.userInput", "Hello agent!") action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} executor = InvokeAzureAgentExecutor(action_def) @@ -839,15 +841,15 @@ async def test_agent_executor_build_input_text_from_string_messages(self, mock_c input_text = await executor._build_input_text(state, {}, "=Local.userInput") assert input_text == "Hello agent!" 
- async def test_agent_executor_build_input_text_from_message_list(self, mock_context, mock_shared_state): + async def test_agent_executor_build_input_text_from_message_list(self, mock_context, mock_state): """Test _build_input_text extracts text from message list.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set( + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set( "Conversation.messages", [ {"role": "user", "content": "First"}, @@ -862,15 +864,15 @@ async def test_agent_executor_build_input_text_from_message_list(self, mock_cont input_text = await executor._build_input_text(state, {}, "=Conversation.messages") assert input_text == "Last message" - async def test_agent_executor_build_input_text_from_message_with_text_attr(self, mock_context, mock_shared_state): + async def test_agent_executor_build_input_text_from_message_with_text_attr(self, mock_context, mock_state): """Test _build_input_text extracts text from message with text attribute.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.messages", [{"text": "From attribute"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.messages", [{"text": "From attribute"}]) action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} executor = InvokeAzureAgentExecutor(action_def) @@ -878,14 +880,14 @@ async def test_agent_executor_build_input_text_from_message_with_text_attr(self, input_text = await executor._build_input_text(state, {}, "=Local.messages") assert input_text == "From attribute" - async def test_agent_executor_build_input_text_fallback_chain(self, mock_context, mock_shared_state): + async def test_agent_executor_build_input_text_fallback_chain(self, mock_context, mock_state): """Test _build_input_text fallback chain when no messages expression.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"query": "workflow input"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"query": "workflow input"}) action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} executor = InvokeAzureAgentExecutor(action_def) @@ -894,15 +896,15 @@ async def test_agent_executor_build_input_text_fallback_chain(self, mock_context input_text = await executor._build_input_text(state, {}, None) assert input_text == "workflow input" - async def test_agent_executor_build_input_text_from_system_last_message(self, mock_context, mock_shared_state): + async def test_agent_executor_build_input_text_from_system_last_message(self, mock_context, mock_state): """Test _build_input_text falls back to system.LastMessage.Text.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("System.LastMessage", {"Text": "From last message"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("System.LastMessage", {"Text": "From last message"}) action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} executor = InvokeAzureAgentExecutor(action_def) @@ -910,14 +912,14 @@ async def 
test_agent_executor_build_input_text_from_system_last_message(self, mo input_text = await executor._build_input_text(state, {}, None) assert input_text == "From last message" - async def test_agent_executor_missing_agent_name(self, mock_context, mock_shared_state): + async def test_agent_executor_missing_agent_name(self, mock_context, mock_state): """Test agent executor with missing agent name logs warning.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "InvokeAzureAgent"} # No agent specified executor = InvokeAzureAgentExecutor(action_def) @@ -929,7 +931,7 @@ async def test_agent_executor_missing_agent_name(self, mock_context, mock_shared msg = mock_context.send_message.call_args[0][0] assert isinstance(msg, ActionComplete) - async def test_agent_executor_with_working_agent(self, mock_context, mock_shared_state): + async def test_agent_executor_with_working_agent(self, mock_context, mock_state): """Test agent executor with a working mock agent.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -944,9 +946,9 @@ class MockResult: mock_agent = MagicMock() mock_agent.run = AsyncMock(return_value=MockResult(text="Agent response", messages=[])) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "User query") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "User query") action_def = { "kind": "InvokeAzureAgent", @@ -961,15 +963,15 @@ class MockResult: mock_agent.run.assert_called_once() # Verify result was stored - result = await state.get("Local.result") + result = state.get("Local.result") assert result == "Agent response" # Verify agent state was set - assert await state.get("Agent.response") == "Agent response" - assert await state.get("Agent.name") == "TestAgent" - assert await state.get("Agent.text") == "Agent response" + assert state.get("Agent.response") == "Agent response" + assert state.get("Agent.name") == "TestAgent" + assert state.get("Agent.text") == "Agent response" - async def test_agent_executor_with_agent_from_registry(self, mock_context, mock_shared_state): + async def test_agent_executor_with_agent_from_registry(self, mock_context, mock_state): """Test agent executor retrieves agent from shared state registry.""" from agent_framework_declarative._workflows._executors_agents import ( AGENT_REGISTRY_KEY, @@ -986,11 +988,11 @@ class MockResult: mock_agent.run = AsyncMock(return_value=MockResult(text="Registry agent", messages=[])) # Store in registry - mock_shared_state._data[AGENT_REGISTRY_KEY] = {"RegistryAgent": mock_agent} + mock_state._data[AGENT_REGISTRY_KEY] = {"RegistryAgent": mock_agent} - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "Query") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "Query") action_def = { "kind": "InvokeAzureAgent", @@ -1002,7 +1004,7 @@ class MockResult: mock_agent.run.assert_called_once() - async def test_agent_executor_parses_json_response(self, mock_context, mock_shared_state): + async def test_agent_executor_parses_json_response(self, mock_context, mock_state): """Test agent executor parses JSON response into responseObject.""" from 
agent_framework_declarative._workflows._executors_agents import ( InvokeAzureAgentExecutor, @@ -1016,9 +1018,9 @@ class MockResult: mock_agent = MagicMock() mock_agent.run = AsyncMock(return_value=MockResult(text='{"status": "ok", "count": 42}', messages=[])) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "Query") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "Query") action_def = { "kind": "InvokeAzureAgent", @@ -1031,17 +1033,17 @@ class MockResult: await executor.handle_action(ActionTrigger(), mock_context) - parsed = await state.get("Local.Parsed") + parsed = state.get("Local.Parsed") assert parsed == {"status": "ok", "count": 42} - async def test_invoke_tool_executor_not_found(self, mock_context, mock_shared_state): + async def test_invoke_tool_executor_not_found(self, mock_context, mock_state): """Test InvokeToolExecutor when tool not found.""" from agent_framework_declarative._workflows._executors_agents import ( InvokeToolExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "InvokeTool", @@ -1052,10 +1054,10 @@ async def test_invoke_tool_executor_not_found(self, mock_context, mock_shared_st await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.result") + result = state.get("Local.result") assert result == {"error": "Tool 'MissingTool' not found in registry"} - async def test_invoke_tool_executor_sync_tool(self, mock_context, mock_shared_state): + async def test_invoke_tool_executor_sync_tool(self, mock_context, mock_state): """Test InvokeToolExecutor with synchronous tool.""" from agent_framework_declarative._workflows._executors_agents import ( TOOL_REGISTRY_KEY, @@ -1065,10 +1067,10 @@ async def test_invoke_tool_executor_sync_tool(self, mock_context, mock_shared_st def my_tool(x: int, y: int) -> int: return x + y - mock_shared_state._data[TOOL_REGISTRY_KEY] = {"add": my_tool} + mock_state._data[TOOL_REGISTRY_KEY] = {"add": my_tool} - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "InvokeTool", @@ -1080,10 +1082,10 @@ def my_tool(x: int, y: int) -> int: await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.result") + result = state.get("Local.result") assert result == 8 - async def test_invoke_tool_executor_async_tool(self, mock_context, mock_shared_state): + async def test_invoke_tool_executor_async_tool(self, mock_context, mock_state): """Test InvokeToolExecutor with asynchronous tool.""" from agent_framework_declarative._workflows._executors_agents import ( TOOL_REGISTRY_KEY, @@ -1093,10 +1095,10 @@ async def test_invoke_tool_executor_async_tool(self, mock_context, mock_shared_s async def my_async_tool(input: str) -> str: return f"Processed: {input}" - mock_shared_state._data[TOOL_REGISTRY_KEY] = {"process": my_async_tool} + mock_state._data[TOOL_REGISTRY_KEY] = {"process": my_async_tool} - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "InvokeTool", @@ -1108,7 +1110,7 @@ async def my_async_tool(input: str) -> str: await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.result") 
+ result = state.get("Local.result") assert result == "Processed: test data" @@ -1120,15 +1122,15 @@ async def my_async_tool(input: str) -> str: class TestControlFlowCoverage: """Tests for control flow executors covering uncovered code paths.""" - async def test_foreach_with_source_alias(self, mock_context, mock_shared_state): + async def test_foreach_with_source_alias(self, mock_context, mock_state): """Test ForeachInitExecutor with 'source' alias (interpreter mode).""" from agent_framework_declarative._workflows._executors_control_flow import ( ForeachInitExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.data", [10, 20, 30]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.data", [10, 20, 30]) action_def = { "kind": "Foreach", @@ -1146,19 +1148,19 @@ async def test_foreach_with_source_alias(self, mock_context, mock_shared_state): assert msg.current_item == 10 assert msg.current_index == 0 - async def test_foreach_next_continues_iteration(self, mock_context, mock_shared_state): + async def test_foreach_next_continues_iteration(self, mock_context, mock_state): """Test ForeachNextExecutor continues to next item.""" from agent_framework_declarative._workflows._executors_control_flow import ( LOOP_STATE_KEY, ForeachNextExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.data", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.data", ["a", "b", "c"]) # Set up loop state as ForeachInitExecutor would - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data[LOOP_STATE_KEY] = { "foreach_init": { "items": ["a", "b", "c"], @@ -1166,7 +1168,7 @@ async def test_foreach_next_continues_iteration(self, mock_context, mock_shared_ "length": 3, } } - await state.set_state_data(state_data) + state.set_state_data(state_data) action_def = { "kind": "Foreach", @@ -1182,15 +1184,15 @@ async def test_foreach_next_continues_iteration(self, mock_context, mock_shared_ assert msg.current_index == 1 assert msg.current_item == "b" - async def test_switch_evaluator_with_value_cases(self, mock_context, mock_shared_state): + async def test_switch_evaluator_with_value_cases(self, mock_context, mock_state): """Test SwitchEvaluatorExecutor with value/cases schema.""" from agent_framework_declarative._workflows._executors_control_flow import ( SwitchEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.status", "pending") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.status", "pending") action_def = { "kind": "Switch", @@ -1209,15 +1211,15 @@ async def test_switch_evaluator_with_value_cases(self, mock_context, mock_shared assert msg.matched is True assert msg.branch_index == 1 # Second case matched - async def test_switch_evaluator_default_case(self, mock_context, mock_shared_state): + async def test_switch_evaluator_default_case(self, mock_context, mock_state): """Test SwitchEvaluatorExecutor falls through to default.""" from agent_framework_declarative._workflows._executors_control_flow import ( SwitchEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.status", "unknown") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.status", "unknown") 
action_def = { "kind": "Switch", @@ -1236,14 +1238,14 @@ async def test_switch_evaluator_default_case(self, mock_context, mock_shared_sta assert msg.matched is False assert msg.branch_index == -1 # Default case - async def test_switch_evaluator_no_value(self, mock_context, mock_shared_state): + async def test_switch_evaluator_no_value(self, mock_context, mock_state): """Test SwitchEvaluatorExecutor with no value defaults to else.""" from agent_framework_declarative._workflows._executors_control_flow import ( SwitchEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "Switch"} # No value cases = [{"match": "x"}] @@ -1255,14 +1257,14 @@ async def test_switch_evaluator_no_value(self, mock_context, mock_shared_state): assert isinstance(msg, ConditionResult) assert msg.branch_index == -1 - async def test_join_executor_accepts_condition_result(self, mock_context, mock_shared_state): + async def test_join_executor_accepts_condition_result(self, mock_context, mock_state): """Test JoinExecutor accepts ConditionResult as trigger.""" from agent_framework_declarative._workflows._executors_control_flow import ( JoinExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "_Join"} executor = JoinExecutor(action_def) @@ -1273,14 +1275,14 @@ async def test_join_executor_accepts_condition_result(self, mock_context, mock_s msg = mock_context.send_message.call_args[0][0] assert isinstance(msg, ActionComplete) - async def test_break_loop_executor(self, mock_context, mock_shared_state): + async def test_break_loop_executor(self, mock_context, mock_state): """Test BreakLoopExecutor emits LoopControl.""" from agent_framework_declarative._workflows._executors_control_flow import ( BreakLoopExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "BreakLoop"} executor = BreakLoopExecutor(action_def, loop_next_executor_id="loop_next") @@ -1291,14 +1293,14 @@ async def test_break_loop_executor(self, mock_context, mock_shared_state): assert isinstance(msg, LoopControl) assert msg.action == "break" - async def test_continue_loop_executor(self, mock_context, mock_shared_state): + async def test_continue_loop_executor(self, mock_context, mock_state): """Test ContinueLoopExecutor emits LoopControl.""" from agent_framework_declarative._workflows._executors_control_flow import ( ContinueLoopExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "ContinueLoop"} executor = ContinueLoopExecutor(action_def, loop_next_executor_id="loop_next") @@ -1309,14 +1311,14 @@ async def test_continue_loop_executor(self, mock_context, mock_shared_state): assert isinstance(msg, LoopControl) assert msg.action == "continue" - async def test_foreach_next_no_loop_state(self, mock_context, mock_shared_state): + async def test_foreach_next_no_loop_state(self, mock_context, mock_state): """Test ForeachNextExecutor with missing loop state.""" from agent_framework_declarative._workflows._executors_control_flow import ( ForeachNextExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = 
DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "Foreach", @@ -1331,18 +1333,18 @@ async def test_foreach_next_no_loop_state(self, mock_context, mock_shared_state) assert isinstance(msg, LoopIterationResult) assert msg.has_next is False - async def test_foreach_next_loop_complete(self, mock_context, mock_shared_state): + async def test_foreach_next_loop_complete(self, mock_context, mock_state): """Test ForeachNextExecutor when loop is complete.""" from agent_framework_declarative._workflows._executors_control_flow import ( LOOP_STATE_KEY, ForeachNextExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Set up loop state at last item - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b"], @@ -1350,7 +1352,7 @@ async def test_foreach_next_loop_complete(self, mock_context, mock_shared_state) "length": 2, } } - await state.set_state_data(state_data) + state.set_state_data(state_data) action_def = { "kind": "Foreach", @@ -1365,18 +1367,18 @@ async def test_foreach_next_loop_complete(self, mock_context, mock_shared_state) assert isinstance(msg, LoopIterationResult) assert msg.has_next is False - async def test_foreach_next_handle_break_control(self, mock_context, mock_shared_state): + async def test_foreach_next_handle_break_control(self, mock_context, mock_state): """Test ForeachNextExecutor handles break LoopControl.""" from agent_framework_declarative._workflows._executors_control_flow import ( LOOP_STATE_KEY, ForeachNextExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Set up loop state - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b", "c"], @@ -1384,7 +1386,7 @@ async def test_foreach_next_handle_break_control(self, mock_context, mock_shared "length": 3, } } - await state.set_state_data(state_data) + state.set_state_data(state_data) action_def = { "kind": "Foreach", @@ -1399,18 +1401,18 @@ async def test_foreach_next_handle_break_control(self, mock_context, mock_shared assert isinstance(msg, LoopIterationResult) assert msg.has_next is False - async def test_foreach_next_handle_continue_control(self, mock_context, mock_shared_state): + async def test_foreach_next_handle_continue_control(self, mock_context, mock_state): """Test ForeachNextExecutor handles continue LoopControl.""" from agent_framework_declarative._workflows._executors_control_flow import ( LOOP_STATE_KEY, ForeachNextExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Set up loop state - state_data = await state.get_state_data() + state_data = state.get_state_data() state_data[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b", "c"], @@ -1418,7 +1420,7 @@ async def test_foreach_next_handle_continue_control(self, mock_context, mock_sha "length": 3, } } - await state.set_state_data(state_data) + state.set_state_data(state_data) action_def = { "kind": "Foreach", @@ -1434,14 +1436,14 @@ async def test_foreach_next_handle_continue_control(self, mock_context, mock_sha assert msg.has_next is True assert msg.current_index == 1 - async def test_end_workflow_executor(self, mock_context, 
mock_shared_state): + async def test_end_workflow_executor(self, mock_context, mock_state): """Test EndWorkflowExecutor does not send continuation.""" from agent_framework_declarative._workflows._executors_control_flow import ( EndWorkflowExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "EndWorkflow"} executor = EndWorkflowExecutor(action_def) @@ -1451,14 +1453,14 @@ async def test_end_workflow_executor(self, mock_context, mock_shared_state): # Should NOT send any message mock_context.send_message.assert_not_called() - async def test_end_conversation_executor(self, mock_context, mock_shared_state): + async def test_end_conversation_executor(self, mock_context, mock_state): """Test EndConversationExecutor does not send continuation.""" from agent_framework_declarative._workflows._executors_control_flow import ( EndConversationExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = {"kind": "EndConversation"} executor = EndConversationExecutor(action_def) @@ -1468,15 +1470,15 @@ async def test_end_conversation_executor(self, mock_context, mock_shared_state): # Should NOT send any message mock_context.send_message.assert_not_called() - async def test_condition_group_evaluator_first_match(self, mock_context, mock_shared_state): + async def test_condition_group_evaluator_first_match(self, mock_context, mock_state): """Test ConditionGroupEvaluatorExecutor returns first match.""" from agent_framework_declarative._workflows._executors_control_flow import ( ConditionGroupEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 10) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 10) action_def = {"kind": "ConditionGroup"} conditions = [ @@ -1493,15 +1495,15 @@ async def test_condition_group_evaluator_first_match(self, mock_context, mock_sh assert msg.matched is True assert msg.branch_index == 1 # Second condition (x > 5) is first match - async def test_condition_group_evaluator_no_match(self, mock_context, mock_shared_state): + async def test_condition_group_evaluator_no_match(self, mock_context, mock_state): """Test ConditionGroupEvaluatorExecutor with no matches.""" from agent_framework_declarative._workflows._executors_control_flow import ( ConditionGroupEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 0) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 0) action_def = {"kind": "ConditionGroup"} conditions = [ @@ -1517,14 +1519,14 @@ async def test_condition_group_evaluator_no_match(self, mock_context, mock_share assert msg.matched is False assert msg.branch_index == -1 - async def test_condition_group_evaluator_boolean_true_condition(self, mock_context, mock_shared_state): + async def test_condition_group_evaluator_boolean_true_condition(self, mock_context, mock_state): """Test ConditionGroupEvaluatorExecutor with boolean True condition.""" from agent_framework_declarative._workflows._executors_control_flow import ( ConditionGroupEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = 
{"kind": "ConditionGroup"} conditions = [ @@ -1540,15 +1542,15 @@ async def test_condition_group_evaluator_boolean_true_condition(self, mock_conte assert msg.matched is True assert msg.branch_index == 1 - async def test_if_condition_evaluator_true(self, mock_context, mock_shared_state): + async def test_if_condition_evaluator_true(self, mock_context, mock_state): """Test IfConditionEvaluatorExecutor with true condition.""" from agent_framework_declarative._workflows._executors_control_flow import ( IfConditionEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.flag", True) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", True) action_def = {"kind": "If"} executor = IfConditionEvaluatorExecutor(action_def, condition_expr="=Local.flag") @@ -1560,15 +1562,15 @@ async def test_if_condition_evaluator_true(self, mock_context, mock_shared_state assert msg.matched is True assert msg.branch_index == 0 # Then branch - async def test_if_condition_evaluator_false(self, mock_context, mock_shared_state): + async def test_if_condition_evaluator_false(self, mock_context, mock_state): """Test IfConditionEvaluatorExecutor with false condition.""" from agent_framework_declarative._workflows._executors_control_flow import ( IfConditionEvaluatorExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.flag", False) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", False) action_def = {"kind": "If"} executor = IfConditionEvaluatorExecutor(action_def, condition_expr="=Local.flag") @@ -1589,7 +1591,7 @@ async def test_if_condition_evaluator_false(self, mock_context, mock_shared_stat class TestDeclarativeActionExecutorBase: """Tests for DeclarativeActionExecutor base class.""" - async def test_ensure_state_initialized_with_dict_input(self, mock_context, mock_shared_state): + async def test_ensure_state_initialized_with_dict_input(self, mock_context, mock_state): """Test _ensure_state_initialized with dict input.""" from agent_framework_declarative._workflows._executors_basic import ( SetValueExecutor, @@ -1602,11 +1604,11 @@ async def test_ensure_state_initialized_with_dict_input(self, mock_context, mock await executor.handle_action({"custom": "input"}, mock_context) # State should have been initialized with the dict - state = DeclarativeWorkflowState(mock_shared_state) - inputs = await state.get("Workflow.Inputs") + state = DeclarativeWorkflowState(mock_state) + inputs = state.get("Workflow.Inputs") assert inputs == {"custom": "input"} - async def test_ensure_state_initialized_with_string_input(self, mock_context, mock_shared_state): + async def test_ensure_state_initialized_with_string_input(self, mock_context, mock_state): """Test _ensure_state_initialized with string input.""" from agent_framework_declarative._workflows._executors_basic import ( SetValueExecutor, @@ -1618,11 +1620,11 @@ async def test_ensure_state_initialized_with_string_input(self, mock_context, mo # Trigger with string - should wrap in {"input": ...} await executor.handle_action("string trigger", mock_context) - state = DeclarativeWorkflowState(mock_shared_state) - inputs = await state.get("Workflow.Inputs") + state = DeclarativeWorkflowState(mock_state) + inputs = state.get("Workflow.Inputs") assert inputs == {"input": "string trigger"} - async def test_ensure_state_initialized_with_custom_object(self, mock_context, 
mock_shared_state): + async def test_ensure_state_initialized_with_custom_object(self, mock_context, mock_state): """Test _ensure_state_initialized with custom object converts to string.""" from agent_framework_declarative._workflows._executors_basic import ( SetValueExecutor, @@ -1637,11 +1639,11 @@ def __str__(self): await executor.handle_action(CustomObj(), mock_context) - state = DeclarativeWorkflowState(mock_shared_state) - inputs = await state.get("Workflow.Inputs") + state = DeclarativeWorkflowState(mock_state) + inputs = state.get("Workflow.Inputs") assert inputs == {"input": "custom string"} - async def test_executor_display_name_property(self, mock_context, mock_shared_state): + async def test_executor_display_name_property(self, mock_context, mock_state): """Test executor display_name property.""" from agent_framework_declarative._workflows._executors_basic import ( SetValueExecutor, @@ -1657,7 +1659,7 @@ async def test_executor_display_name_property(self, mock_context, mock_shared_st assert executor.display_name == "My Custom Action" - async def test_executor_action_def_property(self, mock_context, mock_shared_state): + async def test_executor_action_def_property(self, mock_context, mock_state): """Test executor action_def property.""" from agent_framework_declarative._workflows._executors_basic import ( SetValueExecutor, @@ -1677,15 +1679,15 @@ async def test_executor_action_def_property(self, mock_context, mock_shared_stat class TestHumanInputExecutorsCoverage: """Tests for human input executors covering uncovered code paths.""" - async def test_wait_for_input_executor_with_prompt(self, mock_context, mock_shared_state): + async def test_wait_for_input_executor_with_prompt(self, mock_context, mock_state): """Test WaitForInputExecutor with prompt.""" from agent_framework_declarative._workflows._executors_external_input import ( ExternalInputRequest, WaitForInputExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "WaitForInput", @@ -1706,15 +1708,15 @@ async def test_wait_for_input_executor_with_prompt(self, mock_context, mock_shar assert isinstance(request, ExternalInputRequest) assert request.request_type == "user_input" - async def test_wait_for_input_executor_no_prompt(self, mock_context, mock_shared_state): + async def test_wait_for_input_executor_no_prompt(self, mock_context, mock_state): """Test WaitForInputExecutor without prompt.""" from agent_framework_declarative._workflows._executors_external_input import ( ExternalInputRequest, WaitForInputExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "WaitForInput", @@ -1731,15 +1733,15 @@ async def test_wait_for_input_executor_no_prompt(self, mock_context, mock_shared assert isinstance(request, ExternalInputRequest) assert request.request_type == "user_input" - async def test_request_external_input_executor(self, mock_context, mock_shared_state): + async def test_request_external_input_executor(self, mock_context, mock_state): """Test RequestExternalInputExecutor.""" from agent_framework_declarative._workflows._executors_external_input import ( ExternalInputRequest, RequestExternalInputExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": 
"RequestExternalInput", @@ -1763,15 +1765,15 @@ async def test_request_external_input_executor(self, mock_context, mock_shared_s assert request.metadata["required_fields"] == ["approver", "notes"] assert request.metadata["timeout_seconds"] == 3600 - async def test_question_executor_with_choices(self, mock_context, mock_shared_state): + async def test_question_executor_with_choices(self, mock_context, mock_state): """Test QuestionExecutor with choices as dicts and strings.""" from agent_framework_declarative._workflows._executors_external_input import ( ExternalInputRequest, QuestionExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "Question", @@ -1808,7 +1810,7 @@ async def test_question_executor_with_choices(self, mock_context, mock_shared_st class TestAgentExternalLoopCoverage: """Tests for agent executor external loop handling.""" - async def test_agent_executor_with_external_loop(self, mock_context, mock_shared_state): + async def test_agent_executor_with_external_loop(self, mock_context, mock_state): """Test agent executor with external loop that triggers.""" from unittest.mock import patch @@ -1819,10 +1821,10 @@ async def test_agent_executor_with_external_loop(self, mock_context, mock_shared mock_agent = MagicMock() - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "User query") - await state.set("Local.needsMore", True) # Loop condition will be true + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "User query") + state.set("Local.needsMore", True) # Loop condition will be true action_def = { "kind": "InvokeAzureAgent", @@ -1848,7 +1850,7 @@ async def test_agent_executor_with_external_loop(self, mock_context, mock_shared assert isinstance(request, AgentExternalInputRequest) assert request.agent_name == "TestAgent" - async def test_agent_executor_agent_error_handling(self, mock_context, mock_shared_state): + async def test_agent_executor_agent_error_handling(self, mock_context, mock_state): """Test agent executor raises AgentInvocationError on failure.""" from agent_framework_declarative._workflows._executors_agents import ( AgentInvocationError, @@ -1858,9 +1860,9 @@ async def test_agent_executor_agent_error_handling(self, mock_context, mock_shar mock_agent = MagicMock() mock_agent.run = AsyncMock(side_effect=RuntimeError("Agent failed")) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "Query") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "Query") action_def = { "kind": "InvokeAzureAgent", @@ -1876,12 +1878,12 @@ async def test_agent_executor_agent_error_handling(self, mock_context, mock_shar assert "Agent failed" in str(exc_info.value) # Should still store error in state before raising - error = await state.get("Agent.error") + error = state.get("Agent.error") assert "Agent failed" in error - result = await state.get("Local.result") + result = state.get("Local.result") assert result == {"error": "Agent failed"} - async def test_agent_executor_string_result(self, mock_context, mock_shared_state): + async def test_agent_executor_string_result(self, mock_context, mock_state): """Test agent executor with agent that returns string directly.""" from agent_framework_declarative._workflows._executors_agents import ( 
InvokeAzureAgentExecutor, @@ -1890,9 +1892,9 @@ async def test_agent_executor_string_result(self, mock_context, mock_shared_stat mock_agent = MagicMock() mock_agent.run = AsyncMock(return_value="Direct string response") - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.input", "Query") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.input", "Query") action_def = { "kind": "InvokeAzureAgent", @@ -1906,10 +1908,10 @@ async def test_agent_executor_string_result(self, mock_context, mock_shared_stat # Should auto-send output mock_context.yield_output.assert_called_with("Direct string response") - result = await state.get("Local.result") + result = state.get("Local.result") assert result == "Direct string response" - async def test_invoke_tool_with_error(self, mock_context, mock_shared_state): + async def test_invoke_tool_with_error(self, mock_context, mock_state): """Test InvokeToolExecutor handles tool errors.""" from agent_framework_declarative._workflows._executors_agents import ( TOOL_REGISTRY_KEY, @@ -1919,10 +1921,10 @@ async def test_invoke_tool_with_error(self, mock_context, mock_shared_state): def failing_tool(**kwargs): raise ValueError("Tool error") - mock_shared_state._data[TOOL_REGISTRY_KEY] = {"bad_tool": failing_tool} + mock_state._data[TOOL_REGISTRY_KEY] = {"bad_tool": failing_tool} - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "InvokeTool", @@ -1933,7 +1935,7 @@ def failing_tool(**kwargs): await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.result") + result = state.get("Local.result") assert result == {"error": "Tool error"} @@ -1945,51 +1947,51 @@ def failing_tool(**kwargs): class TestPowerFxFunctionsCoverage: """Tests for PowerFx function evaluation coverage.""" - async def test_eval_lower_upper_functions(self, mock_shared_state): + async def test_eval_lower_upper_functions(self, mock_state): """Test Lower and Upper functions.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.text", "Hello World") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.text", "Hello World") - result = await state.eval("=Lower(Local.text)") + result = state.eval("=Lower(Local.text)") assert result == "hello world" - result = await state.eval("=Upper(Local.text)") + result = state.eval("=Upper(Local.text)") assert result == "HELLO WORLD" - async def test_eval_if_function(self, mock_shared_state): + async def test_eval_if_function(self, mock_state): """Test If function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.flag", True) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", True) - result = await state.eval('=If(Local.flag, "yes", "no")') + result = state.eval('=If(Local.flag, "yes", "no")') assert result == "yes" - await state.set("Local.flag", False) - result = await state.eval('=If(Local.flag, "yes", "no")') + state.set("Local.flag", False) + result = state.eval('=If(Local.flag, "yes", "no")') assert result == "no" - async def test_eval_not_function(self, mock_shared_state): + async def test_eval_not_function(self, mock_state): """Test Not function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - 
await state.set("Local.flag", True) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.flag", True) - result = await state.eval("=Not(Local.flag)") + result = state.eval("=Not(Local.flag)") assert result is False - async def test_eval_and_or_functions(self, mock_shared_state): + async def test_eval_and_or_functions(self, mock_state): """Test And and Or functions.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.a", True) - await state.set("Local.b", False) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.a", True) + state.set("Local.b", False) - result = await state.eval("=And(Local.a, Local.b)") + result = state.eval("=And(Local.a, Local.b)") assert result is False - result = await state.eval("=Or(Local.a, Local.b)") + result = state.eval("=Or(Local.a, Local.b)") assert result is True @@ -2325,7 +2327,7 @@ def test_get_branch_exit_none(self): class TestAgentExecutorExternalLoop: """Tests for InvokeAzureAgentExecutor external loop response handling.""" - async def test_handle_external_input_response_no_state(self, mock_context, mock_shared_state): + async def test_handle_external_input_response_no_state(self, mock_context, mock_state): """Test handling external input response when loop state not found.""" from agent_framework_declarative._workflows._executors_agents import ( AgentExternalInputRequest, @@ -2335,7 +2337,7 @@ async def test_handle_external_input_response_no_state(self, mock_context, mock_ executor = InvokeAzureAgentExecutor({"kind": "InvokeAzureAgent", "agent": "TestAgent"}) - # No external loop state in shared_state + # No external loop state in mock_state original_request = AgentExternalInputRequest( request_id="req-1", agent_name="TestAgent", @@ -2353,7 +2355,7 @@ async def test_handle_external_input_response_no_state(self, mock_context, mock_ assert isinstance(call_args, ActionComplete) - async def test_handle_external_input_response_agent_not_found(self, mock_context, mock_shared_state): + async def test_handle_external_input_response_agent_not_found(self, mock_context, mock_state): """Test handling external input raises error when agent not found during resumption.""" from agent_framework_declarative._workflows._executors_agents import ( EXTERNAL_LOOP_STATE_KEY, @@ -2375,11 +2377,11 @@ async def test_handle_external_input_response_agent_not_found(self, mock_context auto_send=True, messages_path="Conversation.messages", ) - mock_shared_state._data[EXTERNAL_LOOP_STATE_KEY] = loop_state + mock_state._data[EXTERNAL_LOOP_STATE_KEY] = loop_state # Initialize declarative state with simple value - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() executor = InvokeAzureAgentExecutor({"kind": "InvokeAzureAgent", "agent": "NonExistentAgent"}) @@ -2598,85 +2600,85 @@ def test_validation_in_foreach_body(self): class TestExpressionEdgeCases: """Tests for expression evaluation edge cases.""" - async def test_division_with_valid_values(self, mock_shared_state): + async def test_division_with_valid_values(self, mock_state): """Test normal division works correctly.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 10) - await state.set("Local.y", 4) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 10) + state.set("Local.y", 4) - result = await 
state.eval("=Local.x / Local.y") + result = state.eval("=Local.x / Local.y") assert result == 2.5 - async def test_multiplication_normal(self, mock_shared_state): + async def test_multiplication_normal(self, mock_state): """Test normal multiplication.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.x", 6) - await state.set("Local.y", 7) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.x", 6) + state.set("Local.y", 7) - result = await state.eval("=Local.x * Local.y") + result = state.eval("=Local.x * Local.y") assert result == 42 class TestLongMessageTextHandling: """Tests for handling long MessageText results that exceed PowerFx limits.""" - async def test_short_message_text_embedded_inline(self, mock_shared_state): + async def test_short_message_text_embedded_inline(self, mock_state): """Test that short MessageText results are embedded inline.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Store a short message short_text = "Hello world" - await state.set("Local.Messages", [{"text": short_text, "contents": [{"type": "text", "text": short_text}]}]) + state.set("Local.Messages", [{"text": short_text, "contents": [{"type": "text", "text": short_text}]}]) # Evaluate a formula with MessageText - should embed inline - result = await state.eval("=Upper(MessageText(Local.Messages))") + result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "HELLO WORLD" # No temp variable should be created for short strings - temp_var = await state.get("Local._TempMessageText0") + temp_var = state.get("Local._TempMessageText0") assert temp_var is None - async def test_long_message_text_stored_in_temp_variable(self, mock_shared_state): + async def test_long_message_text_stored_in_temp_variable(self, mock_state): """Test that long MessageText results are stored in temp variables.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Create a message longer than 500 characters long_text = "A" * 600 # 600 characters exceeds the 500 char threshold - await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) + state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) # Evaluate a formula with MessageText - result = await state.eval("=Upper(MessageText(Local.Messages))") + result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 # Upper on 'A' is still 'A' # A temp variable should have been created - temp_var = await state.get("Local._TempMessageText0") + temp_var = state.get("Local._TempMessageText0") assert temp_var == long_text - async def test_find_with_long_message_text(self, mock_shared_state): + async def test_find_with_long_message_text(self, mock_state): """Test Find function works with long MessageText stored in temp variable.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Create a long message with a keyword to find long_text = "X" * 550 + "CONGRATULATIONS" + "Y" * 50 - await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) + state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": 
long_text}]}])
 
         # Test the pattern used in student_teacher workflow
-        result = await state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))')
+        result = state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))')
         assert result is True
 
-    async def test_find_without_keyword_in_long_text(self, mock_shared_state):
+    async def test_find_without_keyword_in_long_text(self, mock_state):
         """Test Find returns blank when keyword not found in long text."""
-        state = DeclarativeWorkflowState(mock_shared_state)
-        await state.initialize()
+        state = DeclarativeWorkflowState(mock_state)
+        state.initialize()
 
         # Long text without the keyword
         long_text = "X" * 600
-        await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}])
+        state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}])
 
-        result = await state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))')
+        result = state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))')
         assert result is False
 
 
diff --git a/python/packages/declarative/tests/test_graph_executors.py b/python/packages/declarative/tests/test_graph_executors.py
index e03895b4ac..0a4433b095 100644
--- a/python/packages/declarative/tests/test_graph_executors.py
+++ b/python/packages/declarative/tests/test_graph_executors.py
@@ -24,33 +24,31 @@ class TestDeclarativeWorkflowState:
     """Tests for DeclarativeWorkflowState."""
 
     @pytest.fixture
-    def mock_shared_state(self):
+    def mock_state(self):
-        """Create a mock shared state with async get/set methods."""
+        """Create a mock state with sync get/set methods."""
-        shared_state = MagicMock()
-        shared_state._data = {}
+        mock_state = MagicMock()
+        mock_state._data = {}
 
-        async def mock_get(key):
-            if key not in shared_state._data:
-                raise KeyError(key)
-            return shared_state._data[key]
+        def mock_get(key, default=None):
+            return mock_state._data.get(key, default)
 
-        async def mock_set(key, value):
-            shared_state._data[key] = value
+        def mock_set(key, value):
+            mock_state._data[key] = value
 
-        shared_state.get = AsyncMock(side_effect=mock_get)
-        shared_state.set = AsyncMock(side_effect=mock_set)
+        mock_state.get = MagicMock(side_effect=mock_get)
+        mock_state.set = MagicMock(side_effect=mock_set)
 
-        return shared_state
+        return mock_state
 
     @pytest.mark.asyncio
-    async def test_initialize_state(self, mock_shared_state):
+    async def test_initialize_state(self, mock_state):
         """Test initializing the workflow state."""
-        state = DeclarativeWorkflowState(mock_shared_state)
-        await state.initialize({"query": "test"})
+        state = DeclarativeWorkflowState(mock_state)
+        state.initialize({"query": "test"})
 
         # Verify state was set
-        mock_shared_state.set.assert_called_once()
-        call_args = mock_shared_state.set.call_args
+        mock_state.set.assert_called_once()
+        call_args = mock_state.set.call_args
         assert call_args[0][0] == DECLARATIVE_STATE_KEY
         state_data = call_args[0][1]
         assert state_data["Inputs"] == {"query": "test"}
@@ -58,71 +56,71 @@ async def test_initialize_state(self, mock_shared_state):
         assert state_data["Local"] == {}
 
     @pytest.mark.asyncio
-    async def test_get_and_set_values(self, mock_shared_state):
+    async def test_get_and_set_values(self, mock_state):
         """Test getting and setting values."""
-        state = DeclarativeWorkflowState(mock_shared_state)
-        await state.initialize()
+        state = DeclarativeWorkflowState(mock_state)
+        state.initialize()
 
         # Set a turn value
-        await state.set("Local.counter", 5)
+        state.set("Local.counter", 5)
 
         # Get 
the value - result = await state.get("Local.counter") + result = state.get("Local.counter") assert result == 5 @pytest.mark.asyncio - async def test_get_inputs(self, mock_shared_state): + async def test_get_inputs(self, mock_state): """Test getting workflow inputs.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"name": "Alice", "age": 30}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"name": "Alice", "age": 30}) # Get via path - name = await state.get("Workflow.Inputs.name") + name = state.get("Workflow.Inputs.name") assert name == "Alice" # Get all inputs - inputs = await state.get("Workflow.Inputs") + inputs = state.get("Workflow.Inputs") assert inputs == {"name": "Alice", "age": 30} @pytest.mark.asyncio - async def test_append_value(self, mock_shared_state): + async def test_append_value(self, mock_state): """Test appending values to a list.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Append to non-existent list creates it - await state.append("Local.items", "first") - result = await state.get("Local.items") + state.append("Local.items", "first") + result = state.get("Local.items") assert result == ["first"] # Append to existing list - await state.append("Local.items", "second") - result = await state.get("Local.items") + state.append("Local.items", "second") + result = state.get("Local.items") assert result == ["first", "second"] @pytest.mark.asyncio - async def test_eval_expression(self, mock_shared_state): + async def test_eval_expression(self, mock_state): """Test evaluating expressions.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Non-expression returns as-is - result = await state.eval("plain text") + result = state.eval("plain text") assert result == "plain text" # Boolean literals - result = await state.eval("=true") + result = state.eval("=true") assert result is True - result = await state.eval("=false") + result = state.eval("=false") assert result is False # String literals - result = await state.eval('="hello"') + result = state.eval('="hello"') assert result == "hello" # Numeric literals - result = await state.eval("=42") + result = state.eval("=42") assert result == 42 @@ -130,39 +128,37 @@ class TestDeclarativeActionExecutor: """Tests for DeclarativeActionExecutor subclasses.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - 
return shared_state + return mock_state @pytest.mark.asyncio - async def test_set_value_executor(self, mock_context, mock_shared_state): + async def test_set_value_executor(self, mock_context, mock_state): """Test SetValueExecutor.""" # Initialize state - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SetValue", @@ -180,10 +176,10 @@ async def test_set_value_executor(self, mock_context, mock_shared_state): assert isinstance(message, ActionComplete) @pytest.mark.asyncio - async def test_send_activity_executor(self, mock_context, mock_shared_state): + async def test_send_activity_executor(self, mock_context, mock_state): """Test SendActivityExecutor.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "SendActivity", @@ -199,11 +195,11 @@ async def test_send_activity_executor(self, mock_context, mock_shared_state): # Note: ConditionEvaluatorExecutor tests removed - conditions are now evaluated on edges - async def test_foreach_init_with_items(self, mock_context, mock_shared_state): + async def test_foreach_init_with_items(self, mock_context, mock_state): """Test ForeachInitExecutor with items.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a", "b", "c"]) action_def = { "kind": "Foreach", @@ -224,10 +220,10 @@ async def test_foreach_init_with_items(self, mock_context, mock_shared_state): assert message.current_item == "a" @pytest.mark.asyncio - async def test_foreach_init_empty(self, mock_context, mock_shared_state): + async def test_foreach_init_empty(self, mock_context, mock_state): """Test ForeachInitExecutor with empty items list.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Use a literal empty list - no expression evaluation needed action_def = { @@ -390,43 +386,41 @@ class TestAgentExecutors: """Tests for agent-related executors.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_invoke_agent_not_found(self, mock_context, mock_shared_state): + async def test_invoke_agent_not_found(self, mock_context, 
mock_state): """Test InvokeAzureAgentExecutor raises error when agent not found.""" from agent_framework_declarative._workflows import ( AgentInvocationError, InvokeAzureAgentExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "InvokeAzureAgent", @@ -447,44 +441,42 @@ class TestHumanInputExecutors: """Tests for human input executors.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() ctx.request_info = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_question_executor(self, mock_context, mock_shared_state): + async def test_question_executor(self, mock_context, mock_state): """Test QuestionExecutor.""" from agent_framework_declarative._workflows import ( ExternalInputRequest, QuestionExecutor, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "Question", @@ -505,15 +497,15 @@ async def test_question_executor(self, mock_context, mock_shared_state): assert "What is your name?" 
in request.message @pytest.mark.asyncio - async def test_confirmation_executor(self, mock_context, mock_shared_state): + async def test_confirmation_executor(self, mock_context, mock_state): """Test ConfirmationExecutor.""" from agent_framework_declarative._workflows import ( ConfirmationExecutor, ExternalInputRequest, ) - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "Confirmation", @@ -539,41 +531,39 @@ class TestParseValueExecutor: """Tests for the ParseValue action executor.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_parse_value_string(self, mock_context, mock_shared_state): + async def test_parse_value_string(self, mock_context, mock_state): """Test ParseValue with string type.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", "hello world") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", "hello world") action_def = { "kind": "ParseValue", @@ -584,17 +574,17 @@ async def test_parse_value_string(self, mock_context, mock_shared_state): executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == "hello world" @pytest.mark.asyncio - async def test_parse_value_number(self, mock_context, mock_shared_state): + async def test_parse_value_number(self, mock_context, mock_state): """Test ParseValue with number type.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", "123") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", "123") action_def = { "kind": "ParseValue", @@ -605,17 +595,17 @@ async def test_parse_value_number(self, mock_context, mock_shared_state): executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == 123 @pytest.mark.asyncio - async def test_parse_value_float(self, mock_context, mock_shared_state): + async 
def test_parse_value_float(self, mock_context, mock_state): """Test ParseValue with float number.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", "3.14") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", "3.14") action_def = { "kind": "ParseValue", @@ -626,17 +616,17 @@ async def test_parse_value_float(self, mock_context, mock_shared_state): executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == 3.14 @pytest.mark.asyncio - async def test_parse_value_boolean_true(self, mock_context, mock_shared_state): + async def test_parse_value_boolean_true(self, mock_context, mock_state): """Test ParseValue with boolean type (true).""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", "true") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", "true") action_def = { "kind": "ParseValue", @@ -647,17 +637,17 @@ async def test_parse_value_boolean_true(self, mock_context, mock_shared_state): executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result is True @pytest.mark.asyncio - async def test_parse_value_boolean_false(self, mock_context, mock_shared_state): + async def test_parse_value_boolean_false(self, mock_context, mock_state): """Test ParseValue with boolean type (false).""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", "no") + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", "no") action_def = { "kind": "ParseValue", @@ -668,17 +658,17 @@ async def test_parse_value_boolean_false(self, mock_context, mock_shared_state): executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result is False @pytest.mark.asyncio - async def test_parse_value_object_from_json(self, mock_context, mock_shared_state): + async def test_parse_value_object_from_json(self, mock_context, mock_state): """Test ParseValue with object type from JSON string.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", '{"name": "Alice", "age": 30}') + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", '{"name": "Alice", "age": 30}') action_def = { "kind": "ParseValue", @@ -689,17 +679,17 @@ async def test_parse_value_object_from_json(self, mock_context, mock_shared_stat executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == {"name": "Alice", "age": 30} 
@pytest.mark.asyncio - async def test_parse_value_array_from_json(self, mock_context, mock_shared_state): + async def test_parse_value_array_from_json(self, mock_context, mock_state): """Test ParseValue with array type from JSON string.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", '["a", "b", "c"]') + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", '["a", "b", "c"]') action_def = { "kind": "ParseValue", @@ -710,17 +700,17 @@ async def test_parse_value_array_from_json(self, mock_context, mock_shared_state executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == ["a", "b", "c"] @pytest.mark.asyncio - async def test_parse_value_no_type_conversion(self, mock_context, mock_shared_state): + async def test_parse_value_no_type_conversion(self, mock_context, mock_state): """Test ParseValue without type conversion.""" from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.rawValue", {"status": "active"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.rawValue", {"status": "active"}) action_def = { "kind": "ParseValue", @@ -730,7 +720,7 @@ async def test_parse_value_no_type_conversion(self, mock_context, mock_shared_st executor = ParseValueExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.parsedValue") + result = state.get("Local.parsedValue") assert result == {"status": "active"} @@ -738,41 +728,39 @@ class TestEditTableExecutor: """Tests for the EditTable action executor.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_edit_table_add(self, mock_context, mock_shared_state): + async def test_edit_table_add(self, mock_context, mock_state): """Test EditTable with add operation.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + 
state.set("Local.items", ["a", "b"]) action_def = { "kind": "EditTable", @@ -783,17 +771,17 @@ async def test_edit_table_add(self, mock_context, mock_shared_state): executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "b", "c"] @pytest.mark.asyncio - async def test_edit_table_insert_at_index(self, mock_context, mock_shared_state): + async def test_edit_table_insert_at_index(self, mock_context, mock_state): """Test EditTable with insert at specific index.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a", "c"]) action_def = { "kind": "EditTable", @@ -805,17 +793,17 @@ async def test_edit_table_insert_at_index(self, mock_context, mock_shared_state) executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "b", "c"] @pytest.mark.asyncio - async def test_edit_table_remove_by_value(self, mock_context, mock_shared_state): + async def test_edit_table_remove_by_value(self, mock_context, mock_state): """Test EditTable with remove by value.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a", "b", "c"]) action_def = { "kind": "EditTable", @@ -826,17 +814,17 @@ async def test_edit_table_remove_by_value(self, mock_context, mock_shared_state) executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "c"] @pytest.mark.asyncio - async def test_edit_table_remove_by_index(self, mock_context, mock_shared_state): + async def test_edit_table_remove_by_index(self, mock_context, mock_state): """Test EditTable with remove by index.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a", "b", "c"]) action_def = { "kind": "EditTable", @@ -847,17 +835,17 @@ async def test_edit_table_remove_by_index(self, mock_context, mock_shared_state) executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "c"] @pytest.mark.asyncio - async def test_edit_table_clear(self, mock_context, mock_shared_state): + async def test_edit_table_clear(self, mock_context, mock_state): """Test EditTable with clear operation.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + 
state.initialize() + state.set("Local.items", ["a", "b", "c"]) action_def = { "kind": "EditTable", @@ -867,17 +855,17 @@ async def test_edit_table_clear(self, mock_context, mock_shared_state): executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == [] @pytest.mark.asyncio - async def test_edit_table_update_at_index(self, mock_context, mock_shared_state): + async def test_edit_table_update_at_index(self, mock_context, mock_state): """Test EditTable with update at index.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.items", ["a", "b", "c"]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.items", ["a", "b", "c"]) action_def = { "kind": "EditTable", @@ -889,16 +877,16 @@ async def test_edit_table_update_at_index(self, mock_context, mock_shared_state) executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.items") + result = state.get("Local.items") assert result == ["a", "B", "c"] @pytest.mark.asyncio - async def test_edit_table_creates_new_list(self, mock_context, mock_shared_state): + async def test_edit_table_creates_new_list(self, mock_context, mock_state): """Test EditTable creates new list if not exists.""" from agent_framework_declarative._workflows._executors_basic import EditTableExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "EditTable", @@ -909,7 +897,7 @@ async def test_edit_table_creates_new_list(self, mock_context, mock_shared_state executor = EditTableExecutor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.newItems") + result = state.get("Local.newItems") assert result == ["first"] @@ -917,41 +905,39 @@ class TestEditTableV2Executor: """Tests for the EditTableV2 action executor.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_edit_table_v2_add(self, mock_context, mock_shared_state): + async def test_edit_table_v2_add(self, mock_context, mock_state): """Test EditTableV2 with add operation.""" from 
agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1, "name": "Alice"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1, "name": "Alice"}]) action_def = { "kind": "EditTableV2", @@ -962,17 +948,17 @@ async def test_edit_table_v2_add(self, mock_context, mock_shared_state): executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] @pytest.mark.asyncio - async def test_edit_table_v2_add_or_update_new(self, mock_context, mock_shared_state): + async def test_edit_table_v2_add_or_update_new(self, mock_context, mock_state): """Test EditTableV2 with addOrUpdate - adding new record.""" from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1, "name": "Alice"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1, "name": "Alice"}]) action_def = { "kind": "EditTableV2", @@ -984,17 +970,17 @@ async def test_edit_table_v2_add_or_update_new(self, mock_context, mock_shared_s executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] @pytest.mark.asyncio - async def test_edit_table_v2_add_or_update_existing(self, mock_context, mock_shared_state): + async def test_edit_table_v2_add_or_update_existing(self, mock_context, mock_state): """Test EditTableV2 with addOrUpdate - updating existing record.""" from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) action_def = { "kind": "EditTableV2", @@ -1006,17 +992,17 @@ async def test_edit_table_v2_add_or_update_existing(self, mock_context, mock_sha executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [{"id": 1, "name": "Alice Updated"}, {"id": 2, "name": "Bob"}] @pytest.mark.asyncio - async def test_edit_table_v2_remove_by_key(self, mock_context, mock_shared_state): + async def test_edit_table_v2_remove_by_key(self, mock_context, mock_state): """Test EditTableV2 with remove by key.""" from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) action_def = { "kind": "EditTableV2", @@ -1028,17 +1014,17 @@ 
async def test_edit_table_v2_remove_by_key(self, mock_context, mock_shared_state executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [{"id": 2, "name": "Bob"}] @pytest.mark.asyncio - async def test_edit_table_v2_clear(self, mock_context, mock_shared_state): + async def test_edit_table_v2_clear(self, mock_context, mock_state): """Test EditTableV2 with clear operation.""" from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1}, {"id": 2}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1}, {"id": 2}]) action_def = { "kind": "EditTableV2", @@ -1048,17 +1034,17 @@ async def test_edit_table_v2_clear(self, mock_context, mock_shared_state): executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [] @pytest.mark.asyncio - async def test_edit_table_v2_update_by_key(self, mock_context, mock_shared_state): + async def test_edit_table_v2_update_by_key(self, mock_context, mock_state): """Test EditTableV2 with update by key.""" from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() - await state.set("Local.records", [{"id": 1, "status": "pending"}, {"id": 2, "status": "pending"}]) + state = DeclarativeWorkflowState(mock_state) + state.initialize() + state.set("Local.records", [{"id": 1, "status": "pending"}, {"id": 2, "status": "pending"}]) action_def = { "kind": "EditTableV2", @@ -1070,7 +1056,7 @@ async def test_edit_table_v2_update_by_key(self, mock_context, mock_shared_state executor = EditTableV2Executor(action_def) await executor.handle_action(ActionTrigger(), mock_context) - result = await state.get("Local.records") + result = state.get("Local.records") assert result == [{"id": 1, "status": "complete"}, {"id": 2, "status": "pending"}] @@ -1078,40 +1064,38 @@ class TestCancelDialogExecutors: """Tests for CancelDialog and CancelAllDialogs executors.""" @pytest.fixture - def mock_context(self, mock_shared_state): + def mock_context(self, mock_state): """Create a mock workflow context.""" ctx = MagicMock() - ctx.shared_state = mock_shared_state + ctx.state = mock_state ctx.send_message = AsyncMock() ctx.yield_output = AsyncMock() return ctx @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) - return shared_state + return mock_state @pytest.mark.asyncio - async def test_cancel_dialog_executor(self, 
mock_context, mock_shared_state): + async def test_cancel_dialog_executor(self, mock_context, mock_state): """Test CancelDialogExecutor completes without error.""" from agent_framework_declarative._workflows._executors_control_flow import CancelDialogExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "CancelDialog", @@ -1123,12 +1107,12 @@ async def test_cancel_dialog_executor(self, mock_context, mock_shared_state): # No assertions needed - just verify it doesn't raise @pytest.mark.asyncio - async def test_cancel_all_dialogs_executor(self, mock_context, mock_shared_state): + async def test_cancel_all_dialogs_executor(self, mock_context, mock_state): """Test CancelAllDialogsExecutor completes without error.""" from agent_framework_declarative._workflows._executors_control_flow import CancelAllDialogsExecutor - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() action_def = { "kind": "CancelAllDialogs", diff --git a/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py index 91cf378578..8f0cd39d31 100644 --- a/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py +++ b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py @@ -16,7 +16,7 @@ - String interpolation: {Variable.Path} """ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -29,123 +29,125 @@ class TestPowerFxBuiltinFunctions: """Test PowerFx built-in functions used in YAML workflows.""" @pytest.fixture - def mock_shared_state(self): - """Create a mock shared state with async get/set methods.""" - shared_state = MagicMock() - shared_state._data = {} + def mock_state(self): + """Create a mock state with sync get/set methods.""" + state = MagicMock() + state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + def mock_has(key): + return key in state._data - async def test_concat_simple(self, mock_shared_state): + state.get = MagicMock(side_effect=mock_get) + state.set = MagicMock(side_effect=mock_set) + state.has = MagicMock(side_effect=mock_has) + return state + + async def test_concat_simple(self, mock_state): """Test Concat function with simple strings.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Concat("Nice to meet you, ", Local.userName, "!") - await state.set("Local.userName", "Alice") - result = await state.eval('=Concat("Nice to meet you, ", Local.userName, "!")') + state.set("Local.userName", "Alice") + result = state.eval('=Concat("Nice to meet you, ", Local.userName, "!")') assert result == "Nice to meet you, Alice!" 
- async def test_concat_multiple_args(self, mock_shared_state): + async def test_concat_multiple_args(self, mock_state): """Test Concat with multiple arguments.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Concat(Local.greeting, ", ", Local.name, "!") - await state.set("Local.greeting", "Hello") - await state.set("Local.name", "World") - result = await state.eval('=Concat(Local.greeting, ", ", Local.name, "!")') + state.set("Local.greeting", "Hello") + state.set("Local.name", "World") + result = state.eval('=Concat(Local.greeting, ", ", Local.name, "!")') assert result == "Hello, World!" - async def test_concat_with_local_namespace(self, mock_shared_state): + async def test_concat_with_local_namespace(self, mock_state): """Test Concat using Local.* namespace (maps to Local.*).""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Concat("Starting math coaching session for: ", Local.Problem) - await state.set("Local.Problem", "2 + 2") - result = await state.eval('=Concat("Starting math coaching session for: ", Local.Problem)') + state.set("Local.Problem", "2 + 2") + result = state.eval('=Concat("Starting math coaching session for: ", Local.Problem)') assert result == "Starting math coaching session for: 2 + 2" - async def test_if_with_isblank(self, mock_shared_state): + async def test_if_with_isblank(self, mock_state): """Test If function with IsBlank.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"name": ""}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"name": ""}) # From YAML: =If(IsBlank(inputs.name), "World", inputs.name) # When input is blank - result = await state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') + result = state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') assert result == "World" # When input is provided - await state.initialize({"name": "Alice"}) - result = await state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') + state.initialize({"name": "Alice"}) + result = state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') assert result == "Alice" - async def test_not_function(self, mock_shared_state): + async def test_not_function(self, mock_state): """Test Not function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Not(Local.EscalationParameters.IsComplete) - await state.set("Local.EscalationParameters", {"IsComplete": False}) - result = await state.eval("=Not(Local.EscalationParameters.IsComplete)") + state.set("Local.EscalationParameters", {"IsComplete": False}) + result = state.eval("=Not(Local.EscalationParameters.IsComplete)") assert result is True - await state.set("Local.EscalationParameters", {"IsComplete": True}) - result = await state.eval("=Not(Local.EscalationParameters.IsComplete)") + state.set("Local.EscalationParameters", {"IsComplete": True}) + result = state.eval("=Not(Local.EscalationParameters.IsComplete)") assert result is False - async def test_or_function(self, mock_shared_state): + async def test_or_function(self, mock_state): """Test Or function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state 
= DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Or(Local.feeling = "great", Local.feeling = "good") - await state.set("Local.feeling", "great") - result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + state.set("Local.feeling", "great") + result = state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') assert result is True - await state.set("Local.feeling", "good") - result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + state.set("Local.feeling", "good") + result = state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') assert result is True - await state.set("Local.feeling", "bad") - result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + state.set("Local.feeling", "bad") + result = state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') assert result is False - async def test_upper_function(self, mock_shared_state): + async def test_upper_function(self, mock_state): """Test Upper function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Upper(System.LastMessage.Text) - await state.set("System.LastMessage", {"Text": "hello world"}) - result = await state.eval("=Upper(System.LastMessage.Text)") + state.set("System.LastMessage", {"Text": "hello world"}) + result = state.eval("=Upper(System.LastMessage.Text)") assert result == "HELLO WORLD" - async def test_find_function(self, mock_shared_state): + async def test_find_function(self, mock_state): """Test Find function.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =!IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))) - await state.set("Local.TeacherResponse", "CONGRATULATIONS! You solved it!") - result = await state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') + state.set("Local.TeacherResponse", "CONGRATULATIONS! 
You solved it!") + result = state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') assert result is True - await state.set("Local.TeacherResponse", "Try again") - result = await state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') + state.set("Local.TeacherResponse", "Try again") + result = state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') assert result is False @@ -153,55 +155,53 @@ class TestPowerFxSystemVariables: """Test System.* variable access.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_system_conversation_id(self, mock_shared_state): + async def test_system_conversation_id(self, mock_state): """Test System.ConversationId access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: conversationId: =System.ConversationId - await state.set("System.ConversationId", "conv-12345") - result = await state.eval("=System.ConversationId") + state.set("System.ConversationId", "conv-12345") + result = state.eval("=System.ConversationId") assert result == "conv-12345" - async def test_system_last_message_text(self, mock_shared_state): + async def test_system_last_message_text(self, mock_state): """Test System.LastMessage.Text access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Upper(System.LastMessage.Text) <> "EXIT" - await state.set("System.LastMessage", {"Text": "Hello"}) - result = await state.eval("=System.LastMessage.Text") + state.set("System.LastMessage", {"Text": "Hello"}) + result = state.eval("=System.LastMessage.Text") assert result == "Hello" - async def test_system_last_message_exit_check(self, mock_shared_state): + async def test_system_last_message_exit_check(self, mock_state): """Test the exit check pattern from YAML.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: when: =Upper(System.LastMessage.Text) <> "EXIT" - await state.set("System.LastMessage", {"Text": "hello"}) - result = await state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') + state.set("System.LastMessage", {"Text": "hello"}) + result = state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') assert result is True - await state.set("System.LastMessage", {"Text": "exit"}) - result = await state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') + state.set("System.LastMessage", {"Text": "exit"}) + result = state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') assert result is False @@ -209,99 +209,95 @@ class 
TestPowerFxComparisonOperators: """Test comparison operators used in YAML workflows.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_less_than(self, mock_shared_state): + async def test_less_than(self, mock_state): """Test < operator.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: condition: =Local.age < 65 - await state.set("Local.age", 30) - assert await state.eval("=Local.age < 65") is True + state.set("Local.age", 30) + assert state.eval("=Local.age < 65") is True - await state.set("Local.age", 70) - assert await state.eval("=Local.age < 65") is False + state.set("Local.age", 70) + assert state.eval("=Local.age < 65") is False - async def test_less_than_with_local(self, mock_shared_state): + async def test_less_than_with_local(self, mock_state): """Test < with Local namespace.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: condition: =Local.TurnCount < 4 - await state.set("Local.TurnCount", 2) - assert await state.eval("=Local.TurnCount < 4") is True + state.set("Local.TurnCount", 2) + assert state.eval("=Local.TurnCount < 4") is True - await state.set("Local.TurnCount", 5) - assert await state.eval("=Local.TurnCount < 4") is False + state.set("Local.TurnCount", 5) + assert state.eval("=Local.TurnCount < 4") is False - async def test_equality(self, mock_shared_state): + async def test_equality(self, mock_state): """Test = equality operator.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Local.feeling = "great" - await state.set("Local.feeling", "great") - assert await state.eval('=Local.feeling = "great"') is True + state.set("Local.feeling", "great") + assert state.eval('=Local.feeling = "great"') is True - await state.set("Local.feeling", "bad") - assert await state.eval('=Local.feeling = "great"') is False + state.set("Local.feeling", "bad") + assert state.eval('=Local.feeling = "great"') is False - async def test_inequality(self, mock_shared_state): + async def test_inequality(self, mock_state): """Test <> inequality operator.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Upper(System.LastMessage.Text) <> "EXIT" - await state.set("Local.status", "active") - assert await state.eval('=Local.status <> "done"') is True - assert await state.eval('=Local.status <> "active"') is False + state.set("Local.status", "active") + assert state.eval('=Local.status <> "done"') is True + 
assert state.eval('=Local.status <> "active"') is False class TestPowerFxArithmetic: """Test arithmetic operations.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_addition(self, mock_shared_state): + async def test_addition(self, mock_state): """Test + operator.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: value: =Local.TurnCount + 1 - await state.set("Local.TurnCount", 3) - result = await state.eval("=Local.TurnCount + 1") + state.set("Local.TurnCount", 3) + result = state.eval("=Local.TurnCount + 1") assert result == 4 @@ -309,97 +305,95 @@ class TestPowerFxCustomFunctions: """Test custom functions (UserMessage, MessageText, AgentMessage).""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state @pytest.mark.asyncio - async def test_agent_message_function(self, mock_shared_state): + async def test_agent_message_function(self, mock_state): """Test AgentMessage function (.NET compatibility alias for AssistantMessage).""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From .NET YAML: messages: =AgentMessage(Local.Response) - await state.set("Local.Response", "Here is the analysis result") - result = await state.eval("=AgentMessage(Local.Response)") + state.set("Local.Response", "Here is the analysis result") + result = state.eval("=AgentMessage(Local.Response)") assert isinstance(result, dict) assert result["role"] == "assistant" assert result["text"] == "Here is the analysis result" @pytest.mark.asyncio - async def test_agent_message_with_empty_string(self, mock_shared_state): + async def test_agent_message_with_empty_string(self, mock_state): """Test AgentMessage with empty string.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("Local.Response", "") - result = await 
state.eval("=AgentMessage(Local.Response)") + state.set("Local.Response", "") + result = state.eval("=AgentMessage(Local.Response)") assert result["role"] == "assistant" assert result["text"] == "" @pytest.mark.asyncio - async def test_user_message_with_variable(self, mock_shared_state): + async def test_user_message_with_variable(self, mock_state): """Test UserMessage function with variable reference.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: messages: =UserMessage(Local.ServiceParameters.IssueDescription) - await state.set("Local.ServiceParameters", {"IssueDescription": "My computer won't boot"}) - result = await state.eval("=UserMessage(Local.ServiceParameters.IssueDescription)") + state.set("Local.ServiceParameters", {"IssueDescription": "My computer won't boot"}) + result = state.eval("=UserMessage(Local.ServiceParameters.IssueDescription)") assert isinstance(result, dict) assert result["role"] == "user" assert result["text"] == "My computer won't boot" - async def test_user_message_with_simple_variable(self, mock_shared_state): + async def test_user_message_with_simple_variable(self, mock_state): """Test UserMessage with simple variable.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: messages: =Local.Problem - await state.set("Local.Problem", "What is 2+2?") - result = await state.eval("=UserMessage(Local.Problem)") + state.set("Local.Problem", "What is 2+2?") + result = state.eval("=UserMessage(Local.Problem)") assert result["role"] == "user" assert result["text"] == "What is 2+2?" - async def test_message_text_with_list(self, mock_shared_state): + async def test_message_text_with_list(self, mock_state): """Test MessageText extracts text from message list.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set( + state.set( "Local.messages", [ {"role": "user", "text": "Hello"}, {"role": "assistant", "text": "Hi there!"}, ], ) - result = await state.eval("=MessageText(Local.messages)") + result = state.eval("=MessageText(Local.messages)") assert result == "Hi there!" 
- async def test_message_text_empty_list(self, mock_shared_state): + async def test_message_text_empty_list(self, mock_state): """Test MessageText with empty list returns empty string.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() - await state.set("Local.messages", []) - result = await state.eval("=MessageText(Local.messages)") + state.set("Local.messages", []) + result = state.eval("=MessageText(Local.messages)") assert result == "" @@ -407,51 +401,49 @@ class TestPowerFxNestedVariables: """Test nested variable access patterns from YAML.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_nested_local_variable(self, mock_shared_state): + async def test_nested_local_variable(self, mock_state): """Test nested Local.* variable access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Local.ServiceParameters.IssueDescription - await state.set("Local.ServiceParameters", {"IssueDescription": "Screen is black"}) - result = await state.eval("=Local.ServiceParameters.IssueDescription") + state.set("Local.ServiceParameters", {"IssueDescription": "Screen is black"}) + result = state.eval("=Local.ServiceParameters.IssueDescription") assert result == "Screen is black" - async def test_nested_routing_parameters(self, mock_shared_state): + async def test_nested_routing_parameters(self, mock_state): """Test RoutingParameters access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Local.RoutingParameters.TeamName - await state.set("Local.RoutingParameters", {"TeamName": "Windows Support"}) - result = await state.eval("=Local.RoutingParameters.TeamName") + state.set("Local.RoutingParameters", {"TeamName": "Windows Support"}) + result = state.eval("=Local.RoutingParameters.TeamName") assert result == "Windows Support" - async def test_nested_ticket_parameters(self, mock_shared_state): + async def test_nested_ticket_parameters(self, mock_state): """Test TicketParameters access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: =Local.TicketParameters.TicketId - await state.set("Local.TicketParameters", {"TicketId": "TKT-12345"}) - result = await state.eval("=Local.TicketParameters.TicketId") + state.set("Local.TicketParameters", {"TicketId": "TKT-12345"}) + result = state.eval("=Local.TicketParameters.TicketId") assert result == "TKT-12345" @@ -459,39 +451,37 @@ class TestPowerFxUndefinedVariables: """Test graceful 
handling of undefined variables.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_undefined_local_variable_returns_none(self, mock_shared_state): + async def test_undefined_local_variable_returns_none(self, mock_state): """Test that undefined Local.* variables return None.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Variable not set - should return None (not raise) - result = await state.eval("=Local.UndefinedVariable") + result = state.eval("=Local.UndefinedVariable") assert result is None - async def test_undefined_nested_variable_returns_none(self, mock_shared_state): + async def test_undefined_nested_variable_returns_none(self, mock_state): """Test that undefined nested variables return None.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # Nested undefined variable - result = await state.eval("=Local.Something.Nested.Deep") + result = state.eval("=Local.Something.Nested.Deep") assert result is None @@ -499,41 +489,39 @@ class TestStringInterpolation: """Test string interpolation patterns.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_interpolate_local_variable(self, mock_shared_state): + async def test_interpolate_local_variable(self, mock_state): """Test {Local.Variable} interpolation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: activity: "Created ticket #{Local.TicketParameters.TicketId}" - await state.set("Local.TicketParameters", {"TicketId": "TKT-999"}) - result = await state.interpolate_string("Created ticket #{Local.TicketParameters.TicketId}") + state.set("Local.TicketParameters", {"TicketId": "TKT-999"}) + result = state.interpolate_string("Created ticket #{Local.TicketParameters.TicketId}") assert result == "Created ticket #TKT-999" - async def 
test_interpolate_routing_team(self, mock_shared_state): + async def test_interpolate_routing_team(self, mock_state): """Test routing team interpolation.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize() + state = DeclarativeWorkflowState(mock_state) + state.initialize() # From YAML: activity: Routing to {Local.RoutingParameters.TeamName} - await state.set("Local.RoutingParameters", {"TeamName": "Linux Support"}) - result = await state.interpolate_string("Routing to {Local.RoutingParameters.TeamName}") + state.set("Local.RoutingParameters", {"TeamName": "Linux Support"}) + result = state.interpolate_string("Routing to {Local.RoutingParameters.TeamName}") assert result == "Routing to Linux Support" @@ -541,41 +529,39 @@ class TestWorkflowInputsAccess: """Test Workflow.Inputs access patterns.""" @pytest.fixture - def mock_shared_state(self): + def mock_state(self): """Create a mock shared state.""" - shared_state = MagicMock() - shared_state._data = {} + mock_state = MagicMock() + mock_state._data = {} - async def mock_get(key): - if key not in shared_state._data: - raise KeyError(key) - return shared_state._data[key] + def mock_get(key, default=None): + return mock_state._data.get(key, default) - async def mock_set(key, value): - shared_state._data[key] = value + def mock_set(key, value): + mock_state._data[key] = value - shared_state.get = AsyncMock(side_effect=mock_get) - shared_state.set = AsyncMock(side_effect=mock_set) - return shared_state + mock_state.get = MagicMock(side_effect=mock_get) + mock_state.set = MagicMock(side_effect=mock_set) + return mock_state - async def test_inputs_name(self, mock_shared_state): + async def test_inputs_name(self, mock_state): """Test inputs.name access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"name": "Alice", "age": 25}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"name": "Alice", "age": 25}) # .NET style (standard) - result = await state.eval("=Workflow.Inputs.name") + result = state.eval("=Workflow.Inputs.name") assert result == "Alice" # Also test inputs.name shorthand - result = await state.eval("=inputs.name") + result = state.eval("=inputs.name") assert result == "Alice" - async def test_inputs_problem(self, mock_shared_state): + async def test_inputs_problem(self, mock_state): """Test inputs.problem access.""" - state = DeclarativeWorkflowState(mock_shared_state) - await state.initialize({"problem": "What is 5 * 6?"}) + state = DeclarativeWorkflowState(mock_state) + state.initialize({"problem": "What is 5 * 6?"}) # .NET style (standard) - result = await state.eval("=Workflow.Inputs.problem") + result = state.eval("=Workflow.Inputs.problem") assert result == "What is 5 * 6?" diff --git a/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx b/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx index e3b52055bc..ca589f4e42 100644 --- a/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx +++ b/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx @@ -86,8 +86,8 @@ export function CheckpointInfoModal({ (cp) => cp.checkpoint_id === selectedCheckpointId ); - const executorIds = fullCheckpoint?.shared_state?._executor_state - ? Object.keys(fullCheckpoint.shared_state._executor_state) + const executorIds = fullCheckpoint?.state?._executor_state + ? 
Object.keys(fullCheckpoint.state._executor_state) : []; const messageExecutors = fullCheckpoint?.messages ? Object.keys(fullCheckpoint.messages) @@ -348,11 +348,11 @@ export function CheckpointInfoModal({ {/* Shared State */}
Shared State
- {fullCheckpoint?.shared_state && Object.keys(fullCheckpoint.shared_state).filter( + {fullCheckpoint?.state && Object.keys(fullCheckpoint.state).filter( (k) => k !== "_executor_state" ).length > 0 ? (
- {Object.keys(fullCheckpoint.shared_state) + {Object.keys(fullCheckpoint.state) .filter((k) => k !== "_executor_state") .map((key) => ( diff --git a/python/packages/devui/frontend/src/types/index.ts b/python/packages/devui/frontend/src/types/index.ts index 3cbc471403..7d6e9a8f73 100644 --- a/python/packages/devui/frontend/src/types/index.ts +++ b/python/packages/devui/frontend/src/types/index.ts @@ -290,7 +290,7 @@ export interface FullCheckpoint { workflow_id: string; timestamp: string; messages: Record; - shared_state: Record; + state: Record; pending_request_info_events: Record; iteration_count: number; metadata: Record; diff --git a/python/packages/devui/tests/test_checkpoints.py b/python/packages/devui/tests/test_checkpoints.py index fbaf8734cd..3e1e0c96c7 100644 --- a/python/packages/devui/tests/test_checkpoints.py +++ b/python/packages/devui/tests/test_checkpoints.py @@ -106,7 +106,7 @@ async def test_conversation_scoped_checkpoint_save(self, checkpoint_manager, tes from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, shared_state={"test": "data"} + checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} ) # Get checkpoint storage for this conversation and save @@ -144,7 +144,7 @@ async def test_conversation_isolation(self, checkpoint_manager, test_workflow): checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, - shared_state={"conversation": "A"}, + state={"conversation": "A"}, ) storage_a = checkpoint_manager.get_checkpoint_storage(conv_a) await storage_a.save_checkpoint(checkpoint_a) @@ -181,7 +181,7 @@ async def test_list_checkpoints_in_session(self, checkpoint_manager, test_workfl checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, - shared_state={"iteration": i}, + state={"iteration": i}, ) saved_id = await storage.save_checkpoint(checkpoint) checkpoint_ids.append(saved_id) @@ -217,7 +217,7 @@ async def test_checkpoints_appear_as_conversation_items(self, checkpoint_manager checkpoint_id=f"checkpoint_{i}", workflow_id=test_workflow.id, messages={}, - shared_state={"iteration": i}, + state={"iteration": i}, ) saved_id = await storage.save_checkpoint(checkpoint) checkpoint_ids.append(saved_id) @@ -259,7 +259,7 @@ async def test_load_checkpoint_from_session(self, checkpoint_manager, test_workf checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, - shared_state={"test_key": "test_value"}, + state={"test_key": "test_value"}, ) # Save to this session @@ -272,7 +272,7 @@ async def test_load_checkpoint_from_session(self, checkpoint_manager, test_workf assert loaded_checkpoint is not None assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id - assert loaded_checkpoint.shared_state == {"test_key": "test_value"} + assert loaded_checkpoint.state == {"test_key": "test_value"} class TestCheckpointStorage: @@ -298,7 +298,7 @@ async def test_checkpoint_storage_protocol(self, checkpoint_manager, test_workfl from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, shared_state={"test": "data"} + checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} ) # Test save_checkpoint @@ -348,7 +348,7 @@ async 
def test_manual_checkpoint_save_via_injected_storage(self, checkpoint_mana from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, shared_state={"injected": True} + checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"injected": True} ) await checkpoint_storage.save_checkpoint(checkpoint) @@ -381,7 +381,7 @@ async def test_checkpoint_roundtrip_via_storage(self, checkpoint_manager, test_w checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, - shared_state={"ready_to_resume": True}, + state={"ready_to_resume": True}, ) checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint) @@ -389,7 +389,7 @@ async def test_checkpoint_roundtrip_via_storage(self, checkpoint_manager, test_w loaded = await checkpoint_storage.load_checkpoint(checkpoint_id) assert loaded is not None assert loaded.checkpoint_id == checkpoint_id - assert loaded.shared_state == {"ready_to_resume": True} + assert loaded.state == {"ready_to_resume": True} # Verify checkpoint is accessible via storage (for UI to list checkpoints) checkpoints = await checkpoint_storage.list_checkpoints() diff --git a/python/packages/devui/tests/test_server.py b/python/packages/devui/tests/test_server.py index ac835bdfb5..16766bc14f 100644 --- a/python/packages/devui/tests/test_server.py +++ b/python/packages/devui/tests/test_server.py @@ -384,7 +384,7 @@ async def test_checkpoint_api_endpoints(test_entities_dir): checkpoint = WorkflowCheckpoint( checkpoint_id="test_checkpoint_1", workflow_id="test_workflow", - shared_state={"key": "value"}, + state={"key": "value"}, iteration_count=1, ) await storage.save_checkpoint(checkpoint) diff --git a/python/samples/getting_started/devui/fanout_workflow/workflow.py b/python/samples/getting_started/devui/fanout_workflow/workflow.py index bb84c28db7..52b893fa1f 100644 --- a/python/samples/getting_started/devui/fanout_workflow/workflow.py +++ b/python/samples/getting_started/devui/fanout_workflow/workflow.py @@ -190,9 +190,9 @@ async def ingest_data(self, request: ProcessingRequest, ctx: WorkflowContext[Dat timestamp=asyncio.get_event_loop().time(), ) - # Store both batch data and original request in shared state - await ctx.set_shared_state(f"batch_{batch.batch_id}", batch) - await ctx.set_shared_state(f"request_{batch.batch_id}", request) + # Store both batch data and original request in workflow state + ctx.set_state(f"batch_{batch.batch_id}", batch) + ctx.set_state(f"request_{batch.batch_id}", request) await ctx.send_message(batch) @@ -205,7 +205,7 @@ class SchemaValidator(Executor): async def validate_schema(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None: """Perform schema validation with processing delay.""" # Check if schema validation is enabled - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") if not request or not request.enable_schema_validation: return @@ -241,7 +241,7 @@ class DataQualityValidator(Executor): async def validate_quality(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None: """Perform data quality validation.""" # Check if quality validation is enabled - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") if not request or not request.enable_quality_validation: return @@ -283,7 +283,7 @@ class 
SecurityValidator(Executor): async def validate_security(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None: """Perform security validation.""" # Check if security validation is enabled - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") if not request or not request.enable_security_validation: return @@ -324,7 +324,7 @@ async def aggregate_validations( return batch_id = reports[0].batch_id - request = await ctx.get_shared_state(f"request_{batch_id}") + request = ctx.get_state(f"request_{batch_id}") await asyncio.sleep(1) # Aggregation processing time @@ -354,8 +354,8 @@ async def aggregate_validations( ) return - # Retrieve original batch from shared state - batch_data = await ctx.get_shared_state(f"batch_{batch_id}") + # Retrieve original batch from workflow state + batch_data = ctx.get_state(f"batch_{batch_id}") if batch_data: await ctx.send_message(batch_data) else: @@ -376,7 +376,7 @@ class DataNormalizer(Executor): @handler async def normalize_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None: """Perform data normalization.""" - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") # Check if normalization is enabled if not request or "normalize" not in request.transformations: @@ -421,7 +421,7 @@ class DataEnrichment(Executor): @handler async def enrich_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None: """Perform data enrichment.""" - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") # Check if enrichment is enabled if not request or "enrich" not in request.transformations: @@ -465,7 +465,7 @@ class DataAggregator(Executor): @handler async def aggregate_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None: """Perform data aggregation.""" - request = await ctx.get_shared_state(f"request_{batch.batch_id}") + request = ctx.get_state(f"request_{batch.batch_id}") # Check if aggregation is enabled if not request or "aggregate" not in request.transformations: @@ -626,12 +626,12 @@ async def process_final_results( # Workflow Builder Helper class WorkflowSetupHelper: - """Helper class to set up the complex workflow with shared state management.""" + """Helper class to set up the complex workflow with state management.""" @staticmethod async def store_batch_data(batch: DataBatch, ctx: WorkflowContext) -> None: - """Store batch data in shared state for later retrieval.""" - await ctx.set_shared_state(f"batch_{batch.batch_id}", batch) + """Store batch data in workflow state for later retrieval.""" + ctx.set_state(f"batch_{batch.batch_id}", batch) # Create the workflow instance diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index 524f93fd61..deb1597021 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -146,16 +146,10 @@ to configure which agents can route to which others with a fluent, type-safe API ### state-management -| Sample | File | Concepts | -| -------------------------------- | ------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------- | -| Shared States | 
[state-management/shared_states_with_agents.py](./state-management/shared_states_with_agents.py) | Store in shared state once and later reuse across agents | -| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@ai_function` tools | - -======= -| Sample | File | Concepts | -|---|---|---| -| Shared States | [state-management/shared_states_with_agents.py](./state-management/shared_states_with_agents.py) | Store in shared state once and later reuse across agents | -| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools | +| Sample | File | Concepts | +| -------------------------------- | ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- | +| State with Agents | [state-management/state_with_agents.py](./state-management/state_with_agents.py) | Store in state once and later reuse across agents | +| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools | ### visualization diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_kwargs.py b/python/samples/getting_started/workflows/agents/workflow_as_agent_kwargs.py index 56b8c6de77..305f6ae07b 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_kwargs.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_kwargs.py @@ -18,7 +18,7 @@ - Build a workflow using SequentialBuilder (or any builder pattern) - Expose the workflow as a reusable agent via workflow.as_agent() - Pass custom context as kwargs when invoking workflow_agent.run() or run_stream() -- kwargs are stored in SharedState and propagated to all agent invocations +- kwargs are stored in State and propagated to all agent invocations - @tool functions receive kwargs via **kwargs parameter When to use workflow.as_agent(): diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index a2628592ea..072faff168 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -83,9 +83,9 @@ async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest, s normalized = " ".join(brief.split()).strip() if not normalized.endswith("."): normalized += "." - # Persist the cleaned brief in shared state so downstream executors and + # Persist the cleaned brief in workflow state so downstream executors and # future checkpoints can recover the original intent. - await ctx.set_shared_state("brief", normalized) + ctx.set_state("brief", normalized) prompt = ( "You are drafting product release notes. Summarise the brief below in two sentences. 
" "Keep it positive and end with a call to action.\n\n" diff --git a/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py b/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py index 44385bffca..275901b8f6 100644 --- a/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py +++ b/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py @@ -38,7 +38,7 @@ Demonstrate how to use a multi-selection edge group to fan out from one executor to multiple possible targets. Show how to: - Implement a selection function that chooses one or more downstream branches based on analysis. -- Share state across branches so different executors can read the same email content. +- Share workflow state across branches so different executors can read the same email content. - Validate agent outputs with Pydantic models for robust structured data exchange. - Merge results from multiple branches (e.g., a summary) back into a typed state. - Apply conditional persistence logic (short vs long emails). @@ -46,7 +46,7 @@ Prerequisites: - Familiarity with WorkflowBuilder, executors, edges, and events. - Understanding of multi-selection edge groups and how their selection function maps to target ids. -- Experience with shared state in workflows for persisting and reusing objects. +- Experience with workflow state for persisting and reusing objects. """ @@ -89,8 +89,8 @@ class DatabaseEvent(WorkflowEvent): ... @executor(id="store_email") async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: new_email = Email(email_id=str(uuid4()), email_content=email_text) - await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) - await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) + ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) + ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=new_email.email_content)], should_respond=True) @@ -100,8 +100,8 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest @executor(id="to_analysis_result") async def to_analysis_result(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None: parsed = AnalysisResultAgent.model_validate_json(response.agent_response.text) - email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}") + email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY) + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{email_id}") await ctx.send_message( AnalysisResult( spam_decision=parsed.spam_decision, @@ -118,7 +118,7 @@ async def submit_to_email_assistant(analysis: AnalysisResult, ctx: WorkflowConte if analysis.spam_decision != "NotSpam": raise RuntimeError("This executor should only handle NotSpam messages.") - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) ) @@ -133,7 +133,7 @@ async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContex @executor(id="summarize_email") async def summarize_email(analysis: AnalysisResult, ctx: WorkflowContext[AgentExecutorRequest]) -> None: # Only called for long 
NotSpam emails by selection_func - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) ) @@ -142,8 +142,8 @@ async def summarize_email(analysis: AnalysisResult, ctx: WorkflowContext[AgentEx @executor(id="merge_summary") async def merge_summary(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None: summary = EmailSummaryModel.model_validate_json(response.agent_response.text) - email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}") + email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY) + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{email_id}") # Build an AnalysisResult mirroring to_analysis_result but with summary await ctx.send_message( AnalysisResult( @@ -167,7 +167,7 @@ async def handle_spam(analysis: AnalysisResult, ctx: WorkflowContext[Never, str] @executor(id="handle_uncertain") async def handle_uncertain(analysis: AnalysisResult, ctx: WorkflowContext[Never, str]) -> None: if analysis.spam_decision == "Uncertain": - email: Email | None = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") + email: Email | None = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}") await ctx.yield_output( f"Email marked as uncertain: {analysis.reason}. Email content: {getattr(email, 'email_content', '')}" ) diff --git a/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py b/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py index f2090e4acc..8ebbe9af2b 100644 --- a/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py +++ b/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py @@ -27,13 +27,13 @@ """ Sample: Switch-Case Edge Group with an explicit Uncertain branch. -The workflow stores a single email in shared state, asks a spam detection agent for a three way decision, +The workflow stores a single email in workflow state, asks a spam detection agent for a three way decision, then routes with a switch-case group: NotSpam to the drafting assistant, Spam to a spam handler, and Default to an Uncertain handler. Purpose: Demonstrate deterministic one of N routing with switch-case edges. Show how to: -- Persist input once in shared state, then pass around a small typed pointer that carries the email id. +- Persist input once in workflow state, then pass around a small typed pointer that carries the email id. - Validate agent JSON with Pydantic models for robust parsing. - Keep executor responsibilities narrow. Transform model output to a typed DetectionResult, then route based on that type. @@ -76,7 +76,7 @@ class DetectionResult: @dataclass class Email: - # In memory record of the email content stored in shared state. + # In memory record of the email content stored in workflow state. email_id: str email_content: str @@ -95,8 +95,8 @@ def condition(message: Any) -> bool: async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: # Persist the raw email once. Store under a unique key and set the current pointer for convenience. 
new_email = Email(email_id=str(uuid4()), email_content=email_text) - await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) - await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) + ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) + ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) # Kick off the detector by forwarding the email as a user message to the spam_detection_agent. await ctx.send_message( @@ -108,7 +108,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowContext[DetectionResult]) -> None: # Parse the detector JSON into a typed model. Attach the current email id for downstream lookups. parsed = DetectionResultAgent.model_validate_json(response.agent_response.text) - email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) + email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY) await ctx.send_message(DetectionResult(spam_decision=parsed.spam_decision, reason=parsed.reason, email_id=email_id)) @@ -118,8 +118,8 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon if detection.spam_decision != "NotSpam": raise RuntimeError("This executor should only handle NotSpam messages.") - # Load the original content from shared state using the id carried in DetectionResult. - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") + # Load the original content from workflow state using the id carried in DetectionResult. + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) ) @@ -145,7 +145,7 @@ async def handle_spam(detection: DetectionResult, ctx: WorkflowContext[Never, st async def handle_uncertain(detection: DetectionResult, ctx: WorkflowContext[Never, str]) -> None: # Uncertain path terminal. Surface the original content to aid human review. if detection.spam_decision == "Uncertain": - email: Email | None = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") + email: Email | None = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.yield_output( f"Email marked as uncertain: {detection.reason}. Email content: {getattr(email, 'email_content', '')}" ) diff --git a/python/samples/getting_started/workflows/parallelism/map_reduce_and_visualization.py b/python/samples/getting_started/workflows/parallelism/map_reduce_and_visualization.py index d98c6cb78b..af2a6ad53d 100644 --- a/python/samples/getting_started/workflows/parallelism/map_reduce_and_visualization.py +++ b/python/samples/getting_started/workflows/parallelism/map_reduce_and_visualization.py @@ -10,11 +10,10 @@ from agent_framework import ( Executor, # Base class for custom workflow steps WorkflowBuilder, # Fluent builder for executors and edges - WorkflowContext, # Per run context with shared state and messaging + WorkflowContext, # Per run context with workflow state and messaging WorkflowOutputEvent, # Event emitted when workflow yields output WorkflowViz, # Utility to visualize a workflow graph handler, # Decorator to expose an Executor method as a step - tool, ) from typing_extensions import Never @@ -27,7 +26,7 @@ Purpose: Show how to: -- Partition input once and coordinate parallel mappers with shared state. +- Partition input once and coordinate parallel mappers with workflow state. 
- Implement map, shuffle, and reduce executors that pass file paths instead of large payloads. - Use fan out and fan in edges to express parallelism and joins. - Persist intermediate results to disk to bound memory usage for large inputs. @@ -50,8 +49,8 @@ # Ensure the temporary directory exists os.makedirs(TEMP_DIR, exist_ok=True) -# Define a key for the shared state to store the data to be processed -SHARED_STATE_DATA_KEY = "data_to_be_processed" +# Define a key for the workflow state to store the data to be processed +STATE_DATA_KEY = "data_to_be_processed" class SplitCompleted: @@ -70,17 +69,17 @@ def __init__(self, map_executor_ids: list[str], id: str | None = None): @handler async def split(self, data: str, ctx: WorkflowContext[SplitCompleted]) -> None: - """Tokenize input and assign contiguous index ranges to each mapper via shared state. + """Tokenize input and assign contiguous index ranges to each mapper via workflow state. Args: data: The raw text to process. - ctx: Workflow context to persist shared state and send messages. + ctx: Workflow context to persist state and send messages. """ # Process data into a list of words and remove empty lines or words. word_list = self._preprocess(data) # Store tokenized words once so all mappers can read by index. - await ctx.set_shared_state(SHARED_STATE_DATA_KEY, word_list) + ctx.set_state(STATE_DATA_KEY, word_list) # Divide indices into contiguous slices for each mapper. map_executor_count = len(self._map_executor_ids) @@ -91,8 +90,8 @@ async def _process_chunk(i: int) -> None: start_index = i * chunk_size end_index = start_index + chunk_size if i < map_executor_count - 1 else len(word_list) - # The mapper reads its slice from shared state keyed by its own executor id. - await ctx.set_shared_state(self._map_executor_ids[i], (start_index, end_index)) + # The mapper reads its slice from workflow state keyed by its own executor id. + ctx.set_state(self._map_executor_ids[i], (start_index, end_index)) await ctx.send_message(SplitCompleted(), self._map_executor_ids[i]) tasks = [asyncio.create_task(_process_chunk(i)) for i in range(map_executor_count)] @@ -120,11 +119,11 @@ async def map(self, _: SplitCompleted, ctx: WorkflowContext[MapCompleted]) -> No Args: _: SplitCompleted marker indicating maps can begin. - ctx: Workflow context for shared state access and messaging. + ctx: Workflow context for workflow state access and messaging. """ # Retrieve tokens and our assigned slice. - data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY) - chunk_start, chunk_end = await ctx.get_shared_state(self.id) + data_to_be_processed: list[str] = ctx.get_state(STATE_DATA_KEY) + chunk_start, chunk_end = ctx.get_state(self.id) results = [(item, 1) for item in data_to_be_processed[chunk_start:chunk_end]] @@ -287,7 +286,8 @@ async def main(): # Step 2: Build the workflow graph using fan out and fan in edges. 
workflow = ( - workflow_builder.set_start_executor("split_data_executor") + workflow_builder + .set_start_executor("split_data_executor") .add_fan_out_edges( "split_data_executor", ["map_executor_0", "map_executor_1", "map_executor_2"], diff --git a/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py b/python/samples/getting_started/workflows/state-management/state_with_agents.py similarity index 89% rename from python/samples/getting_started/workflows/state-management/shared_states_with_agents.py rename to python/samples/getting_started/workflows/state-management/state_with_agents.py index 700dcb1b95..239a8f2950 100644 --- a/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py +++ b/python/samples/getting_started/workflows/state-management/state_with_agents.py @@ -15,7 +15,6 @@ WorkflowBuilder, WorkflowContext, executor, - tool, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential @@ -23,14 +22,14 @@ from typing_extensions import Never """ -Sample: Shared state with agents and conditional routing. +Sample: Workflow state with agents and conditional routing. Store an email once by id, classify it with a detector agent, then either draft a reply with an assistant agent or finish with a spam notice. Stream events as the workflow runs. Purpose: Show how to: -- Use shared state to decouple large payloads from messages and pass around lightweight references. +- Use workflow state to decouple large payloads from messages and pass around lightweight references. - Enforce structured agent outputs with Pydantic models via response_format for robust parsing. - Route using conditional edges based on a typed intermediate DetectionResult. - Compose agent backed executors with function style executors and yield the final output when the workflow completes. @@ -60,7 +59,7 @@ class EmailResponse(BaseModel): @dataclass class DetectionResult: - """Internal detection result enriched with the shared state email_id for later lookups.""" + """Internal detection result enriched with the state email_id for later lookups.""" is_spam: bool reason: str @@ -69,7 +68,7 @@ class DetectionResult: @dataclass class Email: - """In memory record stored in shared state to avoid re-sending large bodies on edges.""" + """In memory record stored in state to avoid re-sending large bodies on edges.""" email_id: str email_content: str @@ -93,7 +92,7 @@ def condition(message: Any) -> bool: @executor(id="store_email") async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: - """Persist the raw email content in shared state and trigger spam detection. + """Persist the raw email content in state and trigger spam detection. Responsibilities: - Generate a unique email_id (UUID) for downstream retrieval. @@ -101,8 +100,8 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest - Emit an AgentExecutorRequest asking the detector to respond. 
""" new_email = Email(email_id=str(uuid4()), email_content=email_text) - await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) - await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) + ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) + ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=new_email.email_content)], should_respond=True) @@ -115,11 +114,11 @@ async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowCont Steps: 1) Validate the agent's JSON output into DetectionResultAgent. - 2) Retrieve the current email_id from shared state. + 2) Retrieve the current email_id from workflow state. 3) Send a typed DetectionResult for conditional routing. """ parsed = DetectionResultAgent.model_validate_json(response.agent_response.text) - email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) + email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY) await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id)) @@ -133,8 +132,8 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon if detection.is_spam: raise RuntimeError("This executor should only handle non-spam messages.") - # Load the original content by id from shared state and forward it to the assistant. - email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") + # Load the original content by id from workflow state and forward it to the assistant. + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) ) @@ -183,7 +182,7 @@ def create_email_assistant_agent() -> ChatAgent: async def main() -> None: - """Build and run the shared state with agents and conditional routing workflow.""" + """Build and run the workflow state with agents and conditional routing workflow.""" # Build the workflow graph with conditional edges. 
# Flow: diff --git a/python/samples/getting_started/workflows/state-management/workflow_kwargs.py b/python/samples/getting_started/workflows/state-management/workflow_kwargs.py index bf7320f834..796164efce 100644 --- a/python/samples/getting_started/workflows/state-management/workflow_kwargs.py +++ b/python/samples/getting_started/workflows/state-management/workflow_kwargs.py @@ -16,7 +16,7 @@ Key Concepts: - Pass custom context as kwargs when invoking workflow.run_stream() or workflow.run() -- kwargs are stored in SharedState and passed to all agent invocations +- kwargs are stored in State and passed to all agent invocations - @tool functions receive kwargs via **kwargs parameter - Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns From f4bf6a4650ea72caa119e89a2271349b9eae8d35 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 4 Feb 2026 15:47:39 +0900 Subject: [PATCH 2/5] Fixes --- .../agent_framework/_workflows/__init__.py | 2 - .../core/agent_framework/_workflows/_state.py | 28 +- .../core/tests/workflow/test_state.py | 303 ++++++++++++++++++ .../workflow/checkpoint-info-modal.tsx | 4 +- 4 files changed, 323 insertions(+), 14 deletions(-) create mode 100644 python/packages/core/tests/workflow/test_state.py diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 8dc5a71054..7c0a2e4ad4 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -102,7 +102,6 @@ RunnerContext, ) from ._sequential import SequentialBuilder -from ._state import State from ._validation import ( EdgeDuplicationError, GraphConnectivityError, @@ -181,7 +180,6 @@ "SequentialBuilder", "SingleEdgeGroup", "StandardMagenticManager", - "State", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SuperStepCompletedEvent", diff --git a/python/packages/core/agent_framework/_workflows/_state.py b/python/packages/core/agent_framework/_workflows/_state.py index 8e8a793389..4dca14a499 100644 --- a/python/packages/core/agent_framework/_workflows/_state.py +++ b/python/packages/core/agent_framework/_workflows/_state.py @@ -46,27 +46,35 @@ def get(self, key: str, default: Any = None) -> Any: The value if found, otherwise the default value. """ if key in self._pending: - return self._pending[key] + value = self._pending[key] + if value is _DeleteSentinel: + return default + return value return self._committed.get(key, default) def has(self, key: str) -> bool: """Check if a key exists in pending or committed state.""" - return key in self._pending or key in self._committed + if key in self._pending: + return self._pending[key] is not _DeleteSentinel + return key in self._committed def delete(self, key: str) -> None: """Mark a key for deletion. - If the key exists in pending, it is removed from pending. - If the key exists only in committed, a sentinel is stored in pending - to indicate deletion at commit time. + If the key exists in committed state, a sentinel is stored in pending + to indicate deletion at commit time. If it only exists in pending, + it is removed from pending. 
""" - if key in self._pending: - del self._pending[key] - elif key in self._committed: - self._pending[key] = _DeleteSentinel - else: + if key not in self._pending and key not in self._committed: raise KeyError(f"Key '{key}' not found in state.") + if key in self._committed: + # Mark for deletion from committed state at commit time + self._pending[key] = _DeleteSentinel + elif key in self._pending: + # Only exists in pending, safe to just remove + del self._pending[key] + def clear(self) -> None: """Clear both committed and pending state.""" self._committed.clear() diff --git a/python/packages/core/tests/workflow/test_state.py b/python/packages/core/tests/workflow/test_state.py new file mode 100644 index 0000000000..486fc9fa25 --- /dev/null +++ b/python/packages/core/tests/workflow/test_state.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for the State class superstep caching behavior.""" + +import pytest + +from agent_framework._workflows._state import State + + +class TestStateBasicOperations: + """Tests for basic State get/set/has/delete operations.""" + + def test_set_and_get(self) -> None: + state = State() + state.set("key", "value") + assert state.get("key") == "value" + + def test_get_with_default(self) -> None: + state = State() + assert state.get("missing") is None + assert state.get("missing", "default") == "default" + + def test_has_returns_true_for_existing_key(self) -> None: + state = State() + state.set("key", "value") + assert state.has("key") is True + + def test_has_returns_false_for_missing_key(self) -> None: + state = State() + assert state.has("missing") is False + + def test_delete_existing_key(self) -> None: + state = State() + state.set("key", "value") + state.commit() + state.delete("key") + state.commit() + assert state.has("key") is False + assert state.get("key") is None + + def test_delete_missing_key_raises(self) -> None: + state = State() + with pytest.raises(KeyError, match="Key 'missing' not found"): + state.delete("missing") + + def test_clear(self) -> None: + state = State() + state.set("key1", "value1") + state.commit() + state.set("key2", "value2") + state.clear() + assert state.get("key1") is None + assert state.get("key2") is None + + +class TestSuperstepCaching: + """Tests for superstep caching semantics - pending vs committed state.""" + + def test_set_writes_to_pending_not_committed(self) -> None: + state = State() + state.set("key", "value") + + # Value is in pending + assert "key" in state._pending + # Value is NOT in committed + assert "key" not in state._committed + # But get() still returns it + assert state.get("key") == "value" + + def test_commit_moves_pending_to_committed(self) -> None: + state = State() + state.set("key", "value") + + # Before commit: in pending, not committed + assert "key" in state._pending + assert "key" not in state._committed + + state.commit() + + # After commit: in committed, pending cleared + assert "key" not in state._pending + assert "key" in state._committed + assert state.get("key") == "value" + + def test_discard_clears_pending_without_committing(self) -> None: + state = State() + state.set("existing", "original") + state.commit() + + # Make a pending change + state.set("existing", "modified") + state.set("new_key", "new_value") + + # Discard pending changes + state.discard() + + # Original value is preserved, new key never committed + assert state.get("existing") == "original" + assert state.get("new_key") is None + + def test_pending_overrides_committed_on_get(self) 
-> None: + state = State() + state.set("key", "committed_value") + state.commit() + + state.set("key", "pending_value") + + # get() returns pending value, not committed + assert state.get("key") == "pending_value" + # But committed still has old value + assert state._committed["key"] == "committed_value" + + def test_multiple_sets_before_commit(self) -> None: + state = State() + state.set("key", "value1") + state.set("key", "value2") + state.set("key", "value3") + + # Only final value is in pending + assert state.get("key") == "value3" + + state.commit() + assert state.get("key") == "value3" + + +class TestDeleteWithSuperstepCaching: + """Tests for delete behavior with superstep caching.""" + + def test_delete_pending_only_key(self) -> None: + state = State() + state.set("key", "value") + # Key only in pending, not committed + assert "key" in state._pending + assert "key" not in state._committed + + state.delete("key") + + # Should be removed from pending + assert "key" not in state._pending + assert state.get("key") is None + assert state.has("key") is False + + def test_delete_committed_key_marks_for_deletion(self) -> None: + state = State() + state.set("key", "value") + state.commit() + + state.delete("key") + + # Key should be marked for deletion in pending (sentinel) + assert "key" in state._pending + # get() should return default (not the sentinel!) + assert state.get("key") is None + assert state.get("key", "default") == "default" + # has() should return False + assert state.has("key") is False + # But committed still has it until commit() + assert "key" in state._committed + + def test_delete_committed_key_removed_on_commit(self) -> None: + state = State() + state.set("key", "value") + state.commit() + + state.delete("key") + state.commit() + + # Now it should be gone from committed too + assert "key" not in state._committed + assert "key" not in state._pending + + def test_delete_key_in_both_pending_and_committed(self) -> None: + """Test delete when key exists in both pending (modified) and committed.""" + state = State() + state.set("key", "original") + state.commit() + + # Modify the key (now in both pending and committed) + state.set("key", "modified") + assert state._pending["key"] == "modified" + assert state._committed["key"] == "original" + + # Delete should mark for deletion from committed + state.delete("key") + + # Should be marked for deletion + assert state.get("key") is None + assert state.has("key") is False + + # After commit, key should be fully removed + state.commit() + assert "key" not in state._committed + assert "key" not in state._pending + + def test_discard_after_delete_restores_committed_value(self) -> None: + state = State() + state.set("key", "value") + state.commit() + + state.delete("key") + # Key appears deleted + assert state.has("key") is False + + state.discard() + # After discard, committed value is restored + assert state.has("key") is True + assert state.get("key") == "value" + + +class TestFailureScenarios: + """Tests simulating failure scenarios - pending changes should not leak to committed.""" + + def test_failure_before_commit_preserves_committed_state(self) -> None: + """Simulate executor failure - pending changes should not affect committed state.""" + state = State() + state.set("key1", "original1") + state.set("key2", "original2") + state.commit() + + # Superstep starts - make some changes + state.set("key1", "modified1") + state.set("key3", "new_value") + state.delete("key2") + + # Simulate failure - we call discard() instead of 
commit() + state.discard() + + # All original values should be intact + assert state.get("key1") == "original1" + assert state.get("key2") == "original2" + assert state.get("key3") is None + + def test_no_partial_commits(self) -> None: + """Ensure commit is atomic - either all changes apply or none.""" + state = State() + state.set("key1", "value1") + state.set("key2", "value2") + state.set("key3", "value3") + + # Before commit - nothing in committed + assert len(state._committed) == 0 + + state.commit() + + # After commit - all three values committed together + assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"} + + def test_repeated_supersteps_are_isolated(self) -> None: + """Test that each superstep's changes are isolated until committed.""" + state = State() + + # Superstep 1 + state.set("counter", 1) + state.commit() + assert state.get("counter") == 1 + + # Superstep 2 + state.set("counter", 2) + state.set("temp", "should_be_discarded") + state.discard() # Simulate failure + assert state.get("counter") == 1 # Reverted to superstep 1 value + assert state.get("temp") is None + + # Superstep 3 + state.set("counter", 3) + state.commit() + assert state.get("counter") == 3 + + +class TestExportImport: + """Tests for state serialization (export/import).""" + + def test_export_returns_committed_only(self) -> None: + state = State() + state.set("committed_key", "committed_value") + state.commit() + state.set("pending_key", "pending_value") + + exported = state.export_state() + + # Only committed state is exported + assert exported == {"committed_key": "committed_value"} + assert "pending_key" not in exported + + def test_import_merges_into_committed(self) -> None: + state = State() + state.set("existing", "original") + state.commit() + + state.import_state({"imported": "value", "existing": "overwritten"}) + + assert state.get("imported") == "value" + assert state.get("existing") == "overwritten" + + def test_import_does_not_affect_pending(self) -> None: + state = State() + state.set("pending_key", "pending_value") + + state.import_state({"imported": "value"}) + + # Pending is still there + assert state.get("pending_key") == "pending_value" + assert "pending_key" in state._pending diff --git a/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx b/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx index ca589f4e42..10c9a0d38c 100644 --- a/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx +++ b/python/packages/devui/frontend/src/components/features/workflow/checkpoint-info-modal.tsx @@ -345,9 +345,9 @@ export function CheckpointInfoModal({
)}
- {/* Shared State */}
+ {/* Workflow State */}
- Shared State
+ Workflow State
{fullCheckpoint?.state && Object.keys(fullCheckpoint.state).filter( (k) => k !== "_executor_state" ).length > 0 ? ( From bd04dd65e2b45c74660c0784f70f94c3f1fc7a05 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 5 Feb 2026 09:47:13 +0900 Subject: [PATCH 3/5] Address PR feedback --- .../_workflows/_agent_executor.py | 2 +- .../agent_framework/_workflows/_checkpoint.py | 6 ++- .../core/agent_framework/_workflows/_state.py | 7 ++++ .../_workflows/_workflow_context.py | 39 +------------------ .../core/tests/workflow/test_runner.py | 2 +- 5 files changed, 14 insertions(+), 42 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 8385e0a4d7..9facbd07d8 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -333,7 +333,7 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentResponse | None: Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {} + run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) response = await self._agent.run( self._cache, diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint.py b/python/packages/core/agent_framework/_workflows/_checkpoint.py index c249c02bdc..874ded5568 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint.py @@ -26,8 +26,10 @@ class WorkflowCheckpoint: workflow_id: Identifier of the workflow this checkpoint belongs to timestamp: ISO 8601 timestamp when checkpoint was created messages: Messages exchanged between executors - state: Complete workflow state including user data and executor states. - Executor states are stored under the reserved key '_executor_state'. + state: Committed workflow state including user data and executor states. + This contains only committed state; pending state changes are not + included in checkpoints. Executor states are stored under the + reserved key '_executor_state'. iteration_count: Current iteration number when checkpoint was created metadata: Additional metadata (e.g., superstep info, graph signature) version: Checkpoint format version diff --git a/python/packages/core/agent_framework/_workflows/_state.py b/python/packages/core/agent_framework/_workflows/_state.py index 4dca14a499..093cfea8b6 100644 --- a/python/packages/core/agent_framework/_workflows/_state.py +++ b/python/packages/core/agent_framework/_workflows/_state.py @@ -32,6 +32,13 @@ def set(self, key: str, value: Any) -> None: The value will be visible to subsequent `get()` calls but won't be committed to the actual state until `commit()` is called. + + Note: + When multiple executors run concurrently within the same superstep, + each executor's writes go to the same pending buffer. The last write + for a given key wins when commit() is called. This is consistent with + the .NET behavior and the superstep execution model where all executors + in a superstep see the same committed state at the start. 
""" self._pending[key] = value diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index cfd48b436f..481d8db615 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -9,10 +9,9 @@ from opentelemetry.propagate import inject from opentelemetry.trace import SpanKind -from typing_extensions import Never, TypeVar, deprecated +from typing_extensions import Never, TypeVar from ..observability import OtelAttr, create_workflow_span -from ._const import EXECUTOR_STATE_KEY from ._events import ( RequestInfoEvent, WorkflowEvent, @@ -457,42 +456,6 @@ def get_yielded_outputs(self) -> list[Any]: """ return self._yielded_outputs.copy() - @deprecated( - "Override `on_checkpoint_save()` methods instead. " - "For cross-executor state sharing, use set_state() instead. " - "This API will be removed after 12/01/2025." - ) - async def set_executor_state(self, state: dict[str, Any]) -> None: - """Store executor state in state under a reserved key. - - Executors call this with a JSON-serializable dict capturing the minimal - state needed to resume. It replaces any previously stored state. - """ - existing_states = self._state.get(EXECUTOR_STATE_KEY, {}) - - if not isinstance(existing_states, dict): - raise ValueError("Existing executor states in state is not a dictionary.") - - existing_states[self._executor_id] = state - self._state.set(EXECUTOR_STATE_KEY, existing_states) - - @deprecated( - "Override `on_checkpoint_restore()` methods instead. " - "For cross-executor state sharing, use get_state() instead. " - "This API will be removed after 12/01/2025." - ) - async def get_executor_state(self) -> dict[str, Any] | None: - """Retrieve previously persisted state for this executor, if any.""" - has_existing_states = self._state.has(EXECUTOR_STATE_KEY) - if not has_existing_states: - return None - - existing_states = self._state.get(EXECUTOR_STATE_KEY) - if not isinstance(existing_states, dict): - raise ValueError("Existing executor states in state is not a dictionary.") - - return existing_states.get(self._executor_id) # type: ignore - def is_streaming(self) -> bool: """Check if the workflow is running in streaming mode. diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 7098d305d9..b3c97126c8 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -227,7 +227,7 @@ async def test_runner_cancellation_stops_active_executor(): executor_a.id: executor_a, executor_b.id: executor_b, } - shared_state = SharedState() + shared_state = State() ctx = InProcRunnerContext() runner = Runner(edges, executors, shared_state, ctx) From 6ca9874c8006bbe36aed5cf658742071b01c5504 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 5 Feb 2026 10:26:06 +0900 Subject: [PATCH 4/5] Remove dead links --- python/samples/getting_started/workflows/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index deb1597021..dc21829053 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -37,8 +37,6 @@ Once comfortable with these, explore the rest of the samples below. 
| Azure Chat Agents (Streaming) | [agents/azure_chat_agents_streaming.py](./agents/azure_chat_agents_streaming.py) | Add Azure Chat agents as edges and handle streaming events | | Azure AI Agents (Streaming) | [agents/azure_ai_agents_streaming.py](./agents/azure_ai_agents_streaming.py) | Add Azure AI agents as edges and handle streaming events | | Azure AI Agents (Shared Thread) | [agents/azure_ai_agents_with_shared_thread.py](./agents/azure_ai_agents_with_shared_thread.py) | Share a common message thread between multiple Azure AI agents in a workflow | -| Azure Chat Agents (Function Bridge) | [agents/azure_chat_agents_function_bridge.py](./agents/azure_chat_agents_function_bridge.py) | Chain two agents with a function executor that injects external context | -| Azure Chat Agents (Tools + HITL) | [agents/azure_chat_agents_tool_calls_with_feedback.py](./agents/azure_chat_agents_tool_calls_with_feedback.py) | Tool-enabled writer/editor pipeline with human feedback gating | | Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods | | Sequential Workflow as Agent | [agents/sequential_workflow_as_agent.py](./agents/sequential_workflow_as_agent.py) | Build a sequential workflow orchestrating agents, then expose it as a reusable agent | | Concurrent Workflow as Agent | [agents/concurrent_workflow_as_agent.py](./agents/concurrent_workflow_as_agent.py) | Build a concurrent fan-out/fan-in workflow, then expose it as a reusable agent | From c7394667b71d4ab6190e22f37bd937562226c1ee Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 5 Feb 2026 10:32:44 +0900 Subject: [PATCH 5/5] Fix lab test import --- python/packages/lab/lightning/tests/test_lightning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/lab/lightning/tests/test_lightning.py b/python/packages/lab/lightning/tests/test_lightning.py index c56adf2b20..ae64ec772f 100644 --- a/python/packages/lab/lightning/tests/test_lightning.py +++ b/python/packages/lab/lightning/tests/test_lightning.py @@ -9,7 +9,7 @@ agentlightning = pytest.importorskip("agentlightning") -from agent_framework import AgentExecutor, AgentRunEvent, ChatAgent, WorkflowBuilder, Workflow +from agent_framework import AgentExecutor, ChatAgent, WorkflowBuilder, Workflow, WorkflowOutputEvent from agent_framework_lab_lightning import AgentFrameworkTracer from agent_framework.openai import OpenAIChatClient from agentlightning import TracerTraceToTriplet @@ -109,8 +109,8 @@ def workflow_two_agents(): async def test_openai_workflow_two_agents(workflow_two_agents: Workflow): events = await workflow_two_agents.run("Please analyze the quarterly sales data") - # Get all AgentRunEvent data - agent_outputs = [event.data for event in events if isinstance(event, AgentRunEvent)] + # Get all WorkflowOutputEvent data + agent_outputs = [event.data for event in events if isinstance(event, WorkflowOutputEvent)] # Check that we have outputs from both agents assert len(agent_outputs) == 2
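For reviewers, the superstep semantics that the new test_state.py asserts can be condensed into a short, self-contained sketch. This is only an illustrative sketch, not part of the patch: it assumes the same internal import path the new tests use (agent_framework._workflows._state), and it exercises only the behaviors asserted above (set/get, commit, discard, delete).

# Illustrative sketch of the commit/discard behavior asserted by test_state.py.
# Assumes the same internal import path as the new tests.
from agent_framework._workflows._state import State

state = State()

# Writes go to the pending buffer; they are readable immediately but only
# become durable when commit() is called at the end of a superstep.
state.set("brief", "Launch notes")
assert state.get("brief") == "Launch notes"
state.commit()

# A failed superstep calls discard(), so pending writes never leak into
# committed state.
state.set("brief", "half-finished edit")
state.set("scratch", "temp value")
state.discard()
assert state.get("brief") == "Launch notes"
assert state.get("scratch") is None

# Deleting a committed key stages a sentinel; the key looks gone right away
# and is physically removed from committed state at the next commit().
state.delete("brief")
assert state.has("brief") is False
state.commit()
assert state.get("brief", "default") == "default"

The executor-facing surface follows the same pattern in the updated samples: ctx.set_state(...) and ctx.get_state(key, default) are called synchronously, without await.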