Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ...telemetry.tracing import trace_call_llm
from ...telemetry.tracing import trace_send_data
from ...telemetry.tracing import tracer
from opentelemetry import trace
from ...tools.base_toolset import BaseToolset
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
Expand Down Expand Up @@ -1102,28 +1103,34 @@ async def _call_llm_async(
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
# Runs before_model_callback if it exists.
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.labels = llm_request.config.labels or {}
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
# Runs before_model_callback if it exists.
# This must be inside the call_llm span so that before_model_callback
# and after_model_callback/on_model_error_callback all share the same
# span context (fixes issue #4851).
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

# Add agent name as a label to the llm_request. This will help with slicing
# the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)
llm_request.config = (
llm_request.config or types.GenerateContentConfig()
)
llm_request.config.labels = llm_request.config.labels or {}

# Calls the LLM.
llm = self.__get_llm(invocation_context)
# Add agent name as a label to the llm_request. This will help with
# slicing the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)

# Calls the LLM.
llm = self.__get_llm(invocation_context)

async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
responses_generator = self.run_live(invocation_context)
Expand All @@ -1137,10 +1144,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
) as agen:
async for llm_response in agen:
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
# Re-activate the call_llm span so after_model_callback sees
# the same span_id as before_model_callback (issue #4851).
with trace.use_span(span):
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
Comment on lines +1149 to +1153
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for handling the after_model_callback is duplicated in the else branch on lines 1192-1196. To improve maintainability and avoid repeating code (DRY principle), consider extracting this logic into a local helper coroutine within the _call_llm_with_tracing function.

For example:

async def _apply_after_model_callback(response: LlmResponse) -> LlmResponse:
    """Applies after_model_callback within the correct span context."""
    with trace.use_span(span):
        if altered_response := await self._handle_after_model_callback(
            invocation_context, response, model_response_event
        ):
            return altered_response
    return response

You could then replace both duplicated blocks with a single call:
llm_response = await _apply_after_model_callback(llm_response)

# only yield partial response in SSE streaming mode
if (
invocation_context.run_config.streaming_mode
Expand Down Expand Up @@ -1177,10 +1187,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
span,
)
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
# Re-activate the call_llm span so after_model_callback sees
# the same span_id as before_model_callback (issue #4851).
with trace.use_span(span):
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response

yield llm_response

Expand Down
226 changes: 226 additions & 0 deletions tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that LLM callbacks share the same OTel span context (issue #4851).

When OpenTelemetry tracing is enabled, before_model_callback,
after_model_callback, and on_model_error_callback must all execute within
the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin)
see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events.
"""

from typing import Optional
from unittest import mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.flows.llm_flows import base_llm_flow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.telemetry import tracing as adk_tracing
from google.genai import types
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
import pytest

from ... import testing_utils


def _make_real_tracer():
  """Build a tracer backed by a real SDK ``TracerProvider``.

  The default no-op tracer yields invalid (zero) span IDs, so the tests
  patch this tracer in to obtain genuine span contexts.
  """
  return TracerProvider().get_tracer('test_tracer')


class SpanCapturingPlugin(BasePlugin):
  """Plugin that records the active OTel span ID in each model callback.

  After a run, the captured IDs can be compared to verify that
  before/after/error model callbacks all executed inside the same
  ``call_llm`` span.
  """

  def __init__(self):
    super().__init__(name='span_capturing_plugin')
    # Span IDs observed in each callback; None until the callback fires
    # inside a valid (non-zero) span.
    self.before_model_span_id: Optional[int] = None
    self.after_model_span_id: Optional[int] = None
    self.on_model_error_span_id: Optional[int] = None

  @staticmethod
  def _current_span_id() -> Optional[int]:
    """Return the current span's ID, or None when no valid span is active."""
    ctx = trace.get_current_span().get_span_context()
    return ctx.span_id if ctx and ctx.span_id else None

  async def before_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
  ) -> Optional[LlmResponse]:
    sid = self._current_span_id()
    if sid is not None:
      self.before_model_span_id = sid
    return None

  async def after_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_response: LlmResponse,
  ) -> Optional[LlmResponse]:
    sid = self._current_span_id()
    if sid is not None:
      self.after_model_span_id = sid
    return None

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    sid = self._current_span_id()
    if sid is not None:
      self.on_model_error_span_id = sid
    # Swallow the model error by returning a synthetic response so the
    # run completes and the captured span IDs can be inspected.
    return LlmResponse(
        content=testing_utils.ModelContent(
            [types.Part.from_text(text='error handled')]
        )
    )


