Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def execute(

# Invoke the handler with the message and context
with _framework_event_origin():
invoke_event = ExecutorInvokedEvent(self.id)
invoke_event = ExecutorInvokedEvent(self.id, message)
await context.add_event(invoke_event)
try:
await handler(message, context)
Expand All @@ -275,7 +275,9 @@ async def execute(
await context.add_event(failure_event)
raise
with _framework_event_origin():
completed_event = ExecutorCompletedEvent(self.id)
# Include sent messages as the completion data
sent_messages = context.get_sent_messages()
completed_event = ExecutorCompletedEvent(self.id, sent_messages if sent_messages else None)
await context.add_event(completed_event)

def _create_context_for_handler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def __init__(
self._runner_context = runner_context
self._shared_state = shared_state

# Track messages sent via send_message() for ExecutorCompletedEvent
self._sent_messages: list[Any] = []

# Store trace contexts and source span IDs for linking (supporting multiple sources)
self._trace_contexts = trace_contexts or []
self._source_span_ids = source_span_ids or []
Expand All @@ -313,6 +316,9 @@ async def send_message(self, message: T_Out, target_id: str | None = None) -> No
# Create Message wrapper
msg = Message(data=message, source_id=self._executor_id, target_id=target_id)

# Track sent message for ExecutorCompletedEvent
self._sent_messages.append(message)

# Inject current trace context if tracing enabled
if OBSERVABILITY_SETTINGS.ENABLED and span and span.is_recording(): # type: ignore[name-defined]
trace_context: dict[str, str] = {}
Expand Down Expand Up @@ -410,6 +416,14 @@ def shared_state(self) -> SharedState:
"""Get the shared state."""
return self._shared_state

def get_sent_messages(self) -> list[Any]:
"""Get all messages sent via send_message() during this handler execution.

Returns:
A list of messages that were sent to downstream executors.
"""
return self._sent_messages.copy()

@deprecated(
"Override `on_checkpoint_save()` methods instead. "
"For cross-executor state sharing, use set_shared_state() instead. "
Expand Down Expand Up @@ -448,7 +462,7 @@ async def get_executor_state(self) -> dict[str, Any] | None:
if not isinstance(existing_states, dict):
raise ValueError("Existing executor states in shared state is not a dictionary.")

return existing_states.get(self._executor_id)
return existing_states.get(self._executor_id) # type: ignore

def is_streaming(self) -> bool:
"""Check if the workflow is running in streaming mode.
Expand Down
162 changes: 161 additions & 1 deletion python/packages/core/tests/workflow/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import pytest

from agent_framework import Executor, Message, WorkflowContext, handler
from agent_framework import (
Executor,
ExecutorCompletedEvent,
ExecutorInvokedEvent,
Message,
WorkflowBuilder,
WorkflowContext,
handler,
)


def test_executor_without_id():
Expand Down Expand Up @@ -101,3 +109,155 @@ async def handle_integer(self, number: int, ctx: WorkflowContext[int]) -> None:
assert int_handler._handler_spec["name"] == "handle_integer" # type: ignore
assert int_handler._handler_spec["message_type"] is int # type: ignore
assert int_handler._handler_spec["output_types"] == [int] # type: ignore


async def test_executor_invoked_event_contains_input_data():
"""Test that ExecutorInvokedEvent contains the input message data."""

class UpperCaseExecutor(Executor):
@handler
async def handle(self, text: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(text.upper())

class CollectorExecutor(Executor):
@handler
async def handle(self, text: str, ctx: WorkflowContext) -> None:
pass

upper = UpperCaseExecutor(id="upper")
collector = CollectorExecutor(id="collector")

workflow = WorkflowBuilder().add_edge(upper, collector).set_start_executor(upper).build()

events = await workflow.run("hello world")
invoked_events = [e for e in events if isinstance(e, ExecutorInvokedEvent)]

assert len(invoked_events) == 2

# First invoked event should be for 'upper' executor with input "hello world"
upper_invoked = next(e for e in invoked_events if e.executor_id == "upper")
assert upper_invoked.data == "hello world"

# Second invoked event should be for 'collector' executor with input "HELLO WORLD"
collector_invoked = next(e for e in invoked_events if e.executor_id == "collector")
assert collector_invoked.data == "HELLO WORLD"


async def test_executor_completed_event_contains_sent_messages():
"""Test that ExecutorCompletedEvent contains the messages sent via ctx.send_message()."""

class MultiSenderExecutor(Executor):
@handler
async def handle(self, text: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(f"{text}-first")
await ctx.send_message(f"{text}-second")

class CollectorExecutor(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.received: list[str] = []

@handler
async def handle(self, text: str, ctx: WorkflowContext) -> None:
self.received.append(text)

sender = MultiSenderExecutor(id="sender")
collector = CollectorExecutor(id="collector")

workflow = WorkflowBuilder().add_edge(sender, collector).set_start_executor(sender).build()

events = await workflow.run("hello")
completed_events = [e for e in events if isinstance(e, ExecutorCompletedEvent)]

# Sender should have completed with the sent messages
sender_completed = next(e for e in completed_events if e.executor_id == "sender")
assert sender_completed.data is not None
assert sender_completed.data == ["hello-first", "hello-second"]

# Collector should have completed with no sent messages (None)
collector_completed_events = [e for e in completed_events if e.executor_id == "collector"]
# Collector is called twice (once per message from sender)
assert len(collector_completed_events) == 2
for collector_completed in collector_completed_events:
assert collector_completed.data is None


async def test_executor_completed_event_none_when_no_messages_sent():
"""Test that ExecutorCompletedEvent.data is None when no messages are sent."""
from typing_extensions import Never

from agent_framework import WorkflowOutputEvent

class YieldOnlyExecutor(Executor):
@handler
async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None:
await ctx.yield_output(text.upper())

executor = YieldOnlyExecutor(id="yielder")
workflow = WorkflowBuilder().set_start_executor(executor).build()

events = await workflow.run("test")
completed_events = [e for e in events if isinstance(e, ExecutorCompletedEvent)]

assert len(completed_events) == 1
assert completed_events[0].executor_id == "yielder"
assert completed_events[0].data is None

# Verify the output was still yielded correctly
output_events = [e for e in events if isinstance(e, WorkflowOutputEvent)]
assert len(output_events) == 1
assert output_events[0].data == "TEST"


async def test_executor_events_with_complex_message_types():
"""Test that executor events correctly capture complex message types."""
from dataclasses import dataclass

@dataclass
class Request:
query: str
limit: int

@dataclass
class Response:
results: list[str]

class ProcessorExecutor(Executor):
@handler
async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None:
response = Response(results=[request.query.upper()] * request.limit)
await ctx.send_message(response)

class CollectorExecutor(Executor):
@handler
async def handle(self, response: Response, ctx: WorkflowContext) -> None:
pass

processor = ProcessorExecutor(id="processor")
collector = CollectorExecutor(id="collector")

workflow = WorkflowBuilder().add_edge(processor, collector).set_start_executor(processor).build()

input_request = Request(query="hello", limit=3)
events = await workflow.run(input_request)

invoked_events = [e for e in events if isinstance(e, ExecutorInvokedEvent)]
completed_events = [e for e in events if isinstance(e, ExecutorCompletedEvent)]

# Check processor invoked event has the Request object
processor_invoked = next(e for e in invoked_events if e.executor_id == "processor")
assert isinstance(processor_invoked.data, Request)
assert processor_invoked.data.query == "hello"
assert processor_invoked.data.limit == 3

# Check processor completed event has the Response object
processor_completed = next(e for e in completed_events if e.executor_id == "processor")
assert processor_completed.data is not None
assert len(processor_completed.data) == 1
assert isinstance(processor_completed.data[0], Response)
assert processor_completed.data[0].results == ["HELLO", "HELLO", "HELLO"]

# Check collector invoked event has the Response object
collector_invoked = next(e for e in invoked_events if e.executor_id == "collector")
assert isinstance(collector_invoked.data, Response)
assert collector_invoked.data.results == ["HELLO", "HELLO", "HELLO"]
7 changes: 4 additions & 3 deletions python/packages/core/tests/workflow/test_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
WorkflowOutputEvent,
)
from agent_framework._mcp import MCPTool
from agent_framework._workflows import AgentRunEvent
from agent_framework._workflows import _handoff as handoff_module # type: ignore
from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage]
from agent_framework._workflows._workflow_builder import WorkflowBuilder
Expand Down Expand Up @@ -224,12 +225,12 @@ async def test_handoff_preserves_complex_additional_properties(complex_metadata:

# Initial run should preserve complex metadata in the triage response
events = await _drain(workflow.run_stream("Need help with a return"))
agent_events = [ev for ev in events if hasattr(ev, "data") and hasattr(ev.data, "messages")]
agent_events = [ev for ev in events if isinstance(ev, AgentRunEvent)]
if agent_events:
first_agent_event = agent_events[0]
first_agent_event_data = first_agent_event.data
if first_agent_event_data and hasattr(first_agent_event_data, "messages"):
first_agent_message = first_agent_event_data.messages[0] # type: ignore[attr-defined]
if first_agent_event_data and first_agent_event_data.messages:
first_agent_message = first_agent_event_data.messages[0]
assert "complex" in first_agent_message.additional_properties, "Agent event lost complex metadata"
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests, "Workflow should request additional user input"
Expand Down
6 changes: 5 additions & 1 deletion python/samples/getting_started/workflows/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ Once comfortable with these, explore the rest of the samples below.

### observability

For observability samples in Agent Framework, see the [observability getting started samples](../observability/README.md). The [sample](../observability/workflow_observability.py) demonstrates integrating observability into workflows.
| Sample | File | Concepts |
|---|---|---|
| Executor I/O Observation | [observability/executor_io_observation.py](./observability/executor_io_observation.py) | Observe executor input/output data via ExecutorInvokedEvent and ExecutorCompletedEvent without modifying executor code |

For additional observability samples in Agent Framework, see the [observability getting started samples](../observability/README.md). The [sample](../observability/workflow_observability.py) demonstrates integrating observability into workflows.

### orchestration

Expand Down
Loading
Loading