-
Notifications
You must be signed in to change notification settings - Fork 773
feat(swarm): add AgentBase protocol support #2002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
|
|
||
| from .._async import run_async | ||
| from ..agent import Agent | ||
| from ..agent.base import AgentBase | ||
| from ..agent.state import AgentState | ||
| from ..hooks.events import ( | ||
| AfterMultiAgentInvocationEvent, | ||
|
|
@@ -65,7 +66,7 @@ class SwarmNode: | |
| """Represents a node (e.g. Agent) in the swarm.""" | ||
|
|
||
| node_id: str | ||
| executor: Agent | ||
| executor: AgentBase | ||
| swarm: Optional["Swarm"] = None | ||
| _initial_messages: Messages = field(default_factory=list, init=False) | ||
| _initial_state: AgentState = field(default_factory=AgentState, init=False) | ||
|
|
@@ -74,9 +75,14 @@ class SwarmNode: | |
| def __post_init__(self) -> None: | ||
| """Capture initial executor state after initialization.""" | ||
| # Deep copy the initial messages and state to preserve them | ||
| self._initial_messages = copy.deepcopy(self.executor.messages) | ||
| self._initial_state = AgentState(self.executor.state.get()) | ||
| self._initial_model_state = copy.deepcopy(self.executor._model_state) | ||
| if hasattr(self.executor, "messages"): | ||
| self._initial_messages = copy.deepcopy(self.executor.messages) | ||
|
|
||
| if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"): | ||
| self._initial_state = AgentState(self.executor.state.get()) | ||
|
|
||
| if hasattr(self.executor, "_model_state"): | ||
| self._initial_model_state = copy.deepcopy(self.executor._model_state) | ||
|
|
||
| def __hash__(self) -> int: | ||
| """Return hash for SwarmNode based on node_id.""" | ||
|
|
@@ -101,17 +107,26 @@ def reset_executor_state(self) -> None: | |
|
|
||
| If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context. | ||
| """ | ||
| if self.swarm and self.swarm._interrupt_state.activated: | ||
| # Handle interrupt state restoration (Agent-specific) | ||
| if self.swarm and self.swarm._interrupt_state.activated and isinstance(self.executor, Agent): | ||
| if self.node_id not in self.swarm._interrupt_state.context: | ||
| return | ||
| context = self.swarm._interrupt_state.context[self.node_id] | ||
| self.executor.messages = context["messages"] | ||
| self.executor.state = AgentState(context["state"]) | ||
| self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) | ||
| self.executor._model_state = context.get("model_state", {}) | ||
| return | ||
|
|
||
| self.executor.messages = copy.deepcopy(self._initial_messages) | ||
| self.executor.state = AgentState(self._initial_state.get()) | ||
| self.executor._model_state = copy.deepcopy(self._initial_model_state) | ||
| # Reset to initial state (works with any AgentBase that has these attributes) | ||
| if hasattr(self.executor, "messages"): | ||
| self.executor.messages = copy.deepcopy(self._initial_messages) | ||
|
|
||
| if hasattr(self.executor, "state"): | ||
| self.executor.state = AgentState(self._initial_state.get()) | ||
|
|
||
| if hasattr(self.executor, "_model_state"): | ||
| self.executor._model_state = copy.deepcopy(self._initial_model_state) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -236,9 +251,9 @@ class Swarm(MultiAgentBase): | |
|
|
||
| def __init__( | ||
| self, | ||
| nodes: list[Agent], | ||
| nodes: list[AgentBase], | ||
| *, | ||
| entry_point: Agent | None = None, | ||
| entry_point: AgentBase | None = None, | ||
| max_handoffs: int = 20, | ||
| max_iterations: int = 20, | ||
| execution_timeout: float = 900.0, | ||
|
|
@@ -301,6 +316,7 @@ def __init__( | |
|
|
||
| self._resume_from_session = False | ||
|
|
||
| self._handoff_capable_nodes: set[str] = set() | ||
| self._setup_swarm(nodes) | ||
| self._inject_swarm_tools() | ||
| run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) | ||
|
|
@@ -462,33 +478,35 @@ async def _stream_with_timeout( | |
| except asyncio.TimeoutError as err: | ||
| raise Exception(timeout_message) from err | ||
|
|
||
| def _setup_swarm(self, nodes: list[Agent]) -> None: | ||
| def _setup_swarm(self, nodes: list[AgentBase]) -> None: | ||
| """Initialize swarm configuration.""" | ||
| # Validate nodes before setup | ||
| self._validate_swarm(nodes) | ||
|
|
||
| # Validate agents have names and create SwarmNode objects | ||
| for i, node in enumerate(nodes): | ||
| if not node.name: | ||
| # Only access name if it exists (AgentBase protocol doesn't guarantee it) | ||
| node_name = getattr(node, "name", None) | ||
| if not node_name: | ||
| node_id = f"node_{i}" | ||
| node.name = node_id | ||
| logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id) | ||
|
|
||
| node_id = str(node.name) | ||
| logger.debug("node_id=<%s> | agent has no name, using generated id", node_id) | ||
| else: | ||
| node_id = str(node_name) | ||
|
|
||
| # Ensure node IDs are unique | ||
| if node_id in self.nodes: | ||
| raise ValueError(f"Node ID '{node_id}' is not unique. Each agent must have a unique name.") | ||
|
|
||
| self.nodes[node_id] = SwarmNode(node_id, node, swarm=self) | ||
|
|
||
| # Validate entry point if specified | ||
| # Validate entry point if specified (use identity-based lookup to handle nameless AgentBase) | ||
| if self.entry_point is not None: | ||
| entry_point_node_id = str(self.entry_point.name) | ||
| if ( | ||
| entry_point_node_id not in self.nodes | ||
| or self.nodes[entry_point_node_id].executor is not self.entry_point | ||
| ): | ||
| entry_node = None | ||
| for swarm_node in self.nodes.values(): | ||
| if swarm_node.executor is self.entry_point: | ||
| entry_node = swarm_node | ||
| break | ||
| if entry_node is None: | ||
| available_agents = [ | ||
| f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items() | ||
| ] | ||
|
|
@@ -504,7 +522,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None: | |
| first_node = next(iter(self.nodes.keys())) | ||
| logger.debug("entry_point=<%s> | using first node as entry point", first_node) | ||
|
|
||
| def _validate_swarm(self, nodes: list[Agent]) -> None: | ||
| def _validate_swarm(self, nodes: list[AgentBase]) -> None: | ||
| """Validate swarm structure and nodes.""" | ||
| # Check for duplicate object instances | ||
| seen_instances = set() | ||
|
|
@@ -513,18 +531,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None: | |
| raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.") | ||
| seen_instances.add(id(node)) | ||
|
|
||
| # Check for session persistence | ||
| if node._session_manager is not None: | ||
| # Check for session persistence (only Agent has _session_manager attribute) | ||
| if isinstance(node, Agent) and node._session_manager is not None: | ||
| raise ValueError("Session persistence is not supported for Swarm agents yet.") | ||
|
|
||
| def _inject_swarm_tools(self) -> None: | ||
| """Add swarm coordination tools to each agent.""" | ||
| """Add swarm coordination tools to each agent. | ||
|
|
||
| Note: Only Agent instances can receive swarm tools. AgentBase implementations | ||
| without tool_registry will not have handoff capabilities. | ||
| """ | ||
| # Create tool functions with proper closures | ||
| swarm_tools = [ | ||
| self._create_handoff_tool(), | ||
| ] | ||
|
|
||
| injected_count = 0 | ||
| for node in self.nodes.values(): | ||
| # Only Agent (not generic AgentBase) has tool_registry attribute | ||
| if not isinstance(node.executor, Agent): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Non-Agent nodes silently lose key Swarm capabilities with no user feedback. When an AgentBase (non-Agent) node is skipped for tool injection, the user isn't informed at Swarm creation time that these nodes won't be able to initiate handoffs. The debug log on line 549 is only visible at DEBUG level. A user might expect their custom AgentBase to participate in handoffs and be surprised when it doesn't. Suggestion:
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed. |
||
| logger.debug( | ||
| "node_id=<%s> | skipping tool injection for non-Agent node", | ||
| node.node_id, | ||
| ) | ||
| continue | ||
|
|
||
| # Check for existing tools with conflicting names | ||
| existing_tools = node.executor.tool_registry.registry | ||
| conflicting_tools = [] | ||
|
|
@@ -540,11 +571,14 @@ def _inject_swarm_tools(self) -> None: | |
|
|
||
| # Use the agent's tool registry to process and register the tools | ||
| node.executor.tool_registry.process_tools(swarm_tools) | ||
| self._handoff_capable_nodes.add(node.node_id) | ||
| injected_count += 1 | ||
|
|
||
| logger.debug( | ||
| "tool_count=<%d>, node_count=<%d> | injected coordination tools into agents", | ||
| "tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools", | ||
| len(swarm_tools), | ||
| len(self.nodes), | ||
| injected_count, | ||
| ) | ||
|
|
||
| def _create_handoff_tool(self) -> Callable[..., Any]: | ||
|
|
@@ -673,10 +707,13 @@ def _build_node_input(self, target_node: SwarmNode) -> str: | |
| context_text += "\n" | ||
| context_text += "\n" | ||
|
|
||
| context_text += ( | ||
| "You have access to swarm coordination tools if you need help from other agents. " | ||
| "If you don't hand off to another agent, the swarm will consider the task complete." | ||
| ) | ||
| if target_node.node_id in self._handoff_capable_nodes: | ||
| context_text += ( | ||
| "You have access to swarm coordination tools if you need help from other agents. " | ||
| "If you don't hand off to another agent, the swarm will consider the task complete." | ||
| ) | ||
| else: | ||
| context_text += "If you complete your task, the swarm will consider the task complete." | ||
|
|
||
| return context_text | ||
|
|
||
|
|
@@ -696,13 +733,19 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M | |
| logger.debug("node=<%s> | node interrupted", node.node_id) | ||
| self.state.completion_status = Status.INTERRUPTED | ||
|
|
||
| # Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes | ||
| self._interrupt_state.context[node.node_id] = { | ||
| "activated": node.executor._interrupt_state.activated, | ||
| "interrupt_state": node.executor._interrupt_state.to_dict(), | ||
| "state": node.executor.state.get(), | ||
| "messages": node.executor.messages, | ||
| "model_state": node.executor._model_state, | ||
| "activated": isinstance(node.executor, Agent) and node.executor._interrupt_state.activated, | ||
| } | ||
| if isinstance(node.executor, Agent): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Silent degradation for interrupt state on non-Agent nodes. When an interrupt occurs on a non-Agent AgentBase node, no executor context is saved (the Suggestion: Either:
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed two things:
|
||
| self._interrupt_state.context[node.node_id].update( | ||
| { | ||
| "interrupt_state": node.executor._interrupt_state.to_dict(), | ||
| "state": node.executor.state.get(), | ||
| "messages": node.executor.messages, | ||
| "model_state": node.executor._model_state, | ||
| } | ||
| ) | ||
|
|
||
| self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) | ||
| self._interrupt_state.activate() | ||
|
|
@@ -1042,5 +1085,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: | |
|
|
||
| def _initial_node(self) -> SwarmNode: | ||
| if self.entry_point: | ||
| return self.nodes[str(self.entry_point.name)] | ||
| for node in self.nodes.values(): | ||
| if node.executor is self.entry_point: | ||
| return node | ||
| return next(iter(self.nodes.values())) # First SwarmNode | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how do you determine handoffs for non-Agent AgentBase instances then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question. Looking at the implementation: non-Agent AgentBase instances currently cannot initiate handoffs because
_inject_swarm_tools(line 534-577) skips tool injection for nodes that aren'tisinstance(node.executor, Agent). They can only be handed to by Agent nodes that do have the handoff tool.This is a significant capability gap that should be clearly communicated. The
_build_node_inputmethod (line 706) also tells all nodes they have "access to swarm coordination tools" regardless of whether tools were injected, which could confuse LLM-backed AgentBase implementations. I've left a separate comment on that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They can't. Only native
Agentinstances get thehandoff_to_agenttool injected sinceAgentBasedoesn't guarantee atool_registry. AgentBase nodes can only be handed to by Agent nodes. This is intentional for the #1720 use case of integrating existing agents from other frameworks as handoff targets.I also updated the prompt text to be conditional now. Agent nodes see "You have access to swarm coordination tools..." while AgentBase nodes just see "If you complete your task, the swarm will consider the task complete." so the LLM isn't told about tools it doesn't have.