diff --git a/python/.env.example b/python/.env.example index c09300d775..ad173125e2 100644 --- a/python/.env.example +++ b/python/.env.example @@ -29,6 +29,9 @@ COPILOTSTUDIOAGENT__AGENTAPPID="" # Anthropic ANTHROPIC_API_KEY="" ANTHROPIC_MODEL="" +# Google Gemini +GEMINI_API_KEY="" +GEMINI_MODEL="" # Ollama OLLAMA_ENDPOINT="" OLLAMA_MODEL="" diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index caeebc5319..9722d2791d 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -36,12 +36,11 @@ all = [ "agent-framework-ag-ui", "agent-framework-azure-ai-search", "agent-framework-anthropic", - "agent-framework-openai", - "agent-framework-claude", "agent-framework-azure-ai", "agent-framework-azurefunctions", "agent-framework-bedrock", "agent-framework-chatkit", + "agent-framework-claude", "agent-framework-copilotstudio", "agent-framework-declarative", "agent-framework-devui", @@ -52,6 +51,7 @@ all = [ "agent-framework-lab", "agent-framework-mem0", "agent-framework-ollama", + "agent-framework-openai", "agent-framework-orchestrations", "agent-framework-purview", "agent-framework-redis", diff --git a/python/packages/gemini/AGENTS.md b/python/packages/gemini/AGENTS.md new file mode 100644 index 0000000000..e50dc06d66 --- /dev/null +++ b/python/packages/gemini/AGENTS.md @@ -0,0 +1,28 @@ +# Gemini Package (agent-framework-gemini) + +Integration with Google's Gemini API via the `google-genai` SDK. 
+ +## Core Classes + +- **`RawGeminiChatClient`** - Lightweight chat client without any layers, for custom pipeline composition +- **`GeminiChatClient`** - Full-featured chat client with function invocation, middleware, and telemetry +- **`GeminiChatOptions`** - Options TypedDict for Gemini-specific parameters +- **`GeminiSettings`** - Settings loaded from environment variables +- **`ThinkingConfig`** - Configuration for extended thinking + +## Gemini Options + +- **`thinking_config`** - Enable extended thinking via `ThinkingConfig` +- **`code_execution`** - Let the model write and run code in a sandboxed environment +- **`google_search_grounding`** - Responses with live Google Search results +- **`google_maps_grounding`** - Responses with Google Maps data + +## Usage + +```python +from agent_framework import Content, Message +from agent_framework_gemini import GeminiChatClient + +client = GeminiChatClient(model="gemini-2.5-flash") +response = await client.get_response([Message(role="user", contents=[Content.from_text("Hello")])]) +``` diff --git a/python/packages/gemini/LICENSE b/python/packages/gemini/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/gemini/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/gemini/README.md b/python/packages/gemini/README.md new file mode 100644 index 0000000000..807dd470f4 --- /dev/null +++ b/python/packages/gemini/README.md @@ -0,0 +1,29 @@ +# Get Started with Microsoft Agent Framework Gemini + +Install the provider package: + +```bash +pip install agent-framework-gemini --pre +``` + +## Gemini Integration + +The Gemini integration enables Microsoft Agent Framework applications to call Google Gemini models with familiar chat abstractions, including streaming, tool/function calling, and structured output. + +## Authentication + +Obtain an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set it via environment variable: + +```bash +export GEMINI_API_KEY="your-api-key" +export GEMINI_MODEL="gemini-2.5-flash" +``` + +## Examples + +See the [Google Gemini samples](../../samples/02-agents/providers/google/) for runnable end-to-end scripts covering: + +- Basic agent with tool calling and streaming +- Extended thinking with `ThinkingConfig` +- Google Search grounding +- Built-in code execution diff --git a/python/packages/gemini/agent_framework_gemini/__init__.py b/python/packages/gemini/agent_framework_gemini/__init__.py new file mode 100644 index 0000000000..42099ae0b1 --- /dev/null +++ b/python/packages/gemini/agent_framework_gemini/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import importlib.metadata + +from ._chat_client import GeminiChatClient, GeminiChatOptions, GeminiSettings, RawGeminiChatClient, ThinkingConfig + +try: + __version__ = importlib.metadata.version(__name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" + +__all__ = [ + "GeminiChatClient", + "GeminiChatOptions", + "GeminiSettings", + "RawGeminiChatClient", + "ThinkingConfig", + "__version__", +] diff --git a/python/packages/gemini/agent_framework_gemini/_chat_client.py b/python/packages/gemini/agent_framework_gemini/_chat_client.py new file mode 100644 index 0000000000..d09f7556b5 --- /dev/null +++ b/python/packages/gemini/agent_framework_gemini/_chat_client.py @@ -0,0 +1,792 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import json +import logging +import sys +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence +from typing import Any, ClassVar, Generic, cast +from uuid import uuid4 + +from agent_framework import ( + AGENT_FRAMEWORK_USER_AGENT, + BaseChatClient, + ChatAndFunctionMiddlewareTypes, + ChatMiddlewareLayer, + ChatOptions, + ChatResponse, + ChatResponseUpdate, + Content, + FinishReasonLiteral, + FunctionInvocationConfiguration, + FunctionInvocationLayer, + FunctionTool, + Message, + ResponseStream, + UsageDetails, + validate_tool_mode, +) +from agent_framework._settings import SecretString, load_settings +from agent_framework.observability import ChatTelemetryLayer +from google import genai +from google.genai import types +from pydantic import BaseModel + +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore # pragma: no cover + +if sys.version_info >= (3, 11): + from typing import 
TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + +logger = logging.getLogger("agent_framework.gemini") + +__all__ = [ + "GeminiChatClient", + "GeminiChatOptions", + "GeminiSettings", + "RawGeminiChatClient", + "ThinkingConfig", +] + +ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None) + + +# region Options & Settings + + +class ThinkingConfig(TypedDict, total=False): + """Extended thinking configuration for Gemini models. + + Attributes: + include_thoughts: Whether to include the model's reasoning thoughts in the response. + thinking_budget: Token budget for Gemini 2.5 models. Set to ``0`` to disable + thinking or ``-1`` to enable a dynamic budget. + thinking_level: Thinking level for Gemini 3.x models. One of + ``ThinkingLevel.THINKING_LEVEL_UNSPECIFIED`` (default), ``ThinkingLevel.MINIMAL``, + ``ThinkingLevel.LOW``, ``ThinkingLevel.MEDIUM``, or ``ThinkingLevel.HIGH``. + """ + + include_thoughts: bool + thinking_budget: int + thinking_level: types.ThinkingLevel + + +class GeminiChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], total=False): + """Google Gemini API-specific chat options. + + Extends ``ChatOptions`` with Gemini-specific fields. Standard options are mapped to their + ``GenerateContentConfig`` equivalents; Gemini-specific fields are declared below. + + Only text output is supported for now. Other modalities may be added later. + + See: https://ai.google.dev/api/generate-content#generationconfig + + Inherited fields from ``ChatOptions``: + model: Model to use for this call (e.g. ``"gemini-2.5-flash"``). + temperature: Controls randomness. Higher values produce more varied output. + max_tokens: Maximum number of tokens to generate (``maxOutputTokens``). + top_p: Nucleus sampling cutoff. Only tokens within the top-p probability mass are considered. 
+ stop: One or more sequences that stop generation when encountered (``stopSequences``). + seed: Fixed seed for reproducible outputs. + frequency_penalty: Reduces repetition by penalising tokens that appear frequently. + presence_penalty: Reduces repetition by penalising tokens that have already appeared. + tools: Function tools the model may call. Accepts ``FunctionTool`` instances or plain callables. + tool_choice: How the model picks a tool. One of ``'auto'``, ``'none'``, or ``'required'``. + response_format: Pydantic model type for structured JSON output. The response text is + parsed into the model and exposed via ``ChatResponse.value``. + instructions: Extra system-level instructions prepended to the system message. + + Not supported, and passing these raises a type error: + - ``logit_bias`` + - ``allow_multiple_tool_calls`` + - ``store`` + - ``user`` + - ``metadata`` + - ``conversation_id`` + """ + + # Gemini's GenerationConfig options + response_schema: dict[str, Any] + """Raw JSON schema dict for structured output (alternative to ``response_format``). + Sets ``response_mime_type`` to ``'application/json'`` and passes the schema directly.""" + + top_k: int + """Top-K sampling: limits token selection to the K most probable tokens.""" + + thinking_config: ThinkingConfig + """Extended thinking configuration. See ``ThinkingConfig`` for available fields.""" + + # Tool options + code_execution: bool + """Allow the model to write and run code in a sandboxed environment.""" + + google_search_grounding: bool | types.GoogleSearch + """Ground responses with live Google Search results. Pass ``True`` to use default settings, + or a ``types.GoogleSearch`` instance for full control (e.g. ``time_range_filter``, + ``search_types``, ``exclude_domains``).""" + + google_maps_grounding: bool | types.GoogleMaps + """Ground responses with Google Maps data. Pass ``True`` to use default settings, + or a ``types.GoogleMaps`` instance for full control (e.g. 
``enable_widget``).""" + + # Unsupported base options. Override with None to indicate not supported + logit_bias: None # type: ignore[misc] + """Not supported in the Gemini API.""" + + allow_multiple_tool_calls: None # type: ignore[misc] + """Not supported. Gemini handles parallel tool calls automatically.""" + + store: None # type: ignore[misc] + """Not supported in the Gemini API.""" + + user: None # type: ignore[misc] + """Not supported in the Gemini API.""" + + metadata: None # type: ignore[misc] + """Not supported in the Gemini API.""" + + conversation_id: None # type: ignore[misc] + """Not supported in the Gemini API.""" + + +GeminiChatOptionsT = TypeVar("GeminiChatOptionsT", bound=TypedDict, default="GeminiChatOptions", covariant=True) # type: ignore[valid-type] + + +class GeminiSettings(TypedDict, total=False): + """Gemini configuration settings loaded from environment or .env files.""" + + api_key: SecretString | None + model: str | None + + +# endregion + + +_GEMINI_SERVICE_URL = "https://generativelanguage.googleapis.com" + +_FINISH_REASON_MAP: dict[str, FinishReasonLiteral] = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "RECITATION": "content_filter", + "LANGUAGE": "content_filter", + "BLOCKLIST": "content_filter", + "PROHIBITED_CONTENT": "content_filter", + "SPII": "content_filter", + "IMAGE_SAFETY": "content_filter", + "IMAGE_PROHIBITED_CONTENT": "content_filter", + "IMAGE_RECITATION": "content_filter", + "MALFORMED_FUNCTION_CALL": "tool_calls", + "UNEXPECTED_TOOL_CALL": "tool_calls", +} + + +class RawGeminiChatClient( + BaseChatClient[GeminiChatOptionsT], + Generic[GeminiChatOptionsT], +): + """A raw Gemini chat client for the Google Gemini API without function invocation, middleware or telemetry. + + Use this when you want full control over the request pipeline. For instance, to opt out of + telemetry, use custom middleware, or compose your own layers. 
If you want the full-featured + client with batteries included, use `GeminiChatClient` instead. + """ + + OTEL_PROVIDER_NAME: ClassVar[str] = "gcp.gemini" # type: ignore[reportIncompatibleVariableOverride, misc] + + def __init__( + self, + *, + api_key: str | None = None, + model: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + client: genai.Client | None = None, + additional_properties: dict[str, Any] | None = None, + ) -> None: + """Create a raw Gemini chat client. + + Args: + api_key: Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable. + model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable. + env_file_path: Path to a ``.env`` file for credential loading. + env_file_encoding: Encoding for the ``.env`` file. + client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required. + additional_properties: Extra properties stored on the client instance. + """ + settings = load_settings( + GeminiSettings, + env_prefix="GEMINI_", + api_key=api_key, + model=model, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + + if client: + self._genai_client = client + else: + resolved_key = settings.get("api_key") + if not resolved_key: + raise ValueError( + "Gemini API key is required. Set via api_key parameter or GEMINI_API_KEY environment variable." 
+ ) + self._genai_client = genai.Client( + api_key=resolved_key.get_secret_value(), + http_options={"headers": {"x-goog-api-client": AGENT_FRAMEWORK_USER_AGENT}}, + ) + + self.model = settings.get("model") + + super().__init__(additional_properties=additional_properties) + + @override + def _inner_get_response( + self, + *, + messages: Sequence[Message], + options: Mapping[str, Any], + stream: bool = False, + **kwargs: Any, + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + if stream: + + async def _stream() -> AsyncIterable[ChatResponseUpdate]: + validated = await self._validate_options(options) + model, contents, config = self._prepare_request(messages, validated) + async for chunk in await self._genai_client.aio.models.generate_content_stream( + model=model, + contents=contents, # type: ignore[arg-type] + config=config, + ): + yield self._process_chunk(chunk) + + return self._build_response_stream(_stream(), response_format=options.get("response_format")) + + async def _get_response() -> ChatResponse: + validated = await self._validate_options(options) + model, contents, config = self._prepare_request(messages, validated) + raw = await self._genai_client.aio.models.generate_content(model=model, contents=contents, config=config) # type: ignore[arg-type] + return self._process_generate_response(raw, response_format=validated.get("response_format")) + + return _get_response() + + @override + def service_url(self) -> str: + """Return the base URL of the Gemini API service. + + Returns: + The Gemini API base URL. + """ + return _GEMINI_SERVICE_URL + + # region Request preparation + + def _prepare_request( + self, + messages: Sequence[Message], + options: Mapping[str, Any], + ) -> tuple[str, list[types.Content], types.GenerateContentConfig]: + """Resolve the model ID, convert messages to Gemini contents, and build the generation config. 
+ + Call this after awaiting ``_validate_options`` so that tools and other options are + fully normalized before the request is assembled. + + Args: + messages: The conversation history as framework Message objects. + options: Validated and normalized chat options. + + Returns: + A tuple of the resolved model, the Gemini contents list, and the generation config. + + Raises: + ValueError: If no model is set on the options or the client instance. + """ + model = options.get("model") or self.model + if not model: + raise ValueError("Gemini model is required. Set via model parameter or GEMINI_MODEL environment variable.") + + system_instruction, contents = self._prepare_gemini_messages(messages) + if call_instructions := options.get("instructions"): + system_instruction = ( + f"{call_instructions}\n{system_instruction}" if system_instruction else call_instructions + ) + + return model, contents, self._prepare_config(options, system_instruction) + + def _prepare_gemini_messages(self, messages: Sequence[Message]) -> tuple[str | None, list[types.Content]]: + """Convert framework messages to Gemini contents and extract system instruction. + + Args: + messages: The full conversation history as framework Message objects. + + Returns: + A tuple of (system_instruction_text, contents_list). System messages are extracted + into the instruction string; tool results are grouped into user-role content blocks. + """ + system_parts: list[str] = [] + contents: list[types.Content] = [] + # Maps call_id to function name so function_result parts can include the required name field. + call_id_to_name: dict[str, str] = {} + # Accumulated functionResponse parts from consecutive tool messages. 
+ pending_tool_parts: list[types.Part] = [] + + def flush_pending_tool_parts() -> None: + if pending_tool_parts: + contents.append(types.Content(role="user", parts=list(pending_tool_parts))) + pending_tool_parts.clear() + + for message in messages: + if message.role == "system": + if message.text: + system_parts.append(message.text) + continue + + if message.role == "tool": + for content in message.contents: + part = self._convert_function_result(content, call_id_to_name) + if part is not None: + pending_tool_parts.append(part) + continue + + # Non-tool message — flush any accumulated tool parts first. + flush_pending_tool_parts() + + parts = self._convert_message_contents(message.contents, call_id_to_name) + if not parts: + continue + + role = "model" if message.role == "assistant" else "user" + contents.append(types.Content(role=role, parts=parts)) + + flush_pending_tool_parts() + + system_instruction = "\n".join(system_parts) if system_parts else None + return system_instruction, contents + + def _convert_message_contents( + self, + message_contents: Sequence[Content], + call_id_to_name: dict[str, str], + ) -> list[types.Part]: + """Convert framework Content objects to Gemini Part objects, tracking function call IDs. + + Args: + message_contents: The content items of a single framework message. + call_id_to_name: Mutable mapping updated with any function call ID-to-name pairs found. + + Returns: + A list of Gemini Part objects representing the message contents. 
+ """ + parts: list[types.Part] = [] + for content in message_contents: + match content.type: + case "text": + parts.append(types.Part(text=content.text or "")) + case "function_call": + call_id = content.call_id or self._generate_tool_call_id() + if content.name: + call_id_to_name[call_id] = content.name + parts.append( + types.Part( + function_call=types.FunctionCall( + id=call_id, + name=content.name or "", + args=content.parse_arguments() or {}, + ) + ) + ) + case _: + logger.debug("Skipping unsupported content type for Gemini: %s", content.type) + return parts + + def _convert_function_result( + self, + content: Content, + call_id_to_name: dict[str, str], + ) -> types.Part | None: + """Convert a function_result Content to a Gemini FunctionResponse Part. + + Args: + content: The framework Content object, expected to be of type ``function_result``. + call_id_to_name: Mapping of call IDs to function names, used to resolve the required name field. + + Returns: + A Gemini Part containing a FunctionResponse, or None if the content type is not + ``function_result`` or the call ID cannot be resolved. + """ + if content.type != "function_result": + return None + + name = call_id_to_name.get(content.call_id or "") + if not name: + logger.warning( + "Skipping function_result: no matching function_call found for call_id=%r", + content.call_id, + ) + return None + + response = self._coerce_to_dict(content.result) + return types.Part( + function_response=types.FunctionResponse( + id=content.call_id, + name=name, + response=response, + ) + ) + + @staticmethod + def _coerce_to_dict(value: Any) -> dict[str, Any]: + """Ensure a tool result value is a dict as required by Gemini's FunctionResponse. + + Args: + value: The raw tool result. May be a dict, JSON string, plain string, None, or any other value. + + Returns: + A dict representation of the value. JSON strings are parsed; all other non-dict values + are wrapped as ``{"result": }``. 
+ """ + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return cast(dict[str, Any], parsed) + except (json.JSONDecodeError, ValueError): + pass + return {"result": value} + if value is None: + return {"result": ""} + return {"result": str(value)} + + def _prepare_config( + self, + options: Mapping[str, Any], + system_instruction: str | None, + ) -> types.GenerateContentConfig: + """Build a ``types.GenerateContentConfig`` from ``ChatOptions``. + + Args: + options: Resolved chat options mapping, typically a ``GeminiChatOptions`` dict. + system_instruction: Combined system instruction text, or None if absent. + + Returns: + A fully populated ``GenerateContentConfig`` ready to pass to the Gemini API. + """ + kwargs: dict[str, Any] = {} + + # Base ChatOptions fields + if system_instruction: + kwargs["system_instruction"] = system_instruction + if (v := options.get("temperature")) is not None: + kwargs["temperature"] = v + if (v := options.get("top_p")) is not None: + kwargs["top_p"] = v + if (v := options.get("max_tokens")) is not None: + kwargs["max_output_tokens"] = v + if (v := options.get("stop")) is not None: + kwargs["stop_sequences"] = v + if (v := options.get("seed")) is not None: + kwargs["seed"] = v + if (v := options.get("frequency_penalty")) is not None: + kwargs["frequency_penalty"] = v + if (v := options.get("presence_penalty")) is not None: + kwargs["presence_penalty"] = v + if options.get("response_format"): + kwargs["response_mime_type"] = "application/json" + if tools := self._prepare_tools(options): + kwargs["tools"] = tools + if tool_config := self._prepare_tool_config(options.get("tool_choice")): + kwargs["tool_config"] = tool_config + # Gemini-specific fields + if schema := options.get("response_schema"): + kwargs["response_mime_type"] = "application/json" + kwargs["response_schema"] = schema + if (v := options.get("top_k")) is not 
None: + kwargs["top_k"] = v + if thinking_config := options.get("thinking_config"): + thinking_config_kwargs = {k: v for k, v in thinking_config.items() if v is not None} + if thinking_config_kwargs: + kwargs["thinking_config"] = types.ThinkingConfig(**thinking_config_kwargs) + + return types.GenerateContentConfig(**kwargs) + + def _prepare_tools(self, options: Mapping[str, Any]) -> list[types.Tool] | None: + """Build the Gemini tool list from options, combining function declarations and built-in tools. + + Args: + options: Resolved chat options containing ``tools``, ``google_search_grounding`` + (``bool`` or ``types.GoogleSearch``), ``google_maps_grounding`` + (``bool`` or ``types.GoogleMaps``), and ``code_execution`` flag. + + Returns: + A list of ``types.Tool`` objects, or None if no tools are configured. + """ + function_tools: list[Any] = options.get("tools") or [] + search_option = options.get("google_search_grounding", False) + maps_option = options.get("google_maps_grounding", False) + include_code_exec = options.get("code_execution", False) + + result: list[types.Tool] = [] + + declarations = [ + types.FunctionDeclaration( + name=tool.name, + description=tool.description or "", + parameters=tool.parameters(), # type: ignore[arg-type] + ) + for tool in function_tools + if isinstance(tool, FunctionTool) + ] + if declarations: + result.append(types.Tool(function_declarations=declarations)) + if search_option: + google_search = search_option if isinstance(search_option, types.GoogleSearch) else types.GoogleSearch() + result.append(types.Tool(google_search=google_search)) + if maps_option: + google_maps = maps_option if isinstance(maps_option, types.GoogleMaps) else types.GoogleMaps() + result.append(types.Tool(google_maps=google_maps)) + if include_code_exec: + result.append(types.Tool(code_execution=types.ToolCodeExecution())) + + return result or None + + def _prepare_tool_config(self, tool_choice: Any) -> types.ToolConfig | None: + """Build a Gemini 
``ToolConfig`` from the framework tool_choice value. + + Args: + tool_choice: Raw tool_choice value from options (string, dict, or None). + + Returns: + A ``types.ToolConfig`` with the appropriate ``FunctionCallingConfig``, or None + if no tool_choice is set or the mode is unsupported. + """ + tool_mode = validate_tool_mode(tool_choice) + if not tool_mode: + return None + + match tool_mode.get("mode"): + case "auto": + function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None + case "none": + function_calling_mode, allowed_names = types.FunctionCallingConfigMode.NONE, None + case "required": + function_calling_mode = types.FunctionCallingConfigMode.ANY + name = tool_mode.get("required_function_name") + allowed_names = [name] if name else None + case unknown_mode: + logger.warning("Unsupported tool_choice mode for Gemini: %s", unknown_mode) + return None + + function_calling_kwargs: dict[str, Any] = {"mode": function_calling_mode} + if allowed_names: + function_calling_kwargs["allowed_function_names"] = allowed_names + + return types.ToolConfig(function_calling_config=types.FunctionCallingConfig(**function_calling_kwargs)) + + # endregion + + # region Response parsing + + def _process_generate_response( + self, + response: types.GenerateContentResponse, + *, + response_format: type[BaseModel] | None = None, + ) -> ChatResponse: + """Convert a Gemini generate_content response to a framework ChatResponse. + + Args: + response: The raw ``GenerateContentResponse`` from the Gemini API. + response_format: Optional Pydantic model type for structured output parsing. + When provided, the response text is parsed into the given model and + made available via ``ChatResponse.value``. + + Returns: + A ``ChatResponse`` with parsed messages, usage details, finish reason, and model ID. 
+ """ + candidate = response.candidates[0] if response.candidates else None + parts: list[types.Part] = (candidate.content.parts or []) if candidate and candidate.content else [] + contents = self._parse_parts(parts) + return ChatResponse( + response_id=None, + messages=[Message(role="assistant", contents=contents, raw_representation=candidate)], + usage_details=self._parse_usage(response.usage_metadata), + model_id=response.model_version or self.model, + finish_reason=self._map_finish_reason( + candidate.finish_reason.name if candidate and candidate.finish_reason else None + ), + response_format=response_format, + raw_representation=response, + ) + + def _process_chunk(self, chunk: types.GenerateContentResponse) -> ChatResponseUpdate: + """Convert a single streaming chunk to a framework ChatResponseUpdate. + + Usage details are attached only to the final chunk, identified by a non-None finish reason. + + Args: + chunk: A streaming ``GenerateContentResponse`` chunk from the Gemini API. + + Returns: + A ``ChatResponseUpdate`` with parsed contents, finish reason, and model ID. + """ + candidate = chunk.candidates[0] if chunk.candidates else None + parts: list[types.Part] = (candidate.content.parts or []) if candidate and candidate.content else [] + contents = self._parse_parts(parts) + + finish_reason = self._map_finish_reason( + candidate.finish_reason.name if candidate and candidate.finish_reason else None + ) + + # Attach usage to the final chunk only (when finish_reason is set). + if finish_reason and (usage := self._parse_usage(chunk.usage_metadata)): + contents.append(Content.from_usage(usage_details=usage)) + + return ChatResponseUpdate( + contents=contents, + model_id=chunk.model_version, + finish_reason=finish_reason, + raw_representation=chunk, + ) + + def _parse_parts(self, parts: Sequence[types.Part]) -> list[Content]: + """Convert Gemini response parts to framework Content objects, skipping thought/reasoning parts. 
+ + Args: + parts: Sequence of ``types.Part`` objects from a Gemini response candidate. + + Returns: + A list of framework ``Content`` objects (text, function_call, or function_result). + """ + contents: list[Content] = [] + for part in parts: + if part.thought: + continue + if part.text is not None: + contents.append(Content.from_text(text=part.text, raw_representation=part)) + elif part.function_call is not None: + function_call = part.function_call + if function_call.id: + call_id = function_call.id + else: + call_id = self._generate_tool_call_id() + logger.debug("function_call missing id; generated fallback call_id=%r", call_id) + contents.append( + Content.from_function_call( + call_id=call_id, + name=function_call.name or "", + arguments=function_call.args or {}, + raw_representation=part, + ) + ) + elif part.function_response is not None: + function_response = part.function_response + contents.append( + Content.from_function_result( + call_id=function_response.id or self._generate_tool_call_id(), + result=function_response.response, + raw_representation=part, + ) + ) + return contents + + def _parse_usage(self, usage: types.GenerateContentResponseUsageMetadata | None) -> UsageDetails | None: + """Extract token usage counts from Gemini usage metadata. + + Args: + usage: The ``GenerateContentResponseUsageMetadata`` from the API response, or None. + + Returns: + A ``UsageDetails`` dict with available token counts, or None if no usage data is present. + """ + if not usage: + return None + details: UsageDetails = {} + if (v := usage.prompt_token_count) is not None: + details["input_token_count"] = v + if (v := usage.candidates_token_count) is not None: + details["output_token_count"] = v + if (v := usage.total_token_count) is not None: + details["total_token_count"] = v + return details or None + + def _map_finish_reason(self, reason: str | None) -> FinishReasonLiteral | None: + """Map a Gemini finish reason string to the framework's FinishReasonLiteral. 
+ + Args: + reason: The finish reason name from the Gemini API (e.g. ``"STOP"``), or None. + + Returns: + The corresponding ``FinishReasonLiteral``, or None if the reason is absent or unmapped. + """ + if not reason: + return None + return _FINISH_REASON_MAP.get(reason) + + # endregion + + @staticmethod + def _generate_tool_call_id() -> str: + """Generate a unique fallback ID for tool calls that lack one. + + Returns: + A unique string in the format ``tool-call-``. + """ + return f"tool-call-{uuid4().hex}" + + +class GeminiChatClient( + FunctionInvocationLayer[GeminiChatOptionsT], + ChatMiddlewareLayer[GeminiChatOptionsT], + ChatTelemetryLayer[GeminiChatOptionsT], + RawGeminiChatClient[GeminiChatOptionsT], + Generic[GeminiChatOptionsT], +): + """Gemini chat client for the Google Gemini API with function invocation, middleware and telemetry.""" + + def __init__( + self, + *, + api_key: str | None = None, + model: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + client: genai.Client | None = None, + additional_properties: dict[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + function_invocation_configuration: FunctionInvocationConfiguration | None = None, + ) -> None: + """Create a Gemini chat client. + + Args: + api_key: The Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable. + model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable. + env_file_path: Path to a ``.env`` file for credential loading. + env_file_encoding: Encoding for the ``.env`` file. + client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required. + additional_properties: Extra properties stored on the client instance. + middleware: Optional middleware chain applied to every call. + function_invocation_configuration: Optional configuration for the function invocation loop. 
+ """ + super().__init__( + api_key=api_key, + model=model, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + client=client, + additional_properties=additional_properties, + middleware=middleware, + function_invocation_configuration=function_invocation_configuration, + ) diff --git a/python/packages/gemini/agent_framework_gemini/py.typed b/python/packages/gemini/agent_framework_gemini/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/gemini/pyproject.toml b/python/packages/gemini/pyproject.toml new file mode 100644 index 0000000000..b8fb764c48 --- /dev/null +++ b/python/packages/gemini/pyproject.toml @@ -0,0 +1,102 @@ +[project] +name = "agent-framework-gemini" +description = "Google Gemini integration for Microsoft Agent Framework." +authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "1.0.0b260319" +license-files = ["LICENSE"] +urls.homepage = "https://aka.ms/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Framework :: Pydantic :: 2", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.0.0rc5", + "google-genai>=1.0.0,<2.0.0", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = 
"0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [] +markers = [ + "integration: marks tests as integration tests that require external services", + "flaky: marks tests as flaky and eligible for automatic retry", +] +timeout = 120 + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" +include = ["agent_framework_gemini"] +exclude = ['tests'] + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_gemini"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" + +[tool.poe.tasks.mypy] +help = "Run MyPy for this package." +cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_gemini" + +[tool.poe.tasks.test] +help = "Run the default unit test suite for this package." +cmd = 'pytest -m "not integration" --cov=agent_framework_gemini --cov-report=term-missing:skip-covered tests' + +[tool.uv.build-backend] +module-name = "agent_framework_gemini" +module-root = "" + +[build-system] +requires = ["uv_build>=0.8.2,<0.9.0"] +build-backend = "uv_build" diff --git a/python/packages/gemini/tests/test_gemini_client.py b/python/packages/gemini/tests/test_gemini_client.py new file mode 100644 index 0000000000..07c511232e --- /dev/null +++ b/python/packages/gemini/tests/test_gemini_client.py @@ -0,0 +1,1253 @@ +# Copyright (c) Microsoft. All rights reserved. 

"""Unit tests for ``GeminiChatClient`` using a mocked ``google-genai`` client (no network)."""

from __future__ import annotations

import logging
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework import Content, FunctionTool, Message
from google.genai import types
from pydantic import BaseModel

from agent_framework_gemini import GeminiChatClient, GeminiChatOptions, ThinkingConfig

# Marker for tests that need a real Gemini API key; such tests are skipped otherwise.
skip_if_no_api_key = pytest.mark.skipif(
    not os.getenv("GEMINI_API_KEY"),
    reason="GEMINI_API_KEY not set; skipping integration tests.",
)

_TEST_MODEL = "gemini-2.5-flash"

# stub helpers: MagicMock stand-ins for google-genai response objects


def _make_part(
    *,
    text: str | None = None,
    thought: bool = False,
    function_call: tuple[str, str, dict[str, Any]] | None = None,
) -> MagicMock:
    """Build a mock types.Part.

    Args:
        text: Text content of the part.
        thought: Whether this is a thinking/reasoning part.
        function_call: Tuple of (id, name, args) if this is a function call part.

    Returns:
        A MagicMock mimicking ``types.Part`` with ``text``/``thought``/``function_call``/
        ``function_response`` attributes set.
    """
    part = MagicMock()
    part.text = text
    part.thought = thought
    part.function_response = None

    if function_call:
        mock_function_call = MagicMock()
        mock_function_call.id, mock_function_call.name, mock_function_call.args = function_call
        part.function_call = mock_function_call
    else:
        # Explicit None so code checking ``part.function_call is not None`` behaves correctly.
        part.function_call = None

    return part


def _make_response(
    parts: list[MagicMock],
    *,
    finish_reason: str | None = "STOP",
    model_version: str = "gemini-2.5-flash-001",
    prompt_tokens: int | None = 10,
    output_tokens: int | None = 5,
    total_tokens: int | None = 15,
) -> MagicMock:
    """Build a mock types.GenerateContentResponse.

    Args:
        parts: Parts attached to the single candidate's content.
        finish_reason: Finish-reason name for the candidate; None leaves it unset.
        model_version: Value reported as the serving model version.
        prompt_tokens: Prompt token count; with ``output_tokens`` also None, usage metadata is omitted.
        output_tokens: Candidate (output) token count.
        total_tokens: Total token count.
    """
    response = MagicMock()
    candidate = MagicMock()
    candidate.content.parts = parts

    if finish_reason:
        candidate.finish_reason.name = finish_reason
    else:
        candidate.finish_reason = None

    response.candidates = [candidate]
    response.model_version = model_version

    # Usage metadata is attached only when at least one token count is provided.
    if prompt_tokens is not None or output_tokens is not None:
        usage = MagicMock()
        usage.prompt_token_count = prompt_tokens
        usage.candidates_token_count = output_tokens
        usage.total_token_count = total_tokens
        response.usage_metadata = usage
    else:
        response.usage_metadata = None

    return response


async def _async_iter(items: list[Any]):
    """Async generator used to simulate generate_content_stream results."""
    for item in items:
        yield item


def _make_gemini_client(
    model: str = "gemini-2.5-flash",
    mock_client: MagicMock | None = None,
) -> tuple[GeminiChatClient, MagicMock]:
    """Return a (GeminiChatClient, mock_genai_client) pair."""
    mock = mock_client or MagicMock()
    client = GeminiChatClient(client=mock, model=model)
    return client, mock


# settings & initialization


def test_model_stored_on_instance() -> None:
    """The model passed at construction is exposed via ``client.model``."""
    client, _ = _make_gemini_client(model="gemini-2.5-pro")
    assert client.model == "gemini-2.5-pro"


def test_client_created_from_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """Construction succeeds from GEMINI_API_KEY alone, without an injected client."""
    monkeypatch.setenv("GEMINI_API_KEY", "test-key-123")
    client = GeminiChatClient(model="gemini-2.5-flash")
    assert client.model == "gemini-2.5-flash"


def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without an API key or injected client, construction raises a ValueError naming GEMINI_API_KEY."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_MODEL", raising=False)

    with pytest.raises(ValueError, match="GEMINI_API_KEY"):
        GeminiChatClient(model="gemini-2.5-flash")


async def test_missing_model_raises_on_get_response() -> None:
    """get_response raises a ValueError mentioning 'model' when no model is configured."""
    client, mock = _make_gemini_client(model=None)  # type: ignore[arg-type]
    mock.aio.models.generate_content = AsyncMock()

    with pytest.raises(ValueError, match="model"):
        await client.get_response(messages=[Message(role="user", contents=[Content.from_text("hi")])])


# text response


async def test_get_response_returns_text() -> None:
    """A plain text part in the candidate is surfaced as the response message text."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hello!")]))

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.messages[0].text == "Hello!"


