Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 100 additions & 21 deletions python/packages/azure-ai/agent_framework_azure_ai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import json
import sys
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from contextlib import suppress
from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast

from agent_framework import (
Expand Down Expand Up @@ -218,6 +220,10 @@ class MyOptions(ChatOptions, total=False):
self._is_application_endpoint = "/applications/" in project_client._config.endpoint # type: ignore
# Track whether we should close client connection
self._should_close_client = should_close_client
# Track creation-time agent configuration for runtime mismatch warnings.
self.warn_runtime_tools_and_structure_changed = False
self._created_agent_tool_names: set[str] = set()
self._created_agent_structured_output_signature: str | None = None

async def configure_azure_monitor(
self,
Expand Down Expand Up @@ -341,18 +347,18 @@ async def _get_agent_reference_or_create(
"Agent name is required. Provide 'agent_name' when initializing AzureAIClient "
"or 'name' when initializing Agent."
)
# If the agent exists and we do not want to track agent configuration, return early
if self.agent_version is not None and not self.warn_runtime_tools_and_structure_changed:
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}

# If no agent_version is provided, either use latest version or create a new agent:
if self.agent_version is None:
# Try to use latest version if requested and agent exists
if self.use_latest_version:
try:
with suppress(ResourceNotFoundError):
existing_agent = await self.project_client.agents.get(self.agent_name)
self.agent_version = existing_agent.versions.latest.version
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
except ResourceNotFoundError:
# Agent doesn't exist, fall through to creation logic
pass

if "model" not in run_options or not run_options["model"]:
raise ServiceInitializationError(
Expand Down Expand Up @@ -395,14 +401,101 @@ async def _get_agent_reference_or_create(
)

self.agent_version = created_agent.version

self.warn_runtime_tools_and_structure_changed = True
self._created_agent_tool_names = self._extract_tool_names(run_options.get("tools"))
self._created_agent_structured_output_signature = self._get_structured_output_signature(chat_options)
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}

async def _close_client_if_needed(self) -> None:
"""Close project_client session if we created it."""
if self._should_close_client:
await self.project_client.close()

def _extract_tool_names(self, tools: Any) -> set[str]:
"""Extract comparable tool names from runtime tool payloads."""
if not isinstance(tools, Sequence) or isinstance(tools, str | bytes):
return set()
return {self._get_tool_name(tool) for tool in tools}

def _get_tool_name(self, tool: Any) -> str:
"""Get a stable name for a tool for runtime comparison."""
if isinstance(tool, FunctionTool):
return tool.name
if isinstance(tool, Mapping):
tool_type = tool.get("type")
if tool_type == "function":
if isinstance(function_data := tool.get("function"), Mapping) and function_data.get("name"):
return str(function_data["name"])
if tool.get("name"):
return str(tool["name"])
if tool.get("name"):
return str(tool["name"])
if tool.get("server_label"):
return f"mcp:{tool['server_label']}"
if tool_type:
return str(tool_type)
if getattr(tool, "name", None):
return str(tool.name)
if getattr(tool, "server_label", None):
return f"mcp:{tool.server_label}"
if getattr(tool, "type", None):
return str(tool.type)
return type(tool).__name__

def _get_structured_output_signature(self, chat_options: Mapping[str, Any] | None) -> str | None:
"""Build a stable signature for structured_output/response_format values."""
if not chat_options:
return None
response_format = chat_options.get("response_format")
if response_format is None:
return None
if isinstance(response_format, type):
return f"{response_format.__module__}.{response_format.__qualname__}"
if isinstance(response_format, Mapping):
return json.dumps(response_format, sort_keys=True, default=str)
return str(response_format)

def _remove_agent_level_run_options(
self,
run_options: dict[str, Any],
chat_options: Mapping[str, Any] | None = None,
) -> None:
"""Remove request-level options that Azure AI only supports at agent creation time."""
runtime_tools = run_options.get("tools")
runtime_structured_output = self._get_structured_output_signature(chat_options)

if runtime_tools is not None or runtime_structured_output is not None:
tools_changed = runtime_tools is not None
structured_output_changed = runtime_structured_output is not None

if self.warn_runtime_tools_and_structure_changed:
if runtime_tools is not None:
tools_changed = self._extract_tool_names(runtime_tools) != self._created_agent_tool_names
if runtime_structured_output is not None:
structured_output_changed = (
runtime_structured_output != self._created_agent_structured_output_signature
)

if tools_changed or structured_output_changed:
logger.warning(
"AzureAIClient does not support runtime tools or structured_output overrides after agent creation. "
"Use AzureOpenAIResponsesClient instead."
)

agent_level_option_to_run_keys = {
"model_id": ("model",),
"tools": ("tools",),
"response_format": ("response_format", "text", "text_format"),
"rai_config": ("rai_config",),
"temperature": ("temperature",),
"top_p": ("top_p",),
"reasoning": ("reasoning",),
}

for run_keys in agent_level_option_to_run_keys.values():
for run_key in run_keys:
run_options.pop(run_key, None)

@override
async def _prepare_options(
self,
Expand All @@ -427,22 +520,8 @@ async def _prepare_options(
agent_reference = await self._get_agent_reference_or_create(run_options, instructions, options)
run_options["extra_body"] = {"agent": agent_reference}

# Remove properties that are not supported on request level
# but were configured on agent level
exclude = [
"model",
"tools",
"response_format",
"rai_config",
"temperature",
"top_p",
"text",
"text_format",
"reasoning",
]

for property in exclude:
run_options.pop(property, None)
# Remove only keys that map to this client's declared options TypedDict.
self._remove_agent_level_run_options(run_options, options)

return run_options

Expand Down
149 changes: 149 additions & 0 deletions python/packages/azure-ai/tests/test_azure_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def create_test_azure_ai_client(
client.conversation_id = conversation_id
client._is_application_endpoint = False # type: ignore
client._should_close_client = should_close_client # type: ignore
client.warn_runtime_tools_and_structure_changed = False # type: ignore
client._created_agent_tool_names = set() # type: ignore
client._created_agent_structured_output_signature = None # type: ignore
client.additional_properties = {}
client.middleware = None

Expand Down Expand Up @@ -773,6 +776,82 @@ async def test_agent_creation_with_tools(
assert call_args[1]["definition"].tools == test_tools


async def test_runtime_tools_override_logs_warning(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when runtime tools differ from creation-time tools."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")

mock_agent = MagicMock()
mock_agent.name = "test-agent"
mock_agent.version = "1.0"
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
):
await client._prepare_options(messages, {})

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_two"}]},
),
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
):
await client._prepare_options(messages, {})
mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]


async def test_prepare_options_logs_warning_for_tools_with_existing_agent_version(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when tools are supplied against an existing agent version."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
),
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
):
run_options = await client._prepare_options(messages, {})

mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
assert "tools" not in run_options


async def test_prepare_options_logs_warning_for_tools_on_application_endpoint(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when runtime tools are removed for application endpoints."""
client = create_test_azure_ai_client(mock_project_client)
client._is_application_endpoint = True # type: ignore
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
),
patch.object(client, "_get_agent_reference_or_create", new_callable=AsyncMock) as mock_get_agent_reference,
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
):
run_options = await client._prepare_options(messages, {})

mock_get_agent_reference.assert_not_called()
mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
assert "tools" not in run_options
assert "extra_body" not in run_options


async def test_use_latest_version_existing_agent(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -872,6 +951,13 @@ class ResponseFormatModel(BaseModel):
model_config = ConfigDict(extra="forbid")


class AlternateResponseFormatModel(BaseModel):
"""Alternate model for structured output warning checks."""

summary: str
confidence: float


async def test_agent_creation_with_response_format(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -964,6 +1050,36 @@ async def test_agent_creation_with_mapping_response_format(
assert format_config.strict is True


async def test_runtime_structured_output_override_logs_warning(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when runtime structured_output differs from creation-time configuration."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")

mock_agent = MagicMock()
mock_agent.name = "test-agent"
mock_agent.version = "1.0"
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model"},
):
await client._prepare_options(messages, {"response_format": ResponseFormatModel})

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={"model": "test-model"},
),
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
):
await client._prepare_options(messages, {"response_format": AlternateResponseFormatModel})
mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]


async def test_prepare_options_excludes_response_format(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -1001,6 +1117,39 @@ async def test_prepare_options_excludes_response_format(
assert run_options["extra_body"]["agent"]["name"] == "test-agent"


async def test_prepare_options_keeps_values_for_unsupported_option_keys(
mock_project_client: MagicMock,
) -> None:
"""Test that run_options removal only applies to known AzureAI agent-level option mappings."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={
"model": "test-model",
"tools": [{"type": "function", "name": "weather"}],
"text": {"format": {"type": "json_schema", "name": "schema"}},
"text_format": ResponseFormatModel,
"custom_option": "keep-me",
},
),
patch.object(
client,
"_get_agent_reference_or_create",
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
),
):
run_options = await client._prepare_options(messages, {})

assert "model" not in run_options
assert "tools" not in run_options
assert "text" not in run_options
assert "text_format" not in run_options
assert run_options["custom_option"] == "keep-me"


def test_get_conversation_id_with_store_true_and_conversation_id() -> None:
"""Test _get_conversation_id returns conversation ID when store is True and conversation exists."""
client = create_test_azure_ai_client(MagicMock())
Expand Down