3 changes: 3 additions & 0 deletions src/google/adk/agents/run_config.py
@@ -193,6 +193,9 @@ class RunConfig(BaseModel):
speech_config: Optional[types.SpeechConfig] = None
"""Speech configuration for the live agent."""

http_options: Optional[types.HttpOptions] = None
"""HTTP options for the agent execution (e.g. custom headers)."""

response_modalities: Optional[list[str]] = None
"""The output modalities. If not set, it's default to AUDIO."""

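For context, a minimal sketch of how a caller might set the new field (header names and values here are illustrative, not part of the ADK API):

from google.adk.agents.run_config import RunConfig
from google.genai import types

run_config = RunConfig(
    http_options=types.HttpOptions(
        # Hypothetical per-request headers for tracing/attribution.
        headers={'X-Trace-Id': 'trace-123'},
        timeout=30000,  # milliseconds, matching the values used in the tests below
    )
)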
8 changes: 8 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -520,6 +520,14 @@ async def _preprocess_async(
f'Expected agent to be an LlmAgent, but got {type(agent)}'
)

# Propagate http_options from RunConfig to LlmRequest as defaults.
if (
invocation_context.run_config
and invocation_context.run_config.http_options
):
llm_request.config.http_options = (
invocation_context.run_config.http_options.model_copy(deep=True)
)
# Runs processors.
for processor in self.request_processors:
async with Aclosing(
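The deep copy matters here: _build_basic_request (below) mutates llm_request.config.http_options in place when merging, so copying keeps the caller's RunConfig instance untouched. A minimal illustration using pydantic's model_copy:

from google.genai import types

original = types.HttpOptions(headers={'X-A': '1'})
copied = original.model_copy(deep=True)
copied.headers['X-B'] = '2'
assert 'X-B' not in original.headers  # the shared RunConfig stays pristine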
32 changes: 31 additions & 1 deletion src/google/adk/flows/llm_flows/basic.py
@@ -17,7 +17,6 @@
from __future__ import annotations

from typing import AsyncGenerator
from typing import Generator

from google.genai import types
from typing_extensions import override
@@ -45,11 +44,42 @@ def _build_basic_request(
agent = invocation_context.agent
model = agent.canonical_model
llm_request.model = model if isinstance(model, str) else model.model

# Preserve http_options propagated from RunConfig
run_config_http_options = llm_request.config.http_options

llm_request.config = (
agent.generate_content_config.model_copy(deep=True)
if agent.generate_content_config
else types.GenerateContentConfig()
)

if run_config_http_options:
# Merge RunConfig http_options back, overriding agent config
if not llm_request.config.http_options:
llm_request.config.http_options = run_config_http_options
else:
# Merge headers
if run_config_http_options.headers:
if not llm_request.config.http_options.headers:
llm_request.config.http_options.headers = {}
llm_request.config.http_options.headers.update(
run_config_http_options.headers
)

# Merge other http_options fields if present in RunConfig.
# RunConfig values override agent defaults.
# Note: base_url, api_version, base_url_resource_scope are intentionally
# excluded as they are configuration-time settings, not request-time.
for field in [
'timeout',
'retry_options',
'extra_body',
]:
val = getattr(run_config_http_options, field, None)
if val is not None:
setattr(llm_request.config.http_options, field, val)

# Only set output_schema if no tools are specified. As of now, models don't
# support output_schema and tools together. We have a workaround to support
# both output_schema and tools at the same time. See
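The merge semantics are easiest to see with a small sketch (values mirror the unit test added below):

from google.genai import types

agent_opts = types.HttpOptions(
    timeout=1000, headers={'Agent-Header': 'agent-val'}
)
run_config_opts = types.HttpOptions(
    timeout=500,
    headers={'RunConfig-Header': 'run-val', 'Agent-Header': 'run-val-override'},
)

# After _build_basic_request merges the two:
#   timeout == 500  (RunConfig wins)
#   headers == {
#       'Agent-Header': 'run-val-override',  # conflicting key: RunConfig wins
#       'RunConfig-Header': 'run-val',       # RunConfig-only key is added
#   }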
23 changes: 23 additions & 0 deletions src/google/adk/models/lite_llm.py
@@ -1836,6 +1836,29 @@ async def generate_content_async(
if generation_params:
completion_args.update(generation_params)

if llm_request.config.http_options:
http_opts = llm_request.config.http_options
if http_opts.headers:
extra_headers = completion_args.get("extra_headers", {})
if isinstance(extra_headers, dict):
extra_headers = extra_headers.copy()
else:
extra_headers = {}
extra_headers.update(http_opts.headers)
completion_args["extra_headers"] = extra_headers

if http_opts.timeout is not None:
completion_args["timeout"] = http_opts.timeout

if http_opts.retry_options is not None:
# Map google.genai.types.HttpRetryOptions to litellm's parameters.
# LiteLLM accepts num_retries as a top-level parameter.
if http_opts.retry_options.attempts is not None:
completion_args["num_retries"] = http_opts.retry_options.attempts

if http_opts.extra_body is not None:
completion_args["extra_body"] = http_opts.extra_body

if stream:
text = ""
reasoning_parts: List[types.Part] = []
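Roughly, the mapping from http_options to the litellm.acompletion kwargs looks like this (a sketch assuming all four fields are set; values are illustrative):

http_opts = types.HttpOptions(
    headers={"X-User-Id": "user-123"},
    timeout=30000,
    retry_options=types.HttpRetryOptions(attempts=3),
    extra_body={"priority": "high"},
)
# completion_args then receives:
#   extra_headers={"X-User-Id": "user-123"}  # merged with any init-time extra_headers
#   timeout=30000
#   num_retries=3
#   extra_body={"priority": "high"}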
47 changes: 47 additions & 0 deletions tests/unittests/flows/llm_flows/test_basic_processor.py
@@ -188,3 +188,50 @@ async def test_sets_model_name(self):

# Should have set the model name
assert llm_request.model == 'gemini-1.5-flash'

@pytest.mark.asyncio
async def test_preserves_merged_http_options(self):
"""Test that processor preserves and merges existing http_options."""
from google.genai import types

agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
generate_content_config=types.GenerateContentConfig(
http_options=types.HttpOptions(
timeout=1000,
headers={'Agent-Header': 'agent-val'},
)
),
)

invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()

# Simulate http_options propagated from RunConfig
llm_request.config.http_options = types.HttpOptions(
timeout=500, # Should override agent
headers={
'RunConfig-Header': 'run-val',
'Agent-Header': 'run-val-override',
},
)

processor = _BasicLlmRequestProcessor()

# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)

# Verify timeout from RunConfig wins
assert llm_request.config.http_options.timeout == 500

# Verify headers merged, RunConfig wins
assert (
llm_request.config.http_options.headers['RunConfig-Header'] == 'run-val'
)
assert (
llm_request.config.http_options.headers['Agent-Header']
== 'run-val-override'
)
183 changes: 182 additions & 1 deletion tests/unittests/models/test_litellm.py
@@ -624,7 +624,6 @@ def test_schema_to_dict_filters_none_enum_values():
),
]


STREAM_WITH_EMPTY_CHUNK = [
ModelResponse(
choices=[
@@ -3793,3 +3792,185 @@ def test_handles_litellm_logger_names(logger_name):
finally:
# Clean up
test_logger.removeHandler(handler)


@pytest.mark.asyncio
async def test_generate_content_async_passes_http_options_headers_as_extra_headers(
mock_acompletion, lite_llm_instance
):
"""Test that http_options.headers from LlmRequest are forwarded to litellm."""
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(
headers={"X-User-Id": "user-123", "X-Trace-Id": "trace-abc"}
)
),
)

async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert "extra_headers" in kwargs
assert kwargs["extra_headers"]["X-User-Id"] == "user-123"
assert kwargs["extra_headers"]["X-Trace-Id"] == "trace-abc"


@pytest.mark.asyncio
async def test_generate_content_async_merges_http_options_with_existing_extra_headers(
mock_response,
):
"""Test that http_options.headers merge with pre-existing extra_headers."""
mock_acompletion = AsyncMock(return_value=mock_response)
mock_client = MockLLMClient(mock_acompletion, Mock())
# Create instance with pre-existing extra_headers via kwargs
lite_llm_with_extra = LiteLlm(
model="test_model",
llm_client=mock_client,
extra_headers={"X-Api-Key": "secret-key"},
)

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(headers={"X-User-Id": "user-456"})
),
)

async for _ in lite_llm_with_extra.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert "extra_headers" in kwargs
# Both existing and new headers should be present
assert kwargs["extra_headers"]["X-Api-Key"] == "secret-key"
assert kwargs["extra_headers"]["X-User-Id"] == "user-456"


@pytest.mark.asyncio
async def test_generate_content_async_http_options_headers_override_existing(
mock_response,
):
"""Test that http_options.headers override same-key extra_headers from init."""
mock_acompletion = AsyncMock(return_value=mock_response)
mock_client = MockLLMClient(mock_acompletion, Mock())
lite_llm_with_extra = LiteLlm(
model="test_model",
llm_client=mock_client,
extra_headers={"X-Override-Me": "old-value"},
)

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(headers={"X-Override-Me": "new-value"})
),
)

async for _ in lite_llm_with_extra.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
# Request-level headers should override init-level headers
assert kwargs["extra_headers"]["X-Override-Me"] == "new-value"


@pytest.mark.asyncio
async def test_generate_content_async_passes_http_options_timeout(
mock_acompletion, lite_llm_instance
):
"""Test that http_options.timeout is forwarded to litellm."""

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(timeout=30000)
),
)

async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert "timeout" in kwargs
assert kwargs["timeout"] == 30000


@pytest.mark.asyncio
async def test_generate_content_async_passes_http_options_retry_options(
mock_acompletion, lite_llm_instance
):
"""Test that http_options.retry_options is forwarded to litellm."""

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(
retry_options=types.HttpRetryOptions(
attempts=3,
)
)
),
)

async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert "num_retries" in kwargs
assert kwargs["num_retries"] == 3


@pytest.mark.asyncio
async def test_generate_content_async_passes_http_options_extra_body(
mock_acompletion, lite_llm_instance
):
"""Test that http_options.extra_body is forwarded to litellm."""

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
http_options=types.HttpOptions(
extra_body={"custom_field": "custom_value", "priority": "high"}
)
),
)

async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert "extra_body" in kwargs
assert kwargs["extra_body"]["custom_field"] == "custom_value"
assert kwargs["extra_body"]["priority"] == "high"