async def test_get_response_model_id_from_response() -> None:
    """``response.model_id`` reflects the model_version reported by the API response."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], model_version="gemini-2.5-pro-002")
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.model_id == "gemini-2.5-pro-002"


async def test_get_response_uses_model_from_options() -> None:
    """``options['model']`` overrides the client's default model for the call."""
    client, mock = _make_gemini_client(model="gemini-2.5-flash")
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        options={"model": "gemini-2.5-pro"},
    )

    call_kwargs = mock.aio.models.generate_content.call_args.kwargs
    assert call_kwargs["model"] == "gemini-2.5-pro"


async def test_get_response_usage_details() -> None:
    """Token counts from usage metadata are mapped into ``usage_details``."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response(
            [_make_part(text="Hi")],
            prompt_tokens=20,
            output_tokens=8,
            total_tokens=28,
        )
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.usage_details is not None
    assert response.usage_details["input_token_count"] == 20
    assert response.usage_details["output_token_count"] == 8
    assert response.usage_details["total_token_count"] == 28


async def test_get_response_no_usage_when_metadata_absent() -> None:
    """No ``usage_details`` are reported when the response carries no usage metadata."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], prompt_tokens=None, output_tokens=None)
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert not response.usage_details


# finish reasons


@pytest.mark.parametrize(
    ("gemini_reason", "expected"),
    [
        ("STOP", "stop"),
        ("MAX_TOKENS", "length"),
        ("SAFETY", "content_filter"),
        ("RECITATION", "content_filter"),
        ("BLOCKLIST", "content_filter"),
        ("PROHIBITED_CONTENT", "content_filter"),
        ("SPII", "content_filter"),
        ("MALFORMED_FUNCTION_CALL", "tool_calls"),
        ("OTHER", None),
    ],
)
async def test_finish_reason_mapping(gemini_reason: str, expected: str | None) -> None:
    """Gemini finish-reason names map onto the framework's finish reasons (unmapped -> None)."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], finish_reason=gemini_reason)
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.finish_reason == expected


# message conversion


async def test_system_message_extracted_to_system_instruction() -> None:
    """A single system message is lifted out of the contents into ``config.system_instruction``."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[
            Message(role="system", contents=[Content.from_text("You are concise.")]),
            Message(role="user", contents=[Content.from_text("Hi")]),
        ]
    )

    config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
    assert config.system_instruction == "You are concise."


