Skip to content

Commit 7b18b09

Browse files
authored
Merge branch 'main' into fix/oauth-token-exchange-client-id
2 parents d2f5b60 + 4c6096b commit 7b18b09

File tree

8 files changed

+675
-25
lines changed

8 files changed

+675
-25
lines changed

src/google/adk/agents/run_config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from pydantic import field_validator
2929
from pydantic import model_validator
3030

31+
from ..sessions.base_session_service import GetSessionConfig
32+
3133
logger = logging.getLogger('google_adk.' + __name__)
3234

3335

@@ -319,6 +321,26 @@ class RunConfig(BaseModel):
319321
custom_metadata: Optional[dict[str, Any]] = None
320322
"""Custom metadata for the current invocation."""
321323

324+
get_session_config: Optional[GetSessionConfig] = None
325+
"""Configuration for controlling which events are fetched when loading
326+
a session.
327+
328+
When set, the Runner will pass this configuration to the session service's
329+
``get_session`` method, allowing the caller to limit the events returned
330+
(e.g. via ``num_recent_events`` or ``after_timestamp``). This is especially
331+
useful in combination with ``EventsCompactionConfig`` to avoid loading the
332+
full event history on every invocation.
333+
334+
Example::
335+
336+
from google.adk.agents.run_config import RunConfig
337+
from google.adk.sessions.base_session_service import GetSessionConfig
338+
339+
run_config = RunConfig(
340+
get_session_config=GetSessionConfig(num_recent_events=50),
341+
)
342+
"""
343+
322344
@model_validator(mode='before')
323345
@classmethod
324346
def check_for_deprecated_save_live_audio(cls, data: Any) -> Any:

src/google/adk/models/gemini_llm_connection.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,31 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
179179
)
180180
if message.server_content:
181181
content = message.server_content.model_turn
182+
183+
# Standalone grounding_metadata event (when content is empty)
184+
if (
185+
not (content and content.parts)
186+
and message.server_content.grounding_metadata
187+
and not message.server_content.turn_complete
188+
):
189+
yield LlmResponse(
190+
grounding_metadata=message.server_content.grounding_metadata,
191+
interrupted=message.server_content.interrupted,
192+
model_version=self._model_version,
193+
)
194+
182195
if content and content.parts:
183196
llm_response = LlmResponse(
184-
content=content, interrupted=message.server_content.interrupted
197+
content=content,
198+
interrupted=message.server_content.interrupted,
199+
model_version=self._model_version,
185200
)
201+
# grounding_metadata is yielded again at turn_complete,
202+
# so avoid duplicating it here if turn_complete is true.
203+
if not message.server_content.turn_complete:
204+
llm_response.grounding_metadata = (
205+
message.server_content.grounding_metadata
206+
)
186207
if content.parts[0].text:
187208
text += content.parts[0].text
188209
llm_response.partial = True
@@ -205,6 +226,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
205226
finished=False,
206227
),
207228
partial=True,
229+
model_version=self._model_version,
208230
)
209231
# finished=True and partial transcription may happen in the same
210232
# message.
@@ -215,6 +237,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
215237
finished=True,
216238
),
217239
partial=False,
240+
model_version=self._model_version,
218241
)
219242
self._input_transcription_text = ''
220243
if message.server_content.output_transcription:
@@ -228,6 +251,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
228251
finished=False,
229252
),
230253
partial=True,
254+
model_version=self._model_version,
231255
)
232256
if message.server_content.output_transcription.finished:
233257
yield LlmResponse(
@@ -236,6 +260,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
236260
finished=True,
237261
),
238262
partial=False,
263+
model_version=self._model_version,
239264
)
240265
self._output_transcription_text = ''
241266
# The Gemini API might not send a transcription finished signal.
@@ -253,6 +278,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
253278
finished=True,
254279
),
255280
partial=False,
281+
model_version=self._model_version,
256282
)
257283
self._input_transcription_text = ''
258284
if self._output_transcription_text:
@@ -262,6 +288,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
262288
finished=True,
263289
),
264290
partial=False,
291+
model_version=self._model_version,
265292
)
266293
self._output_transcription_text = ''
267294
if message.server_content.turn_complete:
@@ -271,9 +298,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
271298
yield LlmResponse(
272299
turn_complete=True,
273300
interrupted=message.server_content.interrupted,
301+
grounding_metadata=message.server_content.grounding_metadata,
302+
model_version=self._model_version,
274303
)
275304
break
276-
# in case of empty content or parts, we sill surface it
305+
# in case of empty content or parts, we still surface it
277306
# in case it's an interrupted message, we merge the previous partial
278307
# text. Otherwise we don't merge, because content can be None when the model
279308
# safety threshold is triggered
@@ -282,7 +311,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
282311
yield self.__build_full_text_response(text)
283312
text = ''
284313
else:
285-
yield LlmResponse(interrupted=message.server_content.interrupted)
314+
yield LlmResponse(
315+
interrupted=message.server_content.interrupted,
316+
model_version=self._model_version,
317+
)
286318
if message.tool_call:
287319
if text:
288320
yield self.__build_full_text_response(text)
@@ -291,12 +323,16 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
291323
types.Part(function_call=function_call)
292324
for function_call in message.tool_call.function_calls
293325
]
294-
yield LlmResponse(content=types.Content(role='model', parts=parts))
326+
yield LlmResponse(
327+
content=types.Content(role='model', parts=parts),
328+
model_version=self._model_version,
329+
)
295330
if message.session_resumption_update:
296331
logger.debug('Received session resumption message: %s', message)
297332
yield (
298333
LlmResponse(
299-
live_session_resumption_update=message.session_resumption_update
334+
live_session_resumption_update=message.session_resumption_update,
335+
model_version=self._model_version,
300336
)
301337
)
302338