@pytest.mark.asyncio
async def test_before_and_after_model_callbacks_share_span_id():
  """Verify before_model_callback and after_model_callback share the same span.

  This is the core regression test for issue #4851. Before the fix,
  before_model_callback ran outside the call_llm span, causing a span_id
  mismatch between LLM_REQUEST and LLM_RESPONSE events.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  mock_model = testing_utils.MockModel.create(responses=['model_response'])
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \
      mock.patch.object(adk_tracing, 'tracer', real_tracer):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are irrelevant here; only the span IDs the
    # plugin captured matter, so the result is deliberately discarded.
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert plugin.before_model_span_id is not None, (
      'before_model_callback did not capture a span ID'
  )
  assert plugin.after_model_span_id is not None, (
      'after_model_callback did not capture a span ID'
  )

  # The span IDs must match — this is the core assertion for issue #4851
  assert plugin.before_model_span_id == plugin.after_model_span_id, (
      f'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      f'after_model_callback span_id={plugin.after_model_span_id:#018x}. '
      f'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_and_on_error_model_callbacks_share_span_id():
  """Verify before_model_callback and on_model_error_callback share span.

  When the model raises an error, on_model_error_callback should see the
  same span as before_model_callback.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  # Model configured to raise, so the error-callback path is exercised.
  mock_model = testing_utils.MockModel.create(
      responses=[], error=SystemError('model error')
  )
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \
      mock.patch.object(adk_tracing, 'tracer', real_tracer):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are irrelevant here; only the span IDs the
    # plugin captured matter, so the result is deliberately discarded.
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert plugin.before_model_span_id is not None, (
      'before_model_callback did not capture a span ID'
  )
  assert plugin.on_model_error_span_id is not None, (
      'on_model_error_callback did not capture a span ID'
  )

  # The span IDs must match
  assert plugin.before_model_span_id == plugin.on_model_error_span_id, (
      f'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      f'on_model_error_callback span_id='
      f'{plugin.on_model_error_span_id:#018x}. '
      f'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_model_callback_short_circuit_has_span():
  """Verify before_model_callback has a valid span when short-circuiting."""

  class ShortCircuitPlugin(BasePlugin):
    """Records its span ID, then short-circuits the LLM call."""

    def __init__(self):
      super().__init__(name='short_circuit_plugin')
      self.span_id: Optional[int] = None

    async def before_model_callback(
        self,
        *,
        callback_context: CallbackContext,
        llm_request: LlmRequest,
    ) -> Optional[LlmResponse]:
      span_ctx = trace.get_current_span().get_span_context()
      if span_ctx and span_ctx.span_id:
        self.span_id = span_ctx.span_id
      # Returning a response skips the actual model call entirely.
      return LlmResponse(
          content=testing_utils.ModelContent(
              [types.Part.from_text(text='short-circuited')]
          )
      )

  sc_plugin = ShortCircuitPlugin()
  sdk_tracer = _make_real_tracer()

  model = testing_utils.MockModel.create(responses=['model_response'])
  agent = Agent(
      name='test_agent',
      model=model,
  )

  with mock.patch.object(base_llm_flow, 'tracer', sdk_tracer):
    with mock.patch.object(adk_tracing, 'tracer', sdk_tracer):
      runner = testing_utils.TestInMemoryRunner(agent, plugins=[sc_plugin])
      events = await runner.run_async_with_new_session('test')

  # The callback should have a valid (non-zero) span ID from the call_llm span
  assert sc_plugin.span_id is not None and sc_plugin.span_id != 0, (
      'before_model_callback should have a valid span ID even when '
      'short-circuiting the LLM call'
  )

  # Verify the short-circuit response was received
  simplified = testing_utils.simplify_events(events)
  assert any('short-circuited' in str(e) for e in simplified)