From ef8a93cf82d0c23104ad1eaf8c2245f026c346cd Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 30 Jan 2026 18:07:55 -0500 Subject: [PATCH] interrupts - graph - multiagent nodes --- src/strands/multiagent/graph.py | 70 +++++++------ tests/strands/multiagent/test_graph.py | 97 ++++++++++++++++++- .../{test_agent.py => test_node.py} | 12 ++- .../interrupts/multiagent/test_session.py | 28 ++++-- 4 files changed, 166 insertions(+), 41 deletions(-) rename tests_integ/interrupts/multiagent/{test_agent.py => test_node.py} (93%) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d296753c0..6b135d1a7 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -603,17 +603,20 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + def _activate_interrupt( + self, node: GraphNode, interrupts: list[Interrupt], from_hook: bool = False + ) -> MultiAgentNodeInterruptEvent: """Activate the interrupt state. Args: node: The interrupted node. interrupts: The interrupts raised by the user. + from_hook: Whether the interrupt originated from a hook (e.g., BeforeNodeCallEvent). Returns: MultiAgentNodeInterruptEvent """ - logger.debug("node=<%s> | node interrupted", node.node_id) + logger.debug("node=<%s>, from_hook=<%s> | node interrupted", node.node_id, from_hook) node.execution_status = Status.INTERRUPTED @@ -622,13 +625,20 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() + + self._interrupt_state.context[node.node_id] = { + "from_hook": from_hook, + "interrupt_ids": [interrupt.id for interrupt in interrupts], + } + if isinstance(node.executor, Agent): - self._interrupt_state.context[node.node_id] = { - "activated": node.executor._interrupt_state.activated, - "interrupt_state": node.executor._interrupt_state.to_dict(), - "state": node.executor.state.get(), - "messages": node.executor.messages, - } + self._interrupt_state.context[node.node_id].update( + { + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + ) return MultiAgentNodeInterruptEvent(node.node_id, interrupts) @@ -866,7 +876,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) start_time = time.time() try: if interrupts: - yield self._activate_interrupt(node, interrupts) + yield self._activate_interrupt(node, interrupts, from_hook=True) return if before_event.cancel_node: @@ -896,20 +906,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - if multi_agent_result.status == Status.INTERRUPTED: - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from a multi agent node" - ) - node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, + status=multi_agent_result.status, accumulated_usage=multi_agent_result.accumulated_usage, accumulated_metrics=multi_agent_result.accumulated_metrics, execution_count=multi_agent_result.execution_count, + interrupts=multi_agent_result.interrupts, ) elif isinstance(node.executor, Agent): @@ -1040,18 +1044,26 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """ if self._interrupt_state.activated: context = self._interrupt_state.context - if node.node_id in context and context[node.node_id]["activated"]: - agent_context = context[node.node_id] - agent = cast(Agent, node.executor) - agent.messages = agent_context["messages"] - agent.state = AgentState(agent_context["state"]) - agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"]) - - responses = context["responses"] - interrupts = agent._interrupt_state.interrupts - return [ - response for response in responses if response["interruptResponse"]["interruptId"] in interrupts - ] + if node.node_id in context: + node_context = context[node.node_id] + + # Only route responses if the interrupt originated from the node's execution + if not node_context["from_hook"]: + # Filter responses to only those for this node's interrupts + node_responses = [ + response + for response in context["responses"] + if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"] + ] + + if isinstance(node.executor, MultiAgentBase): + return node_responses + + agent = node.executor + agent.messages = node_context["messages"] + agent.state = AgentState(node_context["state"]) + agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) + return node_responses # Get satisfied dependencies dependency_results = {} diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c511328d4..0fbb102a4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2228,7 +2228,8 @@ def test_graph_interrupt_on_agent(agenerator): ], ) graph._interrupt_state.context["test_agent"] = { - "activated": True, + "from_hook": False, + "interrupt_ids": [interrupt.id], "interrupt_state": { "activated": True, "context": {}, @@ -2259,3 +2260,97 @@ def test_graph_interrupt_on_agent(agenerator): assert len(multiagent_result.results) == 1 agent.stream_async.assert_called_once_with(responses, invocation_state={}) + + +def test_graph_interrupt_on_multiagent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + multiagent = create_mock_multi_agent("test_multiagent", "Multi-agent completed") + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={}, + status=Status.INTERRUPTED, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(multiagent, "test_multiagent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_multiagent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={ + "inner_node": NodeResult( + result=AgentResult( + message={"role": "assistant", "content": [{"text": "Inner completed"}]}, + stop_reason="end_turn", + state={}, + metrics={}, + ) + ) + }, + status=Status.COMPLETED, + ), + }, + ], + ) + graph._interrupt_state.context["test_multiagent"] = { + "from_hook": False, + "interrupt_ids": [interrupt.id], + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + multiagent.stream_async.assert_called_once_with(responses, {}) diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_node.py similarity index 93% rename from tests_integ/interrupts/multiagent/test_agent.py rename to tests_integ/interrupts/multiagent/test_node.py index 1a6ad87c6..23e7a62bc 100644 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_node.py @@ -65,13 +65,13 @@ def swarm(weather_agent): @pytest.fixture -def graph(info_agent, day_agent, time_agent, weather_agent): +def graph(info_agent, day_agent, time_agent, swarm): builder = GraphBuilder() builder.add_node(info_agent, "info") builder.add_node(day_agent, "day") builder.add_node(time_agent, "time") - builder.add_node(weather_agent, "weather") + builder.add_node(swarm, "weather") builder.add_edge("info", "day") builder.add_edge("info", "time") @@ -82,7 +82,7 @@ def graph(info_agent, day_agent, time_agent, weather_agent): return builder.build() -def test_swarm_interrupt_agent(swarm): +def test_swarm_interrupt_node(swarm): multiagent_result = swarm("What is the weather?") tru_status = multiagent_result.status @@ -122,7 +122,7 @@ def test_swarm_interrupt_agent(swarm): assert "sunny" in weather_message -def test_graph_interrupt_agent(graph): +def test_graph_interrupt_node(graph): multiagent_result = graph("What is the day, time, and weather?") tru_result_status = multiagent_result.status @@ -180,7 +180,9 @@ def test_graph_interrupt_agent(graph): day_message = json.dumps(multiagent_result.results["day"].result.message).lower() time_message = json.dumps(multiagent_result.results["time"].result.message).lower() - weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() assert "monday" in day_message assert "12:01" in time_message + + nested_multiagent_result = multiagent_result.results["weather"].result + weather_message = json.dumps(nested_multiagent_result.results["weather"].result.message).lower() assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index 96b9844bf..8a5979d63 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -72,15 +72,23 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): def test_graph_interrupt_session(weather_tool, tmpdir): + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() multiagent_result = graph("Can you check the weather and then summarize the results?") @@ -105,15 +113,23 @@ def test_graph_interrupt_session(weather_tool, tmpdir): interrupt = multiagent_result.interrupts[0] + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() responses = [