From 0ed898af3e1440ec3b83c9e745c2fa3e5212aab6 Mon Sep 17 00:00:00 2001 From: brucearctor <5032356+brucearctor@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:31:18 -0700 Subject: [PATCH] fix: ensure LLM callbacks share the same OTel span context (#4851) Move before_model_callback inside the call_llm span and wrap after_model_callback with trace.use_span(span) to re-activate the call_llm span context. This ensures before_model_callback, after_model_callback, and on_model_error_callback all see the same span_id, fixing the mismatch that broke the BigQuery Analytics Plugin. The root cause was twofold: 1. before_model_callback ran outside the call_llm span 2. after_model_callback ran inside a child generate_content span (created by _run_and_handle_error via use_inference_span) Fixes #4851 --- .../adk/flows/llm_flows/base_llm_flow.py | 65 +++-- .../test_llm_callback_span_consistency.py | 226 ++++++++++++++++++ 2 files changed, 265 insertions(+), 26 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index bd0037bdcb..47ec9bfc61 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -46,6 +46,7 @@ from ...telemetry.tracing import trace_call_llm from ...telemetry.tracing import trace_send_data from ...telemetry.tracing import tracer +from opentelemetry import trace from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -1102,28 +1103,34 @@ async def _call_llm_async( llm_request: LlmRequest, model_response_event: Event, ) -> AsyncGenerator[LlmResponse, None]: - # Runs before_model_callback if it exists. - if response := await self._handle_before_model_callback( - invocation_context, llm_request, model_response_event - ): - yield response - return - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.labels = llm_request.config.labels or {} + async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: + with tracer.start_as_current_span('call_llm') as span: + # Runs before_model_callback if it exists. + # This must be inside the call_llm span so that before_model_callback + # and after_model_callback/on_model_error_callback all share the same + # span context (fixes issue #4851). + if response := await self._handle_before_model_callback( + invocation_context, llm_request, model_response_event + ): + yield response + return - # Add agent name as a label to the llm_request. This will help with slicing - # the billing reports on a per-agent basis. - if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: - llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( - invocation_context.agent.name - ) + llm_request.config = ( + llm_request.config or types.GenerateContentConfig() + ) + llm_request.config.labels = llm_request.config.labels or {} - # Calls the LLM. - llm = self.__get_llm(invocation_context) + # Add agent name as a label to the llm_request. This will help with + # slicing the billing reports on a per-agent basis. + if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: + llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( + invocation_context.agent.name + ) + + # Calls the LLM. + llm = self.__get_llm(invocation_context) - async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: - with tracer.start_as_current_span('call_llm') as span: if invocation_context.run_config.support_cfc: invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) @@ -1137,10 +1144,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: ) as agen: async for llm_response in agen: # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + # Re-activate the call_llm span so after_model_callback sees + # the same span_id as before_model_callback (issue #4851). + with trace.use_span(span): + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response # only yield partial response in SSE streaming mode if ( invocation_context.run_config.streaming_mode @@ -1177,10 +1187,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: span, ) # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + # Re-activate the call_llm span so after_model_callback sees + # the same span_id as before_model_callback (issue #4851). + with trace.use_span(span): + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response yield llm_response diff --git a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py new file mode 100644 index 0000000000..d4109dfed0 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py @@ -0,0 +1,226 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that LLM callbacks share the same OTel span context (issue #4851). + +When OpenTelemetry tracing is enabled, before_model_callback, +after_model_callback, and on_model_error_callback must all execute within +the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin) +see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events. +""" + +from typing import Optional +from unittest import mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import Agent +from google.adk.flows.llm_flows import base_llm_flow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.telemetry import tracing as adk_tracing +from google.genai import types +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +import pytest + +from ... import testing_utils + + +def _make_real_tracer(): + """Create a real tracer that produces valid span IDs.""" + provider = TracerProvider() + return provider.get_tracer('test_tracer') + + +class SpanCapturingPlugin(BasePlugin): + """Plugin that captures the current span ID in each model callback.""" + + def __init__(self): + super().__init__(name='span_capturing_plugin') + self.before_model_span_id: Optional[int] = None + self.after_model_span_id: Optional[int] = None + self.on_model_error_span_id: Optional[int] = None + + async def before_model_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.before_model_span_id = ctx.span_id + return None + + async def after_model_callback( + self, + *, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.after_model_span_id = ctx.span_id + return None + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.on_model_error_span_id = ctx.span_id + return LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text='error handled')] + ) + ) + + +@pytest.mark.asyncio +async def test_before_and_after_model_callbacks_share_span_id(): + """Verify before_model_callback and after_model_callback share the same span. + + This is the core regression test for issue #4851. Before the fix, + before_model_callback ran outside the call_llm span, causing a span_id + mismatch between LLM_REQUEST and LLM_RESPONSE events. + """ + plugin = SpanCapturingPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create(responses=['model_response']) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # Both callbacks should have captured a span ID + assert plugin.before_model_span_id is not None, ( + 'before_model_callback did not capture a span ID' + ) + assert plugin.after_model_span_id is not None, ( + 'after_model_callback did not capture a span ID' + ) + + # The span IDs must match — this is the core assertion for issue #4851 + assert plugin.before_model_span_id == plugin.after_model_span_id, ( + f'Span ID mismatch: before_model_callback span_id=' + f'{plugin.before_model_span_id:#018x}, ' + f'after_model_callback span_id={plugin.after_model_span_id:#018x}. ' + f'Both callbacks must run inside the same call_llm span.' + ) + + +@pytest.mark.asyncio +async def test_before_and_on_error_model_callbacks_share_span_id(): + """Verify before_model_callback and on_model_error_callback share span. + + When the model raises an error, on_model_error_callback should see the + same span as before_model_callback. + """ + plugin = SpanCapturingPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create( + responses=[], error=SystemError('model error') + ) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # Both callbacks should have captured a span ID + assert plugin.before_model_span_id is not None, ( + 'before_model_callback did not capture a span ID' + ) + assert plugin.on_model_error_span_id is not None, ( + 'on_model_error_callback did not capture a span ID' + ) + + # The span IDs must match + assert plugin.before_model_span_id == plugin.on_model_error_span_id, ( + f'Span ID mismatch: before_model_callback span_id=' + f'{plugin.before_model_span_id:#018x}, ' + f'on_model_error_callback span_id=' + f'{plugin.on_model_error_span_id:#018x}. ' + f'Both callbacks must run inside the same call_llm span.' + ) + + +@pytest.mark.asyncio +async def test_before_model_callback_short_circuit_has_span(): + """Verify before_model_callback has a valid span when short-circuiting.""" + + class ShortCircuitPlugin(BasePlugin): + + def __init__(self): + super().__init__(name='short_circuit_plugin') + self.span_id: Optional[int] = None + + async def before_model_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.span_id = ctx.span_id + return LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text='short-circuited')] + ) + ) + + plugin = ShortCircuitPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create(responses=['model_response']) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # The callback should have a valid (non-zero) span ID from the call_llm span + assert plugin.span_id is not None and plugin.span_id != 0, ( + 'before_model_callback should have a valid span ID even when ' + 'short-circuiting the LLM call' + ) + + # Verify the short-circuit response was received + simplified = testing_utils.simplify_events(events) + assert any('short-circuited' in str(e) for e in simplified)