|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import enum |
4 | 5 | import inspect |
5 | 6 | import json |
6 | 7 | import weakref |
|
48 | 49 | if TYPE_CHECKING: |
49 | 50 | from .agent import Agent, AgentBase |
50 | 51 | from .items import RunItem, ToolApprovalItem |
| 52 | + from .mcp.server import MCPServer |
51 | 53 |
|
52 | 54 |
|
53 | 55 | ToolParams = ParamSpec("ToolParams") |
@@ -182,6 +184,59 @@ class ComputerProvider(Generic[ComputerT]): |
182 | 184 | ] |
183 | 185 |
|
184 | 186 |
|
| 187 | +class ToolOriginType(str, enum.Enum): |
| 188 | + """The type of tool origin.""" |
| 189 | + |
| 190 | + FUNCTION = "function" |
| 191 | + """Regular Python function tool created via @function_tool decorator.""" |
| 192 | + |
| 193 | + MCP = "mcp" |
| 194 | + """MCP server tool converted via MCPUtil.to_function_tool().""" |
| 195 | + |
| 196 | + AGENT_AS_TOOL = "agent_as_tool" |
| 197 | + """Agent converted to tool via agent.as_tool().""" |
| 198 | + |
| 199 | + |
| 200 | +@dataclass |
| 201 | +class ToolOrigin: |
| 202 | + """Information about the origin/source of a function tool.""" |
| 203 | + |
| 204 | + type: ToolOriginType |
| 205 | + """The type of tool origin.""" |
| 206 | + |
| 207 | + mcp_server: MCPServer | None = None |
| 208 | + """The MCP server object. Only set when type is MCP.""" |
| 209 | + |
| 210 | + agent_as_tool: Agent[Any] | None = None |
| 211 | + """The agent object. Only set when type is AGENT_AS_TOOL.""" |
| 212 | + |
| 213 | + def __repr__(self) -> str: |
| 214 | + """Custom repr that only includes relevant fields.""" |
| 215 | + parts = [f"type={self.type.value!r}"] |
| 216 | + if self.mcp_server is not None: |
| 217 | + parts.append(f"mcp_server_name={self.mcp_server.name!r}") |
| 218 | + if self.agent_as_tool is not None: |
| 219 | + parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") |
| 220 | + return f"ToolOrigin({', '.join(parts)})" |
| 221 | + |
| 222 | + |
| 223 | +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: |
| 224 | + """Extract origin information from a FunctionTool. |
| 225 | +
|
| 226 | + Args: |
| 227 | + function_tool: The function tool to extract origin info from. |
| 228 | +
|
| 229 | + Returns: |
| 230 | + ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type). |
| 231 | + """ |
| 232 | + origin = function_tool._tool_origin |
| 233 | + if origin is None: |
| 234 | + # Default to FUNCTION if not explicitly set |
| 235 | + return ToolOrigin(type=ToolOriginType.FUNCTION) |
| 236 | + |
| 237 | + return origin |
| 238 | + |
| 239 | + |
185 | 240 | @dataclass |
186 | 241 | class FunctionToolResult: |
187 | 242 | tool: FunctionTool |
@@ -264,6 +319,9 @@ class FunctionTool: |
264 | 319 | _agent_instance: Any = field(default=None, init=False, repr=False) |
265 | 320 | """Internal reference to the agent instance if this is an agent-as-tool.""" |
266 | 321 |
|
| 322 | + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) |
| 323 | + """Internal field tracking the origin of this tool (FUNCTION, MCP, or AGENT_AS_TOOL).""" |
| 324 | + |
267 | 325 | def __post_init__(self): |
268 | 326 | if self.strict_json_schema: |
269 | 327 | self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) |
|
0 commit comments