Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,17 +603,20 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent:
def _activate_interrupt(
self, node: GraphNode, interrupts: list[Interrupt], from_hook: bool = False
) -> MultiAgentNodeInterruptEvent:
"""Activate the interrupt state.

Args:
node: The interrupted node.
interrupts: The interrupts raised by the user.
from_hook: Whether the interrupt originated from a hook (e.g., BeforeNodeCallEvent).

Returns:
MultiAgentNodeInterruptEvent
"""
logger.debug("node=<%s> | node interrupted", node.node_id)
logger.debug("node=<%s>, from_hook=<%s> | node interrupted", node.node_id, from_hook)

node.execution_status = Status.INTERRUPTED

Expand All @@ -622,13 +625,20 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()

self._interrupt_state.context[node.node_id] = {
"from_hook": from_hook,
"interrupt_ids": [interrupt.id for interrupt in interrupts],
}

if isinstance(node.executor, Agent):
self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}
self._interrupt_state.context[node.node_id].update(
{
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}
)

return MultiAgentNodeInterruptEvent(node.node_id, interrupts)

Expand Down Expand Up @@ -866,7 +876,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
start_time = time.time()
try:
if interrupts:
yield self._activate_interrupt(node, interrupts)
yield self._activate_interrupt(node, interrupts, from_hook=True)
return

if before_event.cancel_node:
Expand Down Expand Up @@ -896,20 +906,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if multi_agent_result is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

if multi_agent_result.status == Status.INTERRUPTED:
raise NotImplementedError(
f"node_id=<{node.node_id}>, "
"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
"| user raised interrupt from a multi agent node"
)

node_result = NodeResult(
result=multi_agent_result,
execution_time=multi_agent_result.execution_time,
status=Status.COMPLETED,
status=multi_agent_result.status,
accumulated_usage=multi_agent_result.accumulated_usage,
accumulated_metrics=multi_agent_result.accumulated_metrics,
execution_count=multi_agent_result.execution_count,
interrupts=multi_agent_result.interrupts,
)

