Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 261 additions & 4 deletions verifiers/envs/cli_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess
import time
import uuid
import json
from typing import Any

from aiohttp import web
Expand All @@ -25,6 +26,37 @@
logger = logging.getLogger(__name__)


def _truncate(text: str, max_len: int = 500) -> str:
"""Truncate text for logging."""
if len(text) <= max_len:
return text
return text[:max_len] + f"... ({len(text) - max_len} more chars)"


def _format_message(msg: dict) -> str:
    """Render a chat message as a one-line summary for logging.

    Handles three shapes: assistant messages carrying ``tool_calls``,
    tool-result messages (``role == "tool"``), and plain content messages.
    """
    role = msg.get("role", "unknown")

    # Assistant tool-call messages: summarize each call as name(args).
    tool_calls = msg.get("tool_calls") or []
    if tool_calls:
        tools_summary = []
        for tc in tool_calls:
            func = tc.get("function", {})
            name = func.get("name", "unknown")
            # Arguments are normally a JSON string, but some producers pass
            # a dict; coerce to str so _truncate can slice it safely.
            args = str(func.get("arguments", ""))
            tools_summary.append(f"{name}({_truncate(args, 200)})")
        return f"[{role}] tool_calls: {', '.join(tools_summary)}"

    content = msg.get("content", "")

    # Tool results: include the tool_call_id so the call can be correlated.
    if role == "tool":
        tool_call_id = msg.get("tool_call_id", "")
        return f"[{role}:{tool_call_id}] {_truncate(str(content), 300)}"

    return f"[{role}] {_truncate(str(content), 500)}"


class CliAgentEnv(vf.MultiTurnEnv):
"""
Environment for running full agent code inside sandboxes.
Expand Down Expand Up @@ -55,6 +87,7 @@ def __init__(
environment_vars: dict[str, str] | None = None,
team_id: str | None = None,
advanced_configs: AdvancedConfigs | None = None,
log_requests: bool = True,
**kwargs,
):
super().__init__(max_turns=max_turns, message_type="chat", **kwargs)
Expand All @@ -77,12 +110,14 @@ def __init__(
self.environment_vars = environment_vars
self.team_id = team_id
self.advanced_configs = advanced_configs
self.log_requests = log_requests
self.active_rollouts: dict[str, dict[str, Any]] = {}
self.intercepts: dict[str, dict[str, Any]] = {} # request_id -> intercept data
self.interception_server: Any = None
self._server_lock = asyncio.Lock()
self._server_runner: Any = None
self._server_site: Any = None
self._request_counts: dict[str, int] = {} # rollout_id -> request count

def _ensure_cloudflared_installed(self) -> str:
"""Install cloudflared if not already installed. Returns path to cloudflared binary."""
Expand Down Expand Up @@ -318,12 +353,45 @@ async def get_prompt_messages(self, state: State) -> Messages:
process request immediately with injected sampling_args, store response in intercept,
return messages.
"""
rollout_id = state.get("rollout_id", "unknown")
request_id_queue = state["request_id_queue"]

request_id = await asyncio.wait_for(
request_id_queue.get(),
timeout=self.request_timeout,
)
# Track request count for logging
if rollout_id not in self._request_counts:
self._request_counts[rollout_id] = 0
self._request_counts[rollout_id] += 1
req_num = self._request_counts[rollout_id]

# Poll for requests while checking completion periodically
# This avoids blocking for the full request_timeout when agent finishes
poll_interval = 5.0 # Check completion every 5 seconds
elapsed = 0.0
request_id = None
while elapsed < self.request_timeout:
try:
request_id = await asyncio.wait_for(
request_id_queue.get(),
timeout=poll_interval,
)
break # Got a request, continue processing
except asyncio.TimeoutError:
elapsed += poll_interval
# Check if agent signaled completion
if await self.agent_signaled_completion(state):
logger.debug("Agent signaled completion while waiting for request")
# Set flag for early exit - stop conditions will handle termination
state["_cli_agent_completed"] = True
return [] # Return empty messages, stop condition will trigger
# Check timeout
if await self.timeout_reached(state):
logger.debug("Timeout reached while waiting for request")
state["_cli_agent_completed"] = True
return []

if request_id is None:
raise asyncio.TimeoutError(
f"No request received within {self.request_timeout}s"
)

intercept = self.intercepts[request_id]
messages = intercept["messages"]
Expand All @@ -333,6 +401,16 @@ async def get_prompt_messages(self, state: State) -> Messages:
request_tools = intercept.get("tools")
effective_sampling_args = state.get("sampling_args") or {}

# Log the intercepted request
if self.log_requests:
logger.info(
f"[Request #{req_num}] model={request_model}, "
f"messages={len(messages)}, tools={len(request_tools or [])}"
)
# Log the last few messages (most relevant context)
for msg in messages[-3:]:
logger.info(f" {_format_message(msg)}")

client = state.get("client")
if client is None:
raise RuntimeError("Client not set in state")
Expand All @@ -347,6 +425,18 @@ async def get_prompt_messages(self, state: State) -> Messages:
message_type=None,
)

# Log the response
if self.log_requests and response.choices:
choice = response.choices[0]
msg = choice.message
if msg.tool_calls:
tools = [f"{tc.function.name}(...)" for tc in msg.tool_calls]
logger.info(f"[Response #{req_num}] tool_calls: {', '.join(tools)}")
elif msg.content:
logger.info(f"[Response #{req_num}] {_truncate(msg.content, 200)}")
else:
logger.info(f"[Response #{req_num}] (empty)")

