diff --git a/contributing/samples/hello_world_gemma/__init__.py b/contributing/samples/hello_world_gemma/__init__.py
new file mode 100644
index 0000000000..7d5bb0b1c6
--- /dev/null
+++ b/contributing/samples/hello_world_gemma/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from . import agent
diff --git a/contributing/samples/hello_world_gemma/agent.py b/contributing/samples/hello_world_gemma/agent.py
new file mode 100644
index 0000000000..3407d721d3
--- /dev/null
+++ b/contributing/samples/hello_world_gemma/agent.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import random
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.gemma_llm import Gemma
+from google.genai.types import GenerateContentConfig
+
+
+def roll_die(sides: int) -> int:
+  """Roll a die and return the rolled result.
+
+  Args:
+    sides: The integer number of sides the die has.
+
+  Returns:
+    An integer of the result of rolling the die.
+  """
+  return random.randint(1, sides)
+
+
+async def check_prime(nums: list[int]) -> str:
+  """Check which numbers in a given list are prime.
+
+  Args:
+    nums: The list of numbers to check.
+
+  Returns:
+    A str indicating which numbers are prime.
+  """
+  primes = set()
+  for number in nums:
+    if number <= 1:
+      continue
+    is_prime = True
+    for i in range(2, int(number**0.5) + 1):
+      if number % i == 0:
+        is_prime = False
+        break
+    if is_prime:
+      primes.add(number)
+  return (
+      "No prime numbers found."
+      if not primes
+      else f"{', '.join(str(num) for num in primes)} are prime numbers."
+  )
+
+
+root_agent = Agent(
+    model=Gemma(model="gemma-3-27b-it"),
+    name="data_processing_agent",
+    description=(
+        "hello world agent that can roll many-sided dice and check if numbers"
+        " are prime."
+    ),
+    instruction="""
+      You roll dice and answer questions about the outcome of the dice rolls.
+      You can roll dice of different sizes.
+      You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
+      It is ok to discuss previous dice rolls, and comment on the dice rolls.
+      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
+      You should never roll a die on your own.
+      When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
+      You should not check prime numbers before calling the tool.
+      When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
+      1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
+      2. After you receive the response from the roll_die tool, you should call the check_prime tool with the roll_die result.
+        2.1 If the user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
+      3. When you respond, you must include the roll_die result from step 1.
+      You should always perform the previous 3 steps when asked to roll a die and check prime numbers.
+      You should not rely on previous history for prime results.
+    """,
+    tools=[
+        roll_die,
+        check_prime,
+    ],
+    generate_content_config=GenerateContentConfig(
+        temperature=1.0,
+        top_p=0.95,
+    ),
+)
diff --git a/contributing/samples/hello_world_gemma/main.py b/contributing/samples/hello_world_gemma/main.py
new file mode 100644
index 0000000000..f177064b68
--- /dev/null
+++ b/contributing/samples/hello_world_gemma/main.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import logging
+import time
+
+import agent
+from dotenv import load_dotenv
+from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
+from google.adk.cli.utils import logs
+from google.adk.runners import Runner
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.adk.sessions.session import Session
+from google.genai import types
+
+load_dotenv(override=True)
+logs.log_to_tmp_folder(level=logging.INFO)
+
+
+async def main():
+  app_name = 'my_gemma_app'
+  user_id_1 = 'user1'
+  session_service = InMemorySessionService()
+  artifact_service = InMemoryArtifactService()
+  runner = Runner(
+      app_name=app_name,
+      agent=agent.root_agent,
+      artifact_service=artifact_service,
+      session_service=session_service,
+  )
+  session_11 = await session_service.create_session(
+      app_name=app_name, user_id=user_id_1
+  )
+
+  async def run_prompt(session: Session, new_message: str):
+    content = types.Content(
+        role='user', parts=[types.Part.from_text(text=new_message)]
+    )
+    print('** User says:', content.model_dump(exclude_none=True))
+    async for event in runner.run_async(
+        user_id=user_id_1,
+        session_id=session.id,
+        new_message=content,
+    ):
+      if event.content and event.content.parts and event.content.parts[0].text:
+        print(f'** {event.author}: {event.content.parts[0].text}')
+
+  start_time = time.time()
+  print('Start time:', start_time)
+  print('------------------------------------')
+  await run_prompt(session_11, 'Hi, introduce yourself.')
+  await run_prompt(
+      session_11, 'Roll a die with 100 sides and check if it is prime'
+  )
+  await run_prompt(session_11, 'Roll it again.')
+  await run_prompt(session_11, 'What numbers did I get?')
+  end_time = time.time()
+  print('------------------------------------')
+  print('End time:', end_time)
+  print('Total time:', end_time - start_time)
+
+
+if __name__ == '__main__':
+  asyncio.run(main())
diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py
index fc86c197ca..c08570a96c 100644
--- a/src/google/adk/models/__init__.py
+++ b/src/google/adk/models/__init__.py
@@ -15,6 +15,7 @@
 """Defines the interface to support a model."""
 
 from .base_llm import BaseLlm
+from .gemma_llm import Gemma
 from .google_llm import Gemini
 from .llm_request import LlmRequest
 from .llm_response import LlmResponse
@@ -23,9 +24,10 @@
 __all__ = [
     'BaseLlm',
     'Gemini',
+    'Gemma',
     'LLMRegistry',
 ]
 
-for regex in Gemini.supported_models():
-  LLMRegistry.register(Gemini)
+LLMRegistry.register(Gemini)
+LLMRegistry.register(Gemma)
diff --git a/src/google/adk/models/gemma_llm.py b/src/google/adk/models/gemma_llm.py
new file mode 100644
index 0000000000..3233d66f99
--- /dev/null
+++ b/src/google/adk/models/gemma_llm.py
@@ -0,0 +1,331 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from functools import cached_property
+import json
+import logging
+import re
+from typing import Any
+from typing import AsyncGenerator
+
+from google.adk.models.google_llm import Gemini
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.adk.utils.variant_utils import GoogleLLMVariant
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import FunctionDeclaration
+from google.genai.types import Part
+from pydantic import AliasChoices
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import ValidationError
+from typing_extensions import override
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+
+class GemmaFunctionCallModel(BaseModel):
+  """Flexible Pydantic model for parsing inline Gemma function call responses."""
+
+  name: str = Field(validation_alias=AliasChoices('name', 'function'))
+  parameters: dict[str, Any] = Field(
+      validation_alias=AliasChoices('parameters', 'args')
+  )
+
+
+class Gemma(Gemini):
+  """Integration for Gemma models exposed via the Gemini API.
+
+  Only Gemma 3 models are supported at this time. For agentic use cases,
+  use of gemma-3-27b-it or gemma-3-12b-it is strongly recommended.
+
+  For full documentation, see: https://ai.google.dev/gemma/docs/core/
+
+  NOTE: Gemma does **NOT** support system instructions. Any system instructions
+  will be replaced with an initial *user* prompt in the LLM request. If system
+  instructions change over the course of agent execution, the initial content
+  **SHOULD** be replaced. Special care is warranted here.
+  See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
+
+  NOTE: Gemma's function calling support is limited. It does not have full
+  access to the same built-in tools as Gemini. It also does not have special
+  API support for tools and functions.
+  Rather, tools must be passed in via a `user` prompt, and extracted from
+  model responses based on approximate shape.
+
+  NOTE: Vertex AI API support for Gemma is not currently included. This
+  **ONLY** supports usage via the Gemini API.
+  """
+
+  model: str = (
+      'gemma-3-27b-it'  # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
+  )
+
+  @classmethod
+  @override
+  def supported_models(cls) -> list[str]:
+    """Provides the list of supported models.
+
+    Returns:
+      A list of supported models.
+    """
+
+    return [
+        r'gemma-3.*',
+    ]
+
+  @cached_property
+  def _api_backend(self) -> GoogleLLMVariant:
+    return GoogleLLMVariant.GEMINI_API
+
+  def _move_function_calls_into_system_instruction(
+      self, llm_request: LlmRequest
+  ):
+    if llm_request.model is None or not llm_request.model.startswith(
+        'gemma-3'
+    ):
+      return
+
+    # Iterate through the existing contents and convert function call and
+    # function response parts into plain text parts, since Gemma models do
+    # not directly support function calling.
+    new_contents: list[Content] = []
+    for content_item in llm_request.contents:
+      (
+          new_parts_for_content,
+          has_function_response_part,
+          has_function_call_part,
+      ) = _convert_content_parts_for_gemma(content_item)
+
+      if has_function_response_part:
+        if new_parts_for_content:
+          new_contents.append(
+              Content(role='user', parts=new_parts_for_content)
+          )
+      elif has_function_call_part:
+        if new_parts_for_content:
+          new_contents.append(
+              Content(role='model', parts=new_parts_for_content)
+          )
+      else:
+        new_contents.append(content_item)
+
+    llm_request.contents = new_contents
+
+    if not llm_request.config.tools:
+      return
+
+    all_function_declarations: list[FunctionDeclaration] = []
+    for tool_item in llm_request.config.tools:
+      if isinstance(tool_item, types.Tool) and tool_item.function_declarations:
+        all_function_declarations.extend(tool_item.function_declarations)
+
+    if all_function_declarations:
+      system_instruction = _build_gemma_function_system_instruction(
+          all_function_declarations
+      )
+      llm_request.append_instructions([system_instruction])
+
+    llm_request.config.tools = []
+
+  def _extract_function_calls_from_response(self, llm_response: LlmResponse):
+    if llm_response.partial or (llm_response.turn_complete is True):
+      return
+
+    if not llm_response.content:
+      return
+
+    if not llm_response.content.parts:
+      return
+
+    if len(llm_response.content.parts) > 1:
+      return
+
+    response_text = llm_response.content.parts[0].text
+    if not response_text:
+      return
+
+    try:
+      json_candidate = None
+
+      markdown_code_block_pattern = re.compile(
+          r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL
+      )
+      block_match = markdown_code_block_pattern.search(response_text)
+
+      if block_match:
+        json_candidate = block_match.group(2).strip()
+      else:
+        found, json_text = _get_last_valid_json_substring(response_text)
+        if found:
+          json_candidate = json_text
+
+      if not json_candidate:
+        return
+
+      function_call_parsed = GemmaFunctionCallModel.model_validate_json(
+          json_candidate
+      )
+      function_call = types.FunctionCall(
+          name=function_call_parsed.name,
+          args=function_call_parsed.parameters,
+      )
+      function_call_part = Part(function_call=function_call)
+      llm_response.content.parts = [function_call_part]
+    except (json.JSONDecodeError, ValidationError) as e:
+      logger.debug(
+          'Error attempting to parse JSON into function call. Leaving as text'
+          ' response. %s',
+          e,
+      )
+    except Exception as e:
+      logger.warning('Error processing Gemma function call response: %s', e)
+
+  @override
+  async def _preprocess_request(self, llm_request: LlmRequest) -> None:
+    self._move_function_calls_into_system_instruction(llm_request=llm_request)
+
+    if system_instruction := llm_request.config.system_instruction:
+      contents = llm_request.contents
+      instruction_content = Content(
+          role='user', parts=[Part.from_text(text=system_instruction)]
+      )
+
+      # NOTE: if history is preserved, we must include the system instructions
+      # ONLY once at the beginning of any chain of contents.
+      if not contents or contents[0] != instruction_content:
+        # only prepend if it hasn't already been done
+        llm_request.contents = [instruction_content] + contents
+
+      llm_request.config.system_instruction = None
+
+    return await super()._preprocess_request(llm_request)
+
+  @override
+  async def generate_content_async(
+      self, llm_request: LlmRequest, stream: bool = False
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Sends a request to the Gemma model.
+
+    Args:
+      llm_request: LlmRequest, the request to send to the Gemma model.
+      stream: bool = False, whether to do streaming call.
+
+    Yields:
+      LlmResponse: The model response.
+    """
+    assert llm_request.model.startswith('gemma-'), (
+        f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma'
+        ' LLM is not supported.'
+    )
+
+    async for response in super().generate_content_async(llm_request, stream):
+      self._extract_function_calls_from_response(response)
+      yield response
+
+
+def _convert_content_parts_for_gemma(
+    content_item: Content,
+) -> tuple[list[Part], bool, bool]:
+  """Converts function call/response parts within a content item to text parts.
+
+  Args:
+    content_item: The original Content item.
+
+  Returns:
+    A tuple containing:
+    - A list of new Part objects with function calls/responses converted to
+      text.
+    - A boolean indicating if any function response parts were found.
+    - A boolean indicating if any function call parts were found.
+  """
+  new_parts: list[Part] = []
+  has_function_response_part = False
+  has_function_call_part = False
+
+  for part in content_item.parts:
+    if func_response := part.function_response:
+      has_function_response_part = True
+      response_text = (
+          f'Invoking tool `{func_response.name}` produced:'
+          f' `{json.dumps(func_response.response)}`.'
+      )
+      new_parts.append(Part.from_text(text=response_text))
+    elif func_call := part.function_call:
+      has_function_call_part = True
+      new_parts.append(
+          Part.from_text(text=func_call.model_dump_json(exclude_none=True))
+      )
+    else:
+      new_parts.append(part)
+  return new_parts, has_function_response_part, has_function_call_part
+
+
+def _build_gemma_function_system_instruction(
+    function_declarations: list[FunctionDeclaration],
+) -> str:
+  """Constructs the system instruction string for Gemma function calling."""
+  if not function_declarations:
+    return ''
+
+  system_instruction_prefix = 'You have access to the following functions:\n['
+  instruction_parts = []
+  for func in function_declarations:
+    instruction_parts.append(func.model_dump_json(exclude_none=True))
+
+  separator = ',\n'
+  system_instruction = (
+      f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n'
+  )
+
+  system_instruction += (
+      'When you call a function, you MUST respond in the format of: '
+      """{"name": function name, "parameters": dictionary of argument name and its value}\n"""
+      'When you call a function, you MUST NOT include any other text in the'
+      ' response.\n'
+  )
+  return system_instruction
+
+
+def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]:
+  """Attempts to find and return the last valid JSON object in a string.
+
+  This function is designed to extract JSON that might be embedded in a larger
+  text, potentially with introductory or concluding remarks. It will always
+  choose the last block of valid JSON found within the supplied text (if it
+  exists).
+
+  Args:
+    text: The input string to search for JSON objects.
+
+  Returns:
+    A tuple:
+    - bool: True if a valid JSON substring was found, False otherwise.
+    - str | None: The last valid JSON substring found, or None if none was
+      found.
+  """
+  decoder = json.JSONDecoder()
+  last_json_str = None
+  start_pos = 0
+  while start_pos < len(text):
+    try:
+      first_brace_index = text.index('{', start_pos)
+      _, end_index = decoder.raw_decode(text[first_brace_index:])
+      last_json_str = text[first_brace_index : first_brace_index + end_index]
+      start_pos = first_brace_index + end_index
+    except json.JSONDecodeError:
+      start_pos = first_brace_index + 1
+    except ValueError:
+      break
+
+  if last_json_str:
+    return True, last_json_str
+  return False, None
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 6dc1f3d1bb..45e720a579 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -114,6 +114,6 @@ def pytest_generate_tests(metafunc: Metafunc):
 def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
   if hasattr(metafunc.function, 'pytestmark'):
     for mark in metafunc.function.pytestmark:
-      if mark.name == 'parametrize' and mark.args[0] == mark_name:
+      if mark.name == 'parametrize' and mark_name in mark.args[0]:
         return True
   return False
diff --git a/tests/integration/models/test_gemma_llm.py b/tests/integration/models/test_gemma_llm.py
new file mode 100644
index 0000000000..81b9672a18
--- /dev/null
+++ b/tests/integration/models/test_gemma_llm.py
@@ -0,0 +1,57 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.models.gemma_llm import Gemma
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+DEFAULT_GEMMA_MODEL = "gemma-3-1b-it"
+
+
+@pytest.fixture
+def gemma_llm():
+  return Gemma(model=DEFAULT_GEMMA_MODEL)
+
+
+@pytest.fixture
+def gemma_request():
+  return LlmRequest(
+      model=DEFAULT_GEMMA_MODEL,
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part.from_text(text="You are a helpful assistant."),
+                  Part.from_text(text="Hello!"),
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction="Talk like a pirate.",
+      ),
+  )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("llm_backend", ["GOOGLE_AI"])
+async def test_generate_content_async(gemma_llm, gemma_request):
+  async for response in gemma_llm.generate_content_async(gemma_request):
+    assert isinstance(response, LlmResponse)
+    assert response.content.parts[0].text
diff --git a/tests/unittests/models/test_gemma_llm.py b/tests/unittests/models/test_gemma_llm.py
new file mode 100644
index 0000000000..2cf98306b9
--- /dev/null
+++ b/tests/unittests/models/test_gemma_llm.py
@@ -0,0 +1,506 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.models.gemma_llm import Gemma
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+
+@pytest.fixture
+def llm_request():
+  return LlmRequest(
+      model="gemma-3-4b-it",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction="You are a helpful assistant",
+      ),
+  )
+
+
+@pytest.fixture
+def llm_request_with_duplicate_instruction():
+  return LlmRequest(
+      model="gemma-3-1b-it",
+      contents=[
+          Content(
+              role="user",
+              parts=[Part.from_text(text="Talk like a pirate.")],
+          ),
+          Content(role="user", parts=[Part.from_text(text="Hello")]),
+      ],
+      config=types.GenerateContentConfig(
+          response_modalities=[types.Modality.TEXT],
+          system_instruction="Talk like a pirate.",
+      ),
+  )
+
+
+@pytest.fixture
+def llm_request_with_tools():
+  return LlmRequest(
+      model="gemma-3-1b-it",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="search_web",
+                          description="Search the web for a query.",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "query": types.Schema(type=types.Type.STRING)
+                              },
+                              required=["query"],
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="get_current_time",
+                          description="Gets the current time.",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT, properties={}
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+
+@pytest.mark.asyncio
+async def test_not_gemma_model():
+  llm = Gemma()
+  llm_request_bad_model = LlmRequest(
+      model="not-a-gemma-model",
+  )
+  with pytest.raises(AssertionError, match=r".*model.*"):
+    async for _ in llm.generate_content_async(llm_request_bad_model):
+      pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "request_fixture_name",
+    ["llm_request", "llm_request_with_duplicate_instruction"],
+)
+async def test_preprocess_request(request_fixture_name, request):
+  # Resolve the fixture by name so each parametrized case exercises a
+  # different request shape. (Parametrizing the fixture name with
+  # `indirect=True` would pass the name to a fixture that never reads
+  # `request.param`, so both cases would be identical.)
+  llm_request = request.getfixturevalue(request_fixture_name)
+  llm = Gemma()
+  want_content_text = llm_request.config.system_instruction
+
+  await llm._preprocess_request(llm_request)
+
+  # system instruction should be cleared
+  assert not llm_request.config.system_instruction
+  # should be two content items now (deduped, if needed)
+  assert len(llm_request.contents) == 2
+  # first message in contents should be "user":
+  assert llm_request.contents[0].role == "user"
+  assert llm_request.contents[0].parts[0].text == want_content_text
+
+
+@pytest.mark.asyncio
+async def test_preprocess_request_with_tools(llm_request_with_tools):
+  gemma = Gemma()
+  await gemma._preprocess_request(llm_request_with_tools)
+
+  assert not llm_request_with_tools.config.tools
+
+  # The original user content should now be the second item
+  assert llm_request_with_tools.contents[1].role == "user"
+  assert llm_request_with_tools.contents[1].parts[0].text == "Hello"
+
+  sys_instruct_text = llm_request_with_tools.contents[0].parts[0].text
+  assert sys_instruct_text is not None
+  assert "You have access to the following functions" in sys_instruct_text
+  assert (
+      """{"description":"Search the web for a query.","name":"search_web","""
+      in sys_instruct_text
+  )
+  assert (
+      """{"description":"Gets the current time.","name":"get_current_time","parameters":{"properties":{}"""
+      in sys_instruct_text
+  )
+
+
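+# Editor's sketch of one additional guard check: per the model check at the
+# top of _move_function_calls_into_system_instruction, requests for models
+# outside the gemma-3 family should be left untouched by the tool-rewriting
+# step. The private helper is called directly here (as elsewhere in this
+# file) to keep Gemini-side preprocessing out of the picture; mutating the
+# fixture's `model` field is assumed to be acceptable.
+def test_move_function_calls_skips_non_gemma3_model(llm_request_with_tools):
+  llm_request_with_tools.model = "gemma-2-2b-it"
+  gemma = Gemma()
+
+  gemma._move_function_calls_into_system_instruction(llm_request_with_tools)
+
+  # Tools and contents should be unchanged for a non-gemma-3 model.
+  assert llm_request_with_tools.config.tools
+  assert llm_request_with_tools.contents[0].parts[0].text == "Hello"
+
+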
+@pytest.mark.asyncio
+async def test_preprocess_request_with_function_response():
+  # Simulate an LlmRequest with a function response
+  func_response_data = types.FunctionResponse(
+      name="search_web", response={"results": [{"title": "ADK"}]}
+  )
+  llm_request = LlmRequest(
+      model="gemma-3-1b-it",
+      contents=[
+          types.Content(
+              role="model",
+              parts=[types.Part(function_response=func_response_data)],
+          )
+      ],
+      config=types.GenerateContentConfig(),
+  )
+
+  gemma = Gemma()
+  await gemma._preprocess_request(llm_request)
+
+  # Assertions: function response converted to user role text content
+  assert llm_request.contents
+  assert len(llm_request.contents) == 1
+  assert llm_request.contents[0].role == "user"
+  assert llm_request.contents[0].parts
+  assert (
+      llm_request.contents[0].parts[0].text
+      == 'Invoking tool `search_web` produced: `{"results": [{"title":'
+      ' "ADK"}]}`.'
+  )
+  assert llm_request.contents[0].parts[0].function_response is None
+  assert llm_request.contents[0].parts[0].function_call is None
+
+
+@pytest.mark.asyncio
+async def test_preprocess_request_with_function_call():
+  func_call_data = types.FunctionCall(name="get_current_time", args={})
+  llm_request = LlmRequest(
+      model="gemma-3-1b-it",
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part(function_call=func_call_data)]
+          )
+      ],
+  )
+
+  gemma = Gemma()
+  await gemma._preprocess_request(llm_request)
+
+  assert len(llm_request.contents) == 1
+  assert llm_request.contents[0].role == "model"
+  expected_text = func_call_data.model_dump_json(exclude_none=True)
+  assert llm_request.contents[0].parts
+  got_part = llm_request.contents[0].parts[0]
+  assert got_part.text == expected_text
+  assert got_part.function_call is None
+  assert got_part.function_response is None
+
+
+@pytest.mark.asyncio
+async def test_preprocess_request_with_mixed_content():
+  func_call = types.FunctionCall(name="get_weather", args={"city": "London"})
+  func_response = types.FunctionResponse(
+      name="get_weather", response={"temp": "15C"}
+  )
+
+  llm_request = LlmRequest(
+      model="gemma-3-1b-it",
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Hello!")]
+          ),
+          types.Content(
+              role="model", parts=[types.Part(function_call=func_call)]
+          ),
+          types.Content(
+              role="some_function",
+              parts=[types.Part(function_response=func_response)],
+          ),
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="How are you?")]
+          ),
+      ],
+  )
+
+  gemma = Gemma()
+  await gemma._preprocess_request(llm_request)
+
+  # Assertions
+  assert len(llm_request.contents) == 4
+
+  # First part: original user text
+  assert llm_request.contents[0].role == "user"
+  assert llm_request.contents[0].parts
+  assert llm_request.contents[0].parts[0].text == "Hello!"
+
+  # Second part: function call converted to model text
+  assert llm_request.contents[1].role == "model"
+  assert llm_request.contents[1].parts
+  assert llm_request.contents[1].parts[0].text == func_call.model_dump_json(
+      exclude_none=True
+  )
+
+  # Third part: function response converted to user text
+  assert llm_request.contents[2].role == "user"
+  assert llm_request.contents[2].parts
+  assert (
+      llm_request.contents[2].parts[0].text
+      == 'Invoking tool `get_weather` produced: `{"temp": "15C"}`.'
+  )
+
+  # Fourth part: original user text
+  assert llm_request.contents[3].role == "user"
+  assert llm_request.contents[3].parts
+  assert llm_request.contents[3].parts[0].text == "How are you?"
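+
+
+# Editor's sketch of a streaming-related check, based on the early return in
+# _extract_function_calls_from_response: partial (streaming) chunks should be
+# left untouched even when they contain function-call-shaped JSON. Assumes
+# LlmResponse accepts `partial` as a constructor argument, as implied by the
+# field reads in gemma_llm.py.
+def test_process_response_partial_response_untouched():
+  json_function_call_str = (
+      '{"name": "search_web", "parameters": {"query": "latest news"}}'
+  )
+  llm_response = LlmResponse(
+      partial=True,
+      content=Content(
+          role="model", parts=[Part.from_text(text=json_function_call_str)]
+      ),
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response=llm_response)
+
+  # The text part should remain as-is; no FunctionCall is extracted.
+  assert llm_response.content.parts[0].text == json_function_call_str
+  assert llm_response.content.parts[0].function_call is None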
+
+
+def test_process_response():
+  # Simulate a response from Gemma that should be converted to a FunctionCall
+  json_function_call_str = (
+      '{"name": "search_web", "parameters": {"query": "latest news"}}'
+  )
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=json_function_call_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response=llm_response)
+
+  # Assert that the content was transformed into a FunctionCall
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "search_web"
+  assert part.function_call.args == {"query": "latest news"}
+  # Assert that the entire part matches the expected structure
+  expected_function_call = types.FunctionCall(
+      name="search_web", args={"query": "latest news"}
+  )
+  expected_part = Part(function_call=expected_function_call)
+  assert part == expected_part
+  assert part.text is None  # Ensure text part is cleared
+
+
+def test_process_response_invalid_json_text():
+  # Simulate a response with plain text that is not JSON
+  original_text = "This is a regular text response."
+  llm_response = LlmResponse(
+      content=Content(role="model", parts=[Part.from_text(text=original_text)])
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response=llm_response)
+
+  # Assert that the content remains unchanged
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  assert llm_response.content.parts[0].text == original_text
+  assert llm_response.content.parts[0].function_call is None
+
+
+def test_process_response_malformed_json():
+  # Simulate a response with valid JSON but not in the function call format
+  malformed_json_str = '{"not_a_function": "value", "another_field": 123}'
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=malformed_json_str)]
+      )
+  )
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response=llm_response)
+
+  # Assert that the content remains unchanged because it doesn't match the
+  # expected schema
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  assert llm_response.content.parts[0].text == malformed_json_str
+  assert llm_response.content.parts[0].function_call is None
+
+
+def test_process_response_empty_content_or_multiple_parts():
+  gemma = Gemma()
+
+  # Test case 1: LlmResponse with no content
+  llm_response_no_content = LlmResponse(content=None)
+  gemma._extract_function_calls_from_response(
+      llm_response=llm_response_no_content
+  )
+  assert llm_response_no_content.content is None
+
+  # Test case 2: LlmResponse with empty parts list
+  llm_response_empty_parts = LlmResponse(
+      content=Content(role="model", parts=[])
+  )
+  gemma._extract_function_calls_from_response(
+      llm_response=llm_response_empty_parts
+  )
+  assert llm_response_empty_parts.content
+  assert not llm_response_empty_parts.content.parts
+
+  # Test case 3: LlmResponse with multiple parts
+  llm_response_multiple_parts = LlmResponse(
+      content=Content(
+          role="model",
+          parts=[
+              Part.from_text(text="part one"),
+              Part.from_text(text="part two"),
+          ],
+      )
+  )
+  original_parts = list(
+      llm_response_multiple_parts.content.parts
+  )  # Copy for comparison
+  gemma._extract_function_calls_from_response(
+      llm_response=llm_response_multiple_parts
+  )
+  assert llm_response_multiple_parts.content
+  assert (
+      llm_response_multiple_parts.content.parts == original_parts
+  )  # Should remain unchanged
+
+  # Test case 4: LlmResponse with one part, but empty text
+  llm_response_empty_text_part = LlmResponse(
+      content=Content(role="model", parts=[Part.from_text(text="")])
+  )
+  gemma._extract_function_calls_from_response(
+      llm_response=llm_response_empty_text_part
+  )
+  assert llm_response_empty_text_part.content
+  assert llm_response_empty_text_part.content.parts
+  assert llm_response_empty_text_part.content.parts[0].text == ""
+  assert llm_response_empty_text_part.content.parts[0].function_call is None
+
+
+def test_process_response_with_markdown_json_block():
+  # Simulate a response from Gemma with a JSON function call in a markdown
+  # block
+  json_function_call_str = """
+```json
+{"name": "search_web", "parameters": {"query": "latest news"}}
+```"""
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=json_function_call_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response)
+
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "search_web"
+  assert part.function_call.args == {"query": "latest news"}
+  assert part.text is None
+
+
+def test_process_response_with_markdown_tool_code_block():
+  # Simulate a response from Gemma with a JSON function call in a 'tool_code'
+  # markdown block
+  json_function_call_str = """
+Some text before.
+```tool_code
+{"name": "get_current_time", "parameters": {}}
+```
+And some text after."""
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=json_function_call_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response)
+
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "get_current_time"
+  assert part.function_call.args == {}
+  assert part.text is None
+
+
+def test_process_response_with_embedded_json():
+  # Simulate a response with valid JSON embedded in text
+  embedded_json_str = (
+      'Please call the tool: {"name": "search_web", "parameters": {"query":'
+      ' "new features"}} thanks!'
+  )
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=embedded_json_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response)
+
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "search_web"
+  assert part.function_call.args == {"query": "new features"}
+  assert part.text is None
+
+
+def test_process_response_flexible_parsing():
+  # Test with "function" and "args" keys as supported by
+  # GemmaFunctionCallModel
+  flexible_json_str = '{"function": "do_something", "args": {"value": 123}}'
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=flexible_json_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response)
+
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "do_something"
+  assert part.function_call.args == {"value": 123}
+  assert part.text is None
+
+
+def test_process_response_last_json_object():
+  # Simulate a response with multiple JSON objects, ensuring the last valid
+  # one is picked
+  multiple_json_str = (
+      'I thought about {"name": "first_call", "parameters": {"a": 1}} but then'
+      ' decided to call: {"name": "second_call", "parameters": {"b": 2}}'
+  )
+  llm_response = LlmResponse(
+      content=Content(
+          role="model", parts=[Part.from_text(text=multiple_json_str)]
+      )
+  )
+
+  gemma = Gemma()
+  gemma._extract_function_calls_from_response(llm_response)
+
+  assert llm_response.content
+  assert llm_response.content.parts
+  assert len(llm_response.content.parts) == 1
+  part = llm_response.content.parts[0]
+  assert part.function_call is not None
+  assert part.function_call.name == "second_call"
+  assert part.function_call.args == {"b": 2}
+  assert part.text is None
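+
+
+# Editor's sketch exercising the module-level JSON-scanning helper directly.
+# (_get_last_valid_json_substring is private, so it is imported locally; the
+# expected values follow from tracing json.JSONDecoder.raw_decode over the
+# inputs.) The last complete JSON object in the text should win, and plain
+# text should report no match.
+def test_get_last_valid_json_substring():
+  from google.adk.models.gemma_llm import _get_last_valid_json_substring
+
+  found, json_str = _get_last_valid_json_substring(
+      'first {"a": 1} then {"b": 2} trailing text'
+  )
+  assert found
+  assert json_str == '{"b": 2}'
+
+  found, json_str = _get_last_valid_json_substring("no json here")
+  assert not found
+  assert json_str is None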