From 3363d3e2d67a045f97be0ea0ef42227a43b954cb Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 18 Mar 2026 16:22:05 +0900 Subject: [PATCH 1/3] feat: add any-llm model support with responses-compatible routing --- examples/model_providers/README.md | 25 +- examples/model_providers/any_llm_auto.py | 50 + examples/model_providers/any_llm_provider.py | 58 + examples/model_providers/litellm_auto.py | 11 +- examples/model_providers/litellm_provider.py | 12 +- pyproject.toml | 1 + src/agents/extensions/models/any_llm_model.py | 1208 +++++++++++++++++ .../extensions/models/any_llm_provider.py | 35 + src/agents/models/multi_provider.py | 7 +- tests/models/test_any_llm_model.py | 467 +++++++ tests/models/test_map.py | 24 + uv.lock | 35 +- 12 files changed, 1910 insertions(+), 23 deletions(-) create mode 100644 examples/model_providers/any_llm_auto.py create mode 100644 examples/model_providers/any_llm_provider.py create mode 100644 src/agents/extensions/models/any_llm_model.py create mode 100644 src/agents/extensions/models/any_llm_provider.py create mode 100644 tests/models/test_any_llm_model.py diff --git a/examples/model_providers/README.md b/examples/model_providers/README.md index f9330c24ad..a477e00f66 100644 --- a/examples/model_providers/README.md +++ b/examples/model_providers/README.md @@ -1,19 +1,24 @@ -# Custom LLM providers +# Model provider examples -The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key and model. +The examples in this directory show how to route models through adapter layers such as LiteLLM and +any-llm. The default examples all use OpenRouter so you only need one API key: ```bash -export EXAMPLE_BASE_URL="..." -export EXAMPLE_API_KEY="..." -export EXAMPLE_MODEL_NAME"..." +export OPENROUTER_API_KEY="..." ``` -Then run the examples, e.g.: +Run one of the adapter examples: +```bash +uv run examples/model_providers/any_llm_provider.py +uv run examples/model_providers/any_llm_auto.py +uv run examples/model_providers/litellm_provider.py +uv run examples/model_providers/litellm_auto.py ``` -python examples/model_providers/custom_example_provider.py -Loops within themselves, -Function calls its own being, -Depth without ending. +Direct-model examples let you override the target model: + +```bash +uv run examples/model_providers/any_llm_provider.py --model openrouter/openai/gpt-5.4-mini +uv run examples/model_providers/litellm_provider.py --model openrouter/openai/gpt-5.4-mini ``` diff --git a/examples/model_providers/any_llm_auto.py b/examples/model_providers/any_llm_auto.py new file mode 100644 index 0000000000..3a6bc8ba76 --- /dev/null +++ b/examples/model_providers/any_llm_auto.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio + +from pydantic import BaseModel + +from agents import Agent, ModelSettings, Runner, function_tool, set_tracing_disabled + +"""This example uses the built-in any-llm routing through OpenRouter. + +Set OPENROUTER_API_KEY before running it. +""" + +set_tracing_disabled(disabled=True) + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +class Result(BaseModel): + output_text: str + tool_results: list[str] + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model="any-llm/openrouter/openai/gpt-5.4-mini", + tools=[get_weather], + model_settings=ModelSettings(tool_choice="required"), + output_type=Result, + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + import os + + if os.getenv("OPENROUTER_API_KEY") is None: + raise ValueError( + "OPENROUTER_API_KEY is not set. Please set the environment variable and try again." + ) + + asyncio.run(main()) diff --git a/examples/model_providers/any_llm_provider.py b/examples/model_providers/any_llm_provider.py new file mode 100644 index 0000000000..931efb11d6 --- /dev/null +++ b/examples/model_providers/any_llm_provider.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import asyncio +import os + +from agents import Agent, Runner, function_tool, set_tracing_disabled +from agents.extensions.models.any_llm_model import AnyLLMModel + +"""This example uses the AnyLLMModel directly. + +You can run it like this: +uv run examples/model_providers/any_llm_provider.py --model openrouter/openai/gpt-5.4-mini +or +uv run examples/model_providers/any_llm_provider.py --model openrouter/anthropic/claude-4.5-sonnet +""" + +set_tracing_disabled(disabled=True) + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(model: str, api_key: str): + if api_key == "dummy": + print("Skipping run because no valid OPENROUTER_API_KEY was provided.") + return + + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=AnyLLMModel(model=model, api_key=api_key), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=False) + parser.add_argument("--api-key", type=str, required=False) + args = parser.parse_args() + + model = args.model or os.environ.get("ANY_LLM_MODEL", "openrouter/openai/gpt-5.4-mini") + api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY", "dummy") + + if not args.model: + print(f"Using default model: {model}") + if not args.api_key: + print("Using OPENROUTER_API_KEY from environment (or dummy placeholder).") + + asyncio.run(main(model, api_key)) diff --git a/examples/model_providers/litellm_auto.py b/examples/model_providers/litellm_auto.py index ca4959a69f..3b30a3ecb9 100644 --- a/examples/model_providers/litellm_auto.py +++ b/examples/model_providers/litellm_auto.py @@ -6,8 +6,9 @@ from agents import Agent, ModelSettings, Runner, function_tool, set_tracing_disabled -"""This example uses the built-in support for LiteLLM. To use this, ensure you have the -ANTHROPIC_API_KEY environment variable set. +"""This example uses the built-in support for LiteLLM through OpenRouter. + +Set OPENROUTER_API_KEY before running it. """ set_tracing_disabled(disabled=True) @@ -32,7 +33,7 @@ async def main(): name="Assistant", instructions="You only respond in haikus.", # We prefix with litellm/ to tell the Runner to use the LitellmModel - model="litellm/anthropic/claude-sonnet-4-5-20250929", + model="litellm/openrouter/openai/gpt-5.4-mini", tools=[get_weather], model_settings=ModelSettings(tool_choice="required"), output_type=Result, @@ -45,9 +46,9 @@ async def main(): if __name__ == "__main__": import os - if os.getenv("ANTHROPIC_API_KEY") is None: + if os.getenv("OPENROUTER_API_KEY") is None: raise ValueError( - "ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again." + "OPENROUTER_API_KEY is not set. Please set the environment variable and try again." ) asyncio.run(main()) diff --git a/examples/model_providers/litellm_provider.py b/examples/model_providers/litellm_provider.py index ea5f09ab32..d9e7db7734 100644 --- a/examples/model_providers/litellm_provider.py +++ b/examples/model_providers/litellm_provider.py @@ -8,9 +8,9 @@ """This example uses the LitellmModel directly, to hit any model provider. You can run it like this: -uv run examples/model_providers/litellm_provider.py --model anthropic/claude-3-5-sonnet-20240620 +uv run examples/model_providers/litellm_provider.py --model openrouter/openai/gpt-5.4-mini or -uv run examples/model_providers/litellm_provider.py --model gemini/gemini-2.0-flash +uv run examples/model_providers/litellm_provider.py --model openrouter/anthropic/claude-4.5-sonnet Find more providers here: https://docs.litellm.ai/docs/providers """ @@ -26,7 +26,7 @@ def get_weather(city: str): async def main(model: str, api_key: str): if api_key == "dummy": - print("Skipping run because no valid LITELLM_API_KEY was provided.") + print("Skipping run because no valid OPENROUTER_API_KEY was provided.") return agent = Agent( name="Assistant", @@ -48,12 +48,12 @@ async def main(model: str, api_key: str): parser.add_argument("--api-key", type=str, required=False) args = parser.parse_args() - model = args.model or os.environ.get("LITELLM_MODEL", "openai/gpt-4o-mini") - api_key = args.api_key or os.environ.get("LITELLM_API_KEY", "dummy") + model = args.model or os.environ.get("LITELLM_MODEL", "openrouter/openai/gpt-5.4-mini") + api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY", "dummy") if not args.model: print(f"Using default model: {model}") if not args.api_key: - print("Using LITELLM_API_KEY from environment (or dummy placeholder).") + print("Using OPENROUTER_API_KEY from environment (or dummy placeholder).") asyncio.run(main(model, api_key)) diff --git a/pyproject.toml b/pyproject.toml index 98808d6f53..e22a8dd4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python" voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] viz = ["graphviz>=0.17"] litellm = ["litellm>=1.81.0, <2"] +any-llm = ["any-llm-sdk>=1.11.0, <2; python_version >= '3.11'"] realtime = ["websockets>=15.0, <16"] sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"] encrypt = ["cryptography>=45.0, <46"] diff --git a/src/agents/extensions/models/any_llm_model.py b/src/agents/extensions/models/any_llm_model.py new file mode 100644 index 0000000000..bf5e526971 --- /dev/null +++ b/src/agents/extensions/models/any_llm_model.py @@ -0,0 +1,1208 @@ +from __future__ import annotations + +import inspect +import json +import time +from collections.abc import AsyncIterator, Iterable +from copy import copy +from typing import TYPE_CHECKING, Any, Literal, cast, overload + +from openai import NotGiven, omit +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageCustomToolCall, + ChatCompletionMessageFunctionToolCall, + ChatCompletionMessageParam, +) +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_function_tool_call import Function +from openai.types.responses import Response, ResponseCompletedEvent, ResponseStreamEvent +from pydantic import BaseModel + +from ... import _debug +from ...agent_output import AgentOutputSchemaBase +from ...exceptions import ModelBehaviorError, UserError +from ...handoffs import Handoff +from ...items import ItemHelpers, ModelResponse, TResponseInputItem, TResponseStreamEvent +from ...logger import logger +from ...model_settings import ModelSettings +from ...models._openai_retry import get_openai_retry_advice +from ...models._retry_runtime import should_disable_provider_managed_retries +from ...models.chatcmpl_converter import Converter +from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers +from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler +from ...models.fake_id import FAKE_RESPONSES_ID +from ...models.interface import Model, ModelTracing +from ...models.openai_responses import ( + Converter as OpenAIResponsesConverter, + _coerce_response_includables, + _materialize_responses_tool_params, +) +from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest +from ...tool import Tool +from ...tracing import generation_span, response_span +from ...tracing.span_data import GenerationSpanData +from ...tracing.spans import Span +from ...usage import Usage +from ...util._json import _to_dump_compatible + +try: + from any_llm import AnyLLM +except ImportError as _e: + raise ImportError( + "`any-llm-sdk` is required to use the AnyLLMModel. Install it via the optional " + "dependency group: `pip install 'openai-agents[any-llm]'`. " + "`any-llm-sdk` currently requires Python 3.11+." + ) from _e + +if TYPE_CHECKING: + from openai.types.responses.response_prompt_param import ResponsePromptParam + + +class InternalChatCompletionMessage(ChatCompletionMessage): + """Internal wrapper used to carry normalized reasoning content.""" + + reasoning_content: str = "" + + +class _AnyLLMResponsesParamsShim: + """Fallback shim for tests and older any-llm layouts.""" + + def __init__(self, **payload: Any) -> None: + self._payload = payload + for key, value in payload.items(): + setattr(self, key, value) + + def model_dump(self, *, exclude_none: bool = False) -> dict[str, Any]: + if not exclude_none: + return dict(self._payload) + return {key: value for key, value in self._payload.items() if value is not None} + + +_ANY_LLM_RESPONSES_PARAM_FIELDS = { + "background", + "conversation", + "frequency_penalty", + "include", + "input", + "instructions", + "max_output_tokens", + "max_tool_calls", + "metadata", + "model", + "parallel_tool_calls", + "presence_penalty", + "previous_response_id", + "prompt_cache_key", + "prompt_cache_retention", + "reasoning", + "response_format", + "safety_identifier", + "service_tier", + "store", + "stream", + "stream_options", + "temperature", + "text", + "tool_choice", + "tools", + "top_logprobs", + "top_p", + "truncation", + "user", +} + + +def _convert_any_llm_tool_call_to_openai( + tool_call: Any, +) -> ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall: + tool_call_type = getattr(tool_call, "type", None) + if tool_call_type == "custom": + if isinstance(tool_call, BaseModel): + return ChatCompletionMessageCustomToolCall.model_validate(tool_call.model_dump()) + return ChatCompletionMessageCustomToolCall.model_validate(tool_call) + + function = getattr(tool_call, "function", None) + return ChatCompletionMessageFunctionToolCall( + id=str(getattr(tool_call, "id", "")), + type="function", + function=Function( + name=str(getattr(function, "name", "") or ""), + arguments=str(getattr(function, "arguments", "") or ""), + ), + ) + + +def _flatten_any_llm_reasoning_value(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, dict): + for key in ("content", "text", "thinking"): + flattened = _flatten_any_llm_reasoning_value(value.get(key)) + if flattened: + return flattened + return "" + if isinstance(value, Iterable) and not isinstance(value, (str, bytes)): + parts = [_flatten_any_llm_reasoning_value(item) for item in value] + return "".join(part for part in parts if part) + + for attr in ("content", "text", "thinking"): + flattened = _flatten_any_llm_reasoning_value(getattr(value, attr, None)) + if flattened: + return flattened + return "" + + +def _extract_any_llm_reasoning_text(value: Any) -> str: + direct_reasoning_content = getattr(value, "reasoning_content", None) + if isinstance(direct_reasoning_content, str): + return direct_reasoning_content + + reasoning = getattr(value, "reasoning", None) + if reasoning is None and isinstance(value, dict): + reasoning = value.get("reasoning") + if reasoning is None: + direct_reasoning_content = value.get("reasoning_content") + if isinstance(direct_reasoning_content, str): + return direct_reasoning_content + + if reasoning is None: + thinking = getattr(value, "thinking", None) + if thinking is None and isinstance(value, dict): + thinking = value.get("thinking") + return _flatten_any_llm_reasoning_value(thinking) + + return _flatten_any_llm_reasoning_value(reasoning) + + +def _normalize_any_llm_message(message: ChatCompletionMessage) -> ChatCompletionMessage: + if message.role != "assistant": + raise ModelBehaviorError(f"Unsupported role: {message.role}") + + tool_calls: ( + list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None + ) = None + if message.tool_calls: + tool_calls = [ + _convert_any_llm_tool_call_to_openai(tool_call) for tool_call in message.tool_calls + ] + + return InternalChatCompletionMessage( + content=message.content, + refusal=message.refusal, + role="assistant", + annotations=message.annotations, + audio=message.audio, + tool_calls=tool_calls, + reasoning_content=_extract_any_llm_reasoning_text(message), + ) + + +class AnyLLMModel(Model): + """Use any-llm as an adapter layer for chat completions and native Responses where supported.""" + + def __init__( + self, + model: str, + base_url: str | None = None, + api_key: str | None = None, + api: Literal["responses", "chat_completions"] | None = None, + ): + self.model = model + self.base_url = base_url + self.api_key = api_key + self.api: Literal["responses", "chat_completions"] | None = self._validate_api(api) + self._provider_name, self._provider_model = self._split_model_name(model) + self._provider_cache: dict[bool, Any] = {} + + def get_retry_advice(self, request: ModelRetryAdviceRequest) -> ModelRetryAdvice | None: + return get_openai_retry_advice(request) + + async def close(self) -> None: + seen_clients: set[int] = set() + for provider in self._provider_cache.values(): + client = getattr(provider, "client", None) + if client is None or id(client) in seen_clients: + continue + seen_clients.add(id(client)) + await self._maybe_aclose(client) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + ) -> ModelResponse: + if self._selected_api() == "responses": + return await self._get_response_via_responses( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + + return await self._get_response_via_chat( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + prompt=prompt, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + ) -> AsyncIterator[TResponseStreamEvent]: + if self._selected_api() == "responses": + async for chunk in self._stream_response_via_responses( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ): + yield chunk + return + + async for chunk in self._stream_response_via_chat( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + prompt=prompt, + ): + yield chunk + + async def _get_response_via_responses( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> ModelResponse: + with response_span(disabled=tracing.is_disabled()) as span_response: + response = await self._fetch_responses_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=False, + prompt=prompt, + ) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("LLM responded") + else: + logger.debug( + "LLM resp:\n%s\n", + json.dumps( + [item.model_dump() for item in response.output], + indent=2, + ensure_ascii=False, + ), + ) + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, + ) + if response.usage + else Usage() + ) + + if tracing.include_data(): + span_response.span_data.response = response + span_response.span_data.input = input + + return ModelResponse( + output=response.output, + usage=usage, + response_id=response.id, + request_id=getattr(response, "_request_id", None), + ) + + async def _stream_response_via_responses( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + with response_span(disabled=tracing.is_disabled()) as span_response: + stream = await self._fetch_responses_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=True, + prompt=prompt, + ) + + final_response: Response | None = None + try: + async for chunk in stream: + if isinstance(chunk, ResponseCompletedEvent): + final_response = chunk.response + elif getattr(chunk, "type", None) in {"response.failed", "response.incomplete"}: + terminal_response = getattr(chunk, "response", None) + if isinstance(terminal_response, Response): + final_response = terminal_response + yield chunk + finally: + await self._maybe_aclose(stream) + + if tracing.include_data() and final_response: + span_response.span_data.response = final_response + span_response.span_data.input = input + + async def _get_response_via_chat( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + prompt: ResponsePromptParam | None, + ) -> ModelResponse: + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() + | { + "base_url": str(self.base_url or ""), + "provider": self._provider_name, + "model_impl": "any-llm", + }, + disabled=tracing.is_disabled(), + ) as span_generation: + response = await self._fetch_chat_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + span=span_generation, + tracing=tracing, + stream=False, + prompt=prompt, + ) + + message: ChatCompletionMessage | None = None + first_choice: Choice | None = None + if response.choices: + first_choice = response.choices[0] + message = first_choice.message + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Received model response") + else: + if message is not None: + logger.debug( + "LLM resp:\n%s\n", + json.dumps(message.model_dump(), indent=2, ensure_ascii=False), + ) + else: + finish_reason = first_choice.finish_reason if first_choice else "-" + logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.prompt_tokens_details, # type: ignore[arg-type] + output_tokens_details=response.usage.completion_tokens_details, # type: ignore[arg-type] + ) + if response.usage + else Usage() + ) + + if tracing.include_data(): + span_generation.span_data.output = ( + [message.model_dump()] if message is not None else [] + ) + span_generation.span_data.usage = { + "requests": usage.requests, + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens, + "input_tokens_details": usage.input_tokens_details.model_dump(), + "output_tokens_details": usage.output_tokens_details.model_dump(), + } + + provider_data: dict[str, Any] = {"model": self.model} + if message is not None and hasattr(response, "id"): + provider_data["response_id"] = response.id + + items = ( + Converter.message_to_output_items( + _normalize_any_llm_message(message), + provider_data=provider_data, + ) + if message is not None + else [] + ) + + logprob_models = None + if first_choice and first_choice.logprobs and first_choice.logprobs.content: + logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text( + first_choice.logprobs.content + ) + + if logprob_models: + self._attach_logprobs_to_output(items, logprob_models) + + return ModelResponse(output=items, usage=usage, response_id=None) + + async def _stream_response_via_chat( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[TResponseStreamEvent]: + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() + | { + "base_url": str(self.base_url or ""), + "provider": self._provider_name, + "model_impl": "any-llm", + }, + disabled=tracing.is_disabled(), + ) as span_generation: + response, stream = await self._fetch_chat_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + span=span_generation, + tracing=tracing, + stream=True, + prompt=prompt, + ) + + final_response: Response | None = None + try: + async for chunk in ChatCmplStreamHandler.handle_stream( + response, + cast(Any, self._normalize_chat_stream(stream)), + model=self.model, + ): + yield chunk + if chunk.type == "response.completed": + final_response = chunk.response + finally: + await self._maybe_aclose(stream) + + if tracing.include_data() and final_response: + span_generation.span_data.output = [final_response.model_dump()] + + if final_response and final_response.usage: + span_generation.span_data.usage = { + "requests": 1, + "input_tokens": final_response.usage.input_tokens, + "output_tokens": final_response.usage.output_tokens, + "total_tokens": final_response.usage.total_tokens, + "input_tokens_details": ( + final_response.usage.input_tokens_details.model_dump() + if final_response.usage.input_tokens_details + else {"cached_tokens": 0} + ), + "output_tokens_details": ( + final_response.usage.output_tokens_details.model_dump() + if final_response.usage.output_tokens_details + else {"reasoning_tokens": 0} + ), + } + + @overload + async def _fetch_chat_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[True], + prompt: ResponsePromptParam | None, + ) -> tuple[Response, AsyncIterator[ChatCompletionChunk]]: ... + + @overload + async def _fetch_chat_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[False], + prompt: ResponsePromptParam | None, + ) -> ChatCompletion: ... + + async def _fetch_chat_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: bool, + prompt: ResponsePromptParam | None, + ) -> ChatCompletion | tuple[Response, AsyncIterator[ChatCompletionChunk]]: + if prompt is not None: + raise UserError("AnyLLMModel does not currently support prompt-managed requests.") + + preserve_thinking_blocks = ( + model_settings.reasoning is not None and model_settings.reasoning.effort is not None + ) + converted_messages = Converter.items_to_messages( + input, + preserve_thinking_blocks=preserve_thinking_blocks, + preserve_tool_output_all_content=True, + model=self.model, + ) + if any(name in self.model.lower() for name in ["anthropic", "claude", "gemini"]): + converted_messages = self._fix_tool_message_ordering(converted_messages) + + if system_instructions: + converted_messages.insert(0, {"content": system_instructions, "role": "system"}) + converted_messages = _to_dump_compatible(converted_messages) + + if tracing.include_data(): + span.span_data.input = converted_messages + + parallel_tool_calls = ( + True + if model_settings.parallel_tool_calls and tools + else False + if model_settings.parallel_tool_calls is False + else None + ) + tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) + response_format = Converter.convert_response_format(output_schema) + converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] + for handoff in handoffs: + converted_tools.append(Converter.convert_handoff_tool(handoff)) + converted_tools = _to_dump_compatible(converted_tools) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Calling LLM") + else: + logger.debug( + "Calling any-llm provider %s with messages:\n%s\nTools:\n%s\nStream: %s\n" + "Tool choice: %s\nResponse format: %s\n", + self._provider_name, + json.dumps(converted_messages, indent=2, ensure_ascii=False), + json.dumps(converted_tools, indent=2, ensure_ascii=False), + stream, + tool_choice, + response_format, + ) + + reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None + if reasoning_effort is None and model_settings.extra_args: + reasoning_effort = cast(Any, model_settings.extra_args.get("reasoning_effort")) + + stream_options = None + if stream and model_settings.include_usage is not None: + stream_options = {"include_usage": model_settings.include_usage} + + extra_kwargs = self._build_chat_extra_kwargs(model_settings) + extra_kwargs.pop("reasoning_effort", None) + + ret = await self._get_provider().acompletion( + model=self._provider_model, + messages=converted_messages, + tools=converted_tools or None, + temperature=model_settings.temperature, + top_p=model_settings.top_p, + frequency_penalty=model_settings.frequency_penalty, + presence_penalty=model_settings.presence_penalty, + max_tokens=model_settings.max_tokens, + tool_choice=self._remove_not_given(tool_choice), + response_format=self._remove_not_given(response_format), + parallel_tool_calls=parallel_tool_calls, + stream=stream, + stream_options=stream_options, + reasoning_effort=reasoning_effort, + top_logprobs=model_settings.top_logprobs, + extra_headers=self._merge_headers(model_settings), + **extra_kwargs, + ) + + if isinstance(ret, ChatCompletion): + return self._normalize_chat_completion_response(ret) + + responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( + model_settings.tool_choice + ) + if responses_tool_choice is None or responses_tool_choice is omit: + responses_tool_choice = "auto" + + response = Response( + id=FAKE_RESPONSES_ID, + created_at=time.time(), + model=self.model, + object="response", + output=[], + tool_choice=responses_tool_choice, # type: ignore[arg-type] + top_p=model_settings.top_p, + temperature=model_settings.temperature, + tools=[], + parallel_tool_calls=parallel_tool_calls or False, + reasoning=model_settings.reasoning, + ) + return response, cast(AsyncIterator[ChatCompletionChunk], ret) + + @overload + async def _fetch_responses_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[True], + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: ... + + @overload + async def _fetch_responses_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[False], + prompt: ResponsePromptParam | None, + ) -> Response: ... + + async def _fetch_responses_response( + self, + *, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: bool, + prompt: ResponsePromptParam | None, + ) -> Response | AsyncIterator[ResponseStreamEvent]: + if prompt is not None: + raise UserError("AnyLLMModel does not currently support prompt-managed requests.") + + if not self._supports_responses(): + raise UserError(f"Provider '{self._provider_name}' does not support the Responses API.") + + list_input = ItemHelpers.input_to_new_input_list(input) + list_input = _to_dump_compatible(list_input) + list_input = self._remove_openai_responses_api_incompatible_fields(list_input) + + parallel_tool_calls = ( + True + if model_settings.parallel_tool_calls and tools + else False + if model_settings.parallel_tool_calls is False + else None + ) + + tool_choice = OpenAIResponsesConverter.convert_tool_choice( + model_settings.tool_choice, + tools=tools, + handoffs=handoffs, + model=self._provider_model, + ) + converted_tools = OpenAIResponsesConverter.convert_tools( + tools, + handoffs, + model=self._provider_model, + tool_choice=model_settings.tool_choice, + ) + converted_tools_payload = _materialize_responses_tool_params(converted_tools.tools) + + include_set = set(converted_tools.includes) + if model_settings.response_include is not None: + include_set.update(_coerce_response_includables(model_settings.response_include)) + if model_settings.top_logprobs is not None: + include_set.add("message.output_text.logprobs") + include = list(include_set) or None + + text = OpenAIResponsesConverter.get_response_format(output_schema) + if model_settings.verbosity is not None: + if text is not omit: + text["verbosity"] = model_settings.verbosity # type: ignore[index] + else: + text = {"verbosity": model_settings.verbosity} + + request_kwargs: dict[str, Any] = { + "model": self._provider_model, + "input": list_input, + "instructions": system_instructions, + "tools": converted_tools_payload or None, + "tool_choice": self._remove_not_given(tool_choice), + "temperature": model_settings.temperature, + "top_p": model_settings.top_p, + "max_output_tokens": model_settings.max_tokens, + "stream": stream, + "truncation": model_settings.truncation, + "store": model_settings.store, + "previous_response_id": previous_response_id, + "conversation": conversation_id, + "include": include, + "parallel_tool_calls": parallel_tool_calls, + "reasoning": _to_dump_compatible(model_settings.reasoning) + if model_settings.reasoning is not None + else None, + "text": self._remove_not_given(text), + **self._build_responses_extra_kwargs(model_settings), + } + transport_kwargs = self._build_responses_transport_kwargs(model_settings) + + response = await self._call_any_llm_responses( + request_kwargs=request_kwargs, + transport_kwargs=transport_kwargs, + ) + + if stream: + return cast(AsyncIterator[ResponseStreamEvent], response) + + return self._normalize_response(response) + + @staticmethod + def _split_model_name(model: str) -> tuple[str, str]: + if not model: + raise UserError("AnyLLMModel requires a non-empty model name.") + if "/" not in model: + return "openai", model + + provider_name, provider_model = model.split("/", 1) + if not provider_name or not provider_model: + raise UserError( + "AnyLLMModel expects model names in the form 'provider/model', " + "for example 'openrouter/openai/gpt-5.4-mini'." + ) + return provider_name, provider_model + + def _supports_responses(self) -> bool: + return bool(getattr(self._get_provider(), "SUPPORTS_RESPONSES", False)) + + @staticmethod + def _validate_api( + api: Literal["responses", "chat_completions"] | None, + ) -> Literal["responses", "chat_completions"] | None: + if api not in {None, "responses", "chat_completions"}: + raise UserError( + "AnyLLMModel api must be one of: None, 'responses', 'chat_completions'." + ) + return api + + def _selected_api(self) -> Literal["responses", "chat_completions"]: + if self.api is not None: + if self.api == "responses" and not self._supports_responses(): + raise UserError( + f"Provider '{self._provider_name}' does not support the Responses API." + ) + return self.api + + return "responses" if self._supports_responses() else "chat_completions" + + def _get_provider(self) -> Any: + disable_provider_retries = should_disable_provider_managed_retries() + cached = self._provider_cache.get(disable_provider_retries) + if cached is not None: + return cached + + base_provider = self._provider_cache.get(False) + if base_provider is None: + base_provider = AnyLLM.create( + self._provider_name, + api_key=self.api_key, + api_base=self.base_url, + ) + self._provider_cache[False] = base_provider + + if disable_provider_retries: + cloned = self._clone_provider_without_retries(base_provider) + self._provider_cache[True] = cloned + return cloned + + return base_provider + + def _clone_provider_without_retries(self, provider: Any) -> Any: + client = getattr(provider, "client", None) + with_options = getattr(client, "with_options", None) + if not callable(with_options): + return provider + + cloned_provider = copy(provider) + cloned_provider.client = with_options(max_retries=0) + return cloned_provider + + def _normalize_response(self, response: Any) -> Response: + if isinstance(response, Response): + return response + if isinstance(response, BaseModel): + return Response.model_validate(response.model_dump()) + return Response.model_validate(response) + + def _normalize_chat_completion_response(self, response: Any) -> ChatCompletion: + if isinstance(response, ChatCompletion): + return response + if isinstance(response, BaseModel): + return ChatCompletion.model_validate(response.model_dump()) + return ChatCompletion.model_validate(response) + + async def _normalize_chat_stream( + self, stream: AsyncIterator[ChatCompletionChunk] + ) -> AsyncIterator[ChatCompletionChunk]: + async for chunk in stream: + yield self._normalize_chat_chunk(chunk) + + def _normalize_chat_chunk(self, chunk: Any) -> ChatCompletionChunk: + normalized_chunk = chunk + if not isinstance(normalized_chunk, ChatCompletionChunk): + normalized_chunk = ChatCompletionChunk.model_validate(chunk) + if not normalized_chunk.choices: + return normalized_chunk + + delta = normalized_chunk.choices[0].delta + reasoning_text = _extract_any_llm_reasoning_text(delta) + if not reasoning_text: + return normalized_chunk + + payload = normalized_chunk.model_dump() + choices = payload.get("choices") + if not isinstance(choices, list) or not choices: + return normalized_chunk + + delta_payload = choices[0].get("delta") + if not isinstance(delta_payload, dict): + return normalized_chunk + + delta_payload["reasoning"] = reasoning_text + choices[0]["delta"] = delta_payload + payload["choices"] = choices + return ChatCompletionChunk.model_validate(payload) + + @staticmethod + async def _maybe_aclose(value: Any) -> None: + aclose = getattr(value, "aclose", None) + if callable(aclose): + await aclose() + return + + close = getattr(value, "close", None) + if callable(close): + result = close() + if inspect.isawaitable(result): + await result + + def _build_chat_extra_kwargs(self, model_settings: ModelSettings) -> dict[str, Any]: + extra_kwargs: dict[str, Any] = {} + if model_settings.extra_query: + extra_kwargs["extra_query"] = copy(model_settings.extra_query) + if model_settings.metadata: + extra_kwargs["metadata"] = copy(model_settings.metadata) + if isinstance(model_settings.extra_body, dict): + extra_kwargs.update(model_settings.extra_body) + if model_settings.extra_args: + extra_kwargs.update(model_settings.extra_args) + return extra_kwargs + + def _build_responses_extra_kwargs(self, model_settings: ModelSettings) -> dict[str, Any]: + extra_kwargs = dict(model_settings.extra_args or {}) + if model_settings.top_logprobs is not None: + extra_kwargs["top_logprobs"] = model_settings.top_logprobs + if model_settings.metadata is not None: + extra_kwargs["metadata"] = copy(model_settings.metadata) + if model_settings.extra_query is not None: + extra_kwargs["extra_query"] = copy(model_settings.extra_query) + if model_settings.extra_body is not None: + extra_kwargs["extra_body"] = copy(model_settings.extra_body) + return extra_kwargs + + def _build_responses_transport_kwargs(self, model_settings: ModelSettings) -> dict[str, Any]: + transport_kwargs: dict[str, Any] = {} + headers = self._merge_headers(model_settings) + if headers: + transport_kwargs["extra_headers"] = headers + return transport_kwargs + + async def _call_any_llm_responses( + self, + *, + request_kwargs: dict[str, Any], + transport_kwargs: dict[str, Any], + ) -> Response | AsyncIterator[ResponseStreamEvent]: + provider = self._get_provider() + if not transport_kwargs: + response = await provider.aresponses( + model=request_kwargs["model"], + input_data=request_kwargs["input"], + **{ + key: value + for key, value in request_kwargs.items() + if key not in {"model", "input"} + }, + ) + return cast(Response | AsyncIterator[ResponseStreamEvent], response) + + params_payload = { + key: value + for key, value in request_kwargs.items() + if key in _ANY_LLM_RESPONSES_PARAM_FIELDS + } + provider_kwargs = { + key: value + for key, value in request_kwargs.items() + if key not in _ANY_LLM_RESPONSES_PARAM_FIELDS + } + provider_kwargs.update(transport_kwargs) + + # any-llm 1.11.0 validates public `aresponses()` kwargs against ResponsesParams, + # which rejects OpenAI transport kwargs like `extra_headers`. Build the params + # model ourselves so we can still pass transport kwargs through to the provider. + response = await provider._aresponses( + self._make_any_llm_responses_params(params_payload), + **provider_kwargs, + ) + return cast(Response | AsyncIterator[ResponseStreamEvent], response) + + @staticmethod + def _make_any_llm_responses_params(payload: dict[str, Any]) -> Any: + try: + from any_llm.types.responses import ResponsesParams as AnyLLMResponsesParams + except ImportError: + return _AnyLLMResponsesParamsShim(**payload) + + return AnyLLMResponsesParams(**payload) + + def _remove_openai_responses_api_incompatible_fields(self, list_input: list[Any]) -> list[Any]: + has_provider_data = any( + isinstance(item, dict) and item.get("provider_data") for item in list_input + ) + if not has_provider_data: + return list_input + + result: list[Any] = [] + for item in list_input: + cleaned = self._clean_item_for_openai(item) + if cleaned is not None: + result.append(cleaned) + return result + + def _clean_item_for_openai(self, item: Any) -> Any | None: + if not isinstance(item, dict): + return item + + if item.get("type") == "reasoning" and item.get("provider_data"): + return None + if item.get("id") == FAKE_RESPONSES_ID: + del item["id"] + if "provider_data" in item: + del item["provider_data"] + return item + + def _attach_logprobs_to_output(self, output_items: list[Any], logprobs: list[Any]) -> None: + from openai.types.responses import ResponseOutputMessage, ResponseOutputText + + for output_item in output_items: + if not isinstance(output_item, ResponseOutputMessage): + continue + for content in output_item.content: + if isinstance(content, ResponseOutputText): + content.logprobs = logprobs + return + + def _remove_not_given(self, value: Any) -> Any: + if value is omit or isinstance(value, NotGiven): + return None + return value + + def _merge_headers(self, model_settings: ModelSettings) -> dict[str, str]: + headers: dict[str, str] = {**HEADERS} + for source in (model_settings.extra_headers or {}, HEADERS_OVERRIDE.get() or {}): + for key, value in source.items(): + if isinstance(value, str): + headers[key] = value + return headers + + def _fix_tool_message_ordering( + self, messages: list[ChatCompletionMessageParam] + ) -> list[ChatCompletionMessageParam]: + if not messages: + return messages + + tool_call_messages: dict[str, tuple[int, ChatCompletionMessageParam]] = {} + tool_result_messages: dict[str, tuple[int, ChatCompletionMessageParam]] = {} + paired_tool_result_indices: set[int] = set() + fixed_messages: list[ChatCompletionMessageParam] = [] + used_indices: set[int] = set() + + for index, message in enumerate(messages): + if not isinstance(message, dict): + continue + message_dict = cast(dict[str, Any], message) + + if message_dict.get("role") == "assistant" and message_dict.get("tool_calls"): + tool_calls = message_dict.get("tool_calls", []) + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if isinstance(tool_call, dict) and tool_call.get("id"): + single_tool_msg = message_dict.copy() + single_tool_msg["tool_calls"] = [tool_call] + tool_call_messages[str(tool_call["id"])] = ( + index, + cast(ChatCompletionMessageParam, single_tool_msg), + ) + elif message_dict.get("role") == "tool" and message_dict.get("tool_call_id"): + tool_result_messages[str(message_dict["tool_call_id"])] = ( + index, + cast(ChatCompletionMessageParam, message_dict), + ) + + for tool_id in tool_call_messages: + if tool_id in tool_result_messages: + paired_tool_result_indices.add(tool_result_messages[tool_id][0]) + + for index, original_message in enumerate(messages): + if index in used_indices: + continue + + if not isinstance(original_message, dict): + fixed_messages.append(original_message) + used_indices.add(index) + continue + + role = original_message.get("role") + if role == "assistant" and original_message.get("tool_calls"): + tool_calls = original_message.get("tool_calls", []) + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + tool_id_value = tool_call.get("id") + if not isinstance(tool_id_value, str): + continue + tool_id = tool_id_value + if tool_id in tool_call_messages and tool_id in tool_result_messages: + _, tool_call_message = tool_call_messages[tool_id] + tool_result_index, tool_result_message = tool_result_messages[tool_id] + fixed_messages.append(tool_call_message) + fixed_messages.append(tool_result_message) + used_indices.add(tool_call_messages[tool_id][0]) + used_indices.add(tool_result_index) + elif tool_id in tool_call_messages: + _, tool_call_message = tool_call_messages[tool_id] + fixed_messages.append(tool_call_message) + used_indices.add(tool_call_messages[tool_id][0]) + used_indices.add(index) + elif role == "tool": + if index not in paired_tool_result_indices: + fixed_messages.append(original_message) + used_indices.add(index) + else: + fixed_messages.append(original_message) + used_indices.add(index) + + return fixed_messages diff --git a/src/agents/extensions/models/any_llm_provider.py b/src/agents/extensions/models/any_llm_provider.py new file mode 100644 index 0000000000..f327869499 --- /dev/null +++ b/src/agents/extensions/models/any_llm_provider.py @@ -0,0 +1,35 @@ +from typing import Literal + +from ...models.default_models import get_default_model +from ...models.interface import Model, ModelProvider +from .any_llm_model import AnyLLMModel + +DEFAULT_MODEL: str = f"openai/{get_default_model()}" + + +class AnyLLMProvider(ModelProvider): + """A ModelProvider that routes model calls through any-llm. + + API keys are typically sourced from the provider-specific environment variables expected by + any-llm, such as `OPENAI_API_KEY` or `OPENROUTER_API_KEY`. For custom wiring or explicit + credentials, instantiate `AnyLLMModel` directly. + """ + + def __init__( + self, + *, + api_key: str | None = None, + base_url: str | None = None, + api: Literal["responses", "chat_completions"] | None = None, + ) -> None: + self.api_key = api_key + self.base_url = base_url + self.api = api + + def get_model(self, model_name: str | None) -> Model: + return AnyLLMModel( + model=model_name or DEFAULT_MODEL, + api_key=self.api_key, + base_url=self.base_url, + api=self.api, + ) diff --git a/src/agents/models/multi_provider.py b/src/agents/models/multi_provider.py index bc7126d5ad..dc9087c430 100644 --- a/src/agents/models/multi_provider.py +++ b/src/agents/models/multi_provider.py @@ -61,6 +61,7 @@ class MultiProvider(ModelProvider): mapping is: - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1" - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1" + - "any-llm/" prefix -> AnyLLMProvider. e.g. "any-llm/openrouter/openai/gpt-4.1" You can override or customize this mapping. The ``openai`` prefix is ambiguous for some OpenAI-compatible backends because a string like ``openai/gpt-4.1`` could mean either "route @@ -143,6 +144,10 @@ def _create_fallback_provider(self, prefix: str) -> ModelProvider: from ..extensions.models.litellm_provider import LitellmProvider return LitellmProvider() + elif prefix == "any-llm": + from ..extensions.models.any_llm_provider import AnyLLMProvider + + return AnyLLMProvider() else: raise UserError(f"Unknown prefix: {prefix}") @@ -181,7 +186,7 @@ def _resolve_prefixed_model( if self.provider_map and (provider := self.provider_map.get_provider(prefix)): return provider, stripped_model_name - if prefix == "litellm": + if prefix in {"litellm", "any-llm"}: return self._get_fallback_provider(prefix), stripped_model_name if prefix == "openai": diff --git a/tests/models/test_any_llm_model.py b/tests/models/test_any_llm_model.py new file mode 100644 index 0000000000..49fb635ed8 --- /dev/null +++ b/tests/models/test_any_llm_model.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +import importlib +import sys +import types as pytypes +from collections.abc import AsyncIterator +from typing import Any + +import pytest +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from openai.types.completion_usage import CompletionUsage, PromptTokensDetails +from openai.types.responses import Response, ResponseCompletedEvent, ResponseOutputMessage +from openai.types.responses.response_output_text import ResponseOutputText +from openai.types.responses.response_usage import ( + InputTokensDetails, + OutputTokensDetails, + ResponseUsage, +) + +from agents import ModelSettings, ModelTracing, __version__ +from agents.exceptions import UserError +from agents.models.chatcmpl_helpers import HEADERS_OVERRIDE + + +class FakeAnyLLMProvider: + def __init__( + self, + *, + supports_responses: bool, + chat_response: Any | None = None, + responses_response: Any | None = None, + ) -> None: + self.SUPPORTS_RESPONSES = supports_responses + self.chat_response = chat_response + self.responses_response = responses_response + self.chat_calls: list[dict[str, Any]] = [] + self.responses_calls: list[dict[str, Any]] = [] + self.private_responses_calls: list[dict[str, Any]] = [] + + async def acompletion(self, **kwargs: Any) -> Any: + self.chat_calls.append(kwargs) + return self.chat_response + + async def aresponses(self, **kwargs: Any) -> Any: + self.responses_calls.append(kwargs) + return self.responses_response + + async def _aresponses(self, params: Any, **kwargs: Any) -> Any: + self.private_responses_calls.append({"params": params, "kwargs": kwargs}) + return self.responses_response + + +def _import_any_llm_module( + monkeypatch: pytest.MonkeyPatch, + provider: FakeAnyLLMProvider, +) -> tuple[Any, list[dict[str, Any]]]: + create_calls: list[dict[str, Any]] = [] + + class FakeAnyLLMFactory: + @staticmethod + def create(provider_name: str, api_key: str | None = None, api_base: str | None = None): + create_calls.append( + { + "provider_name": provider_name, + "api_key": api_key, + "api_base": api_base, + } + ) + return provider + + fake_any_llm: Any = pytypes.ModuleType("any_llm") + fake_any_llm.AnyLLM = FakeAnyLLMFactory + + sys.modules.pop("agents.extensions.models.any_llm_model", None) + monkeypatch.setitem(sys.modules, "any_llm", fake_any_llm) + + module = importlib.import_module("agents.extensions.models.any_llm_model") + monkeypatch.setattr(module, "AnyLLM", FakeAnyLLMFactory, raising=True) + return module, create_calls + + +def _chat_completion(text: str) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl_123", + created=0, + model="fake-model", + object="chat.completion", + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage(role="assistant", content=text), + ) + ], + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + prompt_tokens_details=PromptTokensDetails(cached_tokens=2), + ), + ) + + +def _responses_output(text: str) -> list[Any]: + return [ + ResponseOutputMessage( + id="msg_123", + role="assistant", + status="completed", + type="message", + content=[ + ResponseOutputText( + text=text, + type="output_text", + annotations=[], + logprobs=[], + ) + ], + ) + ] + + +def _response(text: str, response_id: str = "resp_123") -> Response: + return Response( + id=response_id, + created_at=123, + model="fake-model", + object="response", + output=_responses_output(text), + tool_choice="none", + tools=[], + parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=11, + output_tokens=13, + total_tokens=24, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), + ) + + +async def _empty_chat_stream() -> AsyncIterator[ChatCompletionChunk]: + if False: + yield ChatCompletionChunk( + id="chunk_123", + created=0, + model="fake-model", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(), finish_reason=None)], + ) + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +@pytest.mark.parametrize("override_ua", [None, "test_user_agent"]) +async def test_user_agent_header_any_llm_chat(override_ua: str | None, monkeypatch) -> None: + provider = FakeAnyLLMProvider(supports_responses=False, chat_response=_chat_completion("Hello")) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openrouter/openai/gpt-5.4-mini") + expected_ua = override_ua or f"Agents/Python {__version__}" + + if override_ua is not None: + token = HEADERS_OVERRIDE.set({"User-Agent": override_ua}) + else: + token = None + try: + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + finally: + if token is not None: + HEADERS_OVERRIDE.reset(token) + + assert provider.chat_calls[0]["extra_headers"]["User-Agent"] == expected_ua + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_chat_path_is_used_when_responses_are_unsupported(monkeypatch) -> None: + provider = FakeAnyLLMProvider(supports_responses=False, chat_response=_chat_completion("Hello")) + module, create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openrouter/openai/gpt-5.4-mini", api_key="router-key") + response = await model.get_response( + system_instructions="You are terse.", + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id="resp_prev", + conversation_id="conv_123", + prompt=None, + ) + + assert create_calls == [ + { + "provider_name": "openrouter", + "api_key": "router-key", + "api_base": None, + } + ] + assert len(provider.chat_calls) == 1 + assert provider.responses_calls == [] + assert provider.chat_calls[0]["model"] == "openai/gpt-5.4-mini" + assert response.response_id is None + assert response.output[0].content[0].text == "Hello" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_responses_path_is_used_when_supported(monkeypatch) -> None: + provider = FakeAnyLLMProvider(supports_responses=True, responses_response=_response("Hello")) + module, create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="gpt-5.4-mini", api_key="openai-key") + response = await model.get_response( + system_instructions="You are terse.", + input="hi", + model_settings=ModelSettings(store=True), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id="resp_prev", + conversation_id="conv_123", + prompt=None, + ) + + assert create_calls == [ + { + "provider_name": "openai", + "api_key": "openai-key", + "api_base": None, + } + ] + assert provider.chat_calls == [] + assert provider.responses_calls == [] + assert len(provider.private_responses_calls) == 1 + params = provider.private_responses_calls[0]["params"] + kwargs = provider.private_responses_calls[0]["kwargs"] + assert params.model == "gpt-5.4-mini" + assert params.previous_response_id == "resp_prev" + assert params.conversation == "conv_123" + assert kwargs["extra_headers"]["User-Agent"] == f"Agents/Python {__version__}" + assert response.response_id == "resp_123" + assert response.output[0].content[0].text == "Hello" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_can_force_chat_completions_when_responses_are_supported(monkeypatch) -> None: + provider = FakeAnyLLMProvider( + supports_responses=True, + chat_response=_chat_completion("Hello from chat"), + responses_response=_response("Hello from responses"), + ) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openai/gpt-4.1-mini", api="chat_completions") + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id="resp_prev", + conversation_id="conv_123", + prompt=None, + ) + + assert len(provider.chat_calls) == 1 + assert provider.responses_calls == [] + assert response.response_id is None + assert response.output[0].content[0].text == "Hello from chat" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_forced_responses_errors_when_provider_does_not_support_it( + monkeypatch, +) -> None: + provider = FakeAnyLLMProvider(supports_responses=False, chat_response=_chat_completion("Hello")) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openrouter/openai/gpt-4.1-mini", api="responses") + with pytest.raises(UserError, match="does not support the Responses API"): + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_stream_uses_chat_handler_when_responses_are_unsupported(monkeypatch) -> None: + provider = FakeAnyLLMProvider(supports_responses=False, chat_response=_empty_chat_stream()) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + completed = ResponseCompletedEvent( + type="response.completed", + response=_response("Hello from stream"), + sequence_number=1, + ) + + async def fake_handle_stream(response, stream, model=None): + assert model == "openrouter/openai/gpt-5.4-mini" + async for _chunk in stream: + pass + yield completed + + monkeypatch.setattr(module.ChatCmplStreamHandler, "handle_stream", fake_handle_stream) + + model = AnyLLMModel(model="openrouter/openai/gpt-5.4-mini") + events = [ + event + async for event in model.stream_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + ] + + assert [event.type for event in events] == ["response.completed"] + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_stream_passthrough_uses_responses_when_supported(monkeypatch) -> None: + async def response_stream() -> AsyncIterator[ResponseCompletedEvent]: + yield ResponseCompletedEvent( + type="response.completed", + response=_response("Hello from responses stream"), + sequence_number=1, + ) + + provider = FakeAnyLLMProvider(supports_responses=True, responses_response=response_stream()) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openai/gpt-5.4-mini") + events = [ + event + async for event in model.stream_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id="resp_prev", + conversation_id="conv_123", + prompt=None, + ) + ] + + assert [event.type for event in events] == ["response.completed"] + assert provider.responses_calls == [] + assert provider.private_responses_calls[0]["params"].previous_response_id == "resp_prev" + assert provider.private_responses_calls[0]["params"].conversation == "conv_123" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_responses_path_passes_transport_kwargs_via_private_provider_api( + monkeypatch, +) -> None: + provider = FakeAnyLLMProvider(supports_responses=True, responses_response=_response("Hello")) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openai/gpt-5.4-mini") + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings( + extra_headers={"X-Test-Header": "test"}, + extra_query={"trace": "1"}, + extra_body={"foo": "bar"}, + ), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + assert provider.responses_calls == [] + assert len(provider.private_responses_calls) == 1 + call = provider.private_responses_calls[0] + assert call["kwargs"]["extra_headers"]["X-Test-Header"] == "test" + assert call["kwargs"]["extra_query"] == {"trace": "1"} + assert call["kwargs"]["extra_body"] == {"foo": "bar"} + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_prompt_requests_fail_fast(monkeypatch) -> None: + provider = FakeAnyLLMProvider(supports_responses=True, responses_response=_response("Hello")) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openai/gpt-5.4-mini") + with pytest.raises(Exception, match="prompt-managed requests"): + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt={"id": "pmpt_123"}, + ) + + +def test_any_llm_provider_passes_api_override() -> None: + from agents.extensions.models.any_llm_model import AnyLLMModel + from agents.extensions.models.any_llm_provider import AnyLLMProvider + + provider = AnyLLMProvider(api="chat_completions") + model = provider.get_model("openai/gpt-4.1-mini") + + assert isinstance(model, AnyLLMModel) + assert model.api == "chat_completions" diff --git a/tests/models/test_map.py b/tests/models/test_map.py index 3e4f913718..15d4e74951 100644 --- a/tests/models/test_map.py +++ b/tests/models/test_map.py @@ -33,6 +33,30 @@ def test_litellm_prefix_is_litellm(): assert isinstance(model, LitellmModel) +def test_any_llm_prefix_uses_any_llm_provider(monkeypatch): + import sys + import types as pytypes + + captured_model: dict[str, Any] = {} + + class FakeAnyLLMModel: + pass + + class FakeAnyLLMProvider: + def get_model(self, model_name): + captured_model["value"] = model_name + return FakeAnyLLMModel() + + fake_module: Any = pytypes.ModuleType("agents.extensions.models.any_llm_provider") + fake_module.AnyLLMProvider = FakeAnyLLMProvider + monkeypatch.setitem(sys.modules, "agents.extensions.models.any_llm_provider", fake_module) + + agent = Agent(model="any-llm/openrouter/openai/gpt-5.4-mini", instructions="", name="test") + model = get_model(agent, RunConfig()) + assert isinstance(model, FakeAnyLLMModel) + assert captured_model["value"] == "openrouter/openai/gpt-5.4-mini" + + def test_no_prefix_can_use_openai_responses_websocket(): agent = Agent(model="gpt-4o", instructions="", name="test") model = get_model( diff --git a/uv.lock b/uv.lock index 7e616e7bba..8bbb2f9304 100644 --- a/uv.lock +++ b/uv.lock @@ -136,6 +136,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "any-llm-sdk" +version = "1.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx", marker = "python_full_version >= '3.11'" }, + { name = "openai", marker = "python_full_version >= '3.11'" }, + { name = "openresponses-types", marker = "python_full_version >= '3.11'" }, + { name = "pydantic", marker = "python_full_version >= '3.11'" }, + { name = "rich", marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/18/161747c16bbe4b15122ac690e7941f3c58f24b3df382189fdbadf0624595/any_llm_sdk-1.11.0.tar.gz", hash = "sha256:cabda4135041127e728d6d6fe6a3c0d77f45c0dd50b38a8f0bc132a2ad948a6a", size = 148392, upload-time = "2026-03-12T13:18:29.74Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/d7/3d89d25e08e7bef70565b8af1872407a636308ba5fa203c667134157344b/any_llm_sdk-1.11.0-py3-none-any.whl", hash = "sha256:1329bfb7c5fea68918ff0a8f47ecde876bb2e2a8cf990500adb6ec119339010f", size = 206124, upload-time = "2026-03-12T13:18:28.116Z" }, +] + [[package]] name = "anyio" version = "4.10.0" @@ -1901,6 +1918,9 @@ dependencies = [ ] [package.optional-dependencies] +any-llm = [ + { name = "any-llm-sdk", marker = "python_full_version >= '3.11'" }, +] dapr = [ { name = "dapr" }, { name = "grpcio" }, @@ -1965,6 +1985,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "any-llm-sdk", marker = "python_full_version >= '3.11' and extra == 'any-llm'", specifier = ">=1.11.0,<2" }, { name = "asyncpg", marker = "extra == 'sqlalchemy'", specifier = ">=0.29.0" }, { name = "cryptography", marker = "extra == 'encrypt'", specifier = ">=45.0,<46" }, { name = "dapr", marker = "extra == 'dapr'", specifier = ">=1.16.0" }, @@ -1984,7 +2005,7 @@ requires-dist = [ { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<16" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice", "viz", "litellm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr"] +provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr"] [package.metadata.requires-dev] dev = [ @@ -2020,6 +2041,18 @@ dev = [ { name = "websockets" }, ] +[[package]] +name = "openresponses-types" +version = "2.3.0.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/26/b612c3215f5599714fa94d63eb5ee59b4eb66dbdeeaf86bb4d848359484d/openresponses_types-2.3.0.post1.tar.gz", hash = "sha256:11b8896d3621d2ac2439f6ff106f34ddcb1bbd517c317a6c852a9df2e98a0753", size = 19254, upload-time = "2026-01-22T20:02:03.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/5f/e16dad89ed24f586da5b01b9b206d3adbf21fe1af8e4dc55d5b93158fde6/openresponses_types-2.3.0.post1-py3-none-any.whl", hash = "sha256:88f6abcef9cad839203abff420dd080978bf6eb33cc06ddc5d78da4ccdba7613", size = 13847, upload-time = "2026-01-22T20:02:02.582Z" }, +] + [[package]] name = "packaging" version = "25.0" From c76f8581246787bd0004f64a0314a816ddc8d181 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 18 Mar 2026 16:47:32 +0900 Subject: [PATCH 2/3] fix review comments --- src/agents/extensions/models/any_llm_model.py | 46 ++++-- tests/models/test_any_llm_model.py | 133 +++++++++++++++++- 2 files changed, 164 insertions(+), 15 deletions(-) diff --git a/src/agents/extensions/models/any_llm_model.py b/src/agents/extensions/models/any_llm_model.py index bf5e526971..8eb6090ebe 100644 --- a/src/agents/extensions/models/any_llm_model.py +++ b/src/agents/extensions/models/any_llm_model.py @@ -17,7 +17,6 @@ ChatCompletionMessageParam, ) from openai.types.chat.chat_completion import Choice -from openai.types.chat.chat_completion_message_function_tool_call import Function from openai.types.responses import Response, ResponseCompletedEvent, ResponseStreamEvent from pydantic import BaseModel @@ -49,7 +48,7 @@ from ...util._json import _to_dump_compatible try: - from any_llm import AnyLLM + from any_llm import AnyLLM # type: ignore[import-not-found] except ImportError as _e: raise ImportError( "`any-llm-sdk` is required to use the AnyLLMModel. Install it via the optional " @@ -118,21 +117,38 @@ def model_dump(self, *, exclude_none: bool = False) -> dict[str, Any]: def _convert_any_llm_tool_call_to_openai( tool_call: Any, ) -> ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall: + tool_call_payload: dict[str, Any] | None = None + if isinstance(tool_call, BaseModel): + dumped = tool_call.model_dump() + if isinstance(dumped, dict): + tool_call_payload = dumped + elif isinstance(tool_call, dict): + tool_call_payload = dict(tool_call) + tool_call_type = getattr(tool_call, "type", None) + if tool_call_type is None and tool_call_payload is not None: + tool_call_type = tool_call_payload.get("type") if tool_call_type == "custom": - if isinstance(tool_call, BaseModel): - return ChatCompletionMessageCustomToolCall.model_validate(tool_call.model_dump()) + if tool_call_payload is not None: + return ChatCompletionMessageCustomToolCall.model_validate(tool_call_payload) return ChatCompletionMessageCustomToolCall.model_validate(tool_call) + if tool_call_payload is not None: + return ChatCompletionMessageFunctionToolCall.model_validate(tool_call_payload) + function = getattr(tool_call, "function", None) - return ChatCompletionMessageFunctionToolCall( - id=str(getattr(tool_call, "id", "")), - type="function", - function=Function( - name=str(getattr(function, "name", "") or ""), - arguments=str(getattr(function, "arguments", "") or ""), - ), - ) + payload: dict[str, Any] = { + "id": str(getattr(tool_call, "id", "")), + "type": "function", + "function": { + "name": str(getattr(function, "name", "") or ""), + "arguments": str(getattr(function, "arguments", "") or ""), + }, + } + extra_content = getattr(tool_call, "extra_content", None) + if extra_content is not None: + payload["extra_content"] = extra_content + return ChatCompletionMessageFunctionToolCall.model_validate(payload) def _flatten_any_llm_reasoning_value(value: Any) -> str: @@ -718,7 +734,7 @@ async def _fetch_chat_response( **extra_kwargs, ) - if isinstance(ret, ChatCompletion): + if not stream: return self._normalize_chat_completion_response(ret) responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( @@ -1071,7 +1087,9 @@ async def _call_any_llm_responses( @staticmethod def _make_any_llm_responses_params(payload: dict[str, Any]) -> Any: try: - from any_llm.types.responses import ResponsesParams as AnyLLMResponsesParams + from any_llm.types.responses import ( # type: ignore[import-not-found] + ResponsesParams as AnyLLMResponsesParams, + ) except ImportError: return _AnyLLMResponsesParamsShim(**payload) diff --git a/tests/models/test_any_llm_model.py b/tests/models/test_any_llm_model.py index 49fb635ed8..b65b206668 100644 --- a/tests/models/test_any_llm_model.py +++ b/tests/models/test_any_llm_model.py @@ -7,7 +7,12 @@ from typing import Any import pytest -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageFunctionToolCall, +) from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage, PromptTokensDetails @@ -18,6 +23,7 @@ OutputTokensDetails, ResponseUsage, ) +from pydantic import BaseModel from agents import ModelSettings, ModelTracing, __version__ from agents.exceptions import UserError @@ -142,6 +148,55 @@ def _response(text: str, response_id: str = "resp_123") -> Response: ) +def _chat_completion_with_tool_call(*, thought_signature: str) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl_tool_123", + created=0, + model="fake-model", + object="chat.completion", + choices=[ + Choice( + index=0, + finish_reason="tool_calls", + message=ChatCompletionMessage( + role="assistant", + content="Calling a tool.", + tool_calls=[ + ChatCompletionMessageFunctionToolCall.model_validate( + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city":"Paris"}', + }, + "extra_content": { + "google": {"thought_signature": thought_signature} + }, + } + ) + ], + ), + ) + ], + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ), + ) + + +class GenericChatCompletionPayload(BaseModel): + id: str + created: int + model: str + object: str + choices: list[Any] + usage: Any + + async def _empty_chat_stream() -> AsyncIterator[ChatCompletionChunk]: if False: yield ChatCompletionChunk( @@ -223,6 +278,78 @@ async def test_any_llm_chat_path_is_used_when_responses_are_unsupported(monkeypa assert response.output[0].content[0].text == "Hello" +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +@pytest.mark.parametrize( + "chat_response", + [ + pytest.param(_chat_completion("Hello").model_dump(), id="dict"), + pytest.param( + GenericChatCompletionPayload.model_validate(_chat_completion("Hello").model_dump()), + id="basemodel", + ), + ], +) +async def test_any_llm_chat_path_normalizes_non_stream_payloads( + monkeypatch, + chat_response: Any, +) -> None: + provider = FakeAnyLLMProvider(supports_responses=False, chat_response=chat_response) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="openrouter/openai/gpt-5.4-mini") + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + assert response.response_id is None + assert response.output[0].content[0].text == "Hello" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_any_llm_chat_path_preserves_gemini_tool_call_metadata(monkeypatch) -> None: + provider = FakeAnyLLMProvider( + supports_responses=False, + chat_response=_chat_completion_with_tool_call(thought_signature="sig_123"), + ) + module, _create_calls = _import_any_llm_module(monkeypatch, provider) + AnyLLMModel = module.AnyLLMModel + + model = AnyLLMModel(model="gemini/gemini-2.0-flash") + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + function_calls = [ + item for item in response.output if getattr(item, "type", None) == "function_call" + ] + assert len(function_calls) == 1 + provider_data = function_calls[0].model_dump()["provider_data"] + assert provider_data["model"] == "gemini/gemini-2.0-flash" + assert provider_data["response_id"] == "chatcmpl_tool_123" + assert provider_data["thought_signature"] == "sig_123" + + @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_any_llm_responses_path_is_used_when_supported(monkeypatch) -> None: @@ -457,6 +584,10 @@ async def test_any_llm_prompt_requests_fail_fast(monkeypatch) -> None: def test_any_llm_provider_passes_api_override() -> None: + pytest.importorskip( + "any_llm", + reason="`any-llm-sdk` is only available when the optional dependency is installed.", + ) from agents.extensions.models.any_llm_model import AnyLLMModel from agents.extensions.models.any_llm_provider import AnyLLMProvider From 2b83eee3af2b1b037ed29d718187a059cfd78512 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 19 Mar 2026 11:05:42 +0900 Subject: [PATCH 3/3] fix type error --- src/agents/extensions/models/any_llm_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/agents/extensions/models/any_llm_model.py b/src/agents/extensions/models/any_llm_model.py index 8eb6090ebe..5302e49779 100644 --- a/src/agents/extensions/models/any_llm_model.py +++ b/src/agents/extensions/models/any_llm_model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib import inspect import json import time @@ -48,7 +49,7 @@ from ...util._json import _to_dump_compatible try: - from any_llm import AnyLLM # type: ignore[import-not-found] + AnyLLM = importlib.import_module("any_llm").AnyLLM except ImportError as _e: raise ImportError( "`any-llm-sdk` is required to use the AnyLLMModel. Install it via the optional " @@ -1087,12 +1088,11 @@ async def _call_any_llm_responses( @staticmethod def _make_any_llm_responses_params(payload: dict[str, Any]) -> Any: try: - from any_llm.types.responses import ( # type: ignore[import-not-found] - ResponsesParams as AnyLLMResponsesParams, - ) + any_llm_responses = importlib.import_module("any_llm.types.responses") except ImportError: return _AnyLLMResponsesParamsShim(**payload) + AnyLLMResponsesParams = any_llm_responses.ResponsesParams return AnyLLMResponsesParams(**payload) def _remove_openai_responses_api_incompatible_fields(self, list_input: list[Any]) -> list[Any]: