From df7479baed0aae0af2b26836180d8765472d0e31 Mon Sep 17 00:00:00 2001 From: Faraaz1994 Date: Wed, 28 Jan 2026 14:52:28 +0530 Subject: [PATCH 1/2] feat(http_options): add dynamic HTTP options support via RunConfig Enable per-request HTTP configuration (headers, timeout, retry_options) to be passed via RunConfig and propagated through the request pipeline to models. --- src/google/adk/agents/run_config.py | 4 +- .../adk/flows/llm_flows/base_llm_flow.py | 16 +++ src/google/adk/flows/llm_flows/basic.py | 34 ++++++- src/google/adk/models/lite_llm.py | 59 +++-------- .../flows/llm_flows/test_basic_processor.py | 47 +++++++++ tests/unittests/models/test_litellm.py | 97 +++++++++++++++++++ 6 files changed, 206 insertions(+), 51 deletions(-) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index b0e96dceed..b56900df59 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -178,7 +178,6 @@ class StreamingMode(Enum): For bidirectional streaming, use runner.run_live() instead of run_async(). """ - class RunConfig(BaseModel): """Configs for runtime behavior of agents. @@ -193,6 +192,9 @@ class RunConfig(BaseModel): speech_config: Optional[types.SpeechConfig] = None """Speech configuration for the live agent.""" + http_options: Optional[types.HttpOptions] = None + """HTTP options for the agent execution (e.g. custom headers).""" + response_modalities: Optional[list[str]] = None """The output modalities. If not set, it's default to AUDIO.""" diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index f9f80e6cd0..fc6fcb2f06 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -520,6 +520,22 @@ async def _preprocess_async( f'Expected agent to be an LlmAgent, but got {type(agent)}' ) + # Propagate http_options from RunConfig to LlmRequest as defaults. + # Request-level settings (from callbacks/processors) take precedence. + if ( + invocation_context.run_config + and invocation_context.run_config.http_options + ): + run_opts = invocation_context.run_config.http_options + if not llm_request.config.http_options: + # Deep-copy to avoid mutating the user's RunConfig across steps. + llm_request.config.http_options = run_opts.model_copy(deep=True) + elif run_opts.headers: + # Merge headers: request-level headers win (use setdefault). + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + for key, value in run_opts.headers.items(): + llm_request.config.http_options.headers.setdefault(key, value) # Runs processors. 
for processor in self.request_processors: async with Aclosing( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 4a296734a5..76a97a881a 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -17,7 +17,6 @@ from __future__ import annotations from typing import AsyncGenerator -from typing import Generator from google.genai import types from typing_extensions import override @@ -28,7 +27,6 @@ from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor - def _build_basic_request( invocation_context: InvocationContext, llm_request: LlmRequest, @@ -45,11 +43,42 @@ def _build_basic_request( agent = invocation_context.agent model = agent.canonical_model llm_request.model = model if isinstance(model, str) else model.model + + # Preserve http_options propagated from RunConfig + run_config_http_options = llm_request.config.http_options + llm_request.config = ( agent.generate_content_config.model_copy(deep=True) if agent.generate_content_config else types.GenerateContentConfig() ) + + if run_config_http_options: + # Merge RunConfig http_options back, overriding agent config + if not llm_request.config.http_options: + llm_request.config.http_options = run_config_http_options + else: + # Merge headers + if run_config_http_options.headers: + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + llm_request.config.http_options.headers.update( + run_config_http_options.headers + ) + + # Merge other http_options fields if present in RunConfig. + # RunConfig values override agent defaults. + # Note: base_url, api_version, base_url_resource_scope are intentionally + # excluded as they are configuration-time settings, not request-time. + for field in [ + 'timeout', + 'retry_options', + 'extra_body', + ]: + val = getattr(run_config_http_options, field, None) + if val is not None: + setattr(llm_request.config.http_options, field, val) + # Only set output_schema if no tools are specified. as of now, model don't # support output_schema and tools together. we have a workaround to support # both output_schema and tools at the same time. see @@ -100,5 +129,4 @@ async def run_async( return yield # Generator requires yield statement in function body. - request_processor = _BasicLlmRequestProcessor() diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 79182d7b0a..bdb2d9f355 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -148,7 +148,6 @@ "completion", ) - def _ensure_litellm_imported() -> None: """Imports LiteLLM with safe defaults. @@ -175,7 +174,6 @@ def _ensure_litellm_imported() -> None: _redirect_litellm_loggers_to_stdout() _LITELLM_IMPORTED = True - def _map_finish_reason( finish_reason: Any, ) -> types.FinishReason | None: @@ -187,7 +185,6 @@ def _map_finish_reason( finish_reason_str = str(finish_reason).lower() return _FINISH_REASON_MAPPING.get(finish_reason_str, types.FinishReason.OTHER) - def _get_provider_from_model(model: str) -> str: """Extracts the provider name from a LiteLLM model string. @@ -213,11 +210,9 @@ def _get_provider_from_model(model: str) -> str: return "openai" return "" - # Default MIME type when none can be inferred _DEFAULT_MIME_TYPE = "application/octet-stream" - def _infer_mime_type_from_uri(uri: str) -> Optional[str]: """Attempts to infer MIME type from a URI's path extension. 
@@ -258,12 +253,10 @@ def _infer_mime_type_from_uri(uri: str) -> Optional[str]: logger.debug("Could not infer MIME type from URI %s: %s", uri, e) return None - def _looks_like_openai_file_id(file_uri: str) -> bool: """Returns True when file_uri resembles an OpenAI/Azure file id.""" return file_uri.startswith("file-") - def _redact_file_uri_for_log( file_uri: str, *, display_name: str | None = None ) -> str: @@ -284,7 +277,6 @@ def _redact_file_uri_for_log( return f"{parsed.scheme}:///{tail}" return f"{parsed.scheme}://" - def _requires_file_uri_fallback( provider: str, model: str, file_uri: str ) -> bool: @@ -297,7 +289,6 @@ def _requires_file_uri_fallback( return True return False - def _decode_inline_text_data(raw_bytes: bytes) -> str: """Decodes inline file bytes that represent textual content.""" try: @@ -306,7 +297,6 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str: logger.debug("Falling back to latin-1 decoding for inline file bytes.") return raw_bytes.decode("latin-1", errors="replace") - def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: """Yields textual fragments from provider specific reasoning payloads.""" if reasoning_value is None: @@ -344,7 +334,6 @@ def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: elif isinstance(reasoning_value, (int, float, bool)): yield str(reasoning_value) - def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: """Converts provider reasoning payloads into Gemini thought parts.""" return [ @@ -353,7 +342,6 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: if text ] - def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: """Fetches the reasoning payload from a LiteLLM message or dict.""" if message is None: @@ -364,35 +352,29 @@ def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: return message.get("reasoning_content") return None - class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str file_id: str format: str - class FunctionChunk(BaseModel): id: Optional[str] name: Optional[str] args: Optional[str] index: Optional[int] = 0 - class TextChunk(BaseModel): text: str - class ReasoningChunk(BaseModel): parts: List[types.Part] - class UsageMetadataChunk(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int cached_prompt_tokens: int = 0 - class LiteLLMClient: """Provides acompletion method (for better testability).""" @@ -444,7 +426,6 @@ def completion( **kwargs, ) - def _safe_json_serialize(obj) -> str: """Convert any Python object to a JSON-serializable type or string. @@ -461,7 +442,6 @@ def _safe_json_serialize(obj) -> str: except (TypeError, OverflowError): return str(obj) - def _part_has_payload(part: types.Part) -> bool: """Checks whether a Part contains usable payload for the model.""" if part.text: @@ -472,7 +452,6 @@ def _part_has_payload(part: types.Part) -> bool: return True return False - def _append_fallback_user_content_if_missing( llm_request: LlmRequest, ) -> None: @@ -508,7 +487,6 @@ def _append_fallback_user_content_if_missing( ) ) - def _extract_cached_prompt_tokens(usage: Any) -> int: """Extracts cached prompt tokens from LiteLLM usage. 
@@ -561,7 +539,6 @@ def _extract_cached_prompt_tokens(usage: Any) -> int: return 0 - async def _content_to_message_param( content: types.Content, *, @@ -680,7 +657,6 @@ async def _content_to_message_param( reasoning_content=reasoning_content or None, ) - def _ensure_tool_results(messages: List[Message]) -> List[Message]: """Insert placeholder tool messages for missing tool results. @@ -741,7 +717,6 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: return healed_messages - async def _get_content( parts: Iterable[types.Part], *, @@ -891,7 +866,6 @@ async def _get_content( return content_objects - def _is_ollama_chat_provider( model: Optional[str], custom_llm_provider: Optional[str] ) -> bool: @@ -905,7 +879,6 @@ def _is_ollama_chat_provider( return True return False - def _flatten_ollama_content( content: OpenAIMessageContent | str | None, ) -> str | None: @@ -946,7 +919,6 @@ def _flatten_ollama_content( except TypeError: return str(blocks) - def _normalize_ollama_chat_messages( messages: list[Message], *, @@ -992,7 +964,6 @@ def _normalize_ollama_chat_messages( return normalized_messages - def _build_tool_call_from_json_dict( candidate: Any, *, index: int ) -> Optional[ChatCompletionMessageToolCall]: @@ -1040,7 +1011,6 @@ def _build_tool_call_from_json_dict( return tool_call - def _parse_tool_calls_from_text( text_block: str, ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]: @@ -1083,7 +1053,6 @@ def _parse_tool_calls_from_text( return tool_calls, remainder or None - def _split_message_content_and_tool_calls( message: Message, ) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]: @@ -1105,7 +1074,6 @@ def _split_message_content_and_tool_calls( return content, [] - def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: """Converts a types.Content role to a litellm role. @@ -1120,7 +1088,6 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: return "assistant" return "user" - TYPE_LABELS = { "STRING": "string", "NUMBER": "number", @@ -1130,7 +1097,6 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: "INTEGER": "integer", } - def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: """Recursively converts a schema object or dict to a pure-python dict. 
@@ -1174,7 +1140,6 @@ def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: return schema_dict - def _function_declaration_to_tool_param( function_declaration: types.FunctionDeclaration, ) -> dict: @@ -1227,7 +1192,6 @@ def _function_declaration_to_tool_param( return tool_params - def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ @@ -1317,7 +1281,6 @@ def _model_response_to_chunk( cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]), ), None - def _model_response_to_generate_content_response( response: ModelResponse, ) -> LlmResponse: @@ -1370,7 +1333,6 @@ def _model_response_to_generate_content_response( ) return llm_response - def _message_to_generate_content_response( message: Message, *, @@ -1417,7 +1379,6 @@ def _message_to_generate_content_response( model_version=model_version, ) - def _to_litellm_response_format( response_schema: types.SchemaUnion, model: str, @@ -1496,7 +1457,6 @@ def _to_litellm_response_format( }, } - async def _get_completion_inputs( llm_request: LlmRequest, model: str, @@ -1591,7 +1551,6 @@ async def _get_completion_inputs( return messages, tools, response_format, generation_params - def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: @@ -1615,7 +1574,6 @@ def _build_function_declaration_log( return_str = str(func_decl.response.model_dump(exclude_none=True)) return f"{func_decl.name}: {param_str} -> {return_str}" - def _build_request_log(req: LlmRequest) -> str: """Builds a request log. @@ -1664,7 +1622,6 @@ def _build_request_log(req: LlmRequest) -> str: ----------------------------------------------------------- """ - def _is_litellm_gemini_model(model_string: str) -> bool: """Check if the model is a Gemini model accessed via LiteLLM. @@ -1677,7 +1634,6 @@ def _is_litellm_gemini_model(model_string: str) -> bool: """ return model_string.startswith(("gemini/gemini-", "vertex_ai/gemini-")) - def _extract_gemini_model_from_litellm(litellm_model: str) -> str: """Extract the pure Gemini model name from a LiteLLM model string. @@ -1692,7 +1648,6 @@ def _extract_gemini_model_from_litellm(litellm_model: str) -> str: return litellm_model.split("/", 1)[1] return litellm_model - def _warn_gemini_via_litellm(model_string: str) -> None: """Warn if Gemini is being used via LiteLLM. @@ -1723,7 +1678,6 @@ def _warn_gemini_via_litellm(model_string: str) -> None: stacklevel=3, ) - def _redirect_litellm_loggers_to_stdout() -> None: """Redirects LiteLLM loggers from stderr to stdout. @@ -1742,7 +1696,6 @@ def _redirect_litellm_loggers_to_stdout() -> None: ): handler.stream = sys.stdout - class LiteLlm(BaseLlm): """Wrapper around litellm. 
@@ -1836,6 +1789,18 @@ async def generate_content_async( if generation_params: completion_args.update(generation_params) + if ( + llm_request.config.http_options + and llm_request.config.http_options.headers + ): + extra_headers = completion_args.get("extra_headers", {}) + if isinstance(extra_headers, dict): + extra_headers = extra_headers.copy() + else: + extra_headers = {} + extra_headers.update(llm_request.config.http_options.headers) + completion_args["extra_headers"] = extra_headers + if stream: text = "" reasoning_parts: List[types.Part] = [] diff --git a/tests/unittests/flows/llm_flows/test_basic_processor.py b/tests/unittests/flows/llm_flows/test_basic_processor.py index af0ccfe0b1..60896b850b 100644 --- a/tests/unittests/flows/llm_flows/test_basic_processor.py +++ b/tests/unittests/flows/llm_flows/test_basic_processor.py @@ -188,3 +188,50 @@ async def test_sets_model_name(self): # Should have set the model name assert llm_request.model == 'gemini-1.5-flash' + + @pytest.mark.asyncio + async def test_preserves_merged_http_options(self): + """Test that processor preserves and merges existing http_options.""" + from google.genai import types + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + generate_content_config=types.GenerateContentConfig( + http_options=types.HttpOptions( + timeout=1000, + headers={'Agent-Header': 'agent-val'}, + ) + ) + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + + # Simulate http_options propagated from RunConfig + llm_request.config.http_options = types.HttpOptions( + timeout=500, # Should override agent + headers={ + 'RunConfig-Header': 'run-val', + 'Agent-Header': 'run-val-override' + } + ) + + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Verify timeout from RunConfig wins + assert llm_request.config.http_options.timeout == 500 + + # Verify headers merged, RunConfig wins + assert ( + llm_request.config.http_options.headers['RunConfig-Header'] == 'run-val' + ) + assert ( + llm_request.config.http_options.headers['Agent-Header'] + == 'run-val-override' + ) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 2ebbc5dfe8..10404e17e6 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -3793,3 +3793,100 @@ def test_handles_litellm_logger_names(logger_name): finally: # Clean up test_logger.removeHandler(handler) + + +@pytest.mark.asyncio +async def test_generate_content_async_passes_http_options_headers_as_extra_headers( + mock_acompletion, lite_llm_instance +): + """Test that http_options.headers from LlmRequest are forwarded to litellm.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions( + headers={"X-User-Id": "user-123", "X-Trace-Id": "trace-abc"} + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "extra_headers" in kwargs + assert kwargs["extra_headers"]["X-User-Id"] == "user-123" + assert kwargs["extra_headers"]["X-Trace-Id"] == "trace-abc" + + +@pytest.mark.asyncio +async def test_generate_content_async_merges_http_options_with_existing_extra_headers( + mock_response, +): 
+ """Test that http_options.headers merge with pre-existing extra_headers.""" + mock_acompletion = AsyncMock(return_value=mock_response) + mock_client = MockLLMClient(mock_acompletion, Mock()) + # Create instance with pre-existing extra_headers via kwargs + lite_llm_with_extra = LiteLlm( + model="test_model", + llm_client=mock_client, + extra_headers={"X-Api-Key": "secret-key"}, + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions(headers={"X-User-Id": "user-456"}) + ), + ) + + async for _ in lite_llm_with_extra.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "extra_headers" in kwargs + # Both existing and new headers should be present + assert kwargs["extra_headers"]["X-Api-Key"] == "secret-key" + assert kwargs["extra_headers"]["X-User-Id"] == "user-456" + + +@pytest.mark.asyncio +async def test_generate_content_async_http_options_headers_override_existing( + mock_response, +): + """Test that http_options.headers override same-key extra_headers from init.""" + mock_acompletion = AsyncMock(return_value=mock_response) + mock_client = MockLLMClient(mock_acompletion, Mock()) + lite_llm_with_extra = LiteLlm( + model="test_model", + llm_client=mock_client, + extra_headers={"X-Override-Me": "old-value"}, + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions(headers={"X-Override-Me": "new-value"}) + ), + ) + + async for _ in lite_llm_with_extra.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + # Request-level headers should override init-level headers + assert kwargs["extra_headers"]["X-Override-Me"] == "new-value" From 10a07d31f43462b61cc91745f88e65d317210e7d Mon Sep 17 00:00:00 2001 From: Faraaz1994 Date: Fri, 30 Jan 2026 22:34:03 +0530 Subject: [PATCH 2/2] feat(http_options): Support passing timeout, retry_options, and extra_body to LiteLLM --- src/google/adk/agents/run_config.py | 1 + .../adk/flows/llm_flows/base_llm_flow.py | 14 +-- src/google/adk/flows/llm_flows/basic.py | 2 + src/google/adk/models/lite_llm.py | 80 ++++++++++++++--- .../flows/llm_flows/test_basic_processor.py | 6 +- tests/unittests/models/test_litellm.py | 86 ++++++++++++++++++- 6 files changed, 163 insertions(+), 26 deletions(-) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index b56900df59..9a4978fe3e 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -178,6 +178,7 @@ class StreamingMode(Enum): For bidirectional streaming, use runner.run_live() instead of run_async(). """ + class RunConfig(BaseModel): """Configs for runtime behavior of agents. diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index fc6fcb2f06..0022049c1d 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -521,21 +521,13 @@ async def _preprocess_async( ) # Propagate http_options from RunConfig to LlmRequest as defaults. - # Request-level settings (from callbacks/processors) take precedence. 
if ( invocation_context.run_config and invocation_context.run_config.http_options ): - run_opts = invocation_context.run_config.http_options - if not llm_request.config.http_options: - # Deep-copy to avoid mutating the user's RunConfig across steps. - llm_request.config.http_options = run_opts.model_copy(deep=True) - elif run_opts.headers: - # Merge headers: request-level headers win (use setdefault). - if not llm_request.config.http_options.headers: - llm_request.config.http_options.headers = {} - for key, value in run_opts.headers.items(): - llm_request.config.http_options.headers.setdefault(key, value) + llm_request.config.http_options = ( + invocation_context.run_config.http_options.model_copy(deep=True) + ) # Runs processors. for processor in self.request_processors: async with Aclosing( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 76a97a881a..303a61e2da 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -27,6 +27,7 @@ from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor + def _build_basic_request( invocation_context: InvocationContext, llm_request: LlmRequest, @@ -129,4 +130,5 @@ async def run_async( return yield # Generator requires yield statement in function body. + request_processor = _BasicLlmRequestProcessor() diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index bdb2d9f355..7ee95ef0d5 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -148,6 +148,7 @@ "completion", ) + def _ensure_litellm_imported() -> None: """Imports LiteLLM with safe defaults. @@ -174,6 +175,7 @@ def _ensure_litellm_imported() -> None: _redirect_litellm_loggers_to_stdout() _LITELLM_IMPORTED = True + def _map_finish_reason( finish_reason: Any, ) -> types.FinishReason | None: @@ -185,6 +187,7 @@ def _map_finish_reason( finish_reason_str = str(finish_reason).lower() return _FINISH_REASON_MAPPING.get(finish_reason_str, types.FinishReason.OTHER) + def _get_provider_from_model(model: str) -> str: """Extracts the provider name from a LiteLLM model string. @@ -210,9 +213,11 @@ def _get_provider_from_model(model: str) -> str: return "openai" return "" + # Default MIME type when none can be inferred _DEFAULT_MIME_TYPE = "application/octet-stream" + def _infer_mime_type_from_uri(uri: str) -> Optional[str]: """Attempts to infer MIME type from a URI's path extension. 
@@ -253,10 +258,12 @@ def _infer_mime_type_from_uri(uri: str) -> Optional[str]: logger.debug("Could not infer MIME type from URI %s: %s", uri, e) return None + def _looks_like_openai_file_id(file_uri: str) -> bool: """Returns True when file_uri resembles an OpenAI/Azure file id.""" return file_uri.startswith("file-") + def _redact_file_uri_for_log( file_uri: str, *, display_name: str | None = None ) -> str: @@ -277,6 +284,7 @@ def _redact_file_uri_for_log( return f"{parsed.scheme}:///{tail}" return f"{parsed.scheme}://" + def _requires_file_uri_fallback( provider: str, model: str, file_uri: str ) -> bool: @@ -289,6 +297,7 @@ def _requires_file_uri_fallback( return True return False + def _decode_inline_text_data(raw_bytes: bytes) -> str: """Decodes inline file bytes that represent textual content.""" try: @@ -297,6 +306,7 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str: logger.debug("Falling back to latin-1 decoding for inline file bytes.") return raw_bytes.decode("latin-1", errors="replace") + def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: """Yields textual fragments from provider specific reasoning payloads.""" if reasoning_value is None: @@ -334,6 +344,7 @@ def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: elif isinstance(reasoning_value, (int, float, bool)): yield str(reasoning_value) + def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: """Converts provider reasoning payloads into Gemini thought parts.""" return [ @@ -342,6 +353,7 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: if text ] + def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: """Fetches the reasoning payload from a LiteLLM message or dict.""" if message is None: @@ -352,29 +364,35 @@ def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: return message.get("reasoning_content") return None + class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str file_id: str format: str + class FunctionChunk(BaseModel): id: Optional[str] name: Optional[str] args: Optional[str] index: Optional[int] = 0 + class TextChunk(BaseModel): text: str + class ReasoningChunk(BaseModel): parts: List[types.Part] + class UsageMetadataChunk(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int cached_prompt_tokens: int = 0 + class LiteLLMClient: """Provides acompletion method (for better testability).""" @@ -426,6 +444,7 @@ def completion( **kwargs, ) + def _safe_json_serialize(obj) -> str: """Convert any Python object to a JSON-serializable type or string. @@ -442,6 +461,7 @@ def _safe_json_serialize(obj) -> str: except (TypeError, OverflowError): return str(obj) + def _part_has_payload(part: types.Part) -> bool: """Checks whether a Part contains usable payload for the model.""" if part.text: @@ -452,6 +472,7 @@ def _part_has_payload(part: types.Part) -> bool: return True return False + def _append_fallback_user_content_if_missing( llm_request: LlmRequest, ) -> None: @@ -487,6 +508,7 @@ def _append_fallback_user_content_if_missing( ) ) + def _extract_cached_prompt_tokens(usage: Any) -> int: """Extracts cached prompt tokens from LiteLLM usage. 
@@ -539,6 +561,7 @@ def _extract_cached_prompt_tokens(usage: Any) -> int: return 0 + async def _content_to_message_param( content: types.Content, *, @@ -657,6 +680,7 @@ async def _content_to_message_param( reasoning_content=reasoning_content or None, ) + def _ensure_tool_results(messages: List[Message]) -> List[Message]: """Insert placeholder tool messages for missing tool results. @@ -717,6 +741,7 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: return healed_messages + async def _get_content( parts: Iterable[types.Part], *, @@ -866,6 +891,7 @@ async def _get_content( return content_objects + def _is_ollama_chat_provider( model: Optional[str], custom_llm_provider: Optional[str] ) -> bool: @@ -879,6 +905,7 @@ def _is_ollama_chat_provider( return True return False + def _flatten_ollama_content( content: OpenAIMessageContent | str | None, ) -> str | None: @@ -919,6 +946,7 @@ def _flatten_ollama_content( except TypeError: return str(blocks) + def _normalize_ollama_chat_messages( messages: list[Message], *, @@ -964,6 +992,7 @@ def _normalize_ollama_chat_messages( return normalized_messages + def _build_tool_call_from_json_dict( candidate: Any, *, index: int ) -> Optional[ChatCompletionMessageToolCall]: @@ -1011,6 +1040,7 @@ def _build_tool_call_from_json_dict( return tool_call + def _parse_tool_calls_from_text( text_block: str, ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]: @@ -1053,6 +1083,7 @@ def _parse_tool_calls_from_text( return tool_calls, remainder or None + def _split_message_content_and_tool_calls( message: Message, ) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]: @@ -1074,6 +1105,7 @@ def _split_message_content_and_tool_calls( return content, [] + def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: """Converts a types.Content role to a litellm role. @@ -1088,6 +1120,7 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: return "assistant" return "user" + TYPE_LABELS = { "STRING": "string", "NUMBER": "number", @@ -1097,6 +1130,7 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: "INTEGER": "integer", } + def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: """Recursively converts a schema object or dict to a pure-python dict. 
@@ -1140,6 +1174,7 @@ def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: return schema_dict + def _function_declaration_to_tool_param( function_declaration: types.FunctionDeclaration, ) -> dict: @@ -1192,6 +1227,7 @@ def _function_declaration_to_tool_param( return tool_params + def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ @@ -1281,6 +1317,7 @@ def _model_response_to_chunk( cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]), ), None + def _model_response_to_generate_content_response( response: ModelResponse, ) -> LlmResponse: @@ -1333,6 +1370,7 @@ def _model_response_to_generate_content_response( ) return llm_response + def _message_to_generate_content_response( message: Message, *, @@ -1379,6 +1417,7 @@ def _message_to_generate_content_response( model_version=model_version, ) + def _to_litellm_response_format( response_schema: types.SchemaUnion, model: str, @@ -1457,6 +1496,7 @@ def _to_litellm_response_format( }, } + async def _get_completion_inputs( llm_request: LlmRequest, model: str, @@ -1551,6 +1591,7 @@ async def _get_completion_inputs( return messages, tools, response_format, generation_params + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: @@ -1574,6 +1615,7 @@ def _build_function_declaration_log( return_str = str(func_decl.response.model_dump(exclude_none=True)) return f"{func_decl.name}: {param_str} -> {return_str}" + def _build_request_log(req: LlmRequest) -> str: """Builds a request log. @@ -1622,6 +1664,7 @@ def _build_request_log(req: LlmRequest) -> str: ----------------------------------------------------------- """ + def _is_litellm_gemini_model(model_string: str) -> bool: """Check if the model is a Gemini model accessed via LiteLLM. @@ -1634,6 +1677,7 @@ def _is_litellm_gemini_model(model_string: str) -> bool: """ return model_string.startswith(("gemini/gemini-", "vertex_ai/gemini-")) + def _extract_gemini_model_from_litellm(litellm_model: str) -> str: """Extract the pure Gemini model name from a LiteLLM model string. @@ -1648,6 +1692,7 @@ def _extract_gemini_model_from_litellm(litellm_model: str) -> str: return litellm_model.split("/", 1)[1] return litellm_model + def _warn_gemini_via_litellm(model_string: str) -> None: """Warn if Gemini is being used via LiteLLM. @@ -1678,6 +1723,7 @@ def _warn_gemini_via_litellm(model_string: str) -> None: stacklevel=3, ) + def _redirect_litellm_loggers_to_stdout() -> None: """Redirects LiteLLM loggers from stderr to stdout. @@ -1696,6 +1742,7 @@ def _redirect_litellm_loggers_to_stdout() -> None: ): handler.stream = sys.stdout + class LiteLlm(BaseLlm): """Wrapper around litellm. 
@@ -1789,17 +1836,28 @@ async def generate_content_async( if generation_params: completion_args.update(generation_params) - if ( - llm_request.config.http_options - and llm_request.config.http_options.headers - ): - extra_headers = completion_args.get("extra_headers", {}) - if isinstance(extra_headers, dict): - extra_headers = extra_headers.copy() - else: - extra_headers = {} - extra_headers.update(llm_request.config.http_options.headers) - completion_args["extra_headers"] = extra_headers + if llm_request.config.http_options: + http_opts = llm_request.config.http_options + if http_opts.headers: + extra_headers = completion_args.get("extra_headers", {}) + if isinstance(extra_headers, dict): + extra_headers = extra_headers.copy() + else: + extra_headers = {} + extra_headers.update(http_opts.headers) + completion_args["extra_headers"] = extra_headers + + if http_opts.timeout is not None: + completion_args["timeout"] = http_opts.timeout + + if http_opts.retry_options is not None: + # Map google.genai.types.HttpRetryOptions to litellm's parameters. + # LiteLLM accepts num_retries as a top-level parameter. + if http_opts.retry_options.attempts is not None: + completion_args["num_retries"] = http_opts.retry_options.attempts + + if http_opts.extra_body is not None: + completion_args["extra_body"] = http_opts.extra_body if stream: text = "" diff --git a/tests/unittests/flows/llm_flows/test_basic_processor.py b/tests/unittests/flows/llm_flows/test_basic_processor.py index 60896b850b..1ff71af083 100644 --- a/tests/unittests/flows/llm_flows/test_basic_processor.py +++ b/tests/unittests/flows/llm_flows/test_basic_processor.py @@ -202,7 +202,7 @@ async def test_preserves_merged_http_options(self): timeout=1000, headers={'Agent-Header': 'agent-val'}, ) - ) + ), ) invocation_context = await _create_invocation_context(agent) @@ -213,8 +213,8 @@ async def test_preserves_merged_http_options(self): timeout=500, # Should override agent headers={ 'RunConfig-Header': 'run-val', - 'Agent-Header': 'run-val-override' - } + 'Agent-Header': 'run-val-override', + }, ) processor = _BasicLlmRequestProcessor() diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 10404e17e6..a837d373eb 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -624,7 +624,6 @@ def test_schema_to_dict_filters_none_enum_values(): ), ] - STREAM_WITH_EMPTY_CHUNK = [ ModelResponse( choices=[ @@ -3890,3 +3889,88 @@ async def test_generate_content_async_http_options_headers_override_existing( _, kwargs = mock_acompletion.call_args # Request-level headers should override init-level headers assert kwargs["extra_headers"]["X-Override-Me"] == "new-value" + + +@pytest.mark.asyncio +async def test_generate_content_async_passes_http_options_timeout( + mock_acompletion, lite_llm_instance +): + """Test that http_options.timeout is forwarded to litellm.""" + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions(timeout=30000) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "timeout" in kwargs + assert kwargs["timeout"] == 30000 + + +@pytest.mark.asyncio +async def test_generate_content_async_passes_http_options_retry_options( + mock_acompletion, lite_llm_instance +): + """Test that 
http_options.retry_options is forwarded to litellm.""" + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions( + retry_options=types.HttpRetryOptions( + attempts=3, + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "num_retries" in kwargs + assert kwargs["num_retries"] == 3 + + +@pytest.mark.asyncio +async def test_generate_content_async_passes_http_options_extra_body( + mock_acompletion, lite_llm_instance +): + """Test that http_options.extra_body is forwarded to litellm.""" + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + config=types.GenerateContentConfig( + http_options=types.HttpOptions( + extra_body={"custom_field": "custom_value", "priority": "high"} + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "extra_body" in kwargs + assert kwargs["extra_body"]["custom_field"] == "custom_value" + assert kwargs["extra_body"]["priority"] == "high"
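
Usage sketch: a minimal, illustrative example of driving the new RunConfig.http_options
path end to end. It is not part of the patch; the runner and session setup
(InMemoryRunner, the "http_options_demo" app name, the user/session ids, and the
"openai/gpt-4o" model string) are assumptions based on the public ADK API, while
RunConfig.http_options and its propagation are what this diff adds.

    import asyncio

    from google.adk.agents import LlmAgent
    from google.adk.agents.run_config import RunConfig
    from google.adk.models.lite_llm import LiteLlm
    from google.adk.runners import InMemoryRunner
    from google.genai import types


    async def main() -> None:
      # Any LlmAgent works; a LiteLLM-backed model exercises the extra_headers /
      # num_retries / extra_body mapping introduced by this patch. The model
      # string is illustrative.
      agent = LlmAgent(name="assistant", model=LiteLlm(model="openai/gpt-4o"))
      runner = InMemoryRunner(agent=agent, app_name="http_options_demo")
      session = await runner.session_service.create_session(
          app_name="http_options_demo", user_id="user-123"
      )

      run_config = RunConfig(
          http_options=types.HttpOptions(
              # Forwarded to LiteLLM as extra_headers, merged with any headers
              # configured on the LiteLlm instance; request-level values win.
              headers={"X-Trace-Id": "trace-abc"},
              # google.genai HttpOptions expresses timeout in milliseconds.
              timeout=30000,
              # attempts maps to LiteLLM's num_retries in lite_llm.py.
              retry_options=types.HttpRetryOptions(attempts=3),
          )
      )

      async for event in runner.run_async(
          user_id="user-123",
          session_id=session.id,
          new_message=types.Content(
              role="user", parts=[types.Part.from_text(text="Hello")]
          ),
          run_config=run_config,
      ):
        print(event)


    if __name__ == "__main__":
      asyncio.run(main())

After the second commit, RunConfig.http_options is deep-copied onto
llm_request.config.http_options before the request processors run, and in basic.py its
timeout, retry_options, and extra_body override the agent's generate_content_config
values, while headers are merged with RunConfig values taking precedence.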