From 8cf15914900f87cfb63b152ace1002c9dff9ffac Mon Sep 17 00:00:00 2001 From: Doug Reid <21148125+douglas-reid@users.noreply.github.com> Date: Wed, 3 Sep 2025 21:48:36 -0700 Subject: [PATCH 1/4] feat[models]: add support for gemma model via gemini api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds support for invoking Gemma models via the Gemini API endpoint. To support agentic function, callbacks are added which can extract and transform function calls and responses into user and model messages in the history. This change is intended to allow developers to explore the use of Gemma models for agentic purposes without requiring local deployment of the models. This should ease the burden of experimentation and testing for developers. A basic "hello world" style agent example is provided to demonstrate proper functioning of Gemma 3 models inside an Agent container, using the dice roll + prime check framework of similar examples for other models. Testing Plan: - add integration and unit tests - manual run of hello_world_gemma agent - manual run of example multi_tool_agent from quickstart using Gemma model Testing Results: | Test Command | Results | | pytest ./tests/unittests | 4386 passed, 2849 warnings in 58.43s | | pytest ./tests/unittests/models/test_google_llm.py | 100 passed, 4 warnings in 5.83s | | pytest ./tests/integration/models/test_google_llm.py | 5 passed, 2 warnings in 3.73s | Manual Testing: Here is a log of `multi_tool_agent` run with locally-built wheel and using Gemma model. ``` ❯ adk run multi_tool_agent Log setup complete: /var/folders/bg/_133c0ds2kb7cn699cpmmh_h0061bp/T/agents_log/agent.20250904_152617.log To access latest log: tail -F /var/folders/bg/_133c0ds2kb7cn699cpmmh_h0061bp/T/agents_log/agent.latest.log /Users//venvs/adk-quickstart/lib/python3.11/site-packages/google/adk/cli/cli.py:143: UserWarning: [EXPERIMENTAL] InMemoryCredentialService: This feature is experimental and may change or be removed in future versions without notice. It may introduce breaking changes at any time. credential_service = InMemoryCredentialService() /Users//venvs/adk-quickstart/lib/python3.11/site-packages/google/adk/auth/credential_service/in_memory_credential_service.py:33: UserWarning: [EXPERIMENTAL] BaseCredentialService: This feature is experimental and may change or be removed in future versions without notice. It may introduce breaking changes at any time. super().__init__() Running agent weather_time_agent, type exit to exit. [user]: what's the weather like today? [weather_time_agent]: Which city are you asking about? [user]: new york [weather_time_agent]: OK. The weather in New York is sunny with a temperature of 25 degrees Celsius (77 degrees Fahrenheit). ``` And here is a snippet of a log generated with DEBUG level logging of the `hello_world_gemma` sample. It demonstrates how function calls are extracted and inserted based on Gemma model interactions: ``` ... 2025-09-04 15:32:41,708 - DEBUG - google_llm.py:138 - LLM Request: ----------------------------------------------------------- System Instruction: None ----------------------------------------------------------- Contents: {"parts":[{"text":"\n You roll dice and answer questions about the outcome of the dice rolls.\n You can roll dice of different sizes...\n"}],"role":"user"} {"parts":[{"text":"Hi, introduce yourself."}],"role":"user"} {"parts":[{"text":"Hello! I am data_processing_agent, a hello world agent that can roll many-sided dice and check if numbers are prime. I'm ready to assist you with those tasks. Let's begin!\n\n\n\n"}],"role":"model"} {"parts":[{"text":"Roll a die with 100 sides and check if it is prime"}],"role":"user"} {"parts":[{"text":"{\"args\":{\"sides\":100},\"name\":\"roll_die\"}"}],"role":"model"} {"parts":[{"text":"Invoking tool `roll_die` produced: `{\"result\": 82}`."}],"role":"user"} {"parts":[{"text":"{\"args\":{\"nums\":[82]},\"name\":\"check_prime\"}"}],"role":"model"} {"parts":[{"text":"Invoking tool `check_prime` produced: `{\"result\": \"No prime numbers found.\"}`."}],"role":"user"} {"parts":[{"text":"The die roll was 82, and it is not a prime number.\n\n\n\n"}],"role":"model"} {"parts":[{"text":"Roll it again."}],"role":"user"} ----------------------------------------------------------- Functions: ----------------------------------------------------------- 2025-09-04 15:32:41,708 - INFO - models.py:8165 - AFC is enabled with max remote calls: 10. 2025-09-04 15:32:42,693 - INFO - google_llm.py:180 - Response received from the model. 2025-09-04 15:32:42,693 - DEBUG - google_llm.py:181 - LLM Response: ----------------------------------------------------------- Text: {"args":{"sides":100},"name":"roll_die"} ----------------------------------------------------------- ... ``` --- .../samples/hello_world_gemma/__init__.py | 16 + .../samples/hello_world_gemma/agent.py | 99 ++++ .../samples/hello_world_gemma/main.py | 77 +++ src/google/adk/models/__init__.py | 6 +- src/google/adk/models/google_llm.py | 315 +++++++++++ tests/integration/conftest.py | 2 +- tests/integration/models/test_google_llm.py | 63 ++- tests/unittests/models/test_google_llm.py | 521 ++++++++++++++++++ 8 files changed, 1089 insertions(+), 10 deletions(-) create mode 100644 contributing/samples/hello_world_gemma/__init__.py create mode 100644 contributing/samples/hello_world_gemma/agent.py create mode 100644 contributing/samples/hello_world_gemma/main.py diff --git a/contributing/samples/hello_world_gemma/__init__.py b/contributing/samples/hello_world_gemma/__init__.py new file mode 100644 index 0000000000..7d5bb0b1c6 --- /dev/null +++ b/contributing/samples/hello_world_gemma/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_gemma/agent.py b/contributing/samples/hello_world_gemma/agent.py new file mode 100644 index 0000000000..d8f3e2ef68 --- /dev/null +++ b/contributing/samples/hello_world_gemma/agent.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random + +from google.adk.agents.llm_agent import Agent +from google.adk.models.google_llm import Gemma +from google.adk.models.google_llm import gemma_functions_after_model_callback +from google.adk.models.google_llm import gemma_functions_before_model_callback +from google.genai.types import GenerateContentConfig + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model=Gemma(model="gemma-3-27b-it"), + name="data_processing_agent", + description=( + "hello world agent that can roll many-sided dice and check if numbers" + " are prime." + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After the user reports a response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + before_model_callback=gemma_functions_before_model_callback, + after_model_callback=gemma_functions_after_model_callback, + generate_content_config=GenerateContentConfig( + temperature=1.0, + top_p=0.95, + ), +) diff --git a/contributing/samples/hello_world_gemma/main.py b/contributing/samples/hello_world_gemma/main.py new file mode 100644 index 0000000000..f177064b68 --- /dev/null +++ b/contributing/samples/hello_world_gemma/main.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import logging +import time + +import agent +from dotenv import load_dotenv +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder(level=logging.INFO) + + +async def main(): + app_name = 'my_gemma_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt( + session_11, 'Roll a die with 100 sides and check if it is prime' + ) + await run_prompt(session_11, 'Roll it again.') + await run_prompt(session_11, 'What numbers did I get?') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index fc86c197ca..c7c8c61adb 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -16,6 +16,7 @@ from .base_llm import BaseLlm from .google_llm import Gemini +from .google_llm import Gemma from .llm_request import LlmRequest from .llm_response import LlmResponse from .registry import LLMRegistry @@ -23,9 +24,10 @@ __all__ = [ 'BaseLlm', 'Gemini', + 'Gemma', 'LLMRegistry', ] -for regex in Gemini.supported_models(): - LLMRegistry.register(Gemini) +LLMRegistry.register(Gemini) +LLMRegistry.register(Gemma) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 57b1451784..67736be978 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -17,17 +17,27 @@ import contextlib from functools import cached_property +import json import logging import os import sys +from typing import Any from typing import AsyncGenerator from typing import cast from typing import Optional from typing import TYPE_CHECKING from typing import Union +from google.adk.agents.callback_context import CallbackContext from google.genai import Client from google.genai import types +from google.genai.types import Content +from google.genai.types import FunctionDeclaration +from google.genai.types import Part +from pydantic import BaseModel +from pydantic import ValidationError +from pydantic.aliases import AliasChoices +from pydantic.fields import Field from typing_extensions import override from .. import version @@ -44,6 +54,16 @@ logger = logging.getLogger('google_adk.' + __name__) + +class GemmaFunctionCallModel(BaseModel): + """Flexible Pydantic model for parsing inline Gemma function call responses.""" + + name: str = Field(validation_alias=AliasChoices('name', 'function')) + parameters: dict[str, Any] = Field( + validation_alias=AliasChoices('parameters', 'args') + ) + + _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} _AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' @@ -420,3 +440,298 @@ def _remove_display_name_if_present( """ if data_obj and data_obj.display_name: data_obj.display_name = None + + +class Gemma(Gemini): + """Integration for Gemma models exposed via the Gemini API. + + For full documentation, see: https://ai.google.dev/gemma/docs/core/ + + NOTE: Gemma does **NOT** support system instructions. Any system instructions + will be replaced with an initial *user* prompt in the LLM request. If system + instructions change over the course of agent execution, the initial content + **SHOULD** be replaced. Special care is warranted here. + See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions + + NOTE: Gemma's function calling support is limited. It does not have full access to the + same built-in tools as Gemini. It also does not have special API support for tools and + functions. Rather, tools must be passed in via a `user` prompt, and extracted from model + responses based on approximate shape. For agent developments, please use the provided + `gemma_functions_before_model_callback` and `gemma_functions_after_model_callback` methods. + + NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports + usage via the Gemini API. + """ + + model: str = ( + 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] + ) + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r'gemma-.*', + ] + + @cached_property + def _api_backend(self) -> GoogleLLMVariant: + return GoogleLLMVariant.GEMINI_API + + @override + async def _preprocess_request(self, llm_request: LlmRequest) -> None: + if system_instruction := llm_request.config.system_instruction: + contents = llm_request.contents + instruction_content = Content( + role='user', parts=[Part.from_text(text=system_instruction)] + ) + + # NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning + # of any chain of contents. + if len(contents) >= 1: + if contents[0] != instruction_content: + # only prepend if it hasn't already been done + llm_request.contents = [instruction_content] + contents + + llm_request.config.system_instruction = None + + return await super()._preprocess_request(llm_request) + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to the Gemma model. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + # print(f'{llm_request=}') + assert llm_request.model.startswith('gemma-'), ( + f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma LLM' + ' is not supported.' + ) + + async for response in super().generate_content_async(llm_request, stream): + yield response + + +def _convert_content_parts_for_gemma( + content_item: Content, +) -> tuple[list[Part], bool, bool]: + """Converts function call/response parts within a content item to text parts. + + Args: + content_item: The original Content item. + + Returns: + A tuple containing: + - A list of new Part objects with function calls/responses converted to text. + - A boolean indicating if any function response parts were found. + - A boolean indicating if any function call parts were found. + """ + new_parts: list[Part] = [] + has_function_response_part = False + has_function_call_part = False + + for part in content_item.parts: + if func_response := part.function_response: + has_function_response_part = True + response_text = ( + f'Invoking tool `{func_response.name}` produced:' + f' `{json.dumps(func_response.response)}`.' + ) + new_parts.append(Part.from_text(text=response_text)) + elif func_call := part.function_call: + has_function_call_part = True + new_parts.append( + Part.from_text(text=func_call.model_dump_json(exclude_none=True)) + ) + else: + new_parts.append(part) + return new_parts, has_function_response_part, has_function_call_part + + +def _build_gemma_function_system_instruction( + function_declarations: list[FunctionDeclaration], +) -> str: + """Constructs the system instruction string for Gemma function calling.""" + if not function_declarations: + return '' + + system_instruction_prefix = 'You have access to the following functions:\n[' + instruction_parts = [] + for func in function_declarations: + instruction_parts.append(func.model_dump_json(exclude_none=True)) + + separator = ',\n' + system_instruction = ( + f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n' + ) + + system_instruction += ( + 'When you call a function, you MUST respond in the format of: ' + """{"name": function name, "parameters": dictionary of argument name and its value}\n""" + 'When you call a function, you MUST NOT include any other text in the' + ' response.\n' + ) + return system_instruction + + +def gemma_functions_before_model_callback( + callback_context: CallbackContext, llm_request: LlmRequest +): + """Translates function calls and responses to the Gemma-supported interaction model. + + NOTE: Gemma is **ONLY** able to handle external function declarations in a tool. It does NOT + have access to the internal Gemini tools (including Google and Enterprise Search, URL Context, etc.). + If the LLM Request includes those tools, they will be ignored and dropped from the request sent to + the model. + """ + + if llm_request.model is None or not llm_request.model.startswith('gemma-3'): + return + + # Iterate through the existing contents to find and convert function calls and responses + # from text parts, as Gemma models don't directly support function calling. + new_contents: list[Content] = [] + for content_item in llm_request.contents: + ( + new_parts_for_content, + has_function_response_part, + has_function_call_part, + ) = _convert_content_parts_for_gemma(content_item) + + if has_function_response_part: + if new_parts_for_content: + new_contents.append(Content(role='user', parts=new_parts_for_content)) + elif has_function_call_part: + if new_parts_for_content: + new_contents.append(Content(role='model', parts=new_parts_for_content)) + else: + new_contents.append(content_item) + + llm_request.contents = new_contents + + if not llm_request.config.tools: + return + + all_function_declarations: list[FunctionDeclaration] = [] + for tool_item in llm_request.config.tools: + if isinstance(tool_item, types.Tool) and tool_item.function_declarations: + all_function_declarations.extend(tool_item.function_declarations) + + if all_function_declarations: + system_instruction = _build_gemma_function_system_instruction( + all_function_declarations + ) + llm_request.append_instructions([system_instruction]) + + llm_request.config.tools = [] + + +def gemma_functions_after_model_callback( + callback_context: CallbackContext, llm_response: LlmResponse +): + """Translates function calls and responses to the Gemma-supported interaction model. + + Model function calls are attempted to be recognized in text responses and extracted into + the objects that can be exploited by `Agents`. Some flexibility in parsing is provided + in an attempt to improve model function in agentic systems. + """ + if not llm_response.content: + return + + if not llm_response.content.parts: + return + + if len(llm_response.content.parts) > 1: + return + + response_text = llm_response.content.parts[0].text + if not response_text: + return + + try: + import re + + json_candidate = None + + markdown_code_block_pattern = re.compile( + r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL + ) + block_match = markdown_code_block_pattern.search(response_text) + + if block_match: + json_candidate = block_match.group(2).strip() + else: + found, json_text = _get_last_valid_json_substring(response_text) + if found: + json_candidate = json_text + + if not json_candidate: + return + + function_call_parsed = GemmaFunctionCallModel.model_validate_json( + json_candidate + ) + function_call = types.FunctionCall( + name=function_call_parsed.name, + args=function_call_parsed.parameters, + ) + function_call_part = Part(function_call=function_call) + llm_response.content.parts = [function_call_part] + except (json.JSONDecodeError, ValidationError) as e: + logger.debug( + f'Error attempting to parse JSON into function call. Leaving as text' + f' response. %s', + e, + ) + except Exception as e: + logger.warning('Error processing Gemma function call response: %s', e) + + +def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: + """Attempts to find and return the last valid JSON object in a string. + + This function is designed to extract JSON that might be embedded in a larger + text, potentially with introductory or concluding remarks. It will always chose + the last block of valid json found within the supplied text (if it exists). + + Args: + text: The input string to search for JSON objects. + + Returns: + A tuple: + - bool: True if a valid JSON substring was found, False otherwise. + - str | None: The last valid JSON substring found, or None if none was + found. + """ + decoder = json.JSONDecoder() + last_json_str = None + start_pos = 0 + first_brace_index = 0 + while start_pos < len(text): + try: + first_brace_index = text.index('{', start_pos) + _, end_index = decoder.raw_decode(text[first_brace_index:]) + last_json_str = text[first_brace_index : first_brace_index + end_index] + start_pos = first_brace_index + end_index + except json.JSONDecodeError: + start_pos = first_brace_index + 1 + except ValueError: + break + + if last_json_str: + return True, last_json_str + return False, None diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 6dc1f3d1bb..45e720a579 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -114,6 +114,6 @@ def pytest_generate_tests(metafunc: Metafunc): def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool: if hasattr(metafunc.function, 'pytestmark'): for mark in metafunc.function.pytestmark: - if mark.name == 'parametrize' and mark.args[0] == mark_name: + if mark.name == 'parametrize' and mark_name in mark.args[0]: return True return False diff --git a/tests/integration/models/test_google_llm.py b/tests/integration/models/test_google_llm.py index 5574eb30ef..9edd6de7e9 100644 --- a/tests/integration/models/test_google_llm.py +++ b/tests/integration/models/test_google_llm.py @@ -13,6 +13,7 @@ # limitations under the License. from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import Gemma from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -20,16 +21,29 @@ from google.genai.types import Part import pytest +DEFAULT_GEMINI_MODEL = "gemini-1.5-flash" +DEFAULT_GEMMA_MODEL = "gemma-3-1b-it" + @pytest.fixture def gemini_llm(): - return Gemini(model="gemini-1.5-flash") + return Gemini(model=DEFAULT_GEMINI_MODEL) + + +@pytest.fixture +def gemma_llm(): + return Gemma(model=DEFAULT_GEMMA_MODEL) + + +@pytest.fixture +def llm(request): + return request.getfixturevalue(request.param) @pytest.fixture -def llm_request(): +def gemini_request(): return LlmRequest( - model="gemini-1.5-flash", + model=DEFAULT_GEMINI_MODEL, contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], config=types.GenerateContentConfig( temperature=0.1, @@ -39,19 +53,54 @@ def llm_request(): ) +@pytest.fixture +def gemma_request(): + return LlmRequest( + model=DEFAULT_GEMMA_MODEL, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="You are a helpful assistant."), + Part.from_text(text="Hello!"), + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction="Talk like a pirate.", + ), + ) + + +@pytest.fixture +def llm_request(request): + return request.getfixturevalue(request.param) + + @pytest.mark.asyncio -async def test_generate_content_async(gemini_llm, llm_request): - async for response in gemini_llm.generate_content_async(llm_request): +@pytest.mark.parametrize( + "llm, llm_request, llm_backend", + [ + ("gemini_llm", "gemini_request", "GOOGLE_AI"), + ("gemini_llm", "gemini_request", "VERTEX"), + ("gemma_llm", "gemma_request", "GOOGLE_AI"), + ], + indirect=True, +) +async def test_generate_content_async(llm, llm_request): + async for response in llm.generate_content_async(llm_request): assert isinstance(response, LlmResponse) assert response.content.parts[0].text @pytest.mark.asyncio -async def test_generate_content_async_stream(gemini_llm, llm_request): +async def test_generate_content_async_stream(gemini_llm, gemini_request): responses = [ resp async for resp in gemini_llm.generate_content_async( - llm_request, stream=True + gemini_request, stream=True ) ] text = "" diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 985d7e91cf..f9639e87d5 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys from typing import Optional @@ -25,11 +26,13 @@ from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import Gemma from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.types import Content +from google.genai.types import GenerateContentConfigOrDict from google.genai.types import Part import pytest @@ -137,6 +140,62 @@ def llm_request_with_computer_use(): ) +@pytest.fixture +def llm_request_with_duplicate_instruction(): + return LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="user", + parts=[types.Part.from_text(text="Talk like a pirate.")], + ), + types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ], + config=types.GenerateContentConfig( + response_modalities=[types.Modality.TEXT], + system_instruction="Talk like a pirate.", + ), + ) + + +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="search_web", + description="Search the web for a query.", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema(type=types.Type.STRING) + }, + required=["query"], + ), + ), + types.FunctionDeclaration( + name="get_current_time", + description="Gets the current time.", + parameters=types.Schema( + type=types.Type.OBJECT, properties={} + ), + ), + ] + ) + ], + ), + ) + + @pytest.fixture def mock_os_environ(): initial_env = os.environ.copy() @@ -1726,3 +1785,465 @@ async def mock_coro(): # Verify cache metadata is preserved assert second_arg.cache_name == cache_metadata.cache_name assert second_arg.invocations_used == cache_metadata.invocations_used +async def test_not_gemma_model(): + llm = Gemma() + llm_request_bad_model = LlmRequest( + model="not-a-gemma-model", + ) + with pytest.raises(AssertionError, match=r".*model.*"): + async for _ in llm.generate_content_async(llm_request_bad_model): + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "llm_request", + ["llm_request", "llm_request_with_duplicate_instruction"], + indirect=True, +) +async def test_gemma_request_preprocess(llm_request): + llm = Gemma() + want_content_text = llm_request.config.system_instruction + + await llm._preprocess_request(llm_request=llm_request) + + # system instruction should be cleared + assert not llm_request.config.system_instruction + # should be two content bits now (deduped, if needed) + assert len(llm_request.contents) == 2 + # first message in contents should be "user": + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts[0].text == want_content_text + + +def test_gemma_functions_before_model_callback(llm_request_with_tools): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_before_model_callback + + gemma_functions_before_model_callback( + mock.MagicMock(spec=CallbackContext), llm_request_with_tools + ) + + assert not llm_request_with_tools.config.tools + + # The original user content should still be the first item + assert llm_request_with_tools.contents[0].role == "user" + assert llm_request_with_tools.contents[0].parts[0].text == "Hello" + + sys_instruct_text = llm_request_with_tools.config.system_instruction + assert sys_instruct_text is not None + assert "You have access to the following functions" in sys_instruct_text + assert ( + """{"description":"Search the web for a query.","name":"search_web",""" + in sys_instruct_text + ) + assert ( + """{"description":"Gets the current time.","name":"get_current_time","parameters":{"properties":{}""" + in sys_instruct_text + ) + + +def test_gemma_functions_after_model_callback_valid_json_function_call(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response from Gemma that should be converted to a FunctionCall + json_function_call_str = ( + '{"name": "search_web", "parameters": {"query": "latest news"}}' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + # Assert that the content was transformed into a FunctionCall + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "latest news"} + # Assert that the entire part matches the expected structure + expected_function_call = types.FunctionCall( + name="search_web", args={"query": "latest news"} + ) + expected_part = Part(function_call=expected_function_call) + assert part == expected_part + assert part.text is None # Ensure text part is cleared + + +def test_gemma_functions_after_model_callback_invalid_json_text(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response with plain text that is not JSON + original_text = "This is a regular text response." + llm_response = LlmResponse( + content=Content(role="model", parts=[Part.from_text(text=original_text)]) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + # Assert that the content remains unchanged + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + assert llm_response.content.parts[0].text == original_text + assert llm_response.content.parts[0].function_call is None + + +def test_gemma_functions_after_model_callback_malformed_json(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response with valid JSON but not in the function call format + malformed_json_str = '{"not_a_function": "value", "another_field": 123}' + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=malformed_json_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + # Assert that the content remains unchanged because it doesn't match the expected schema + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + assert llm_response.content.parts[0].text == malformed_json_str + assert llm_response.content.parts[0].function_call is None + + +def test_gemma_functions_after_model_callback_empty_content_or_multiple_parts(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Test case 1: LlmResponse with no content + llm_response_no_content = LlmResponse(content=None) + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response_no_content + ) + assert llm_response_no_content.content is None + + # Test case 2: LlmResponse with empty parts list + llm_response_empty_parts = LlmResponse( + content=Content(role="model", parts=[]) + ) + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response_empty_parts + ) + assert llm_response_empty_parts.content + assert not llm_response_empty_parts.content.parts + + # Test case 3: LlmResponse with multiple parts + llm_response_multiple_parts = LlmResponse( + content=Content( + role="model", + parts=[ + Part.from_text(text="part one"), + Part.from_text(text="part two"), + ], + ) + ) + original_parts = list( + llm_response_multiple_parts.content.parts + ) # Copy for comparison + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response_multiple_parts + ) + assert llm_response_multiple_parts.content + assert ( + llm_response_multiple_parts.content.parts == original_parts + ) # Should remain unchanged + + # Test case 4: LlmResponse with one part, but empty text + llm_response_empty_text_part = LlmResponse( + content=Content(role="model", parts=[Part.from_text(text="")]) + ) + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response_empty_text_part + ) + assert llm_response_empty_text_part.content + assert llm_response_empty_text_part.content.parts + assert llm_response_empty_text_part.content.parts[0].text == "" + assert llm_response_empty_text_part.content.parts[0].function_call is None + + +def test_gemma_functions_before_model_callback_with_function_response(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_before_model_callback + + # Simulate an LlmRequest with a function response + func_response_data = types.FunctionResponse( + name="search_web", response={"results": [{"title": "ADK"}]} + ) + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="model", + parts=[types.Part(function_response=func_response_data)], + ) + ], + config=types.GenerateContentConfig(), + ) + + gemma_functions_before_model_callback( + mock.MagicMock(spec=CallbackContext), llm_request + ) + + # Assertions: function response converted to user role text content + assert llm_request.contents + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts + assert ( + llm_request.contents[0].parts[0].text + == 'Invoking tool `search_web` produced: `{"results": [{"title":' + ' "ADK"}]}`.' + ) + assert llm_request.contents[0].parts[0].function_response is None + assert llm_request.contents[0].parts[0].function_call is None + + +def test_gemma_functions_before_model_callback_with_function_call(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_before_model_callback + + func_call_data = types.FunctionCall(name="get_current_time", args={}) + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="user", parts=[types.Part(function_call=func_call_data)] + ) + ], + ) + + gemma_functions_before_model_callback( + mock.MagicMock(spec=CallbackContext), llm_request + ) + + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "model" + expected_text = func_call_data.model_dump_json(exclude_none=True) + assert llm_request.contents[0].parts + got_part = llm_request.contents[0].parts[0] + assert got_part.text == expected_text + assert got_part.function_call is None + assert got_part.function_response is None + + +def test_gemma_functions_before_model_callback_mixed_content(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_before_model_callback + + func_call = types.FunctionCall(name="get_weather", args={"city": "London"}) + func_response = types.FunctionResponse( + name="get_weather", response={"temp": "15C"} + ) + + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello!")] + ), + types.Content( + role="model", parts=[types.Part(function_call=func_call)] + ), + types.Content( + role="some_function", + parts=[types.Part(function_response=func_response)], + ), + types.Content( + role="user", parts=[types.Part.from_text(text="How are you?")] + ), + ], + ) + + gemma_functions_before_model_callback( + mock.MagicMock(spec=CallbackContext), llm_request + ) + + # Assertions + assert len(llm_request.contents) == 4 + + # First part: original user text + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts + assert llm_request.contents[0].parts[0].text == "Hello!" + + # Second part: function call converted to model text + assert llm_request.contents[1].role == "model" + assert llm_request.contents[1].parts + assert llm_request.contents[1].parts[0].text == func_call.model_dump_json( + exclude_none=True + ) + + # Third part: function response converted to user text + assert llm_request.contents[2].role == "user" + assert llm_request.contents[2].parts + assert ( + llm_request.contents[2].parts[0].text + == 'Invoking tool `get_weather` produced: `{"temp": "15C"}`.' + ) + + # Fourth part: original user text + assert llm_request.contents[3].role == "user" + assert llm_request.contents[3].parts + assert llm_request.contents[3].parts[0].text == "How are you?" + + +def test_gemma_functions_after_model_callback_markdown_json_block(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response from Gemma with a JSON function call in a markdown block + json_function_call_str = """ +```json +{"name": "search_web", "parameters": {"query": "latest news"}} +```""" + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "latest news"} + assert part.text is None + + +def test_gemma_functions_after_model_callback_markdown_tool_code_block(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response from Gemma with a JSON function call in a 'tool_code' markdown block + json_function_call_str = """ +Some text before. +```tool_code +{"name": "get_current_time", "parameters": {}} +``` +And some text after.""" + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "get_current_time" + assert part.function_call.args == {} + assert part.text is None + + +def test_gemma_functions_after_model_callback_embedded_json(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response with valid JSON embedded in text + embedded_json_str = ( + 'Please call the tool: {"name": "search_web", "parameters": {"query":' + ' "new features"}} thanks!' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=embedded_json_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "new features"} + assert part.text is None + + +def test_gemma_functions_after_model_callback_flexible_parsing(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Test with "function" and "args" keys as supported by GemmaFunctionCallModel + flexible_json_str = '{"function": "do_something", "args": {"value": 123}}' + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=flexible_json_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "do_something" + assert part.function_call.args == {"value": 123} + assert part.text is None + + +def test_gemma_functions_after_model_callback_last_json_object(): + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.google_llm import gemma_functions_after_model_callback + + # Simulate a response with multiple JSON objects, ensuring the last valid one is picked + multiple_json_str = ( + 'I thought about {"name": "first_call", "parameters": {"a": 1}} but then' + ' decided to call: {"name": "second_call", "parameters": {"b": 2}}' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=multiple_json_str)] + ) + ) + + gemma_functions_after_model_callback( + mock.MagicMock(spec=CallbackContext), llm_response + ) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "second_call" + assert part.function_call.args == {"b": 2} + assert part.text is None From 2820c395acffeda045451caff29fe4b8ee321395 Mon Sep 17 00:00:00 2001 From: Doug Reid <21148125+douglas-reid@users.noreply.github.com> Date: Thu, 4 Sep 2025 16:07:20 -0700 Subject: [PATCH 2/4] address review comments from bot --- contributing/samples/hello_world_gemma/agent.py | 2 +- src/google/adk/models/google_llm.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/contributing/samples/hello_world_gemma/agent.py b/contributing/samples/hello_world_gemma/agent.py index d8f3e2ef68..7e70005281 100644 --- a/contributing/samples/hello_world_gemma/agent.py +++ b/contributing/samples/hello_world_gemma/agent.py @@ -45,7 +45,7 @@ async def check_prime(nums: list[int]) -> str: """ primes = set() for number in nums: - number = int(number) + number = number if number <= 1: continue is_prime = True diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 67736be978..15dc5902b2 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -20,6 +20,7 @@ import json import logging import os +import re import sys from typing import Any from typing import AsyncGenerator @@ -445,6 +446,8 @@ def _remove_display_name_if_present( class Gemma(Gemini): """Integration for Gemma models exposed via the Gemini API. + Only Gemma 3 models are supported at this time. + For full documentation, see: https://ai.google.dev/gemma/docs/core/ NOTE: Gemma does **NOT** support system instructions. Any system instructions @@ -463,9 +466,7 @@ class Gemma(Gemini): usage via the Gemini API. """ - model: str = ( - 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] - ) + model: str = 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] @classmethod @override @@ -477,7 +478,7 @@ def supported_models(cls) -> list[str]: """ return [ - r'gemma-.*', + r'gemma-3.*', ] @cached_property @@ -494,7 +495,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: # NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning # of any chain of contents. - if len(contents) >= 1: + if contents: if contents[0] != instruction_content: # only prepend if it hasn't already been done llm_request.contents = [instruction_content] + contents @@ -663,8 +664,6 @@ def gemma_functions_after_model_callback( return try: - import re - json_candidate = None markdown_code_block_pattern = re.compile( @@ -720,7 +719,6 @@ def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: decoder = json.JSONDecoder() last_json_str = None start_pos = 0 - first_brace_index = 0 while start_pos < len(text): try: first_brace_index = text.index('{', start_pos) From 9435da1f4766416220f6f13ad60021ec57761e1c Mon Sep 17 00:00:00 2001 From: Doug Reid <21148125+douglas-reid@users.noreply.github.com> Date: Thu, 4 Sep 2025 16:13:00 -0700 Subject: [PATCH 3/4] apply autoformat --- src/google/adk/models/google_llm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 15dc5902b2..fe3fb3e55e 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -466,7 +466,9 @@ class Gemma(Gemini): usage via the Gemini API. """ - model: str = 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] + model: str = ( + 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] + ) @classmethod @override From 56f9e0de2d44983f61eeed168ad626ec5784746b Mon Sep 17 00:00:00 2001 From: Doug Reid <21148125+douglas-reid@users.noreply.github.com> Date: Tue, 16 Sep 2025 10:22:55 -0700 Subject: [PATCH 4/4] refactor: inline request/response processing --- .../samples/hello_world_gemma/agent.py | 6 +- src/google/adk/models/__init__.py | 2 +- src/google/adk/models/gemma_llm.py | 331 +++++++++++ src/google/adk/models/google_llm.py | 315 ----------- tests/integration/models/test_gemma_llm.py | 57 ++ tests/integration/models/test_google_llm.py | 63 +-- tests/unittests/models/test_gemma_llm.py | 506 +++++++++++++++++ tests/unittests/models/test_google_llm.py | 521 ------------------ 8 files changed, 903 insertions(+), 898 deletions(-) create mode 100644 src/google/adk/models/gemma_llm.py create mode 100644 tests/integration/models/test_gemma_llm.py create mode 100644 tests/unittests/models/test_gemma_llm.py diff --git a/contributing/samples/hello_world_gemma/agent.py b/contributing/samples/hello_world_gemma/agent.py index 7e70005281..3407d721d3 100644 --- a/contributing/samples/hello_world_gemma/agent.py +++ b/contributing/samples/hello_world_gemma/agent.py @@ -16,9 +16,7 @@ import random from google.adk.agents.llm_agent import Agent -from google.adk.models.google_llm import Gemma -from google.adk.models.google_llm import gemma_functions_after_model_callback -from google.adk.models.google_llm import gemma_functions_before_model_callback +from google.adk.models.gemma_llm import Gemma from google.genai.types import GenerateContentConfig @@ -90,8 +88,6 @@ async def check_prime(nums: list[int]) -> str: roll_die, check_prime, ], - before_model_callback=gemma_functions_before_model_callback, - after_model_callback=gemma_functions_after_model_callback, generate_content_config=GenerateContentConfig( temperature=1.0, top_p=0.95, diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index c7c8c61adb..c08570a96c 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -15,8 +15,8 @@ """Defines the interface to support a model.""" from .base_llm import BaseLlm +from .gemma_llm import Gemma from .google_llm import Gemini -from .google_llm import Gemma from .llm_request import LlmRequest from .llm_response import LlmResponse from .registry import LLMRegistry diff --git a/src/google/adk/models/gemma_llm.py b/src/google/adk/models/gemma_llm.py new file mode 100644 index 0000000000..3233d66f99 --- /dev/null +++ b/src/google/adk/models/gemma_llm.py @@ -0,0 +1,331 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import cached_property +import json +import logging +import re +from typing import Any +from typing import AsyncGenerator + +from google.adk.models.google_llm import Gemini +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.utils.variant_utils import GoogleLLMVariant +from google.genai import types +from google.genai.types import Content +from google.genai.types import FunctionDeclaration +from google.genai.types import Part +from pydantic import AliasChoices +from pydantic import BaseModel +from pydantic import Field +from pydantic import ValidationError +from typing_extensions import override + +logger = logging.getLogger('google_adk.' + __name__) + + +class GemmaFunctionCallModel(BaseModel): + """Flexible Pydantic model for parsing inline Gemma function call responses.""" + + name: str = Field(validation_alias=AliasChoices('name', 'function')) + parameters: dict[str, Any] = Field( + validation_alias=AliasChoices('parameters', 'args') + ) + + +class Gemma(Gemini): + """Integration for Gemma models exposed via the Gemini API. + + Only Gemma 3 models are supported at this time. For agentic use cases, + use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended. + + For full documentation, see: https://ai.google.dev/gemma/docs/core/ + + NOTE: Gemma does **NOT** support system instructions. Any system instructions + will be replaced with an initial *user* prompt in the LLM request. If system + instructions change over the course of agent execution, the initial content + **SHOULD** be replaced. Special care is warranted here. + See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions + + NOTE: Gemma's function calling support is limited. It does not have full access to the + same built-in tools as Gemini. It also does not have special API support for tools and + functions. Rather, tools must be passed in via a `user` prompt, and extracted from model + responses based on approximate shape. + + NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports + usage via the Gemini API. + """ + + model: str = ( + 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] + ) + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r'gemma-3.*', + ] + + @cached_property + def _api_backend(self) -> GoogleLLMVariant: + return GoogleLLMVariant.GEMINI_API + + def _move_function_calls_into_system_instruction( + self, llm_request: LlmRequest + ): + if llm_request.model is None or not llm_request.model.startswith('gemma-3'): + return + + # Iterate through the existing contents to find and convert function calls and responses + # from text parts, as Gemma models don't directly support function calling. + new_contents: list[Content] = [] + for content_item in llm_request.contents: + ( + new_parts_for_content, + has_function_response_part, + has_function_call_part, + ) = _convert_content_parts_for_gemma(content_item) + + if has_function_response_part: + if new_parts_for_content: + new_contents.append(Content(role='user', parts=new_parts_for_content)) + elif has_function_call_part: + if new_parts_for_content: + new_contents.append( + Content(role='model', parts=new_parts_for_content) + ) + else: + new_contents.append(content_item) + + llm_request.contents = new_contents + + if not llm_request.config.tools: + return + + all_function_declarations: list[FunctionDeclaration] = [] + for tool_item in llm_request.config.tools: + if isinstance(tool_item, types.Tool) and tool_item.function_declarations: + all_function_declarations.extend(tool_item.function_declarations) + + if all_function_declarations: + system_instruction = _build_gemma_function_system_instruction( + all_function_declarations + ) + llm_request.append_instructions([system_instruction]) + + llm_request.config.tools = [] + + def _extract_function_calls_from_response(self, llm_response: LlmResponse): + if llm_response.partial or (llm_response.turn_complete is True): + return + + if not llm_response.content: + return + + if not llm_response.content.parts: + return + + if len(llm_response.content.parts) > 1: + return + + response_text = llm_response.content.parts[0].text + if not response_text: + return + + try: + json_candidate = None + + markdown_code_block_pattern = re.compile( + r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL + ) + block_match = markdown_code_block_pattern.search(response_text) + + if block_match: + json_candidate = block_match.group(2).strip() + else: + found, json_text = _get_last_valid_json_substring(response_text) + if found: + json_candidate = json_text + + if not json_candidate: + return + + function_call_parsed = GemmaFunctionCallModel.model_validate_json( + json_candidate + ) + function_call = types.FunctionCall( + name=function_call_parsed.name, + args=function_call_parsed.parameters, + ) + function_call_part = Part(function_call=function_call) + llm_response.content.parts = [function_call_part] + except (json.JSONDecodeError, ValidationError) as e: + logger.debug( + f'Error attempting to parse JSON into function call. Leaving as text' + f' response. %s', + e, + ) + except Exception as e: + logger.warning('Error processing Gemma function call response: %s', e) + + @override + async def _preprocess_request(self, llm_request: LlmRequest) -> None: + self._move_function_calls_into_system_instruction(llm_request=llm_request) + + if system_instruction := llm_request.config.system_instruction: + contents = llm_request.contents + instruction_content = Content( + role='user', parts=[Part.from_text(text=system_instruction)] + ) + + # NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning + # of any chain of contents. + if contents: + if contents[0] != instruction_content: + # only prepend if it hasn't already been done + llm_request.contents = [instruction_content] + contents + + llm_request.config.system_instruction = None + + return await super()._preprocess_request(llm_request) + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to the Gemma model. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + # print(f'{llm_request=}') + assert llm_request.model.startswith('gemma-'), ( + f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma LLM' + ' is not supported.' + ) + + async for response in super().generate_content_async(llm_request, stream): + self._extract_function_calls_from_response(response) + yield response + + +def _convert_content_parts_for_gemma( + content_item: Content, +) -> tuple[list[Part], bool, bool]: + """Converts function call/response parts within a content item to text parts. + + Args: + content_item: The original Content item. + + Returns: + A tuple containing: + - A list of new Part objects with function calls/responses converted to text. + - A boolean indicating if any function response parts were found. + - A boolean indicating if any function call parts were found. + """ + new_parts: list[Part] = [] + has_function_response_part = False + has_function_call_part = False + + for part in content_item.parts: + if func_response := part.function_response: + has_function_response_part = True + response_text = ( + f'Invoking tool `{func_response.name}` produced:' + f' `{json.dumps(func_response.response)}`.' + ) + new_parts.append(Part.from_text(text=response_text)) + elif func_call := part.function_call: + has_function_call_part = True + new_parts.append( + Part.from_text(text=func_call.model_dump_json(exclude_none=True)) + ) + else: + new_parts.append(part) + return new_parts, has_function_response_part, has_function_call_part + + +def _build_gemma_function_system_instruction( + function_declarations: list[FunctionDeclaration], +) -> str: + """Constructs the system instruction string for Gemma function calling.""" + if not function_declarations: + return '' + + system_instruction_prefix = 'You have access to the following functions:\n[' + instruction_parts = [] + for func in function_declarations: + instruction_parts.append(func.model_dump_json(exclude_none=True)) + + separator = ',\n' + system_instruction = ( + f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n' + ) + + system_instruction += ( + 'When you call a function, you MUST respond in the format of: ' + """{"name": function name, "parameters": dictionary of argument name and its value}\n""" + 'When you call a function, you MUST NOT include any other text in the' + ' response.\n' + ) + return system_instruction + + +def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: + """Attempts to find and return the last valid JSON object in a string. + + This function is designed to extract JSON that might be embedded in a larger + text, potentially with introductory or concluding remarks. It will always chose + the last block of valid json found within the supplied text (if it exists). + + Args: + text: The input string to search for JSON objects. + + Returns: + A tuple: + - bool: True if a valid JSON substring was found, False otherwise. + - str | None: The last valid JSON substring found, or None if none was + found. + """ + decoder = json.JSONDecoder() + last_json_str = None + start_pos = 0 + while start_pos < len(text): + try: + first_brace_index = text.index('{', start_pos) + _, end_index = decoder.raw_decode(text[first_brace_index:]) + last_json_str = text[first_brace_index : first_brace_index + end_index] + start_pos = first_brace_index + end_index + except json.JSONDecodeError: + start_pos = first_brace_index + 1 + except ValueError: + break + + if last_json_str: + return True, last_json_str + return False, None diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index fe3fb3e55e..57b1451784 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -17,28 +17,17 @@ import contextlib from functools import cached_property -import json import logging import os -import re import sys -from typing import Any from typing import AsyncGenerator from typing import cast from typing import Optional from typing import TYPE_CHECKING from typing import Union -from google.adk.agents.callback_context import CallbackContext from google.genai import Client from google.genai import types -from google.genai.types import Content -from google.genai.types import FunctionDeclaration -from google.genai.types import Part -from pydantic import BaseModel -from pydantic import ValidationError -from pydantic.aliases import AliasChoices -from pydantic.fields import Field from typing_extensions import override from .. import version @@ -55,16 +44,6 @@ logger = logging.getLogger('google_adk.' + __name__) - -class GemmaFunctionCallModel(BaseModel): - """Flexible Pydantic model for parsing inline Gemma function call responses.""" - - name: str = Field(validation_alias=AliasChoices('name', 'function')) - parameters: dict[str, Any] = Field( - validation_alias=AliasChoices('parameters', 'args') - ) - - _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} _AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' @@ -441,297 +420,3 @@ def _remove_display_name_if_present( """ if data_obj and data_obj.display_name: data_obj.display_name = None - - -class Gemma(Gemini): - """Integration for Gemma models exposed via the Gemini API. - - Only Gemma 3 models are supported at this time. - - For full documentation, see: https://ai.google.dev/gemma/docs/core/ - - NOTE: Gemma does **NOT** support system instructions. Any system instructions - will be replaced with an initial *user* prompt in the LLM request. If system - instructions change over the course of agent execution, the initial content - **SHOULD** be replaced. Special care is warranted here. - See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions - - NOTE: Gemma's function calling support is limited. It does not have full access to the - same built-in tools as Gemini. It also does not have special API support for tools and - functions. Rather, tools must be passed in via a `user` prompt, and extracted from model - responses based on approximate shape. For agent developments, please use the provided - `gemma_functions_before_model_callback` and `gemma_functions_after_model_callback` methods. - - NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports - usage via the Gemini API. - """ - - model: str = ( - 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] - ) - - @classmethod - @override - def supported_models(cls) -> list[str]: - """Provides the list of supported models. - - Returns: - A list of supported models. - """ - - return [ - r'gemma-3.*', - ] - - @cached_property - def _api_backend(self) -> GoogleLLMVariant: - return GoogleLLMVariant.GEMINI_API - - @override - async def _preprocess_request(self, llm_request: LlmRequest) -> None: - if system_instruction := llm_request.config.system_instruction: - contents = llm_request.contents - instruction_content = Content( - role='user', parts=[Part.from_text(text=system_instruction)] - ) - - # NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning - # of any chain of contents. - if contents: - if contents[0] != instruction_content: - # only prepend if it hasn't already been done - llm_request.contents = [instruction_content] + contents - - llm_request.config.system_instruction = None - - return await super()._preprocess_request(llm_request) - - @override - async def generate_content_async( - self, llm_request: LlmRequest, stream: bool = False - ) -> AsyncGenerator[LlmResponse, None]: - """Sends a request to the Gemma model. - - Args: - llm_request: LlmRequest, the request to send to the Gemini model. - stream: bool = False, whether to do streaming call. - - Yields: - LlmResponse: The model response. - """ - # print(f'{llm_request=}') - assert llm_request.model.startswith('gemma-'), ( - f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma LLM' - ' is not supported.' - ) - - async for response in super().generate_content_async(llm_request, stream): - yield response - - -def _convert_content_parts_for_gemma( - content_item: Content, -) -> tuple[list[Part], bool, bool]: - """Converts function call/response parts within a content item to text parts. - - Args: - content_item: The original Content item. - - Returns: - A tuple containing: - - A list of new Part objects with function calls/responses converted to text. - - A boolean indicating if any function response parts were found. - - A boolean indicating if any function call parts were found. - """ - new_parts: list[Part] = [] - has_function_response_part = False - has_function_call_part = False - - for part in content_item.parts: - if func_response := part.function_response: - has_function_response_part = True - response_text = ( - f'Invoking tool `{func_response.name}` produced:' - f' `{json.dumps(func_response.response)}`.' - ) - new_parts.append(Part.from_text(text=response_text)) - elif func_call := part.function_call: - has_function_call_part = True - new_parts.append( - Part.from_text(text=func_call.model_dump_json(exclude_none=True)) - ) - else: - new_parts.append(part) - return new_parts, has_function_response_part, has_function_call_part - - -def _build_gemma_function_system_instruction( - function_declarations: list[FunctionDeclaration], -) -> str: - """Constructs the system instruction string for Gemma function calling.""" - if not function_declarations: - return '' - - system_instruction_prefix = 'You have access to the following functions:\n[' - instruction_parts = [] - for func in function_declarations: - instruction_parts.append(func.model_dump_json(exclude_none=True)) - - separator = ',\n' - system_instruction = ( - f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n' - ) - - system_instruction += ( - 'When you call a function, you MUST respond in the format of: ' - """{"name": function name, "parameters": dictionary of argument name and its value}\n""" - 'When you call a function, you MUST NOT include any other text in the' - ' response.\n' - ) - return system_instruction - - -def gemma_functions_before_model_callback( - callback_context: CallbackContext, llm_request: LlmRequest -): - """Translates function calls and responses to the Gemma-supported interaction model. - - NOTE: Gemma is **ONLY** able to handle external function declarations in a tool. It does NOT - have access to the internal Gemini tools (including Google and Enterprise Search, URL Context, etc.). - If the LLM Request includes those tools, they will be ignored and dropped from the request sent to - the model. - """ - - if llm_request.model is None or not llm_request.model.startswith('gemma-3'): - return - - # Iterate through the existing contents to find and convert function calls and responses - # from text parts, as Gemma models don't directly support function calling. - new_contents: list[Content] = [] - for content_item in llm_request.contents: - ( - new_parts_for_content, - has_function_response_part, - has_function_call_part, - ) = _convert_content_parts_for_gemma(content_item) - - if has_function_response_part: - if new_parts_for_content: - new_contents.append(Content(role='user', parts=new_parts_for_content)) - elif has_function_call_part: - if new_parts_for_content: - new_contents.append(Content(role='model', parts=new_parts_for_content)) - else: - new_contents.append(content_item) - - llm_request.contents = new_contents - - if not llm_request.config.tools: - return - - all_function_declarations: list[FunctionDeclaration] = [] - for tool_item in llm_request.config.tools: - if isinstance(tool_item, types.Tool) and tool_item.function_declarations: - all_function_declarations.extend(tool_item.function_declarations) - - if all_function_declarations: - system_instruction = _build_gemma_function_system_instruction( - all_function_declarations - ) - llm_request.append_instructions([system_instruction]) - - llm_request.config.tools = [] - - -def gemma_functions_after_model_callback( - callback_context: CallbackContext, llm_response: LlmResponse -): - """Translates function calls and responses to the Gemma-supported interaction model. - - Model function calls are attempted to be recognized in text responses and extracted into - the objects that can be exploited by `Agents`. Some flexibility in parsing is provided - in an attempt to improve model function in agentic systems. - """ - if not llm_response.content: - return - - if not llm_response.content.parts: - return - - if len(llm_response.content.parts) > 1: - return - - response_text = llm_response.content.parts[0].text - if not response_text: - return - - try: - json_candidate = None - - markdown_code_block_pattern = re.compile( - r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL - ) - block_match = markdown_code_block_pattern.search(response_text) - - if block_match: - json_candidate = block_match.group(2).strip() - else: - found, json_text = _get_last_valid_json_substring(response_text) - if found: - json_candidate = json_text - - if not json_candidate: - return - - function_call_parsed = GemmaFunctionCallModel.model_validate_json( - json_candidate - ) - function_call = types.FunctionCall( - name=function_call_parsed.name, - args=function_call_parsed.parameters, - ) - function_call_part = Part(function_call=function_call) - llm_response.content.parts = [function_call_part] - except (json.JSONDecodeError, ValidationError) as e: - logger.debug( - f'Error attempting to parse JSON into function call. Leaving as text' - f' response. %s', - e, - ) - except Exception as e: - logger.warning('Error processing Gemma function call response: %s', e) - - -def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: - """Attempts to find and return the last valid JSON object in a string. - - This function is designed to extract JSON that might be embedded in a larger - text, potentially with introductory or concluding remarks. It will always chose - the last block of valid json found within the supplied text (if it exists). - - Args: - text: The input string to search for JSON objects. - - Returns: - A tuple: - - bool: True if a valid JSON substring was found, False otherwise. - - str | None: The last valid JSON substring found, or None if none was - found. - """ - decoder = json.JSONDecoder() - last_json_str = None - start_pos = 0 - while start_pos < len(text): - try: - first_brace_index = text.index('{', start_pos) - _, end_index = decoder.raw_decode(text[first_brace_index:]) - last_json_str = text[first_brace_index : first_brace_index + end_index] - start_pos = first_brace_index + end_index - except json.JSONDecodeError: - start_pos = first_brace_index + 1 - except ValueError: - break - - if last_json_str: - return True, last_json_str - return False, None diff --git a/tests/integration/models/test_gemma_llm.py b/tests/integration/models/test_gemma_llm.py new file mode 100644 index 0000000000..81b9672a18 --- /dev/null +++ b/tests/integration/models/test_gemma_llm.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models.gemma_llm import Gemma +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + +DEFAULT_GEMMA_MODEL = "gemma-3-1b-it" + + +@pytest.fixture +def gemma_llm(): + return Gemma(model=DEFAULT_GEMMA_MODEL) + + +@pytest.fixture +def gemma_request(): + return LlmRequest( + model=DEFAULT_GEMMA_MODEL, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="You are a helpful assistant."), + Part.from_text(text="Hello!"), + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction="Talk like a pirate.", + ), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm_backend", ["GOOGLE_AI"]) +async def test_generate_content_async(gemma_llm, gemma_request): + async for response in gemma_llm.generate_content_async(gemma_request): + assert isinstance(response, LlmResponse) + assert response.content.parts[0].text diff --git a/tests/integration/models/test_google_llm.py b/tests/integration/models/test_google_llm.py index 9edd6de7e9..5574eb30ef 100644 --- a/tests/integration/models/test_google_llm.py +++ b/tests/integration/models/test_google_llm.py @@ -13,7 +13,6 @@ # limitations under the License. from google.adk.models.google_llm import Gemini -from google.adk.models.google_llm import Gemma from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -21,29 +20,16 @@ from google.genai.types import Part import pytest -DEFAULT_GEMINI_MODEL = "gemini-1.5-flash" -DEFAULT_GEMMA_MODEL = "gemma-3-1b-it" - @pytest.fixture def gemini_llm(): - return Gemini(model=DEFAULT_GEMINI_MODEL) - - -@pytest.fixture -def gemma_llm(): - return Gemma(model=DEFAULT_GEMMA_MODEL) - - -@pytest.fixture -def llm(request): - return request.getfixturevalue(request.param) + return Gemini(model="gemini-1.5-flash") @pytest.fixture -def gemini_request(): +def llm_request(): return LlmRequest( - model=DEFAULT_GEMINI_MODEL, + model="gemini-1.5-flash", contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], config=types.GenerateContentConfig( temperature=0.1, @@ -53,54 +39,19 @@ def gemini_request(): ) -@pytest.fixture -def gemma_request(): - return LlmRequest( - model=DEFAULT_GEMMA_MODEL, - contents=[ - Content( - role="user", - parts=[ - Part.from_text(text="You are a helpful assistant."), - Part.from_text(text="Hello!"), - ], - ) - ], - config=types.GenerateContentConfig( - temperature=0.1, - response_modalities=[types.Modality.TEXT], - system_instruction="Talk like a pirate.", - ), - ) - - -@pytest.fixture -def llm_request(request): - return request.getfixturevalue(request.param) - - @pytest.mark.asyncio -@pytest.mark.parametrize( - "llm, llm_request, llm_backend", - [ - ("gemini_llm", "gemini_request", "GOOGLE_AI"), - ("gemini_llm", "gemini_request", "VERTEX"), - ("gemma_llm", "gemma_request", "GOOGLE_AI"), - ], - indirect=True, -) -async def test_generate_content_async(llm, llm_request): - async for response in llm.generate_content_async(llm_request): +async def test_generate_content_async(gemini_llm, llm_request): + async for response in gemini_llm.generate_content_async(llm_request): assert isinstance(response, LlmResponse) assert response.content.parts[0].text @pytest.mark.asyncio -async def test_generate_content_async_stream(gemini_llm, gemini_request): +async def test_generate_content_async_stream(gemini_llm, llm_request): responses = [ resp async for resp in gemini_llm.generate_content_async( - gemini_request, stream=True + llm_request, stream=True ) ] text = "" diff --git a/tests/unittests/models/test_gemma_llm.py b/tests/unittests/models/test_gemma_llm.py new file mode 100644 index 0000000000..2cf98306b9 --- /dev/null +++ b/tests/unittests/models/test_gemma_llm.py @@ -0,0 +1,506 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models.gemma_llm import Gemma +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +@pytest.fixture +def llm_request(): + return LlmRequest( + model="gemma-3-4b-it", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction="You are a helpful assistant", + ), + ) + + +@pytest.fixture +def llm_request_with_duplicate_instruction(): + return LlmRequest( + model="gemma-3-1b-it", + contents=[ + Content( + role="user", + parts=[Part.from_text(text="Talk like a pirate.")], + ), + Content(role="user", parts=[Part.from_text(text="Hello")]), + ], + config=types.GenerateContentConfig( + response_modalities=[types.Modality.TEXT], + system_instruction="Talk like a pirate.", + ), + ) + + +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model="gemma-3-1b-it", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="search_web", + description="Search the web for a query.", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema(type=types.Type.STRING) + }, + required=["query"], + ), + ), + types.FunctionDeclaration( + name="get_current_time", + description="Gets the current time.", + parameters=types.Schema( + type=types.Type.OBJECT, properties={} + ), + ), + ] + ) + ], + ), + ) + + +@pytest.mark.asyncio +async def test_not_gemma_model(): + llm = Gemma() + llm_request_bad_model = LlmRequest( + model="not-a-gemma-model", + ) + with pytest.raises(AssertionError, match=r".*model.*"): + async for _ in llm.generate_content_async(llm_request_bad_model): + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "llm_request", + ["llm_request", "llm_request_with_duplicate_instruction"], + indirect=True, +) +async def test_preprocess_request(llm_request): + llm = Gemma() + want_content_text = llm_request.config.system_instruction + + await llm._preprocess_request(llm_request) + + # system instruction should be cleared + assert not llm_request.config.system_instruction + # should be two content bits now (deduped, if needed) + assert len(llm_request.contents) == 2 + # first message in contents should be "user": + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts[0].text == want_content_text + + +@pytest.mark.asyncio +async def test_preprocess_request_with_tools(llm_request_with_tools): + + gemma = Gemma() + await gemma._preprocess_request(llm_request_with_tools) + + assert not llm_request_with_tools.config.tools + + # The original user content should now be the second item + assert llm_request_with_tools.contents[1].role == "user" + assert llm_request_with_tools.contents[1].parts[0].text == "Hello" + + sys_instruct_text = llm_request_with_tools.contents[0].parts[0].text + assert sys_instruct_text is not None + assert "You have access to the following functions" in sys_instruct_text + assert ( + """{"description":"Search the web for a query.","name":"search_web",""" + in sys_instruct_text + ) + assert ( + """{"description":"Gets the current time.","name":"get_current_time","parameters":{"properties":{}""" + in sys_instruct_text + ) + + +@pytest.mark.asyncio +async def test_preprocess_request_with_function_response(): + # Simulate an LlmRequest with a function response + func_response_data = types.FunctionResponse( + name="search_web", response={"results": [{"title": "ADK"}]} + ) + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="model", + parts=[types.Part(function_response=func_response_data)], + ) + ], + config=types.GenerateContentConfig(), + ) + + gemma = Gemma() + await gemma._preprocess_request(llm_request) + + # Assertions: function response converted to user role text content + assert llm_request.contents + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts + assert ( + llm_request.contents[0].parts[0].text + == 'Invoking tool `search_web` produced: `{"results": [{"title":' + ' "ADK"}]}`.' + ) + assert llm_request.contents[0].parts[0].function_response is None + assert llm_request.contents[0].parts[0].function_call is None + + +@pytest.mark.asyncio +async def test_preprocess_request_with_function_call(): + func_call_data = types.FunctionCall(name="get_current_time", args={}) + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="user", parts=[types.Part(function_call=func_call_data)] + ) + ], + ) + + gemma = Gemma() + await gemma._preprocess_request(llm_request) + + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "model" + expected_text = func_call_data.model_dump_json(exclude_none=True) + assert llm_request.contents[0].parts + got_part = llm_request.contents[0].parts[0] + assert got_part.text == expected_text + assert got_part.function_call is None + assert got_part.function_response is None + + +@pytest.mark.asyncio +async def test_preprocess_request_with_mixed_content(): + func_call = types.FunctionCall(name="get_weather", args={"city": "London"}) + func_response = types.FunctionResponse( + name="get_weather", response={"temp": "15C"} + ) + + llm_request = LlmRequest( + model="gemma-3-1b-it", + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello!")] + ), + types.Content( + role="model", parts=[types.Part(function_call=func_call)] + ), + types.Content( + role="some_function", + parts=[types.Part(function_response=func_response)], + ), + types.Content( + role="user", parts=[types.Part.from_text(text="How are you?")] + ), + ], + ) + + gemma = Gemma() + await gemma._preprocess_request(llm_request) + + # Assertions + assert len(llm_request.contents) == 4 + + # First part: original user text + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts + assert llm_request.contents[0].parts[0].text == "Hello!" + + # Second part: function call converted to model text + assert llm_request.contents[1].role == "model" + assert llm_request.contents[1].parts + assert llm_request.contents[1].parts[0].text == func_call.model_dump_json( + exclude_none=True + ) + + # Third part: function response converted to user text + assert llm_request.contents[2].role == "user" + assert llm_request.contents[2].parts + assert ( + llm_request.contents[2].parts[0].text + == 'Invoking tool `get_weather` produced: `{"temp": "15C"}`.' + ) + + # Fourth part: original user text + assert llm_request.contents[3].role == "user" + assert llm_request.contents[3].parts + assert llm_request.contents[3].parts[0].text == "How are you?" + + +def test_process_response(): + # Simulate a response from Gemma that should be converted to a FunctionCall + json_function_call_str = ( + '{"name": "search_web", "parameters": {"query": "latest news"}}' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response=llm_response) + + # Assert that the content was transformed into a FunctionCall + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "latest news"} + # Assert that the entire part matches the expected structure + expected_function_call = types.FunctionCall( + name="search_web", args={"query": "latest news"} + ) + expected_part = Part(function_call=expected_function_call) + assert part == expected_part + assert part.text is None # Ensure text part is cleared + + +def test_process_response_invalid_json_text(): + # Simulate a response with plain text that is not JSON + original_text = "This is a regular text response." + llm_response = LlmResponse( + content=Content(role="model", parts=[Part.from_text(text=original_text)]) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response=llm_response) + + # Assert that the content remains unchanged + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + assert llm_response.content.parts[0].text == original_text + assert llm_response.content.parts[0].function_call is None + + +def test_process_response_malformed_json(): + # Simulate a response with valid JSON but not in the function call format + malformed_json_str = '{"not_a_function": "value", "another_field": 123}' + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=malformed_json_str)] + ) + ) + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response=llm_response) + + # Assert that the content remains unchanged because it doesn't match the expected schema + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + assert llm_response.content.parts[0].text == malformed_json_str + assert llm_response.content.parts[0].function_call is None + + +def test_process_response_empty_content_or_multiple_parts(): + gemma = Gemma() + + # Test case 1: LlmResponse with no content + llm_response_no_content = LlmResponse(content=None) + gemma._extract_function_calls_from_response( + llm_response=llm_response_no_content + ) + assert llm_response_no_content.content is None + + # Test case 2: LlmResponse with empty parts list + llm_response_empty_parts = LlmResponse( + content=Content(role="model", parts=[]) + ) + gemma._extract_function_calls_from_response( + llm_response=llm_response_empty_parts + ) + assert llm_response_empty_parts.content + assert not llm_response_empty_parts.content.parts + + # Test case 3: LlmResponse with multiple parts + llm_response_multiple_parts = LlmResponse( + content=Content( + role="model", + parts=[ + Part.from_text(text="part one"), + Part.from_text(text="part two"), + ], + ) + ) + original_parts = list( + llm_response_multiple_parts.content.parts + ) # Copy for comparison + gemma._extract_function_calls_from_response( + llm_response=llm_response_multiple_parts + ) + assert llm_response_multiple_parts.content + assert ( + llm_response_multiple_parts.content.parts == original_parts + ) # Should remain unchanged + + # Test case 4: LlmResponse with one part, but empty text + llm_response_empty_text_part = LlmResponse( + content=Content(role="model", parts=[Part.from_text(text="")]) + ) + gemma._extract_function_calls_from_response( + llm_response=llm_response_empty_text_part + ) + assert llm_response_empty_text_part.content + assert llm_response_empty_text_part.content.parts + assert llm_response_empty_text_part.content.parts[0].text == "" + assert llm_response_empty_text_part.content.parts[0].function_call is None + + +def test_process_response_with_markdown_json_block(): + # Simulate a response from Gemma with a JSON function call in a markdown block + json_function_call_str = """ +```json +{"name": "search_web", "parameters": {"query": "latest news"}} +```""" + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "latest news"} + assert part.text is None + + +def test_process_response_with_markdown_tool_code_block(): + # Simulate a response from Gemma with a JSON function call in a 'tool_code' markdown block + json_function_call_str = """ +Some text before. +```tool_code +{"name": "get_current_time", "parameters": {}} +``` +And some text after.""" + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=json_function_call_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "get_current_time" + assert part.function_call.args == {} + assert part.text is None + + +def test_process_response_with_embedded_json(): + # Simulate a response with valid JSON embedded in text + embedded_json_str = ( + 'Please call the tool: {"name": "search_web", "parameters": {"query":' + ' "new features"}} thanks!' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=embedded_json_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "search_web" + assert part.function_call.args == {"query": "new features"} + assert part.text is None + + +def test_process_response_flexible_parsing(): + # Test with "function" and "args" keys as supported by GemmaFunctionCallModel + flexible_json_str = '{"function": "do_something", "args": {"value": 123}}' + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=flexible_json_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "do_something" + assert part.function_call.args == {"value": 123} + assert part.text is None + + +def test_process_response_last_json_object(): + # Simulate a response with multiple JSON objects, ensuring the last valid one is picked + multiple_json_str = ( + 'I thought about {"name": "first_call", "parameters": {"a": 1}} but then' + ' decided to call: {"name": "second_call", "parameters": {"b": 2}}' + ) + llm_response = LlmResponse( + content=Content( + role="model", parts=[Part.from_text(text=multiple_json_str)] + ) + ) + + gemma = Gemma() + gemma._extract_function_calls_from_response(llm_response) + + assert llm_response.content + assert llm_response.content.parts + assert len(llm_response.content.parts) == 1 + part = llm_response.content.parts[0] + assert part.function_call is not None + assert part.function_call.name == "second_call" + assert part.function_call.args == {"b": 2} + assert part.text is None diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index f9639e87d5..985d7e91cf 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import sys from typing import Optional @@ -26,13 +25,11 @@ from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import Gemini -from google.adk.models.google_llm import Gemma from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.types import Content -from google.genai.types import GenerateContentConfigOrDict from google.genai.types import Part import pytest @@ -140,62 +137,6 @@ def llm_request_with_computer_use(): ) -@pytest.fixture -def llm_request_with_duplicate_instruction(): - return LlmRequest( - model="gemma-3-1b-it", - contents=[ - types.Content( - role="user", - parts=[types.Part.from_text(text="Talk like a pirate.")], - ), - types.Content( - role="user", parts=[types.Part.from_text(text="Hello")] - ), - ], - config=types.GenerateContentConfig( - response_modalities=[types.Modality.TEXT], - system_instruction="Talk like a pirate.", - ), - ) - - -@pytest.fixture -def llm_request_with_tools(): - return LlmRequest( - model="gemma-3-1b-it", - contents=[ - types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) - ], - config=types.GenerateContentConfig( - tools=[ - types.Tool( - function_declarations=[ - types.FunctionDeclaration( - name="search_web", - description="Search the web for a query.", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "query": types.Schema(type=types.Type.STRING) - }, - required=["query"], - ), - ), - types.FunctionDeclaration( - name="get_current_time", - description="Gets the current time.", - parameters=types.Schema( - type=types.Type.OBJECT, properties={} - ), - ), - ] - ) - ], - ), - ) - - @pytest.fixture def mock_os_environ(): initial_env = os.environ.copy() @@ -1785,465 +1726,3 @@ async def mock_coro(): # Verify cache metadata is preserved assert second_arg.cache_name == cache_metadata.cache_name assert second_arg.invocations_used == cache_metadata.invocations_used -async def test_not_gemma_model(): - llm = Gemma() - llm_request_bad_model = LlmRequest( - model="not-a-gemma-model", - ) - with pytest.raises(AssertionError, match=r".*model.*"): - async for _ in llm.generate_content_async(llm_request_bad_model): - pass - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "llm_request", - ["llm_request", "llm_request_with_duplicate_instruction"], - indirect=True, -) -async def test_gemma_request_preprocess(llm_request): - llm = Gemma() - want_content_text = llm_request.config.system_instruction - - await llm._preprocess_request(llm_request=llm_request) - - # system instruction should be cleared - assert not llm_request.config.system_instruction - # should be two content bits now (deduped, if needed) - assert len(llm_request.contents) == 2 - # first message in contents should be "user": - assert llm_request.contents[0].role == "user" - assert llm_request.contents[0].parts[0].text == want_content_text - - -def test_gemma_functions_before_model_callback(llm_request_with_tools): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_before_model_callback - - gemma_functions_before_model_callback( - mock.MagicMock(spec=CallbackContext), llm_request_with_tools - ) - - assert not llm_request_with_tools.config.tools - - # The original user content should still be the first item - assert llm_request_with_tools.contents[0].role == "user" - assert llm_request_with_tools.contents[0].parts[0].text == "Hello" - - sys_instruct_text = llm_request_with_tools.config.system_instruction - assert sys_instruct_text is not None - assert "You have access to the following functions" in sys_instruct_text - assert ( - """{"description":"Search the web for a query.","name":"search_web",""" - in sys_instruct_text - ) - assert ( - """{"description":"Gets the current time.","name":"get_current_time","parameters":{"properties":{}""" - in sys_instruct_text - ) - - -def test_gemma_functions_after_model_callback_valid_json_function_call(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response from Gemma that should be converted to a FunctionCall - json_function_call_str = ( - '{"name": "search_web", "parameters": {"query": "latest news"}}' - ) - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=json_function_call_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - # Assert that the content was transformed into a FunctionCall - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "search_web" - assert part.function_call.args == {"query": "latest news"} - # Assert that the entire part matches the expected structure - expected_function_call = types.FunctionCall( - name="search_web", args={"query": "latest news"} - ) - expected_part = Part(function_call=expected_function_call) - assert part == expected_part - assert part.text is None # Ensure text part is cleared - - -def test_gemma_functions_after_model_callback_invalid_json_text(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response with plain text that is not JSON - original_text = "This is a regular text response." - llm_response = LlmResponse( - content=Content(role="model", parts=[Part.from_text(text=original_text)]) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - # Assert that the content remains unchanged - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - assert llm_response.content.parts[0].text == original_text - assert llm_response.content.parts[0].function_call is None - - -def test_gemma_functions_after_model_callback_malformed_json(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response with valid JSON but not in the function call format - malformed_json_str = '{"not_a_function": "value", "another_field": 123}' - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=malformed_json_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - # Assert that the content remains unchanged because it doesn't match the expected schema - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - assert llm_response.content.parts[0].text == malformed_json_str - assert llm_response.content.parts[0].function_call is None - - -def test_gemma_functions_after_model_callback_empty_content_or_multiple_parts(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Test case 1: LlmResponse with no content - llm_response_no_content = LlmResponse(content=None) - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response_no_content - ) - assert llm_response_no_content.content is None - - # Test case 2: LlmResponse with empty parts list - llm_response_empty_parts = LlmResponse( - content=Content(role="model", parts=[]) - ) - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response_empty_parts - ) - assert llm_response_empty_parts.content - assert not llm_response_empty_parts.content.parts - - # Test case 3: LlmResponse with multiple parts - llm_response_multiple_parts = LlmResponse( - content=Content( - role="model", - parts=[ - Part.from_text(text="part one"), - Part.from_text(text="part two"), - ], - ) - ) - original_parts = list( - llm_response_multiple_parts.content.parts - ) # Copy for comparison - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response_multiple_parts - ) - assert llm_response_multiple_parts.content - assert ( - llm_response_multiple_parts.content.parts == original_parts - ) # Should remain unchanged - - # Test case 4: LlmResponse with one part, but empty text - llm_response_empty_text_part = LlmResponse( - content=Content(role="model", parts=[Part.from_text(text="")]) - ) - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response_empty_text_part - ) - assert llm_response_empty_text_part.content - assert llm_response_empty_text_part.content.parts - assert llm_response_empty_text_part.content.parts[0].text == "" - assert llm_response_empty_text_part.content.parts[0].function_call is None - - -def test_gemma_functions_before_model_callback_with_function_response(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_before_model_callback - - # Simulate an LlmRequest with a function response - func_response_data = types.FunctionResponse( - name="search_web", response={"results": [{"title": "ADK"}]} - ) - llm_request = LlmRequest( - model="gemma-3-1b-it", - contents=[ - types.Content( - role="model", - parts=[types.Part(function_response=func_response_data)], - ) - ], - config=types.GenerateContentConfig(), - ) - - gemma_functions_before_model_callback( - mock.MagicMock(spec=CallbackContext), llm_request - ) - - # Assertions: function response converted to user role text content - assert llm_request.contents - assert len(llm_request.contents) == 1 - assert llm_request.contents[0].role == "user" - assert llm_request.contents[0].parts - assert ( - llm_request.contents[0].parts[0].text - == 'Invoking tool `search_web` produced: `{"results": [{"title":' - ' "ADK"}]}`.' - ) - assert llm_request.contents[0].parts[0].function_response is None - assert llm_request.contents[0].parts[0].function_call is None - - -def test_gemma_functions_before_model_callback_with_function_call(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_before_model_callback - - func_call_data = types.FunctionCall(name="get_current_time", args={}) - llm_request = LlmRequest( - model="gemma-3-1b-it", - contents=[ - types.Content( - role="user", parts=[types.Part(function_call=func_call_data)] - ) - ], - ) - - gemma_functions_before_model_callback( - mock.MagicMock(spec=CallbackContext), llm_request - ) - - assert len(llm_request.contents) == 1 - assert llm_request.contents[0].role == "model" - expected_text = func_call_data.model_dump_json(exclude_none=True) - assert llm_request.contents[0].parts - got_part = llm_request.contents[0].parts[0] - assert got_part.text == expected_text - assert got_part.function_call is None - assert got_part.function_response is None - - -def test_gemma_functions_before_model_callback_mixed_content(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_before_model_callback - - func_call = types.FunctionCall(name="get_weather", args={"city": "London"}) - func_response = types.FunctionResponse( - name="get_weather", response={"temp": "15C"} - ) - - llm_request = LlmRequest( - model="gemma-3-1b-it", - contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hello!")] - ), - types.Content( - role="model", parts=[types.Part(function_call=func_call)] - ), - types.Content( - role="some_function", - parts=[types.Part(function_response=func_response)], - ), - types.Content( - role="user", parts=[types.Part.from_text(text="How are you?")] - ), - ], - ) - - gemma_functions_before_model_callback( - mock.MagicMock(spec=CallbackContext), llm_request - ) - - # Assertions - assert len(llm_request.contents) == 4 - - # First part: original user text - assert llm_request.contents[0].role == "user" - assert llm_request.contents[0].parts - assert llm_request.contents[0].parts[0].text == "Hello!" - - # Second part: function call converted to model text - assert llm_request.contents[1].role == "model" - assert llm_request.contents[1].parts - assert llm_request.contents[1].parts[0].text == func_call.model_dump_json( - exclude_none=True - ) - - # Third part: function response converted to user text - assert llm_request.contents[2].role == "user" - assert llm_request.contents[2].parts - assert ( - llm_request.contents[2].parts[0].text - == 'Invoking tool `get_weather` produced: `{"temp": "15C"}`.' - ) - - # Fourth part: original user text - assert llm_request.contents[3].role == "user" - assert llm_request.contents[3].parts - assert llm_request.contents[3].parts[0].text == "How are you?" - - -def test_gemma_functions_after_model_callback_markdown_json_block(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response from Gemma with a JSON function call in a markdown block - json_function_call_str = """ -```json -{"name": "search_web", "parameters": {"query": "latest news"}} -```""" - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=json_function_call_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "search_web" - assert part.function_call.args == {"query": "latest news"} - assert part.text is None - - -def test_gemma_functions_after_model_callback_markdown_tool_code_block(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response from Gemma with a JSON function call in a 'tool_code' markdown block - json_function_call_str = """ -Some text before. -```tool_code -{"name": "get_current_time", "parameters": {}} -``` -And some text after.""" - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=json_function_call_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "get_current_time" - assert part.function_call.args == {} - assert part.text is None - - -def test_gemma_functions_after_model_callback_embedded_json(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response with valid JSON embedded in text - embedded_json_str = ( - 'Please call the tool: {"name": "search_web", "parameters": {"query":' - ' "new features"}} thanks!' - ) - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=embedded_json_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "search_web" - assert part.function_call.args == {"query": "new features"} - assert part.text is None - - -def test_gemma_functions_after_model_callback_flexible_parsing(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Test with "function" and "args" keys as supported by GemmaFunctionCallModel - flexible_json_str = '{"function": "do_something", "args": {"value": 123}}' - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=flexible_json_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "do_something" - assert part.function_call.args == {"value": 123} - assert part.text is None - - -def test_gemma_functions_after_model_callback_last_json_object(): - from google.adk.agents.callback_context import CallbackContext - from google.adk.models.google_llm import gemma_functions_after_model_callback - - # Simulate a response with multiple JSON objects, ensuring the last valid one is picked - multiple_json_str = ( - 'I thought about {"name": "first_call", "parameters": {"a": 1}} but then' - ' decided to call: {"name": "second_call", "parameters": {"b": 2}}' - ) - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=multiple_json_str)] - ) - ) - - gemma_functions_after_model_callback( - mock.MagicMock(spec=CallbackContext), llm_response - ) - - assert llm_response.content - assert llm_response.content.parts - assert len(llm_response.content.parts) == 1 - part = llm_response.content.parts[0] - assert part.function_call is not None - assert part.function_call.name == "second_call" - assert part.function_call.args == {"b": 2} - assert part.text is None