60 changes: 49 additions & 11 deletions verifiers/envs/environment.py
@@ -32,9 +32,12 @@
from verifiers.utils.async_utils import maybe_semaphore
from verifiers.utils.eval_utils import make_dataset, save_results
from verifiers.utils.message_utils import (
adapt_tools_for_responses_api,
cleanup_messages,
extract_system_message,
get_overlong_prompt_dummy_response,
)
from verifiers.utils.responses_api_adapter import ResponsesAPIAdapter
from verifiers.utils.path_utils import get_results_path
from verifiers.utils.processing_utils import (
process_chat_format_vllm,
@@ -67,6 +70,7 @@ def __init__(
env_id: str | None = None,
env_args: dict | None = None,
map_kwargs: dict = {},
use_responses_api: bool = False,
**kwargs,
):
self.logger = logging.getLogger(f"verifiers.envs.{self.__class__.__name__}")
@@ -75,6 +79,7 @@
self.system_prompt = system_prompt
self.few_shot = few_shot
self.parser = parser or Parser()
self.use_responses_api = use_responses_api
self.rubric = rubric or Rubric()
if self.parser.__class__ != self.rubric.parser.__class__:
self.logger.warning(
@@ -281,20 +286,53 @@ async def get_model_response(
"modalities": ["text"],
}

            if self.use_responses_api:
                instructions, input_messages = extract_system_message(prompt)
                adapted_tools = adapt_tools_for_responses_api(oai_tools)
                if len(input_messages) == 1:
                    api_input = input_messages[0].get("content", "")
                else:
                    api_input = input_messages
                # Chat Completions sampling params that this path does not
                # forward to the Responses API.
                unsupported_params = {
                    "n",
                    "presence_penalty",
                    "frequency_penalty",
                    "logprobs",
                    "top_logprobs",
                    "logit_bias",
                    "stream",
                    "stream_options",
                    "user",
                    "temperature",
                }
                responses_sampling_args = {
                    k: v
                    for k, v in clean_sampling_args.items()
                    if k not in unsupported_params
                }
                response = await client.responses.create(
                    model=model,
                    instructions=instructions,
                    input=api_input,
                    tools=adapted_tools,
                    **responses_sampling_args,
                )
                return ResponsesAPIAdapter(response)
            if oai_tools:
                response = await client.chat.completions.create(
                    model=model,
                    messages=prompt,  # type: ignore
                    tools=oai_tools,
                    **clean_sampling_args,
                )
            else:
                response = await client.chat.completions.create(
                    model=model,
                    messages=prompt,  # type: ignore
                    **clean_sampling_args,
                )
            return response
elif message_type == "completion":
if oai_tools:
raise ValueError(
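For context, a minimal sketch of how the new flag might be exercised. `MyEnv` is a hypothetical concrete Environment subclass, and the `client`/`prompt` keyword names for `get_model_response` are assumptions, since the full signature is collapsed in this diff:

import asyncio

from openai import AsyncOpenAI

from my_project.envs import MyEnv  # hypothetical Environment subclass


async def main():
    client = AsyncOpenAI()
    # Route model calls through client.responses instead of chat.completions.
    env = MyEnv(use_responses_api=True)
    response = await env.get_model_response(
        client=client,
        model="gpt-4.1",
        prompt=[
            {"role": "system", "content": "Answer in one word."},
            {"role": "user", "content": "Capital of France?"},
        ],
    )
    # Either branch hands back a ChatCompletion-shaped object.
    print(response.choices[0].message.content)


asyncio.run(main())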
4 changes: 3 additions & 1 deletion verifiers/envs/multiturn_env.py
@@ -106,7 +106,9 @@ async def rollout(
response_text: str = ""
if self.message_type == "chat":
assert isinstance(context_messages, list)
            # ResponsesAPIAdapter is not a ChatCompletion, but it exposes the
            # same .choices surface, so accept anything ChatCompletion-shaped.
            assert isinstance(response, ChatCompletion) or hasattr(
                response, "choices"
            )
if response.choices and response.choices[0].message:
response_text = response.choices[0].message.content or ""
response_message: ChatMessage = {
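The loosened assertion is what lets the `ResponsesAPIAdapter` (added below) flow through `rollout`. A tiny sketch of the duck-typing at work, using a stand-in for a real Responses API response:

from openai.types.chat import ChatCompletion

from verifiers.utils.responses_api_adapter import ResponsesAPIAdapter


class _FakeResponsesResponse:
    """Stand-in; real objects come from client.responses.create(...)."""

    output_text = "hello"
    output: list = []


adapter = ResponsesAPIAdapter(_FakeResponsesResponse())
assert not isinstance(adapter, ChatCompletion)  # not a ChatCompletion...
assert hasattr(adapter, "choices")  # ...but it quacks like one
print(adapter.choices[0].message.content)  # "hello"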
38 changes: 38 additions & 0 deletions verifiers/utils/message_utils.py
@@ -155,3 +155,41 @@ def get_overlong_prompt_dummy_response(message_type: MessageType) -> ModelResponse
)
else:
raise ValueError(f"Invalid message type: {message_type}")


def extract_system_message(messages: list) -> tuple[str | None, list]:
"""Extract system message as instructions for Responses API."""
instructions = None
remaining = []

for msg in messages:
if msg.get("role") == "system":
if instructions is None:
instructions = msg.get("content")
else:
remaining.append(msg)

return instructions, remaining
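
A quick illustration of the split on plain dict messages:

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello"},
]
instructions, remaining = extract_system_message(messages)
assert instructions == "You are terse."
assert remaining == [{"role": "user", "content": "Hello"}]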


def adapt_tools_for_responses_api(chat_tools: list | None) -> list | None:
"""Convert Chat Completions tool format to Responses API format."""
if not chat_tools:
return None

adapted = []
for tool in chat_tools:
if tool.get("type") == "function":
func = tool.get("function", {})
adapted.append(
{
"type": "function",
"name": func.get("name"),
"description": func.get("description"),
"parameters": func.get("parameters"),
}
)
else:
adapted.append(tool)

return adapted
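
And the shape change this performs on a representative Chat Completions tool (the weather tool itself is illustrative):

chat_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
(adapted_tool,) = adapt_tools_for_responses_api([chat_tool])
# The Responses API expects name/description/parameters at the top level.
assert adapted_tool["name"] == "get_weather"
assert adapted_tool["parameters"]["required"] == ["city"]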
61 changes: 61 additions & 0 deletions verifiers/utils/responses_api_adapter.py
@@ -0,0 +1,61 @@
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import Choice


class ResponsesAPIAdapter:
"""Adapter to normalize Responses API responses to ChatCompletion format."""

def __init__(self, responses_response):
self._response = responses_response
self._text = getattr(responses_response, "output_text", "")
self._tool_calls = self._extract_tool_calls()

def _extract_tool_calls(self):
tool_calls = []
output = getattr(self._response, "output", [])

for item in output:
item_type = getattr(item, "type", None)
if item_type == "function_call":
tool_calls.append(
{
"id": getattr(item, "call_id", ""),
"type": "function",
"function": {
"name": getattr(item, "name", ""),
"arguments": str(getattr(item, "arguments", {})),
},
}
)

return tool_calls if tool_calls else None

@property
def choices(self):
return [
Choice(
index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=self._text,
                    tool_calls=self._tool_calls,
                ),
                # Mirror Chat Completions: report "tool_calls" when the model
                # called a tool so downstream handlers branch correctly.
                finish_reason="tool_calls" if self._tool_calls else "stop",
)
]

@property
def id(self):
return getattr(self._response, "id", "responses-api-adapter")

@property
def model(self):
return getattr(self._response, "model", "")

@property
def created(self):
return getattr(self._response, "created_at", 0)

@property
def object(self):
return "chat.completion"

@property
def usage(self):
return getattr(self._response, "usage", None)
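
A minimal sketch of the adapter in use (model name illustrative); downstream code can read the result exactly as if it were a ChatCompletion:

import asyncio

from openai import AsyncOpenAI

from verifiers.utils.responses_api_adapter import ResponsesAPIAdapter


async def main():
    client = AsyncOpenAI()
    raw = await client.responses.create(model="gpt-4.1", input="Say hi")
    response = ResponsesAPIAdapter(raw)
    print(response.id, response.model, response.object)
    print(response.choices[0].message.content)


asyncio.run(main())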