6 changes: 6 additions & 0 deletions .gitignore
@@ -45,3 +45,9 @@ scratch/
.vscode/
*.swp
.DS_Store

# Node modules
node_modules/

# Browser env generated files
environments/browser_env/cua-server/node_modules/
11 changes: 11 additions & 0 deletions environments/mcp_env/.env.example
@@ -0,0 +1,11 @@
# OpenAI API key (required for judge-based evaluation)
OPENAI_API_KEY=your_openai_api_key

# Smithery credentials for Exa MCP server
SMITHERY_KEY=your_smithery_key
SMITHERY_PROFILE=your_smithery_profile

# Browserbase MCP server credentials
BROWSERBASE_API_KEY=your_browserbase_api_key
BROWSERBASE_PROJECT_ID=your_browserbase_project_id
GEMINI_API_KEY=your_gemini_api_key
65 changes: 63 additions & 2 deletions environments/mcp_env/README.md
@@ -14,12 +14,73 @@

### Task

- **Type**: <multi-turn | tool use>
- **Type**: multi-turn | tool use
- **Parser**: N/A
- **Rubric overview**: N/A
- **Rubric overview**: Judge-based evaluation using gpt-4.1-mini

### MCP Tools

This environment integrates with MCP (Model Context Protocol) servers to provide tool-calling capabilities. By default, it includes:

#### Exa & Fetch Tools

- **Exa MCP Server**: Search and discovery tool for finding relevant web content (via Smithery)
  - Command: `npx -y @smithery/cli@latest run exa --key <KEY> --profile <PROFILE>`
  - Note: Authentication is handled via the Smithery CLI key/profile

- **Fetch MCP Server**: Fetches web content from URLs
  - Command: `uvx mcp-server-fetch`
  - No API key required

#### Browserbase Tools

- **Browserbase MCP Server**: Browser automation for interacting with web pages using AI-powered navigation
  - Command: `npx @browserbasehq/mcp-server-browserbase`
  - Required environment variables (forwarded to the server subprocess; see the sketch after this list):
    - `BROWSERBASE_API_KEY`
    - `BROWSERBASE_PROJECT_ID`
    - `GEMINI_API_KEY`
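
Under the hood, these variables are read from your shell environment and passed to the Browserbase server subprocess via the config's `env` block. A minimal sketch of the default entry (mirroring `BROWSERBASE_TOOLS` in `mcp_env.py` below):

```python
import os

# Sketch of the default Browserbase server entry (mirrors BROWSERBASE_TOOLS in mcp_env.py):
# the keys are read from the environment and forwarded to the npx subprocess.
browserbase_server = {
    "name": "browserbase",
    "transport": "stdio",
    "command": "npx",
    "args": ["@browserbasehq/mcp-server-browserbase"],
    "env": {
        "BROWSERBASE_API_KEY": os.getenv("BROWSERBASE_API_KEY", ""),
        "BROWSERBASE_PROJECT_ID": os.getenv("BROWSERBASE_PROJECT_ID", ""),
        "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""),
    },
    "description": "Browserbase MCP (via npx)",
}
```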

**Customizing Tools:**

You can pass custom MCP server configurations via the `mcp_servers` argument to `load_environment()`:

```python
custom_servers = [
    {
        "name": "my-server",
        "transport": "stdio",
        "command": "npx",
        "args": ["my-mcp-server"],
        "env": {"API_KEY": "your_key"},
        "description": "Custom MCP server"
    }
]
env = load_environment(mcp_servers=custom_servers)
```
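
The connection layer added in this PR also supports a streamable HTTP transport (`"transport": "http"` with a `url` and optional `headers`). A minimal sketch of such a config, assuming those dict keys map onto `MCPServerConfig` the same way the stdio keys do; the server name, endpoint, and header value here are placeholders:

```python
# Sketch only: an HTTP-based MCP server config. The name, URL, and token are
# hypothetical; "transport": "http" is routed to streamablehttp_client.
http_servers = [
    {
        "name": "my-http-server",
        "transport": "http",
        "url": "https://example.com/mcp",
        "headers": {"Authorization": "Bearer your_token"},
        "description": "Custom MCP server over streamable HTTP",
    }
]
env = load_environment(mcp_servers=http_servers)
```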

### Quickstart

**Prerequisites:**

Export the required API keys for the judge LLM and MCP tools:

```bash
# Required for judge-based evaluation
export OPENAI_API_KEY=your_openai_key

# Required for Exa MCP server (via Smithery)
export SMITHERY_KEY=your_smithery_key
export SMITHERY_PROFILE=your_smithery_profile

# Required for Browserbase MCP server (browser automation)
export BROWSERBASE_API_KEY=your_browserbase_key
export BROWSERBASE_PROJECT_ID=your_project_id
export GEMINI_API_KEY=your_gemini_key
```

**Note:** Not all API keys are required for every task. The Fetch MCP server works without any API key. Only export the keys for the tools you intend to use.
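
For example, if you only have `OPENAI_API_KEY` set, you can restrict the environment to the key-free Fetch server. A minimal sketch, assuming the module is importable as `mcp_env` (matching the file name in this PR):

```python
# Sketch: run with only the Fetch MCP server, which needs no API keys.
from mcp_env import load_environment

fetch_only = [
    {
        "name": "fetch",
        "transport": "stdio",
        "command": "uvx",
        "args": ["mcp-server-fetch"],
        "description": "Fetch MCP server",
    }
]
env = load_environment(mcp_servers=fetch_only)
```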

Run an evaluation with default settings:

```bash
102 changes: 84 additions & 18 deletions environments/mcp_env/mcp_env.py
@@ -19,24 +19,49 @@
EXA_FETCH_TOOLS = [
    {
        "name": "exa",
        "transport": "stdio",
        "command": "npx",
        "args": [
            "-y",
            "exa-mcp-server",
            "@smithery/cli@latest",
            "run",
            "exa",
            "--key",
            os.getenv("SMITHERY_KEY", ""),
            "--profile",
            os.getenv("SMITHERY_PROFILE", ""),
        ],
        "env": {
            "EXA_API_KEY": os.getenv("EXA_API_KEY"),
            "NPM_CONFIG_LOGLEVEL": "silent",
        },
        "description": "Exa MCP server",
    },
    {
        "name": "fetch",
        "transport": "stdio",
        "command": "uvx",
        "args": ["mcp-server-fetch"],
        "description": "Fetch MCP server",
    },
]

BROWSERBASE_TOOLS = [
    {
        "name": "browserbase",
        "transport": "stdio",
        "command": "npx",
        "args": ["@browserbasehq/mcp-server-browserbase"],
        "env": {
            "BROWSERBASE_API_KEY": os.getenv("BROWSERBASE_API_KEY", ""),
            "BROWSERBASE_PROJECT_ID": os.getenv("BROWSERBASE_PROJECT_ID", ""),
            "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""),
            "NPM_CONFIG_LOGLEVEL": "silent",
        },
        "description": "Browserbase MCP (via npx)",
    },
]


class MCPEnv(ToolEnv):
    """Environment for MCP-based tools using the official MCP SDK."""
@@ -67,15 +92,21 @@ def __init__(
        super().__init__(
            tools=[], max_turns=max_turns, error_formatter=error_formatter, **kwargs
        )

        self.logger.info(f"Initializing MCPEnv with {len(self.mcp_servers)} MCP server(s)")

        # Start a persistent background event loop and connect synchronously
        self._bg_loop = asyncio.new_event_loop()
        self._bg_thread = threading.Thread(
            target=self._run_loop, args=(self._bg_loop,), daemon=True
        )
        self._bg_thread.start()
        self.logger.debug("Background event loop started")

        fut = asyncio.run_coroutine_threadsafe(self._connect_servers(), self._bg_loop)
        fut.result()
        self._setup_complete = True
        self.logger.info("MCPEnv initialization complete")

        # cleanup on exit
        atexit.register(
@@ -93,70 +124,105 @@ def _run_loop(self, loop: asyncio.AbstractEventLoop):

    async def _connect_servers(self):
        wrapper_tools = []
        self.logger.info(f"Starting connection to {len(self.mcp_servers)} MCP server(s)")

        for server_config in self.mcp_servers:
            connection = MCPServerConnection(server_config, self.logger)
            tools = await connection.connect()

            self.server_connections[server_config.name] = connection
            self.logger.info(f"Connecting to MCP server: '{server_config.name}'")
            self.logger.debug(f" Transport: {server_config.transport}")
            self.logger.debug(f" Command: {server_config.command}")
            self.logger.debug(f" Args: {server_config.args}")
            if server_config.env:
                env_keys = list(server_config.env.keys())
                self.logger.debug(f" Environment variables: {env_keys}")

            for tool in tools.values():
                wrapper = MCPToolWrapper(server_config.name, tool, connection)
                wrapper_tools.append(wrapper)
                self.mcp_tools[wrapper.__name__] = wrapper
                self.logger.info(
                    f"Registered MCP tool: {wrapper.__name__} from server '{server_config.name}'"
                )
            try:
                connection = MCPServerConnection(server_config, self.logger)
                tools = await connection.connect()

                self.server_connections[server_config.name] = connection
                self.logger.info(f"✓ Successfully connected to '{server_config.name}', discovered {len(tools)} tool(s)")

                for tool in tools.values():
                    wrapper = MCPToolWrapper(server_config.name, tool, connection)
                    wrapper_tools.append(wrapper)
                    self.mcp_tools[wrapper.__name__] = wrapper
                    self.logger.info(
                        f" ├─ Registered MCP tool: {wrapper.__name__}"
                    )
            except Exception as e:
                self.logger.error(f"✗ Failed to connect to MCP server '{server_config.name}': {e}")
                raise

        self.tools = wrapper_tools
        self.oai_tools = [tool.to_oai_tool() for tool in wrapper_tools]
        self.tool_map = {tool.__name__: tool for tool in wrapper_tools}
        self.logger.info(f"✓ Total MCP tools registered: {len(self.tool_map)}")

    async def call_tool(
        self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
    ) -> Message:
        if tool_name in self.tool_map:
            tool_wrapper = self.tool_map[tool_name]
            self.logger.info(f"Calling tool: {tool_name}")
            self.logger.debug(f" Arguments: {tool_args}")

            try:
                result = await tool_wrapper(**tool_args)
                result_str = str(result)
                result_preview = result_str[:200] + "..." if len(result_str) > 200 else result_str
                self.logger.info(f"✓ Tool '{tool_name}' completed successfully")
                self.logger.debug(f" Result preview: {result_preview}")

                return {
                    "role": "tool",
                    "content": str(result),
                    "content": result_str,
                    "tool_call_id": tool_call_id,
                }
            except Exception as e:
                self.logger.error(f"✗ Tool '{tool_name}' failed: {e}")
                return {
                    "role": "tool",
                    "content": self.error_formatter(e),
                    "tool_call_id": tool_call_id,
                }
        else:
            self.logger.error(f"✗ Tool '{tool_name}' not found in tool_map")
            return {
                "role": "tool",
                "content": f"Error: Tool '{tool_name}' not found",
                "tool_call_id": tool_call_id,
            }

    async def cleanup(self):
        for connection in self.server_connections.values():
            await connection.disconnect()
        self.logger.info(f"Cleaning up {len(self.server_connections)} MCP server connection(s)")

        for name, connection in self.server_connections.items():
            try:
                self.logger.debug(f"Disconnecting from MCP server: '{name}'")
                await connection.disconnect()
                self.logger.info(f"✓ Disconnected from MCP server: '{name}'")
            except Exception as e:
                self.logger.error(f"✗ Error disconnecting from MCP server '{name}': {e}")

        self.server_connections.clear()
        self.mcp_tools.clear()
        self.logger.info("Cleanup complete")

    def _shutdown_loop(self):
        self.logger.debug("Shutting down background event loop")
        self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
        self._bg_thread.join(timeout=5)
        self.logger.debug("Background event loop stopped")


def load_environment(
    mcp_servers: list = EXA_FETCH_TOOLS, dataset=None, **kwargs
    mcp_servers: list = EXA_FETCH_TOOLS + BROWSERBASE_TOOLS, dataset=None, **kwargs
) -> vf.Environment:
    """Load an MCPEnv environment with fetch server for testing."""
    dataset = dataset or Dataset.from_dict(
        {
            "question": [
                "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. Their url is primeintellect.ai",
                "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. Their url is primeintellect.ai. Use the browserbase tools to get the information.",
            ],
            "answer": ["ENVIRONMENTS HUB"],
        }
2 changes: 1 addition & 1 deletion environments/mcp_env/pyproject.toml
@@ -5,7 +5,7 @@ tags = ["train", "eval"]
version = "0.1.0"
requires-python = ">=3.11"
dependencies = [
"mcp>=1.14.1",
"mcp[cli]>=1.14.1",
"python-dotenv>=1.1.1",
"verifiers>=0.1.4",
]
59 changes: 44 additions & 15 deletions environments/mcp_env/src/mcp_server_connection.py
@@ -4,6 +4,7 @@

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent, Tool

from .models import MCPServerConfig
@@ -35,27 +36,55 @@ async def connect(self):

    async def _get_connection(self):
        try:
            server_params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args or [],
                env=self.config.env,
            )
            if self.config.transport == "stdio":
                if not self.config.command:
                    raise ValueError("stdio transport requires 'command'")
                server_params = StdioServerParameters(
                    command=self.config.command,
                    args=self.config.args or [],
                    env=self.config.env,
                )

            async with stdio_client(server_params) as (read, write):
                async with ClientSession(read, write) as session:
                    self.session = session
                async with stdio_client(server_params) as (read, write):
                    async with ClientSession(read, write) as session:
                        self.session = session

                    await session.initialize()
                        await session.initialize()

                    tools_response = await session.list_tools()
                        tools_response = await session.list_tools()

                    for tool in tools_response.tools:
                        self.tools[tool.name] = tool
                        for tool in tools_response.tools:
                            self.tools[tool.name] = tool

                    self._ready.set()
                        self._ready.set()

                    while True:
                        await asyncio.sleep(1)
                        while True:
                            await asyncio.sleep(1)

            elif self.config.transport == "http":
                if not self.config.url:
                    raise ValueError("http transport requires 'url'")

                async with streamablehttp_client(
                    self.config.url,
                    headers=self.config.headers or {},
                ) as (read, write, _get_session_id):
                    async with ClientSession(read, write) as session:
                        self.session = session

                        await session.initialize()

                        tools_response = await session.list_tools()

                        for tool in tools_response.tools:
                            self.tools[tool.name] = tool

                        self._ready.set()

                        while True:
                            await asyncio.sleep(1)
            else:
                raise ValueError(f"Unknown transport: {self.config.transport}")

        except asyncio.CancelledError:
            raise