src/google/adk/models/lite_llm.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,15 @@ def _message_to_generate_content_response(
15001500
)
15011501

15021502

1503+
def _finish_reason_to_error_message(
    finish_reason: types.FinishReason,
) -> str:
    """Builds the error text reported for a non-stop finish reason.

    Args:
        finish_reason: The finish reason to describe.

    Returns:
        A fixed message when the reason is ``MAX_TOKENS``; for any other
        reason, a generic message that embeds the reason itself.
    """
    # Hitting the token limit gets a dedicated, friendlier message.
    hit_token_limit = finish_reason == types.FinishReason.MAX_TOKENS
    return (
        "Maximum tokens reached"
        if hit_token_limit
        else f"Finished with {finish_reason}"
    )
1510+
1511+
15031512
def _enforce_strict_openai_schema(schema: dict[str, Any]) -> None:
15041513
"""Recursively transforms a JSON schema for OpenAI strict structured outputs.
15051514
@@ -2000,8 +2009,15 @@ def _finalize_tool_call_response(
20002009
*, model_version: str, finish_reason: str
20012010
) -> LlmResponse:
20022011
tool_calls = []
2012+
has_incomplete_tool_call_args = False
20032013
for index, func_data in function_calls.items():
20042014
if func_data["id"]:
2015+
if finish_reason == "length":
2016+
try:
2017+
json.loads(func_data["args"] or "{}")
2018+
except json.JSONDecodeError:
2019+
has_incomplete_tool_call_args = True
2020+
continue
20052021
tool_calls.append(
20062022
ChatCompletionMessageToolCall(
20072023
type="function",
@@ -2013,6 +2029,19 @@ def _finalize_tool_call_response(
20132029
),
20142030
)
20152031
)
2032+
2033+
if has_incomplete_tool_call_args:
2034+
return LlmResponse(
2035+
error_code=types.FinishReason.MAX_TOKENS,
2036+
error_message=(
2037+
"Tool call arguments were truncated while streaming and"
2038+
" could not be parsed as valid JSON. Increase"
2039+
" `max_output_tokens` and retry."
2040+
),
2041+
finish_reason=types.FinishReason.MAX_TOKENS,
2042+
model_version=model_version,
2043+
)
2044+
20162045
llm_response = _message_to_generate_content_response(
20172046
ChatCompletionAssistantMessage(
20182047
role="assistant",
@@ -2022,7 +2051,13 @@ def _finalize_tool_call_response(
20222051
model_version=model_version,
20232052
thought_parts=list(reasoning_parts) if reasoning_parts else None,
20242053
)
2025-
llm_response.finish_reason = _map_finish_reason(finish_reason)
2054+
mapped_finish_reason = _map_finish_reason(finish_reason)
2055+
llm_response.finish_reason = mapped_finish_reason
2056+
if mapped_finish_reason != types.FinishReason.STOP:
2057+
llm_response.error_code = mapped_finish_reason
2058+
llm_response.error_message = _finish_reason_to_error_message(
2059+
mapped_finish_reason
2060+
)
20262061
return llm_response
20272062

20282063
def _finalize_text_response(
@@ -2037,7 +2072,13 @@ def _finalize_text_response(
20372072
model_version=model_version,
20382073
thought_parts=list(reasoning_parts) if reasoning_parts else None,
20392074
)
2040-
llm_response.finish_reason = _map_finish_reason(finish_reason)
2075+
mapped_finish_reason = _map_finish_reason(finish_reason)
2076+
llm_response.finish_reason = mapped_finish_reason
2077+
if mapped_finish_reason != types.FinishReason.STOP:
2078+
llm_response.error_code = mapped_finish_reason
2079+
llm_response.error_message = _finish_reason_to_error_message(
2080+
mapped_finish_reason
2081+
)
20412082
return llm_response
20422083

20432084
def _reset_stream_buffers() -> None:
@@ -2096,10 +2137,11 @@ def _reset_stream_buffers() -> None:
20962137
)
20972138

20982139
# LiteLLM 1.81+ can set finish_reason="stop" on partial chunks. Only
2099-
# finalize tool calls on an explicit tool_calls finish_reason, or on a
2100-
# stop-only chunk (no content/tool deltas).
2140+
# finalize tool calls on an explicit tool_calls/length finish_reason,
2141+
# or on a stop-only chunk (no content/tool deltas).
21012142
if function_calls and (
21022143
finish_reason == "tool_calls"
2144+
or finish_reason == "length"
21032145
or (finish_reason == "stop" and chunk is None)
21042146
):
21052147
aggregated_llm_response_with_tool_call = (
@@ -2109,16 +2151,14 @@ def _reset_stream_buffers() -> None:
21092151
)
21102152
)
21112153
_reset_stream_buffers()
2112-
elif (
2113-
finish_reason == "stop"
2114-
and (text or reasoning_parts)
2115-
and chunk is None
2116-
and not function_calls
2154+
elif (text or reasoning_parts) and (
2155+
finish_reason == "length"
2156+
or (
2157+
finish_reason == "stop"
2158+
and chunk is None
2159+
and not function_calls
2160+
)
21172161
):
2118-
# Only aggregate text response when we have a true stop signal
2119-
# chunk is None means no content in this chunk, just finish signal.
2120-
# LiteLLM 1.81+ sets finish_reason="stop" on partial chunks with
2121-
# content.
21222162
aggregated_llm_response = _finalize_text_response(
21232163
model_version=part.model,
21242164
finish_reason=finish_reason,

src/google/adk/runners.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from .plugins.base_plugin import BasePlugin
5858
from .plugins.plugin_manager import PluginManager
5959
from .sessions.base_session_service import BaseSessionService
60+
from .sessions.base_session_service import GetSessionConfig
6061
from .sessions.in_memory_session_service import InMemorySessionService
6162
from .sessions.session import Session
6263
from .telemetry.tracing import tracer
@@ -393,7 +394,11 @@ def _format_session_not_found_message(self, session_id: str) -> str:
393394
)
394395

395396
async def _get_or_create_session(
396-
self, *, user_id: str, session_id: str
397+
self,
398+
*,
399+
user_id: str,
400+
session_id: str,
401+
get_session_config: Optional[GetSessionConfig] = None,
397402
) -> Session:
398403
"""Gets the session or creates it if auto-creation is enabled.
399404
@@ -404,6 +409,8 @@ async def _get_or_create_session(
404409
Args:
405410
user_id: The user ID of the session.
406411
session_id: The session ID of the session.
412+
get_session_config: Optional configuration for controlling which events
413+
are fetched from session storage.
407414
408415
Returns:
409416
The existing or newly created `Session`.
@@ -413,7 +420,10 @@ async def _get_or_create_session(
413420
auto_create_session is False.
414421
"""
415422
session = await self.session_service.get_session(
416-
app_name=self.app_name, user_id=user_id, session_id=session_id
423+
app_name=self.app_name,
424+
user_id=user_id,
425+
session_id=session_id,
426+
config=get_session_config,
417427
)
418428
if not session:
419429
if self.auto_create_session:
@@ -535,7 +545,9 @@ async def _run_with_trace(
535545
) -> AsyncGenerator[Event, None]:
536546
with tracer.start_as_current_span('invocation'):
537547
session = await self._get_or_create_session(
538-
user_id=user_id, session_id=session_id
548+
user_id=user_id,
549+
session_id=session_id,
550+
get_session_config=run_config.get_session_config,
539551
)
540552

541553
if not invocation_id and not new_message:
@@ -626,10 +638,14 @@ async def rewind_async(
626638
user_id: str,
627639
session_id: str,
628640
rewind_before_invocation_id: str,
641+
run_config: Optional[RunConfig] = None,
629642
) -> None:
630643
"""Rewinds the session to before the specified invocation."""
644+
run_config = run_config or RunConfig()
631645
session = await self._get_or_create_session(
632-
user_id=user_id, session_id=session_id
646+
user_id=user_id,
647+
session_id=session_id,
648+
get_session_config=run_config.get_session_config,
633649
)
634650
rewind_event_index = -1
635651
for i, event in enumerate(session.events):
@@ -1060,7 +1076,9 @@ async def run_live(
10601076
)
10611077
if not session:
10621078
session = await self._get_or_create_session(
1063-
user_id=user_id, session_id=session_id
1079+
user_id=user_id,
1080+
session_id=session_id,
1081+
get_session_config=run_config.get_session_config,
10641082
)
10651083
invocation_context = self._new_invocation_context_for_live(
10661084
session,
@@ -1231,8 +1249,12 @@ async def run_debug(
12311249
- Performance optimization
12321250
Please use run_async() with proper configuration.
12331251
"""
1252+
run_config = run_config or RunConfig()
12341253
session = await self.session_service.get_session(
1235-
app_name=self.app_name, user_id=user_id, session_id=session_id
1254+
app_name=self.app_name,
1255+
user_id=user_id,
1256+
session_id=session_id,
1257+
config=run_config.get_session_config,
12361258
)
12371259
if not session:
12381260
session = await self.session_service.create_session(

src/google/adk/tools/bigquery/search_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def search_catalog(
4848
dataset_ids_filter: list[str] | None = None,
4949
types_filter: list[str] | None = None,
5050
) -> dict[str, Any]:
51-
"""Searches for BigQuery assets within Dataplex.
51+
"""Finds BigQuery datasets and tables using natural language semantic search via Dataplex.
52+
53+
Use this tool to discover BigQuery assets when you don't know the exact names.
54+
It's ideal for searching based on topics, descriptions, or questions about the data.
5255
5356
Args:
5457
prompt: The base search query (natural language or keywords).

0 commit comments

Comments
 (0)