async def test_multiple_system_messages_concatenated() -> None:
    """Every system message must be folded into the single ``system_instruction`` string."""
    chat_client, genai_mock = _make_gemini_client()
    genai_mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    conversation = [
        Message(role="system", contents=[Content.from_text("Be concise.")]),
        Message(role="system", contents=[Content.from_text("Use bullet points.")]),
        Message(role="user", contents=[Content.from_text("Hi")]),
    ]
    await chat_client.get_response(messages=conversation)

    sent_config: types.GenerateContentConfig = genai_mock.aio.models.generate_content.call_args.kwargs["config"]
    for fragment in ("Be concise.", "Use bullet points."):
        assert fragment in sent_config.system_instruction


async def test_instructions_option_merged_with_system_instruction() -> None:
    """The ``instructions`` option and system messages are merged into one ``system_instruction``."""
    chat_client, genai_mock = _make_gemini_client()
    genai_mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    conversation = [
        Message(role="system", contents=[Content.from_text("Be concise.")]),
        Message(role="user", contents=[Content.from_text("Hi")]),
    ]
    await chat_client.get_response(
        messages=conversation,
        options={"instructions": "Always respond in French."},
    )

    sent_config: types.GenerateContentConfig = genai_mock.aio.models.generate_content.call_args.kwargs["config"]
    for fragment in ("Always respond in French.", "Be concise."):
        assert fragment in sent_config.system_instruction