intercept["response_future"].set_result(response)
intercept["response"] = response
state["current_request_id"] = request_id
Expand All @@ -367,6 +457,29 @@ async def get_model_response(
Return cached response if available (set by get_prompt_messages).
Otherwise fall back to parent implementation.
"""
# If prompt is empty, we're in early-exit mode - return a dummy response
# The stop condition will terminate the rollout on the next iteration
if not prompt:
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

return ChatCompletion(
id="cli_agent_early_exit",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
role="assistant",
content="",
),
)
],
created=int(time.time()),
model=model,
object="chat.completion",
)

for request_id, intercept in list(self.intercepts.items()):
rollout_id = intercept.get("rollout_id")
if rollout_id and rollout_id in self.active_rollouts:
Expand Down Expand Up @@ -427,13 +540,26 @@ async def _handle_intercepted_request(self, request: Any) -> Any:
{"error": f"Invalid JSON: {e}"}, status=400
)

# Log request details including stream parameter
stream_requested = request_body.get("stream", False)
logger.info(
f"Intercepted request: stream={stream_requested}, "
f"model={request_body.get('model')}, "
f"messages={len(request_body.get('messages', []))}"
)

# Force non-streaming - we don't support SSE streaming yet
# The response will be converted to SSE format if stream was requested
request_body["stream"] = False

request_id = f"req_{uuid.uuid4().hex[:8]}"
intercept = {
"request_id": request_id,
"rollout_id": rollout_id,
"messages": request_body["messages"],
"model": request_body.get("model"),
"tools": request_body.get("tools"),
"stream_requested": stream_requested, # Remember if client wanted streaming
"response_future": asyncio.Future(),
}

Expand All @@ -451,8 +577,130 @@ async def _handle_intercepted_request(self, request: Any) -> Any:
response_dict = (
response.model_dump() if hasattr(response, "model_dump") else dict(response)
)

# logger.info(
# f"Response to agent: {json.dumps(response_dict, indent=2, default=str)[:2000]}"
# )

# If client requested streaming, convert to SSE format
if intercept.get("stream_requested", False):
return self._create_sse_response(response_dict)

return web.json_response(response_dict) # type: ignore

def _create_sse_response(self, response_dict: dict) -> web.Response:
    """Convert a non-streaming chat completion into an SSE streaming response.

    Replays each choice as OpenAI-style ``chat.completion.chunk`` events:
    a role delta, then content and/or tool_call deltas, then a final chunk
    carrying the finish_reason, followed by the ``[DONE]`` sentinel.

    Args:
        response_dict: A chat completion response already dumped to a dict.

    Returns:
        An aiohttp ``web.Response`` with ``text/event-stream`` content.
    """
    response_id = response_dict.get("id", "chatcmpl-unknown")
    created = response_dict.get("created", 0)
    model = response_dict.get("model", "unknown")

    def sse_event(index: int, delta: dict, finish_reason=None) -> str:
        # One "data: ..." SSE line wrapping a chat.completion.chunk payload;
        # all four chunk kinds share this envelope.
        chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [
                {
                    "index": index,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }
            ],
        }
        return f"data: {json.dumps(chunk)}\n\n"

    chunks = []

    for choice in response_dict.get("choices", []):
        message = choice.get("message", {})
        finish_reason = choice.get("finish_reason")
        index = choice.get("index", 0)

        # First chunk announces the role.
        chunks.append(sse_event(index, {"role": message.get("role", "assistant")}))

        # Content delta (if any).
        content = message.get("content")
        if content:
            chunks.append(sse_event(index, {"content": content}))

        # Tool calls: send each call with full info in a single delta chunk.
        tool_calls = message.get("tool_calls", [])
        if tool_calls:
            tc_delta = [
                {
                    "index": i,
                    "id": tc.get("id"),
                    "type": tc.get("type", "function"),
                    "function": {
                        "name": tc.get("function", {}).get("name", ""),
                        "arguments": tc.get("function", {}).get("arguments", ""),
                    },
                }
                for i, tc in enumerate(tool_calls)
            ]
            chunks.append(sse_event(index, {"tool_calls": tc_delta}))

        # Final chunk for this choice carries the finish_reason.
        chunks.append(sse_event(index, {}, finish_reason))

    # End-of-stream sentinel expected by OpenAI-compatible clients.
    chunks.append("data: [DONE]\n\n")

    body = "".join(chunks)
    logger.debug(f"SSE response body:\n{body[:1500]}")
    return web.Response(
        body=body,
        status=200,
        content_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )

@vf.teardown
async def teardown_tunnel(self):
"""Stop all cloudflared tunnel processes"""
Expand Down Expand Up @@ -482,6 +730,10 @@ async def cleanup_interception_context(self, state: State):
del self.intercepts[request_id]
del self.active_rollouts[rollout_id]

# Clean up request count
if rollout_id and rollout_id in self._request_counts:
del self._request_counts[rollout_id]

# Decrement active rollouts for the tunnel used by this rollout
tunnel_url = state.get("tunnel_url")
if tunnel_url:
Expand All @@ -493,6 +745,11 @@ async def cleanup_interception_context(self, state: State):
)
break

@vf.stop
async def early_exit_flag_set(self, state: State) -> bool:
    """Stop condition: true once completion detection in get_prompt_messages has flagged early exit."""
    try:
        return state["_cli_agent_completed"]
    except KeyError:
        return False

@vf.stop
async def agent_signaled_completion(self, state: State) -> bool:
"""Check for /tmp/vf_complete marker file"""
Expand Down
Loading