diff --git a/.gitignore b/.gitignore index f920a7c4e..7f95c3c6d 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,9 @@ scratch/ .vscode/ *.swp .DS_Store + +# Node modules +node_modules/ + +# Browser env generated files +environments/browser_env/cua-server/node_modules/ diff --git a/environments/mcp_env/.env.example b/environments/mcp_env/.env.example new file mode 100644 index 000000000..52b78ee64 --- /dev/null +++ b/environments/mcp_env/.env.example @@ -0,0 +1,11 @@ +# OpenAI API key (required for judge-based evaluation) +OPENAI_API_KEY=your_openai_api_key + +# Smithery credentials for Exa MCP server +SMITHERY_KEY=your_smithery_key +SMITHERY_PROFILE=your_smithery_profile + +# Browserbase MCP server credentials +BROWSERBASE_API_KEY=your_browserbase_api_key +BROWSERBASE_PROJECT_ID=your_browserbase_project_id +GEMINI_API_KEY=your_gemini_api_key diff --git a/environments/mcp_env/README.md b/environments/mcp_env/README.md index 700fee9b3..55300b0b7 100644 --- a/environments/mcp_env/README.md +++ b/environments/mcp_env/README.md @@ -14,12 +14,73 @@ ### Task -- **Type**: +- **Type**: multi-turn | tool use - **Parser**: N/A -- **Rubric overview**: N/A +- **Rubric overview**: Judge-based evaluation using gpt-4.1-mini + +### MCP Tools + +This environment integrates with MCP (Model Context Protocol) servers to provide tool-calling capabilities. By default, it includes: + +#### Exa & Fetch Tools + +- **Exa MCP Server**: Search and discovery tool for finding relevant web content (via Smithery) + - Command: `npx -y @smithery/cli@latest run exa --key --profile ` + - Note: Authentication is handled via Smithery CLI key/profile + +- **Fetch MCP Server**: Fetches and retrieves web content from URLs + - Command: `uvx mcp-server-fetch` + - No API key required + +#### Browserbase Tools + +- **Browserbase MCP Server**: Browser automation for interacting with web pages using AI-powered navigation + - Command: `npx @browserbasehq/mcp-server-browserbase` + - Required environment variables: + - `BROWSERBASE_API_KEY` + - `BROWSERBASE_PROJECT_ID` + - `GEMINI_API_KEY` + +**Customizing Tools:** + +You can pass custom MCP server configurations via the `mcp_servers` argument to `load_environment()`: + +```python +custom_servers = [ + { + "name": "my-server", + "transport": "stdio", + "command": "npx", + "args": ["my-mcp-server"], + "env": {"API_KEY": "your_key"}, + "description": "Custom MCP server" + } +] +env = load_environment(mcp_servers=custom_servers) +``` ### Quickstart +**Prerequisites:** + +Export the required API keys for the judge LLM and MCP tools: + +```bash +# Required for judge-based evaluation +export OPENAI_API_KEY=your_openai_key + +# Required for Exa MCP server (via Smithery) +export SMITHERY_KEY=your_smithery_key +export SMITHERY_PROFILE=your_smithery_profile + +# Required for Browserbase MCP server (browser automation) +export BROWSERBASE_API_KEY=your_browserbase_key +export BROWSERBASE_PROJECT_ID=your_project_id +export GEMINI_API_KEY=your_gemini_key +``` + +**Note:** Not all API keys are required for every task. The Fetch MCP server works without any API key. Only export the keys for the tools you intend to use. + Run an evaluation with default settings: ```bash diff --git a/environments/mcp_env/mcp_env.py b/environments/mcp_env/mcp_env.py index a7d3f6244..e31e9f561 100644 --- a/environments/mcp_env/mcp_env.py +++ b/environments/mcp_env/mcp_env.py @@ -19,24 +19,49 @@ EXA_FETCH_TOOLS = [ { "name": "exa", + "transport": "stdio", "command": "npx", "args": [ "-y", - "exa-mcp-server", + "@smithery/cli@latest", + "run", + "exa", + "--key", + os.getenv("SMITHERY_KEY", ""), + "--profile", + os.getenv("SMITHERY_PROFILE", ""), ], "env": { - "EXA_API_KEY": os.getenv("EXA_API_KEY"), + "NPM_CONFIG_LOGLEVEL": "silent", }, "description": "Exa MCP server", }, { "name": "fetch", + "transport": "stdio", "command": "uvx", "args": ["mcp-server-fetch"], "description": "Fetch MCP server", }, ] +BROWSERBASE_TOOLS = [ + { + "name": "browserbase", + "transport": "stdio", + "command": "npx", + "args": ["@browserbasehq/mcp-server-browserbase"], + "env": { + "BROWSERBASE_API_KEY": os.getenv("BROWSERBASE_API_KEY", ""), + "BROWSERBASE_PROJECT_ID": os.getenv("BROWSERBASE_PROJECT_ID", ""), + "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""), + "NPM_CONFIG_LOGLEVEL": "silent", + }, + "description": "Browserbase MCP (via npx)", + }, +] + + class MCPEnv(ToolEnv): """Environment for MCP-based tools using the official MCP SDK.""" @@ -67,15 +92,21 @@ def __init__( super().__init__( tools=[], max_turns=max_turns, error_formatter=error_formatter, **kwargs ) + + self.logger.info(f"Initializing MCPEnv with {len(self.mcp_servers)} MCP server(s)") + # Start a persistent background event loop and connect synchronously self._bg_loop = asyncio.new_event_loop() self._bg_thread = threading.Thread( target=self._run_loop, args=(self._bg_loop,), daemon=True ) self._bg_thread.start() + self.logger.debug("Background event loop started") + fut = asyncio.run_coroutine_threadsafe(self._connect_servers(), self._bg_loop) fut.result() self._setup_complete = True + self.logger.info("MCPEnv initialization complete") # cleanup on exit atexit.register( @@ -93,44 +124,69 @@ def _run_loop(self, loop: asyncio.AbstractEventLoop): async def _connect_servers(self): wrapper_tools = [] + self.logger.info(f"Starting connection to {len(self.mcp_servers)} MCP server(s)") for server_config in self.mcp_servers: - connection = MCPServerConnection(server_config, self.logger) - tools = await connection.connect() - - self.server_connections[server_config.name] = connection + self.logger.info(f"Connecting to MCP server: '{server_config.name}'") + self.logger.debug(f" Transport: {server_config.transport}") + self.logger.debug(f" Command: {server_config.command}") + self.logger.debug(f" Args: {server_config.args}") + if server_config.env: + env_keys = list(server_config.env.keys()) + self.logger.debug(f" Environment variables: {env_keys}") - for tool in tools.values(): - wrapper = MCPToolWrapper(server_config.name, tool, connection) - wrapper_tools.append(wrapper) - self.mcp_tools[wrapper.__name__] = wrapper - self.logger.info( - f"Registered MCP tool: {wrapper.__name__} from server '{server_config.name}'" - ) + try: + connection = MCPServerConnection(server_config, self.logger) + tools = await connection.connect() + + self.server_connections[server_config.name] = connection + self.logger.info(f"✓ Successfully connected to '{server_config.name}', discovered {len(tools)} tool(s)") + + for tool in tools.values(): + wrapper = MCPToolWrapper(server_config.name, tool, connection) + wrapper_tools.append(wrapper) + self.mcp_tools[wrapper.__name__] = wrapper + self.logger.info( + f" ├─ Registered MCP tool: {wrapper.__name__}" + ) + except Exception as e: + self.logger.error(f"✗ Failed to connect to MCP server '{server_config.name}': {e}") + raise self.tools = wrapper_tools self.oai_tools = [tool.to_oai_tool() for tool in wrapper_tools] self.tool_map = {tool.__name__: tool for tool in wrapper_tools} + self.logger.info(f"✓ Total MCP tools registered: {len(self.tool_map)}") async def call_tool( self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs ) -> Message: if tool_name in self.tool_map: tool_wrapper = self.tool_map[tool_name] + self.logger.info(f"Calling tool: {tool_name}") + self.logger.debug(f" Arguments: {tool_args}") + try: result = await tool_wrapper(**tool_args) + result_str = str(result) + result_preview = result_str[:200] + "..." if len(result_str) > 200 else result_str + self.logger.info(f"✓ Tool '{tool_name}' completed successfully") + self.logger.debug(f" Result preview: {result_preview}") + return { "role": "tool", - "content": str(result), + "content": result_str, "tool_call_id": tool_call_id, } except Exception as e: + self.logger.error(f"✗ Tool '{tool_name}' failed: {e}") return { "role": "tool", "content": self.error_formatter(e), "tool_call_id": tool_call_id, } else: + self.logger.error(f"✗ Tool '{tool_name}' not found in tool_map") return { "role": "tool", "content": f"Error: Tool '{tool_name}' not found", @@ -138,25 +194,35 @@ async def call_tool( } async def cleanup(self): - for connection in self.server_connections.values(): - await connection.disconnect() + self.logger.info(f"Cleaning up {len(self.server_connections)} MCP server connection(s)") + + for name, connection in self.server_connections.items(): + try: + self.logger.debug(f"Disconnecting from MCP server: '{name}'") + await connection.disconnect() + self.logger.info(f"✓ Disconnected from MCP server: '{name}'") + except Exception as e: + self.logger.error(f"✗ Error disconnecting from MCP server '{name}': {e}") self.server_connections.clear() self.mcp_tools.clear() + self.logger.info("Cleanup complete") def _shutdown_loop(self): + self.logger.debug("Shutting down background event loop") self._bg_loop.call_soon_threadsafe(self._bg_loop.stop) self._bg_thread.join(timeout=5) + self.logger.debug("Background event loop stopped") def load_environment( - mcp_servers: list = EXA_FETCH_TOOLS, dataset=None, **kwargs + mcp_servers: list = EXA_FETCH_TOOLS + BROWSERBASE_TOOLS, dataset=None, **kwargs ) -> vf.Environment: """Load an MCPEnv environment with fetch server for testing.""" dataset = dataset or Dataset.from_dict( { "question": [ - "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. Their url is primeintellect.ai", + "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. Their url is primeintellect.ai. Use the browserbase tools to get the information.", ], "answer": ["ENVIRONMENTS HUB"], } diff --git a/environments/mcp_env/pyproject.toml b/environments/mcp_env/pyproject.toml index 43529eadf..b59cc1c21 100644 --- a/environments/mcp_env/pyproject.toml +++ b/environments/mcp_env/pyproject.toml @@ -5,7 +5,7 @@ tags = ["train", "eval"] version = "0.1.0" requires-python = ">=3.11" dependencies = [ - "mcp>=1.14.1", + "mcp[cli]>=1.14.1", "python-dotenv>=1.1.1", "verifiers>=0.1.4", ] diff --git a/environments/mcp_env/src/mcp_server_connection.py b/environments/mcp_env/src/mcp_server_connection.py index 675947651..899ce9382 100644 --- a/environments/mcp_env/src/mcp_server_connection.py +++ b/environments/mcp_env/src/mcp_server_connection.py @@ -4,6 +4,7 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client from mcp.types import TextContent, Tool from .models import MCPServerConfig @@ -35,27 +36,55 @@ async def connect(self): async def _get_connection(self): try: - server_params = StdioServerParameters( - command=self.config.command, - args=self.config.args or [], - env=self.config.env, - ) + if self.config.transport == "stdio": + if not self.config.command: + raise ValueError("stdio transport requires 'command'") + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + ) - async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: - self.session = session + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + self.session = session - await session.initialize() + await session.initialize() - tools_response = await session.list_tools() + tools_response = await session.list_tools() - for tool in tools_response.tools: - self.tools[tool.name] = tool + for tool in tools_response.tools: + self.tools[tool.name] = tool - self._ready.set() + self._ready.set() - while True: - await asyncio.sleep(1) + while True: + await asyncio.sleep(1) + + elif self.config.transport == "http": + if not self.config.url: + raise ValueError("http transport requires 'url'") + + async with streamablehttp_client( + self.config.url, + headers=self.config.headers or {}, + ) as (read, write, _get_session_id): + async with ClientSession(read, write) as session: + self.session = session + + await session.initialize() + + tools_response = await session.list_tools() + + for tool in tools_response.tools: + self.tools[tool.name] = tool + + self._ready.set() + + while True: + await asyncio.sleep(1) + else: + raise ValueError(f"Unknown transport: {self.config.transport}") except asyncio.CancelledError: raise diff --git a/environments/mcp_env/src/models.py b/environments/mcp_env/src/models.py index 7a20dd38e..1b5ffc7c8 100644 --- a/environments/mcp_env/src/models.py +++ b/environments/mcp_env/src/models.py @@ -1,11 +1,16 @@ from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Literal, Optional @dataclass class MCPServerConfig: name: str - command: str - args: List[str] | None = None - env: Dict[str, str] | None = None + transport: Literal["stdio", "http"] = "stdio" description: str = "" + # stdio params + command: Optional[str] = None + args: Optional[List[str]] = None + env: Optional[Dict[str, str]] = None + # http params + url: Optional[str] = None + headers: Optional[Dict[str, str]] = None