Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .base import AgentBase
from .conversation_manager import (
ConversationManager,
SlidingWindowConversationManager,
Expand All @@ -83,7 +84,7 @@ class _DefaultCallbackHandlerSentinel:
_DEFAULT_AGENT_ID = "default"


class Agent:
class Agent(AgentBase):
"""Core Agent implementation.

An agent orchestrates the following workflow:
Expand Down
38 changes: 21 additions & 17 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .._async import run_async
from ..agent import Agent
from ..agent.base import AgentBase
from ..agent.state import AgentState
from ..hooks.events import (
AfterMultiAgentInvocationEvent,
Expand Down Expand Up @@ -161,7 +162,7 @@ class GraphNode:
"""Represents a node in the graph."""

node_id: str
executor: Agent | MultiAgentBase
executor: AgentBase | MultiAgentBase
dependencies: set["GraphNode"] = field(default_factory=set)
execution_status: Status = Status.PENDING
result: NodeResult | None = None
Expand Down Expand Up @@ -206,7 +207,7 @@ def __eq__(self, other: Any) -> bool:


def _validate_node_executor(
executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
executor: AgentBase | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
) -> None:
"""Validate a node executor for graph compatibility.

Expand Down Expand Up @@ -245,8 +246,8 @@ def __init__(self) -> None:
self._session_manager: SessionManager | None = None
self._hooks: list[HookProvider] | None = None

def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode:
"""Add an AgentBase or MultiAgentBase instance as a node to the graph."""
_validate_node_executor(executor, self.nodes)

# Auto-generate node_id if not provided
Expand Down Expand Up @@ -864,9 +865,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
logger.debug("node_id=<%s> | executing node", node.node_id)

# Emit node start event
start_event = MultiAgentNodeStartEvent(
node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent"
)
node_type = "multiagent" if isinstance(node.executor, MultiAgentBase) else "agent"
start_event = MultiAgentNodeStartEvent(node_id=node.node_id, node_type=node_type)
yield start_event

before_event, interrupts = await self.hooks.invoke_callbacks_async(
Expand Down Expand Up @@ -916,8 +916,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
interrupts=multi_agent_result.interrupts,
)

elif isinstance(node.executor, Agent):
# For agents, stream their events and collect result
elif isinstance(node.executor, AgentBase):
# For AgentBase implementations (Agent, A2AAgent, etc.), stream events and collect result
agent_response = None
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
# Forward agent events with node context
Expand All @@ -938,14 +938,18 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
)
metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0))

# Handle stop_reason and interrupts (use getattr for AgentBase compatibility)
stop_reason = getattr(agent_response, "stop_reason", "end_turn")
interrupts = getattr(agent_response, "interrupts", None) or []

node_result = NodeResult(
result=agent_response,
execution_time=round((time.time() - start_time) * 1000),
status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED,
status=Status.INTERRUPTED if stop_reason == "interrupt" else Status.COMPLETED,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=1,
interrupts=agent_response.interrupts or [],
interrupts=interrupts,
)
else:
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")
Expand Down Expand Up @@ -1056,13 +1060,13 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"]
]

if isinstance(node.executor, MultiAgentBase):
return node_responses
# Restore Agent-specific state for interrupt resumption
# Only Agent (not generic AgentBase) supports interrupt state restoration
if isinstance(node.executor, Agent):
node.executor.messages = node_context["messages"]
node.executor.state = AgentState(node_context["state"])
node.executor._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"])

agent = node.executor
agent.messages = node_context["messages"]
agent.state = AgentState(node_context["state"])
agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"])
return node_responses

# Get satisfied dependencies
Expand Down
59 changes: 55 additions & 4 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from strands.agent import Agent, AgentResult
from strands.agent import Agent, AgentBase, AgentResult
from strands.agent.state import AgentState
from strands.hooks import AgentInitializedEvent, BeforeNodeCallEvent
from strands.hooks.registry import HookProvider, HookRegistry
Expand Down Expand Up @@ -1103,9 +1103,6 @@ async def test_state_reset_only_with_cycles_enabled():
# Create GraphNode
node = GraphNode("test_node", agent)

# Simulate agent being in completed_nodes (as if revisited)
from strands.multiagent.graph import GraphState

state = GraphState()
state.completed_nodes.add(node)

Expand Down Expand Up @@ -2354,3 +2351,57 @@ def test_graph_interrupt_on_multiagent(agenerator):
assert len(multiagent_result.results) == 1

multiagent.stream_async.assert_called_once_with(responses, {})


@pytest.mark.asyncio
async def test_graph_with_agentbase_implementation(mock_strands_tracer, mock_use_span):
"""Test that Graph accepts any AgentBase implementation (not just Agent)."""

# Create a minimal AgentBase implementation
class CustomAgentBase:
"""Custom AgentBase implementation for testing."""

def __init__(self, name: str, response_text: str):
self.name = name
self.id = f"{name}_id"
self._response_text = response_text

def __call__(self, prompt=None, **kwargs):
return AgentResult(
message={"role": "assistant", "content": [{"text": self._response_text}]},
stop_reason="end_turn",
state={},
metrics=Mock(
accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
accumulated_metrics={"latencyMs": 100.0},
),
)

async def invoke_async(self, prompt=None, **kwargs):
return self(prompt, **kwargs)

async def stream_async(self, prompt=None, **kwargs):
yield {"start": True}
yield {"result": self(prompt, **kwargs)}

# Verify it satisfies AgentBase protocol
custom_agent = CustomAgentBase("custom", "Custom response")
assert isinstance(custom_agent, AgentBase)

# Create a regular mock agent
regular_agent = create_mock_agent("regular", "Regular response")

# Build graph with both
builder = GraphBuilder()
builder.add_node(custom_agent, "custom_node")
builder.add_node(regular_agent, "regular_node")
builder.add_edge("custom_node", "regular_node")
builder.set_entry_point("custom_node")
graph = builder.build()

result = await graph.invoke_async("Test task")

assert result.status == Status.COMPLETED
assert result.completed_nodes == 2
assert "custom_node" in result.results
assert "regular_node" in result.results
32 changes: 32 additions & 0 deletions tests_integ/a2a/test_multiagent_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import pytest
from a2a.client import ClientConfig, ClientFactory

from strands import Agent
from strands.agent.a2a_agent import A2AAgent
from strands.multiagent.graph import GraphBuilder, Status


@pytest.fixture
Expand Down Expand Up @@ -70,3 +72,33 @@ async def test_a2a_agent_with_non_streaming_client_config(a2a_server):
assert result.stop_reason == "end_turn"
finally:
await httpx_client.aclose()


@pytest.mark.asyncio
async def test_graph_with_a2a_agent_and_regular_agent(a2a_server):
"""Test Graph execution with both A2AAgent and regular Agent nodes."""
# Create A2AAgent pointing to the test server
a2a_agent = A2AAgent(endpoint=a2a_server, name="remote_agent")

# Create a regular Agent
regular_agent = Agent(
model="us.amazon.nova-lite-v1:0",
system_prompt="You are a summarizer. Summarize the input briefly.",
name="summarizer",
)

# Build graph with both agent types
builder = GraphBuilder()
builder.add_node(a2a_agent, "remote")
builder.add_node(regular_agent, "summarizer")
builder.add_edge("remote", "summarizer")
builder.set_entry_point("remote")
graph = builder.build()

# Execute the graph
result = await graph.invoke_async("Say hello in one sentence")

assert result.status == Status.COMPLETED
assert result.completed_nodes == 2
assert "remote" in result.results
assert "summarizer" in result.results
Loading