elif isinstance(node.executor, Agent):
Expand Down Expand Up @@ -1040,18 +1044,26 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
"""
if self._interrupt_state.activated:
context = self._interrupt_state.context
if node.node_id in context and context[node.node_id]["activated"]:
agent_context = context[node.node_id]
agent = cast(Agent, node.executor)
agent.messages = agent_context["messages"]
agent.state = AgentState(agent_context["state"])
agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"])

responses = context["responses"]
interrupts = agent._interrupt_state.interrupts
return [
response for response in responses if response["interruptResponse"]["interruptId"] in interrupts
]
if node.node_id in context:
node_context = context[node.node_id]

# Only route responses if the interrupt originated from the node's execution
if not node_context["from_hook"]:
# Filter responses to only those for this node's interrupts
node_responses = [
response
for response in context["responses"]
if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"]
]

if isinstance(node.executor, MultiAgentBase):
return node_responses

agent = node.executor
agent.messages = node_context["messages"]
agent.state = AgentState(node_context["state"])
agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"])
return node_responses

# Get satisfied dependencies
dependency_results = {}
Expand Down
97 changes: 96 additions & 1 deletion tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,7 +2228,8 @@ def test_graph_interrupt_on_agent(agenerator):
],
)
graph._interrupt_state.context["test_agent"] = {
"activated": True,
"from_hook": False,
"interrupt_ids": [interrupt.id],
"interrupt_state": {
"activated": True,
"context": {},
Expand Down Expand Up @@ -2259,3 +2260,97 @@ def test_graph_interrupt_on_agent(agenerator):
assert len(multiagent_result.results) == 1

agent.stream_async.assert_called_once_with(responses, invocation_state={})


def test_graph_interrupt_on_multiagent(agenerator):
exp_interrupts = [
Interrupt(
id="test_id",
name="test_name",
reason="test_reason",
)
]

multiagent = create_mock_multi_agent("test_multiagent", "Multi-agent completed")
multiagent.stream_async = Mock()
multiagent.stream_async.return_value = agenerator(
[
{
"result": MultiAgentResult(
results={},
status=Status.INTERRUPTED,
interrupts=exp_interrupts,
),
},
],
)

builder = GraphBuilder()
builder.add_node(multiagent, "test_multiagent")
graph = builder.build()

multiagent_result = graph("Test task")

tru_result_status = multiagent_result.status
exp_result_status = Status.INTERRUPTED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.INTERRUPTED
assert tru_state_status == exp_state_status

tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes]
exp_node_ids = ["test_multiagent"]
assert tru_node_ids == exp_node_ids

tru_interrupts = multiagent_result.interrupts
assert tru_interrupts == exp_interrupts

interrupt = multiagent_result.interrupts[0]

multiagent.stream_async = Mock()
multiagent.stream_async.return_value = agenerator(
[
{
"result": MultiAgentResult(
results={
"inner_node": NodeResult(
result=AgentResult(
message={"role": "assistant", "content": [{"text": "Inner completed"}]},
stop_reason="end_turn",
state={},
metrics={},
)
)
},
status=Status.COMPLETED,
),
},
],
)
graph._interrupt_state.context["test_multiagent"] = {
"from_hook": False,
"interrupt_ids": [interrupt.id],
}

responses = [
{
"interruptResponse": {
"interruptId": interrupt.id,
"response": "test_response",
},
},
]
multiagent_result = graph(responses)

tru_result_status = multiagent_result.status
exp_result_status = Status.COMPLETED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.COMPLETED
assert tru_state_status == exp_state_status

assert len(multiagent_result.results) == 1

multiagent.stream_async.assert_called_once_with(responses, {})
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def swarm(weather_agent):


@pytest.fixture
def graph(info_agent, day_agent, time_agent, weather_agent):
def graph(info_agent, day_agent, time_agent, swarm):
builder = GraphBuilder()

builder.add_node(info_agent, "info")
builder.add_node(day_agent, "day")
builder.add_node(time_agent, "time")
builder.add_node(weather_agent, "weather")
builder.add_node(swarm, "weather")

builder.add_edge("info", "day")
builder.add_edge("info", "time")
Expand All @@ -82,7 +82,7 @@ def graph(info_agent, day_agent, time_agent, weather_agent):
return builder.build()


def test_swarm_interrupt_agent(swarm):
def test_swarm_interrupt_node(swarm):
multiagent_result = swarm("What is the weather?")

tru_status = multiagent_result.status
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_swarm_interrupt_agent(swarm):
assert "sunny" in weather_message


def test_graph_interrupt_agent(graph):
def test_graph_interrupt_node(graph):
multiagent_result = graph("What is the day, time, and weather?")

tru_result_status = multiagent_result.status
Expand Down Expand Up @@ -180,7 +180,9 @@ def test_graph_interrupt_agent(graph):

day_message = json.dumps(multiagent_result.results["day"].result.message).lower()
time_message = json.dumps(multiagent_result.results["time"].result.message).lower()
weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower()
assert "monday" in day_message
assert "12:01" in time_message

nested_multiagent_result = multiagent_result.results["weather"].result
weather_message = json.dumps(nested_multiagent_result.results["weather"].result.message).lower()
assert "sunny" in weather_message
28 changes: 22 additions & 6 deletions tests_integ/interrupts/multiagent/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,23 @@ def test_swarm_interrupt_session(weather_tool, tmpdir):


def test_graph_interrupt_session(weather_tool, tmpdir):
parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent")
child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child")

weather_agent = Agent(name="weather", tools=[weather_tool])
summarizer_agent = Agent(name="summarizer")
session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir)

weather_builder = GraphBuilder()
weather_builder.add_node(weather_agent, "weather")
weather_builder.set_entry_point("weather")
weather_builder.set_session_manager(child_sm)
weather_graph = weather_builder.build()

builder = GraphBuilder()
builder.add_node(weather_agent, "weather")
builder.add_node(weather_graph, "weather")
builder.add_node(summarizer_agent, "summarizer")
builder.add_edge("weather", "summarizer")
builder.set_session_manager(session_manager)
builder.set_session_manager(parent_sm)
graph = builder.build()

multiagent_result = graph("Can you check the weather and then summarize the results?")
Expand All @@ -105,15 +113,23 @@ def test_graph_interrupt_session(weather_tool, tmpdir):

interrupt = multiagent_result.interrupts[0]

parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent")
child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child")

weather_agent = Agent(name="weather", tools=[weather_tool])
summarizer_agent = Agent(name="summarizer")
session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir)

weather_builder = GraphBuilder()
weather_builder.add_node(weather_agent, "weather")
weather_builder.set_entry_point("weather")
weather_builder.set_session_manager(child_sm)
weather_graph = weather_builder.build()

builder = GraphBuilder()
builder.add_node(weather_agent, "weather")
builder.add_node(weather_graph, "weather")
builder.add_node(summarizer_agent, "summarizer")
builder.add_edge("weather", "summarizer")
builder.set_session_manager(session_manager)
builder.set_session_manager(parent_sm)
graph = builder.build()

responses = [
Expand Down
Loading