async def test_instructions_option_without_system_message() -> None:
    """With no system message, the ``instructions`` option alone becomes ``system_instruction``."""
    chat_client, genai_mock = _make_gemini_client()
    genai_mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await chat_client.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        options={"instructions": "Be helpful."},
    )

    sent_config: types.GenerateContentConfig = genai_mock.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.system_instruction == "Be helpful."
+ + +async def test_assistant_role_mapped_to_model() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Sure")])) + + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Hello")]), + Message(role="assistant", contents=[Content.from_text("Hi there")]), + Message(role="user", contents=[Content.from_text("Follow up")]), + ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + roles = [c.role for c in contents] + assert roles == ["user", "model", "user"] + + +async def test_tool_messages_collapsed_into_single_user_message() -> None: + """Consecutive tool messages must be collapsed into one role='user' message + with multiple functionResponse parts (parallel tool call pattern).""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Run both")]), + Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments={}), + Content.from_function_call(call_id="c2", name="tool_b", arguments={}), + ], + ), + Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="res_a")]), + Message(role="tool", contents=[Content.from_function_result(call_id="c2", result="res_b")]), + ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + # user, model (with 2 function calls), user (with 2 function responses) + assert contents[-1].role == "user" + assert len(contents[-1].parts) == 2 + + +async def test_function_result_name_resolved_from_call_history() -> None: + """function_result name must come from the matching function_call in history.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content 
= AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Go")]), + Message( + role="assistant", + contents=[Content.from_function_call(call_id="call-42", name="get_weather", arguments={})], + ), + Message(role="tool", contents=[Content.from_function_result(call_id="call-42", result="sunny")]), + ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + tool_user_msg = contents[-1] + assert tool_user_msg.role == "user" + function_response = tool_user_msg.parts[0].function_response + assert function_response.name == "get_weather" + assert function_response.id == "call-42" + + +async def test_function_result_resolved_when_call_id_was_generated() -> None: + """When a function_call has no call_id and a fallback is generated, the subsequent + function_result referencing that generated ID must still resolve the function name.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + generated_id = "tool-call-generated-123" + with patch.object(client, "_generate_tool_call_id", return_value=generated_id): + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Go")]), + Message( + role="assistant", + contents=[Content.from_function_call(call_id=None, name="get_weather", arguments={})], # type: ignore[arg-type] + ), + Message( + role="tool", + contents=[Content.from_function_result(call_id=generated_id, result="sunny")], + ), + ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + tool_turn = next(c for c in contents if c.role == "user" and any(p.function_response for p in c.parts)) + assert tool_turn.parts[0].function_response.name == "get_weather" + assert tool_turn.parts[0].function_response.id == generated_id + + +async def 
test_function_result_without_matching_call_is_skipped(caplog: pytest.LogCaptureFixture) -> None: + """A function_result with no prior function_call in history should be skipped with a warning.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + with caplog.at_level(logging.WARNING, logger="agent_framework.gemini"): + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Go")]), + Message( + role="tool", + contents=[Content.from_function_result(call_id="unknown-id", result="oops")], + ), + Message(role="user", contents=[Content.from_text("What happened?")]), + ] + ) + + assert any("unknown-id" in r.message or "function_result" in r.message.lower() for r in caplog.records) + + +async def test_message_with_only_unsupported_content_type_is_skipped() -> None: + """A user message whose contents produce no convertible parts is dropped from the request.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_function_result(call_id="x", result="y")]), + Message(role="user", contents=[Content.from_text("Follow up")]), + ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + assert len(contents) == 1 + assert contents[0].parts[0].text == "Follow up" + + +async def test_non_function_result_content_in_tool_message_is_skipped() -> None: + """Unexpected content types inside a tool message are silently ignored.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[ + Message(role="user", contents=[Content.from_text("Hi")]), + Message(role="tool", contents=[Content.from_text("unexpected")]), 
+ ] + ) + + contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] + assert len(contents) == 1 + + +# thinking parts + + +async def test_thinking_parts_are_silently_skipped() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock( + return_value=_make_response([ + _make_part(text="I should think first...", thought=True), + _make_part(text="The answer is 42."), + ]) + ) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("What is the answer?")])] + ) + + assert len(response.messages[0].contents) == 1 + assert response.messages[0].text == "The answer is 42." + + +# generation config options + + +async def test_prepare_config_temperature() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"temperature": 0.3}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.temperature == 0.3 + + +async def test_prepare_config_max_tokens() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"max_tokens": 512}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.max_output_tokens == 512 + + +async def test_prepare_config_top_p_and_top_k() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"top_p": 
0.9, "top_k": 40}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.top_p == 0.9 + assert config.top_k == 40 + + +async def test_prepare_config_stop_sequences() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"stop": ["END", "STOP"]}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.stop_sequences == ["END", "STOP"] + + +async def test_prepare_config_seed() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"seed": 42}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.seed == 42 + + +async def test_prepare_config_frequency_and_presence_penalty() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"frequency_penalty": 0.5, "presence_penalty": 0.2}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.frequency_penalty == 0.5 + assert config.presence_penalty == 0.2 + + +# thinking config + + +async def test_thinking_config_budget() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + tc: ThinkingConfig = {"thinking_budget": 1024} + + await client.get_response( + 
messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"thinking_config": tc}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert isinstance(config.thinking_config, types.ThinkingConfig) + assert config.thinking_config.thinking_budget == 1024 + + +async def test_thinking_config_level() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + tc: ThinkingConfig = {"thinking_level": types.ThinkingLevel.HIGH} + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"thinking_config": tc}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert isinstance(config.thinking_config, types.ThinkingConfig) + assert config.thinking_config.thinking_level == types.ThinkingLevel.HIGH + + +# structured output + + +async def test_response_format_sets_json_mime_type() -> None: + from pydantic import BaseModel + + class Reply(BaseModel): + text: str + + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="{}")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"response_format": Reply}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.response_mime_type == "application/json" + + +async def test_response_format_populates_value_on_chat_response() -> None: + """When response_format is a Pydantic model, ChatResponse.value must be parsed from the response text.""" + from pydantic import BaseModel + + class Reply(BaseModel): + text: str + + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text='{"text": 
"hello"}')])) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"response_format": Reply}, + ) + + assert response.value == Reply(text="hello") + + +async def test_response_schema_added_to_config() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="{}")])) + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"response_schema": schema}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.response_mime_type == "application/json" + assert config.response_schema == schema + + +async def test_streaming_response_format_passed_to_build_response_stream() -> None: + """Verifies that response_format is forwarded to _build_response_stream when streaming + so that structured output parsing works correctly on the final assembled response.""" + from unittest.mock import patch + + from pydantic import BaseModel + + class Reply(BaseModel): + text: str + + client, mock = _make_gemini_client() + chunks = [_make_response([_make_part(text='{"text": "hello"}')], finish_reason="STOP")] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + with patch.object(client, "_build_response_stream", wraps=client._build_response_stream) as spy: + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"response_format": Reply}, + stream=True, + ) + async for _ in stream: + pass + + _, kwargs = spy.call_args + assert kwargs.get("response_format") is Reply + + +# tool calling + + +async def test_function_call_in_response_mapped_to_content() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock( + 
return_value=_make_response([_make_part(function_call=("call-1", "get_weather", {"city": "Berlin"}))]) + ) + + response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Weather?")])]) + + fc = response.messages[0].contents[0] + assert fc.type == "function_call" + assert fc.name == "get_weather" + assert fc.call_id == "call-1" + + +async def test_function_call_missing_id_gets_fallback() -> None: + """Older Gemini models may omit function_call.id — a UUID fallback must be generated.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock( + return_value=_make_response([ + _make_part(function_call=(None, "search", {"q": "test"})) # id is None + ]) + ) + + response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Search")])]) + + fc = response.messages[0].contents[0] + assert fc.call_id is not None + assert len(fc.call_id) > 0 + + +async def test_function_tool_converted_to_function_declaration() -> None: + def get_weather(city: str) -> str: + """Get the weather for a city.""" + return "sunny" + + tool = FunctionTool(name="get_weather", func=get_weather) + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Weather?")])], + options={"tools": [tool]}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + assert len(config.tools) == 1 + function_declaration = config.tools[0].function_declarations[0] + assert function_declaration.name == "get_weather" + + +async def test_callable_tool_resolved_via_validate_options() -> None: + """Raw callables passed as tools must be normalized by _validate_options into FunctionTools + and reach the Gemini config as function declarations.""" + + def get_weather(city: 
str) -> str: + """Get the weather for a city.""" + return "sunny" + + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Weather?")])], + options={"tools": [get_weather]}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + function_declaration = config.tools[0].function_declarations[0] + assert function_declaration.name == "get_weather" + + +# _coerce_to_dict + + +def test_coerce_to_dict_with_dict_input() -> None: + assert GeminiChatClient._coerce_to_dict({"key": "value"}) == {"key": "value"} + + +def test_coerce_to_dict_with_json_string() -> None: + assert GeminiChatClient._coerce_to_dict('{"key": "value"}') == {"key": "value"} + + +def test_coerce_to_dict_with_plain_string() -> None: + assert GeminiChatClient._coerce_to_dict("some text") == {"result": "some text"} + + +def test_coerce_to_dict_with_none() -> None: + assert GeminiChatClient._coerce_to_dict(None) == {"result": ""} + + +def test_coerce_to_dict_with_numeric_value() -> None: + assert GeminiChatClient._coerce_to_dict(42) == {"result": "42"} + + +def test_coerce_to_dict_with_json_array_string() -> None: + assert GeminiChatClient._coerce_to_dict("[1, 2, 3]") == {"result": "[1, 2, 3]"} + + +def test_coerce_to_dict_with_json_string_literal() -> None: + assert GeminiChatClient._coerce_to_dict('"hello"') == {"result": '"hello"'} + + +# tool choice + + +def _get_function_calling_mode(config: types.GenerateContentConfig) -> str: + return config.tool_config.function_calling_config.mode + + +def _make_dummy_tool() -> FunctionTool: + def dummy(x: int) -> int: + """Dummy.""" + return x + + return FunctionTool(name="dummy", func=dummy) + + +async def _get_config_for_tool_choice(tool_choice: str) -> types.GenerateContentConfig: + tool = 
_make_dummy_tool() + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"tools": [tool], "tool_choice": tool_choice}, + ) + + return mock.aio.models.generate_content.call_args.kwargs["config"] + + +async def test_tool_choice_auto_maps_to_AUTO() -> None: + config = await _get_config_for_tool_choice("auto") + assert _get_function_calling_mode(config) == "AUTO" + + +async def test_tool_choice_none_maps_to_NONE() -> None: + config = await _get_config_for_tool_choice("none") + assert _get_function_calling_mode(config) == "NONE" + + +async def test_tool_choice_required_maps_to_ANY() -> None: + config = await _get_config_for_tool_choice("required") + assert _get_function_calling_mode(config) == "ANY" + + +async def test_tool_choice_required_with_name_sets_allowed_function_names() -> None: + tool = _make_dummy_tool() + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={ + "tools": [tool], + "tool_choice": {"mode": "required", "required_function_name": "dummy"}, + }, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + function_calling_config = config.tool_config.function_calling_config + assert function_calling_config.mode == "ANY" + assert "dummy" in function_calling_config.allowed_function_names + + +async def test_unknown_tool_choice_mode_is_ignored() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")])) + + with patch("agent_framework_gemini._chat_client.validate_tool_mode", return_value={"mode": "unsupported"}): + await 
client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + options={"tool_choice": "auto"}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert not hasattr(config, "tool_config") or config.tool_config is None + + +# built-in tools + + +async def test_google_search_grounding_injects_tool() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Search")])], + options={"google_search_grounding": True}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + assert any(t.google_search for t in config.tools) + + +async def test_google_maps_grounding_injects_tool() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Map")])], + options={"google_maps_grounding": True}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + assert any(t.google_maps for t in config.tools) + + +async def test_google_search_grounding_with_config_uses_provided_instance() -> None: + """Passing a types.GoogleSearch instance forwards it directly rather than constructing a default.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")])) + search_config = types.GoogleSearch(exclude_domains=["example.com"]) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Search")])], + options={"google_search_grounding": search_config}, + 
) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + injected = next((t.google_search for t in config.tools if t.google_search is not None), None) # type: ignore[union-attr] + assert injected is search_config + + +async def test_google_maps_grounding_with_config_uses_provided_instance() -> None: + """Passing a types.GoogleMaps instance forwards it directly rather than constructing a default.""" + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")])) + maps_config = types.GoogleMaps(enable_widget=True) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Map")])], + options={"google_maps_grounding": maps_config}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + injected = next((t.google_maps for t in config.tools if t.google_maps is not None), None) # type: ignore[union-attr] + assert injected is maps_config + + +async def test_code_execution_injects_tool() -> None: + client, mock = _make_gemini_client() + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")])) + + await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Run code")])], + options={"code_execution": True}, + ) + + config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert config.tools is not None + assert any(t.code_execution for t in config.tools) + + +async def test_function_response_part_in_response_mapped_to_content() -> None: + """A function_response part echoed back in a model response is mapped to a function_result Content.""" + client, mock = _make_gemini_client() + part = MagicMock() + part.text = None + part.thought = False + part.function_call = None + 
part.function_response = MagicMock() + part.function_response.id = "call-99" + part.function_response.response = {"result": "done"} + mock.aio.models.generate_content = AsyncMock(return_value=_make_response([part])) + + response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])]) + + assert response.messages[0].contents[0].type == "function_result" + + +# streaming + + +async def test_streaming_yields_text_chunks() -> None: + client, mock = _make_gemini_client() + chunks = [ + _make_response([_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None), + _make_response([_make_part(text="world!")], finish_reason="STOP"), + ] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + stream=True, + ) + + updates = [update async for update in stream] + text = "".join(u.text or "" for u in updates) + assert "Hello" in text + assert "world" in text + + +async def test_streaming_function_call_emitted_immediately() -> None: + """Function calls in streaming chunks must be emitted as they arrive, not deferred.""" + client, mock = _make_gemini_client() + chunks = [ + _make_response( + [_make_part(function_call=("call-1", "search", {"q": "test"}))], + finish_reason=None, + prompt_tokens=None, + output_tokens=None, + ), + _make_response([_make_part(text="Done")], finish_reason="STOP"), + ] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Search")])], + stream=True, + ) + + all_contents = [] + async for update in stream: + all_contents.extend(update.contents) + + function_calls = [c for c in all_contents if c.type == "function_call"] + assert len(function_calls) == 1 + assert function_calls[0].name == "search" + + +async def 
test_streaming_finish_reason_only_on_last_chunk() -> None: + client, mock = _make_gemini_client() + chunks = [ + _make_response([_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None), + _make_response([_make_part(text="world!")], finish_reason="STOP"), + ] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + stream=True, + ) + + updates = [update async for update in stream] + assert updates[0].finish_reason is None + assert updates[-1].finish_reason == "stop" + + +async def test_streaming_usage_only_on_final_chunk() -> None: + client, mock = _make_gemini_client() + chunks = [ + _make_response([_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None), + _make_response([_make_part(text="world!")], finish_reason="STOP", prompt_tokens=10, output_tokens=5), + ] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + stream=True, + ) + + updates = [update async for update in stream] + assert not any(c.type == "usage" for c in updates[0].contents) + assert any(c.type == "usage" for c in updates[-1].contents) + + +async def test_streaming_get_final_response() -> None: + """get_final_response() must return a fully assembled ChatResponse after the stream is exhausted.""" + client, mock = _make_gemini_client() + chunks = [ + _make_response([_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None), + _make_response([_make_part(text="world!")], finish_reason="STOP", prompt_tokens=10, output_tokens=5), + ] + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks)) + + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + 
stream=True, + ) + + async for _ in stream: + pass + + final = await stream.get_final_response() + + assert final.messages[0].text == "Hello world!" + assert final.finish_reason == "stop" + assert final.usage_details is not None + assert final.usage_details["input_token_count"] == 10 + assert final.usage_details["output_token_count"] == 5 + + +# The Gemini API returns a list of candidates, each representing a possible response from the model. +# In practice only one candidate is returned, but the list can be empty or None if the request +# was blocked by safety filters or the API returned an unexpected response. + + +@pytest.mark.parametrize("candidates", [None, []]) +async def test_empty_candidates_returns_empty_message(candidates: list | None) -> None: + """An API response with no candidates must not raise and must return an empty assistant message.""" + client, mock = _make_gemini_client() + response = _make_response([]) + response.candidates = candidates + mock.aio.models.generate_content = AsyncMock(return_value=response) + + result = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])]) + + assert result.messages[0].role == "assistant" + assert result.messages[0].contents == [] + assert result.finish_reason is None + + +@pytest.mark.parametrize("candidates", [None, []]) +async def test_empty_candidates_in_stream_does_not_raise(candidates: list | None) -> None: + """A streaming chunk with no candidates must not raise and must yield an empty update.""" + client, mock = _make_gemini_client() + chunk = _make_response([], finish_reason=None, prompt_tokens=None, output_tokens=None) + chunk.candidates = candidates + mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([chunk])) + + updates = [ + update + async for update in client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Hi")])], + stream=True, + ) + ] + + assert len(updates) == 1 + assert updates[0].contents == [] + 
assert updates[0].finish_reason is None + + +# service_url + + +def test_service_url() -> None: + client, _ = _make_gemini_client() + assert client.service_url() == "https://generativelanguage.googleapis.com" + + +# integration tests + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_basic_chat() -> None: + """Basic request/response round-trip returns a non-empty text reply.""" + client = GeminiChatClient(model=_TEST_MODEL) + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Reply with the single word: hello")])] + ) + + assert response.messages + assert response.messages[0].text + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_streaming() -> None: + """Streaming yields multiple chunks that together form a non-empty response.""" + client = GeminiChatClient(model=_TEST_MODEL) + stream = client.get_response( + messages=[Message(role="user", contents=[Content.from_text("Count from 1 to 5.")])], + stream=True, + ) + + chunks = [update async for update in stream] + assert len(chunks) > 0 + full_text = "".join(u.text or "" for u in chunks) + assert full_text + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_structured_output() -> None: + """Structured output with a Pydantic response_format returns a parsed value via response.value.""" + + class Answer(BaseModel): + answer: str + + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("What is the capital of Germany?")])], + options={"response_format": Answer}, + ) + + assert response.value is not None + assert isinstance(response.value, Answer) + assert response.value.answer + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_tool_calling() -> None: + """Model invokes the registered tool 
when asked a question that requires it.""" + + def get_temperature(city: str) -> str: + """Return the current temperature for a city.""" + return f"22°C in {city}" + + tool = FunctionTool(name="get_temperature", func=get_temperature) + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("What is the temperature in Berlin?")])], + options={"tools": [tool], "tool_choice": "required"}, + ) + + function_calls = [c for c in response.messages[0].contents if c.type == "function_call"] + assert len(function_calls) >= 1 + assert function_calls[0].name == "get_temperature" + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_thinking_config() -> None: + """Model accepts a thinking budget and returns a non-empty text reply.""" + options: GeminiChatOptions = {"thinking_config": ThinkingConfig(thinking_budget=512)} + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("What is 17 * 34?")])], + options=options, + ) + + assert response.messages + assert response.messages[0].text + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_google_search_grounding() -> None: + """Google Search grounding returns a non-empty response for a current-events question.""" + options: GeminiChatOptions = {"google_search_grounding": True} + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[Message(role="user", contents=[Content.from_text("What is the latest stable version of Python?")])], + options=options, + ) + + assert response.messages + assert response.messages[0].text + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_google_maps_grounding() -> None: + """Google Maps grounding returns a non-empty response for a 
location-based question.""" + options: GeminiChatOptions = {"google_maps_grounding": True} + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[ + Message( + role="user", + contents=[Content.from_text("What are some highly rated restaurants in Karlsruhe city center?")], + ) + ], + options=options, + ) + + assert response.messages + assert response.messages[0].text + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_no_api_key +async def test_integration_code_execution() -> None: + """Code execution tool produces a non-empty response for a computation request.""" + options: GeminiChatOptions = {"code_execution": True} + client = GeminiChatClient(model=_TEST_MODEL) + + response = await client.get_response( + messages=[ + Message( + role="user", + contents=[Content.from_text("Compute the sum of the first 100 natural numbers using code.")], + ) + ], + options=options, + ) + + assert response.messages + assert response.messages[0].text diff --git a/python/pyproject.toml b/python/pyproject.toml index c19a6024d0..78464136eb 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -74,21 +74,22 @@ agent-framework-azure-ai = { workspace = true } agent-framework-azurefunctions = { workspace = true } agent-framework-bedrock = { workspace = true } agent-framework-chatkit = { workspace = true } +agent-framework-claude = { workspace = true } agent-framework-copilotstudio = { workspace = true } agent-framework-declarative = { workspace = true } agent-framework-devui = { workspace = true } agent-framework-durabletask = { workspace = true } agent-framework-foundry = { workspace = true } agent-framework-foundry-local = { workspace = true } +agent-framework-gemini = { workspace = true } +agent-framework-github-copilot = { workspace = true } agent-framework-lab = { workspace = true } agent-framework-mem0 = { workspace = true } agent-framework-ollama = { workspace = true } agent-framework-openai = { workspace = true } 
+agent-framework-orchestrations = { workspace = true } agent-framework-purview = { workspace = true } agent-framework-redis = { workspace = true } -agent-framework-github-copilot = { workspace = true } -agent-framework-claude = { workspace = true } -agent-framework-orchestrations = { workspace = true } litellm = { url = "https://files.pythonhosted.org/packages/57/77/0c6eca2cb049793ddf8ce9cdcd5123a35666c4962514788c4fc90edf1d3b/litellm-1.82.1-py3-none-any.whl" } [tool.ruff] diff --git a/python/samples/02-agents/providers/google/README.md b/python/samples/02-agents/providers/google/README.md new file mode 100644 index 0000000000..28fb05abeb --- /dev/null +++ b/python/samples/02-agents/providers/google/README.md @@ -0,0 +1,18 @@ +# Google Gemini Examples + +This folder contains examples demonstrating how to use Google Gemini models with the Agent Framework. + +## Examples + +| File | Description | +|------|-------------| +| [`gemini_basic.py`](gemini_basic.py) | Basic agent with a weather tool, demonstrating both streaming and non-streaming responses. | +| [`gemini_advanced.py`](gemini_advanced.py) | Extended thinking via `ThinkingConfig` for reasoning-heavy questions (Gemini 2.5+). | +| [`gemini_with_google_search.py`](gemini_with_google_search.py) | Google Search grounding for up-to-date answers. | +| [`gemini_with_google_maps.py`](gemini_with_google_maps.py) | Google Maps grounding for location and mapping information. | +| [`gemini_with_code_execution.py`](gemini_with_code_execution.py) | Built-in code execution tool for computing precise answers in a sandboxed environment. 
| + +## Environment Variables + +- `GEMINI_API_KEY`: Your Google AI Studio API key (get one from [Google AI Studio](https://aistudio.google.com/apikey)) +- `GEMINI_MODEL`: The Gemini model to use (e.g., `gemini-2.5-flash`, `gemini-2.5-pro`) diff --git a/python/samples/02-agents/providers/google/gemini_advanced.py b/python/samples/02-agents/providers/google/gemini_advanced.py new file mode 100644 index 0000000000..03ab9ad80d --- /dev/null +++ b/python/samples/02-agents/providers/google/gemini_advanced.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Shows how to enable extended thinking with ThinkingConfig so the model can +reason through complex problems before responding. + +Requires the following environment variables to be set: +- GEMINI_API_KEY +- GEMINI_MODEL +""" + +import asyncio + +from agent_framework import Agent +from agent_framework_gemini import GeminiChatClient, GeminiChatOptions, ThinkingConfig +from dotenv import load_dotenv + +load_dotenv() + + +async def main() -> None: + """Example of extended thinking with a Python version comparison question.""" + print("=== Extended thinking ===") + + options: GeminiChatOptions = { + "thinking_config": ThinkingConfig(thinking_budget=2048), + } + + agent = Agent( + client=GeminiChatClient(), + name="PythonAgent", + instructions="You are a helpful Python expert.", + default_options=options, + ) + + query = "What new language features were introduced in Python between 3.10 and 3.14?" 
+ print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/02-agents/providers/google/gemini_basic.py b/python/samples/02-agents/providers/google/gemini_basic.py new file mode 100644 index 0000000000..d3dc253c14 --- /dev/null +++ b/python/samples/02-agents/providers/google/gemini_basic.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Shows how to use GeminiChatClient with an agent and a custom tool, covering both +non-streaming and streaming responses. + +Requires the following environment variables to be set: +- GEMINI_API_KEY +- GEMINI_MODEL +""" + +import asyncio +from random import randint +from typing import Annotated + +from agent_framework import Agent, tool +from agent_framework_gemini import GeminiChatClient +from dotenv import load_dotenv + +load_dotenv() + + +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production +@tool(approval_mode="never_require") +def get_weather( + location: Annotated[str, "The location to get the weather for."], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def non_streaming_example() -> None: + """Runs the agent and waits for the complete response before printing it.""" + print("=== Non-streaming ===") + + agent = Agent( + client=GeminiChatClient(), + name="WeatherAgent", + instructions="You are a helpful weather agent.", + tools=[get_weather], + ) + + query = "What's the weather like in Karlsruhe, Germany?" 
+ print(f"User: {query}") + result = await agent.run(query) + print(f"Result: {result}\n") + + +async def streaming_example() -> None: + """Runs the agent and prints each chunk as it is received.""" + print("=== Streaming ===") + + agent = Agent( + client=GeminiChatClient(), + name="WeatherAgent", + instructions="You are a helpful weather agent.", + tools=[get_weather], + ) + + query = "What's the weather like in Portland and in Paris?" + print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +async def main() -> None: + await non_streaming_example() + await streaming_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/02-agents/providers/google/gemini_with_code_execution.py b/python/samples/02-agents/providers/google/gemini_with_code_execution.py new file mode 100644 index 0000000000..dd73ad6c75 --- /dev/null +++ b/python/samples/02-agents/providers/google/gemini_with_code_execution.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Shows how to enable Gemini's built-in code execution tool so the model can write +and run code in a sandboxed environment to answer questions. + +Requires the following environment variables to be set: +- GEMINI_API_KEY +- GEMINI_MODEL +""" + +import asyncio + +from agent_framework import Agent +from agent_framework_gemini import GeminiChatClient, GeminiChatOptions +from dotenv import load_dotenv + +load_dotenv() + + +async def main() -> None: + print("=== Code execution ===") + + options: GeminiChatOptions = { + "code_execution": True, + } + + agent = Agent( + client=GeminiChatClient(), + name="CodeAgent", + instructions="You are a helpful assistant. Use code execution to compute precise answers.", + default_options=options, + ) + + query = "What are the first 20 prime numbers? Compute them in code." 
+ print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/02-agents/providers/google/gemini_with_google_maps.py b/python/samples/02-agents/providers/google/gemini_with_google_maps.py new file mode 100644 index 0000000000..375bd23732 --- /dev/null +++ b/python/samples/02-agents/providers/google/gemini_with_google_maps.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Shows how to enable Google Maps grounding so Gemini can retrieve location and +mapping information before responding. + +Requires the following environment variables to be set: +- GEMINI_API_KEY +- GEMINI_MODEL +""" + +import asyncio + +from agent_framework import Agent +from agent_framework_gemini import GeminiChatClient, GeminiChatOptions +from dotenv import load_dotenv + +load_dotenv() + + +async def main() -> None: + print("=== Google Maps grounding ===") + + options: GeminiChatOptions = { + "google_maps_grounding": True, + } + + agent = Agent( + client=GeminiChatClient(), + name="MapsAgent", + instructions="You are a helpful travel assistant. Use Google Maps to provide accurate location information.", + default_options=options, + ) + + query = "What are some highly rated restaurants in the city center of Karlsruhe, Germany?" 
+ print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/02-agents/providers/google/gemini_with_google_search.py b/python/samples/02-agents/providers/google/gemini_with_google_search.py new file mode 100644 index 0000000000..aed53fc8fd --- /dev/null +++ b/python/samples/02-agents/providers/google/gemini_with_google_search.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Shows how to enable Google Search grounding so Gemini can retrieve up-to-date +information from the web before responding. + +Requires the following environment variables to be set: +- GEMINI_API_KEY +- GEMINI_MODEL +""" + +import asyncio + +from agent_framework import Agent +from agent_framework_gemini import GeminiChatClient, GeminiChatOptions +from dotenv import load_dotenv + +load_dotenv() + + +async def main() -> None: + print("=== Google Search grounding ===") + + options: GeminiChatOptions = { + "google_search_grounding": True, + } + + agent = Agent( + client=GeminiChatClient(), + name="SearchAgent", + instructions="You are a helpful assistant. Use Google Search to provide accurate, up-to-date answers.", + default_options=options, + ) + + query = "What is the latest stable release of the .NET SDK?" 
+ print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/uv.lock b/python/uv.lock index 317113161f..18f0cc9464 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -44,6 +44,7 @@ members = [ "agent-framework-durabletask", "agent-framework-foundry", "agent-framework-foundry-local", + "agent-framework-gemini", "agent-framework-github-copilot", "agent-framework-lab", "agent-framework-mem0", @@ -538,6 +539,21 @@ requires-dist = [ { name = "foundry-local-sdk", specifier = ">=0.5.1,<0.5.2" }, ] +[[package]] +name = "agent-framework-gemini" +version = "1.0.0b260319" +source = { editable = "packages/gemini" } +dependencies = [ + { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "google-genai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "google-genai", specifier = ">=1.0.0,<2.0.0" }, +] + [[package]] name = "agent-framework-github-copilot" version = "1.0.0b260319" @@ -2374,6 +2390,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, ] +[package.optional-dependencies] +requests = [ + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[[package]] +name = "google-genai" +version = "1.68.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 
'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "google-auth", extra = ["requests"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "tenacity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "websockets", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" }, +] + [[package]] name = "googleapis-common-protos" version = "1.73.0"