108 changes: 98 additions & 10 deletions verifiers/envs/environment.py
@@ -230,6 +230,7 @@ async def get_model_response(
Convenience function for wrapping (chat, completion) API calls.
Returns special error messages for context length issues.
"""
RAW_RESPONSE = True
sampling_args = sampling_args or {}
# resolve message type first
if message_type is None:
@@ -248,10 +249,21 @@ async def get_model_response(
):
sampling_args.pop("max_completion_tokens")
clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}

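# with_raw_response returns the raw HTTP response wrapper instead of a parsed Pydantic model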
if RAW_RESPONSE:
if message_type == "chat":
fn = client.chat.completions.with_raw_response.create
else:
fn = client.completions.with_raw_response.create
else:
if message_type == "chat":
fn = client.chat.completions.create
else:
fn = client.completions.create

try:
if message_type == "chat":
assert isinstance(prompt, list)
# --- detect audio parts and force text-only modality if caller didn't set one ---
has_audio = False
try:
for m in prompt:
@@ -274,32 +286,26 @@ async def get_model_response(
}

if oai_tools:
response = await fn(
model=model,
messages=prompt, # type: ignore
tools=oai_tools,
**clean_sampling_args,
)
else:
response = await fn(
model=model,
messages=prompt, # type: ignore
**clean_sampling_args,
)
elif message_type == "completion":
if oai_tools:
raise ValueError(
"oai_tools are not supported for completion tasks."
)
assert isinstance(prompt, str)
response = await fn(model=model, prompt=prompt, **clean_sampling_args)
except Exception as e:
# in case of making a request with an overlong prompt, e.g. from a too-long
# environment response, we return a dummy response with finish_reason "length"
if isinstance(e, BadRequestError):
error_text = e.response.text.lower()
context_length_phrases = [
@@ -315,6 +321,88 @@ async def get_model_response(
self.logger.error(f"Error getting model response: {e} \n\nExiting...")
raise e

if RAW_RESPONSE:
from typing import cast, Union, Dict, List, Optional, Any
from openai.types import Completion
from openai.types.chat import ChatCompletion
from openai._legacy_response import LegacyAPIResponse

response = cast(
Union[LegacyAPIResponse[Completion], LegacyAPIResponse[ChatCompletion]],
response,
)

def _as_float(x: Any) -> Optional[float]:
try:
return float(x)
except Exception:
return None

def _coerce_logprobs_inplace(raw: Dict[str, Any]) -> None:
"""Mutate `raw` to normalize/strip logprobs without copying."""
choices = raw.get("choices")
if not isinstance(choices, list):
return

for ch in choices:
lp = ch.get("logprobs")

if lp is None:
continue

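# chat-style logprobs already use the target shape:
#   {"content": [{"token": ..., "logprob": ..., "top_logprobs": [...]}, ...]}
# so only the logprob values need coercion to float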
if isinstance(lp, dict) and isinstance(lp.get("content"), list):
for t in lp["content"]:
if isinstance(t, dict) and "logprob" in t:
t["logprob"] = _as_float(t["logprob"])
tops = t.get("top_logprobs")
if isinstance(tops, list):
for tlp in tops:
if isinstance(tlp, dict) and "logprob" in tlp:
tlp["logprob"] = _as_float(tlp["logprob"])
continue

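# legacy completion-style logprobs arrive as parallel arrays
#   ({"tokens": [...], "token_logprobs": [...], "top_logprobs": [{token: lp, ...}, ...]})
# and are rewritten below into the chat-style {"content": [...]} shape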
if isinstance(lp, dict) and {"tokens", "token_logprobs"} <= set(
lp.keys()
):
tokens = lp.get("tokens") or []
token_lps = lp.get("token_logprobs") or []
top_series = lp.get("top_logprobs")
if not isinstance(top_series, list):
top_series = []

content: List[Dict[str, Any]] = []
for i, (tok, logp) in enumerate(zip(tokens, token_lps)):
item: Dict[str, Any] = {
"token": tok,
"logprob": _as_float(logp),
}
tops_i = top_series[i] if i < len(top_series) else None
if isinstance(tops_i, dict):
item["top_logprobs"] = [
{"token": k, "logprob": _as_float(v)}
for k, v in tops_i.items()
]
content.append(item)

ch["logprobs"] = {"content": content}
continue

def parse_fast_no_validate(response) -> Union[ChatCompletion, Completion]:
"""
Convert the HTTP JSON to a ChatCompletion *without* running Pydantic validation.
Only coerces/normalizes logprobs; everything else passes through untouched.
"""
raw = response.http_response.json()
_coerce_logprobs_inplace(raw)

return ChatCompletion.model_construct(**raw)

# parse the raw HTTP response without Pydantic validation
chat_completion = parse_fast_no_validate(response)
return chat_completion
else:
return response

@abstractmethod
async def rollout(
self,
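For context, here is a minimal standalone sketch of the raw-response flow this diff introduces: call the SDK's `with_raw_response` variant, read the JSON body from the underlying HTTP response, and rebuild a `ChatCompletion` via `model_construct`. The client configuration and model name below are placeholders, not part of the change.

```python
# Standalone sketch of the raw-response flow (client/model names are placeholders).
import asyncio

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion


async def main() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY / base_url are configured elsewhere

    # with_raw_response returns the raw API response wrapper instead of a parsed model
    raw = await client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Say hi"}],
        logprobs=True,
    )

    # .http_response is the underlying httpx.Response; .json() yields the payload dict
    payload = raw.http_response.json()

    # model_construct skips Pydantic validation, so nested fields (choices, usage, ...)
    # remain plain dicts rather than typed models
    completion = ChatCompletion.model_construct(**payload)
    print(completion.id, payload["choices"][0]["message"]["content"])


if __name__ == "__main__":
    asyncio.run(main())
```

Note that `model_construct` performs no validation, so nested fields such as `choices` stay as plain dicts unless they are re-validated downstream.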