diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index f79eb1859..30cedcd0a 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -230,6 +230,7 @@ async def get_model_response(
         Convenience function for wrapping (chat, completion) API calls.
         Returns special error messages for context length issues.
         """
+        RAW_RESPONSE = True
         sampling_args = sampling_args or {}
         # resolve message type first
         if message_type is None:
@@ -248,10 +249,21 @@ async def get_model_response(
         ):
             sampling_args.pop("max_completion_tokens")
         clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}
+
+        if RAW_RESPONSE:
+            if message_type == "chat":
+                fn = client.chat.completions.with_raw_response.create
+            else:
+                fn = client.completions.with_raw_response.create
+        else:
+            if message_type == "chat":
+                fn = client.chat.completions.create
+            else:
+                fn = client.completions.create
+
         try:
             if message_type == "chat":
                 assert isinstance(prompt, list)
-                # --- detect audio parts and force text-only modality if caller didn't set one ---
                 has_audio = False
                 try:
                     for m in prompt:
@@ -274,32 +286,26 @@ async def get_model_response(
                     }
 
                 if oai_tools:
-                    response = await client.chat.completions.create(
+                    response = await fn(
                         model=model,
                         messages=prompt,  # type: ignore
                         tools=oai_tools,
                         **clean_sampling_args,
                     )
                 else:
-                    response = await client.chat.completions.create(
+                    response = await fn(
                         model=model,
                         messages=prompt,  # type: ignore
                         **clean_sampling_args,
                     )
-                return response
             elif message_type == "completion":
                 if oai_tools:
                     raise ValueError(
                         "oai_tools are not supported for completion tasks."
                     )
                 assert isinstance(prompt, str)
-                response = await client.completions.create(
-                    model=model, prompt=prompt, **clean_sampling_args
-                )
-                return response
+                response = await fn(model=model, prompt=prompt, **clean_sampling_args)
         except Exception as e:
-            # in case of making a request with an overlong prompt, e.g from a too-long
-            # environment response, we return a dummy response to with finish_reason "length"
             if isinstance(e, BadRequestError):
                 error_text = e.response.text.lower()
                 context_length_phrases = [
@@ -315,6 +321,88 @@ async def get_model_response(
             self.logger.error(f"Error getting model response: {e} \n\nExiting...")
             raise e
 
+        if RAW_RESPONSE:
+            from typing import cast, Union, Dict, List, Optional, Any
+            from openai.types import Completion
+            from openai.types.chat import ChatCompletion
+            from openai._legacy_response import LegacyAPIResponse
+
+            response = cast(
+                Union[LegacyAPIResponse[Completion], LegacyAPIResponse[ChatCompletion]],
+                response,
+            )
+
+            def _as_float(x: Any) -> Optional[float]:
+                try:
+                    return float(x)
+                except Exception:
+                    return None
+
+            def _coerce_logprobs_inplace(raw: Dict[str, Any]) -> None:
+                """Mutate `raw` to normalize/strip logprobs without copying."""
+                choices = raw.get("choices")
+                if not isinstance(choices, list):
+                    return
+
+                for ch in choices:
+                    lp = ch.get("logprobs")
+
+                    if lp is None:
+                        continue
+
+                    if isinstance(lp, dict) and isinstance(lp.get("content"), list):
+                        for t in lp["content"]:
+                            if isinstance(t, dict) and "logprob" in t:
+                                t["logprob"] = _as_float(t["logprob"])
+                            tops = t.get("top_logprobs")
+                            if isinstance(tops, list):
+                                for tlp in tops:
+                                    if isinstance(tlp, dict) and "logprob" in tlp:
+                                        tlp["logprob"] = _as_float(tlp["logprob"])
+                        continue
+
+                    if isinstance(lp, dict) and {"tokens", "token_logprobs"} <= set(
+                        lp.keys()
+                    ):
+                        tokens = lp.get("tokens") or []
+                        token_lps = lp.get("token_logprobs") or []
+                        top_series = lp.get("top_logprobs")
+                        if not isinstance(top_series, list):
+                            top_series = []
+
+                        content: List[Dict[str, Any]] = []
+                        for i, (tok, logp) in enumerate(zip(tokens, token_lps)):
+                            item: Dict[str, Any] = {
+                                "token": tok,
+                                "logprob": _as_float(logp),
+                            }
+                            tops_i = top_series[i] if i < len(top_series) else None
+                            if isinstance(tops_i, dict):
+                                item["top_logprobs"] = [
+                                    {"token": k, "logprob": _as_float(v)}
+                                    for k, v in tops_i.items()
+                                ]
+                            content.append(item)
+
+                        ch["logprobs"] = {"content": content}
+                        continue
+
+            def parse_fast_no_validate(response) -> Union[ChatCompletion, Completion]:
+                """
+                Convert the HTTP JSON to a ChatCompletion *without* running Pydantic validation.
+                Only coerces/normalizes logprobs; everything else passes through untouched.
+                """
+                raw = response.http_response.json()
+                _coerce_logprobs_inplace(raw)
+
+                return ChatCompletion.model_construct(**raw)
+
+            # --- usage ---
+            chat_completion = parse_fast_no_validate(response)
+            return chat_completion
+        else:
+            return response
+
     @abstractmethod
     async def rollout(
         self,
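
For context, here is a minimal, self-contained sketch of the raw-response path this patch builds on, using the async OpenAI client directly; the model name, prompt, and script scaffolding are illustrative assumptions, not taken from this repository. It also shows the main caveat of `model_construct`: since Pydantic validation is skipped, nested fields such as `choices` stay plain dicts rather than typed models, so they are read with key access below.

    import asyncio

    from openai import AsyncOpenAI
    from openai.types.chat import ChatCompletion


    async def main() -> None:
        client = AsyncOpenAI()

        # with_raw_response returns the HTTP response wrapper instead of a parsed model
        raw = await client.chat.completions.with_raw_response.create(
            model="gpt-4o-mini",  # placeholder model name
            messages=[{"role": "user", "content": "Say hi"}],
            logprobs=True,
        )

        payload = raw.http_response.json()  # plain dict straight from the wire
        completion = ChatCompletion.model_construct(**payload)  # no Pydantic validation

        # Validation was skipped, so nested objects are dicts, not Pydantic models.
        choice = completion.choices[0]
        print(choice["message"]["content"])
        print(choice.get("logprobs"))


    asyncio.run(main())

The coercion in `_coerce_logprobs_inplace` covers the two logprobs shapes the API can return: completion-style responses report parallel `tokens` / `token_logprobs` / `top_logprobs` arrays, while chat-style responses use a `content` list of per-token objects. Normalizing to the chat shape up front presumably lets downstream code consume a single format regardless of which endpoint produced the response.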