Skip to content

Commit e0b9be7

Browse files
authored
Python: fix(declarative): Fix MCP tool connection not passed from YAML to Azure AI agent creation API (#3248)
* fix(declarative): Fix MCP tool connection not passed from YAML * Add samples to README * Fix mypy * Fix mypy again * Address PR comments
1 parent 83e8965 commit e0b9be7

7 files changed

Lines changed: 781 additions & 112 deletions

File tree

python/packages/azure-ai/agent_framework_azure_ai/_client.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from azure.core.exceptions import ResourceNotFoundError
3434
from pydantic import ValidationError
3535

36-
from ._shared import AzureAISettings, create_text_format_config
36+
from ._shared import AzureAISettings, _extract_project_connection_id, create_text_format_config
3737

3838
if sys.version_info >= (3, 13):
3939
from typing import TypeVar # type: ignore # pragma: no cover
@@ -510,6 +510,17 @@ def _prepare_mcp_tool(tool: HostedMCPTool) -> MCPTool: # type: ignore[override]
510510
"""Get MCP tool from HostedMCPTool."""
511511
mcp = MCPTool(server_label=tool.name.replace(" ", "_"), server_url=str(tool.url))
512512

513+
if tool.description:
514+
mcp["server_description"] = tool.description
515+
516+
# Check for project_connection_id in additional_properties (for Azure AI Foundry connections)
517+
project_connection_id = _extract_project_connection_id(tool.additional_properties)
518+
if project_connection_id:
519+
mcp["project_connection_id"] = project_connection_id
520+
elif tool.headers:
521+
# Only use headers if no project_connection_id is available
522+
mcp["headers"] = tool.headers
523+
513524
if tool.allowed_tools:
514525
mcp["allowed_tools"] = list(tool.allowed_tools)
515526

python/packages/azure-ai/agent_framework_azure_ai/_shared.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,37 @@ class AzureAISettings(AFBaseSettings):
8787
model_deployment_name: str | None = None
8888

8989

90+
def _extract_project_connection_id(additional_properties: dict[str, Any] | None) -> str | None:
91+
"""Extract project_connection_id from HostedMCPTool additional_properties.
92+
93+
Checks for both direct 'project_connection_id' key (programmatic usage)
94+
and 'connection.name' structure (declarative/YAML usage).
95+
96+
Args:
97+
additional_properties: The additional_properties dict from a HostedMCPTool.
98+
99+
Returns:
100+
The project_connection_id if found, None otherwise.
101+
"""
102+
if not additional_properties:
103+
return None
104+
105+
# Check for direct project_connection_id (programmatic usage)
106+
project_connection_id = additional_properties.get("project_connection_id")
107+
if isinstance(project_connection_id, str):
108+
return project_connection_id
109+
110+
# Check for connection.name structure (declarative/YAML usage)
111+
if "connection" in additional_properties:
112+
conn = additional_properties["connection"]
113+
if isinstance(conn, dict):
114+
name = conn.get("name")
115+
if isinstance(name, str):
116+
return name
117+
118+
return None
119+
120+
90121
def to_azure_ai_agent_tools(
91122
tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None,
92123
run_options: dict[str, Any] | None = None,
@@ -322,6 +353,11 @@ def from_azure_ai_tools(tools: Sequence[Tool | dict[str, Any]] | None) -> list[T
322353
if "never" in require_approval:
323354
approval_mode["never_require_approval"] = set(require_approval["never"].get("tool_names", [])) # type: ignore
324355

356+
# Preserve project_connection_id in additional_properties
357+
additional_props: dict[str, Any] | None = None
358+
if project_connection_id := mcp_tool.get("project_connection_id"):
359+
additional_props = {"connection": {"name": project_connection_id}}
360+
325361
agent_tools.append(
326362
HostedMCPTool(
327363
name=mcp_tool.get("server_label", "").replace("_", " "),
@@ -330,6 +366,7 @@ def from_azure_ai_tools(tools: Sequence[Tool | dict[str, Any]] | None) -> list[T
330366
headers=mcp_tool.get("headers"),
331367
allowed_tools=mcp_tool.get("allowed_tools"),
332368
approval_mode=approval_mode, # type: ignore
369+
additional_properties=additional_props,
333370
)
334371
)
335372
elif tool_type == "code_interpreter":
@@ -466,7 +503,13 @@ def _prepare_mcp_tool_for_azure_ai(tool: HostedMCPTool) -> MCPTool:
466503
if tool.description:
467504
mcp["server_description"] = tool.description
468505

469-
if tool.headers:
506+
# Check for project_connection_id in additional_properties (for Azure AI Foundry connections)
507+
project_connection_id = _extract_project_connection_id(tool.additional_properties)
508+
if project_connection_id:
509+
mcp["project_connection_id"] = project_connection_id
510+
elif tool.headers:
511+
# Only use headers if no project_connection_id is available
512+
# Note: Azure AI Agent Service may reject headers with sensitive info
470513
mcp["headers"] = tool.headers
471514

472515
if tool.allowed_tools:

python/packages/declarative/agent_framework_declarative/_loader.py

Lines changed: 209 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable, Mapping
44
from pathlib import Path
5-
from typing import Any, Literal, TypedDict
5+
from typing import Any, Literal, TypedDict, cast
66

77
import yaml
88
from agent_framework import (
@@ -89,6 +89,11 @@ class ProviderTypeMapping(TypedDict, total=True):
8989
"name": "AzureAIClient",
9090
"model_id_field": "model_deployment_name",
9191
},
92+
"AzureAI.ProjectProvider": {
93+
"package": "agent_framework.azure",
94+
"name": "AzureAIProjectAgentProvider",
95+
"model_id_field": "model",
96+
},
9297
"Anthropic.Chat": {
9398
"package": "agent_framework.anthropic",
9499
"name": "AnthropicChatClient",
@@ -448,6 +453,175 @@ def create_agent_from_dict(self, agent_def: dict[str, Any]) -> ChatAgent:
448453
**chat_options,
449454
)
450455

456+
async def create_agent_from_yaml_path_async(self, yaml_path: str | Path) -> ChatAgent:
457+
"""Async version: Create a ChatAgent from a YAML file path.
458+
459+
Use this method when the provider requires async initialization, such as
460+
AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service.
461+
462+
Args:
463+
yaml_path: Path to the YAML file representation of a PromptAgent.
464+
465+
Returns:
466+
The ``ChatAgent`` instance created from the YAML file.
467+
468+
Examples:
469+
.. code-block:: python
470+
471+
from agent_framework_declarative import AgentFactory
472+
473+
factory = AgentFactory(
474+
client_kwargs={"credential": credential},
475+
default_provider="AzureAI.ProjectProvider",
476+
)
477+
agent = await factory.create_agent_from_yaml_path_async("agent.yaml")
478+
"""
479+
if not isinstance(yaml_path, Path):
480+
yaml_path = Path(yaml_path)
481+
if not yaml_path.exists():
482+
raise DeclarativeLoaderError(f"YAML file not found at path: {yaml_path}")
483+
yaml_str = yaml_path.read_text()
484+
return await self.create_agent_from_yaml_async(yaml_str)
485+
486+
async def create_agent_from_yaml_async(self, yaml_str: str) -> ChatAgent:
487+
"""Async version: Create a ChatAgent from a YAML string.
488+
489+
Use this method when the provider requires async initialization, such as
490+
AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service.
491+
492+
Args:
493+
yaml_str: YAML string representation of a PromptAgent.
494+
495+
Returns:
496+
The ``ChatAgent`` instance created from the YAML string.
497+
498+
Examples:
499+
.. code-block:: python
500+
501+
from agent_framework_declarative import AgentFactory
502+
503+
yaml_content = '''
504+
kind: Prompt
505+
name: MyAgent
506+
instructions: You are a helpful assistant.
507+
model:
508+
id: gpt-4o
509+
provider: AzureAI.ProjectProvider
510+
'''
511+
512+
factory = AgentFactory(client_kwargs={"credential": credential})
513+
agent = await factory.create_agent_from_yaml_async(yaml_content)
514+
"""
515+
return await self.create_agent_from_dict_async(yaml.safe_load(yaml_str))
516+
517+
async def create_agent_from_dict_async(self, agent_def: dict[str, Any]) -> ChatAgent:
518+
"""Async version: Create a ChatAgent from a dictionary definition.
519+
520+
Use this method when the provider requires async initialization, such as
521+
AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service.
522+
523+
Args:
524+
agent_def: Dictionary representation of a PromptAgent.
525+
526+
Returns:
527+
The ``ChatAgent`` instance created from the dictionary.
528+
529+
Examples:
530+
.. code-block:: python
531+
532+
from agent_framework_declarative import AgentFactory
533+
534+
agent_def = {
535+
"kind": "Prompt",
536+
"name": "MyAgent",
537+
"instructions": "You are a helpful assistant.",
538+
"model": {
539+
"id": "gpt-4o",
540+
"provider": "AzureAI.ProjectProvider",
541+
},
542+
}
543+
544+
factory = AgentFactory(client_kwargs={"credential": credential})
545+
agent = await factory.create_agent_from_dict_async(agent_def)
546+
"""
547+
# Set safe_mode context before parsing YAML to control PowerFx environment variable access
548+
_safe_mode_context.set(self.safe_mode)
549+
prompt_agent = agent_schema_dispatch(agent_def)
550+
if not isinstance(prompt_agent, PromptAgent):
551+
raise DeclarativeLoaderError("Only definitions for a PromptAgent are supported for agent creation.")
552+
553+
# Check if we're using a provider-based approach (like AzureAIProjectAgentProvider)
554+
mapping = self._retrieve_provider_configuration(prompt_agent.model) if prompt_agent.model else None
555+
if mapping and mapping["name"] == "AzureAIProjectAgentProvider":
556+
return await self._create_agent_with_provider(prompt_agent, mapping)
557+
558+
# Fall back to standard ChatClient approach
559+
client = self._get_client(prompt_agent)
560+
chat_options = self._parse_chat_options(prompt_agent.model)
561+
if tools := self._parse_tools(prompt_agent.tools):
562+
chat_options["tools"] = tools
563+
if output_schema := prompt_agent.outputSchema:
564+
chat_options["response_format"] = _create_model_from_json_schema("agent", output_schema.to_json_schema())
565+
return ChatAgent(
566+
chat_client=client,
567+
name=prompt_agent.name,
568+
description=prompt_agent.description,
569+
instructions=prompt_agent.instructions,
570+
**chat_options,
571+
)
572+
573+
async def _create_agent_with_provider(self, prompt_agent: PromptAgent, mapping: ProviderTypeMapping) -> ChatAgent:
574+
"""Create a ChatAgent using AzureAIProjectAgentProvider.
575+
576+
This method handles the special case where we use a provider that creates
577+
agents on a remote service (like Azure AI Agent Service) and returns
578+
ChatAgent instances directly.
579+
"""
580+
# Import the provider class
581+
module_name = mapping["package"]
582+
class_name = mapping["name"]
583+
module = __import__(module_name, fromlist=[class_name])
584+
provider_class = getattr(module, class_name)
585+
586+
# Build provider kwargs from client_kwargs and connection info
587+
provider_kwargs: dict[str, Any] = {}
588+
provider_kwargs.update(self.client_kwargs)
589+
590+
# Handle connection settings for the model
591+
if prompt_agent.model and prompt_agent.model.connection:
592+
match prompt_agent.model.connection:
593+
case RemoteConnection() | AnonymousConnection():
594+
if prompt_agent.model.connection.endpoint:
595+
provider_kwargs["project_endpoint"] = prompt_agent.model.connection.endpoint
596+
case ApiKeyConnection():
597+
if prompt_agent.model.connection.endpoint:
598+
provider_kwargs["project_endpoint"] = prompt_agent.model.connection.endpoint
599+
600+
# Create the provider and use it to create the agent
601+
provider = provider_class(**provider_kwargs)
602+
603+
# Parse tools
604+
tools = self._parse_tools(prompt_agent.tools) if prompt_agent.tools else None
605+
606+
# Parse response format
607+
response_format = None
608+
if prompt_agent.outputSchema:
609+
response_format = _create_model_from_json_schema("agent", prompt_agent.outputSchema.to_json_schema())
610+
611+
# Create the agent using the provider
612+
# The provider's create_agent returns a ChatAgent directly
613+
return cast(
614+
ChatAgent,
615+
await provider.create_agent(
616+
name=prompt_agent.name,
617+
model=prompt_agent.model.id if prompt_agent.model else None,
618+
instructions=prompt_agent.instructions,
619+
description=prompt_agent.description,
620+
tools=tools,
621+
response_format=response_format,
622+
),
623+
)
624+
451625
def _get_client(self, prompt_agent: PromptAgent) -> ChatClientProtocol:
452626
"""Create the ChatClientProtocol instance based on the PromptAgent model."""
453627
if not prompt_agent.model:
@@ -594,12 +768,46 @@ def _parse_tool(self, tool_resource: Tool) -> ToolProtocol:
594768
)
595769
if not approval_mode:
596770
approval_mode = None
771+
772+
# Handle connection settings
773+
headers: dict[str, str] | None = None
774+
additional_properties: dict[str, Any] | None = None
775+
776+
if tool_resource.connection is not None:
777+
match tool_resource.connection:
778+
case ApiKeyConnection():
779+
if tool_resource.connection.apiKey:
780+
headers = {"Authorization": f"Bearer {tool_resource.connection.apiKey}"}
781+
case RemoteConnection():
782+
additional_properties = {
783+
"connection": {
784+
"kind": tool_resource.connection.kind,
785+
"name": tool_resource.connection.name,
786+
"authenticationMode": tool_resource.connection.authenticationMode,
787+
"endpoint": tool_resource.connection.endpoint,
788+
}
789+
}
790+
case ReferenceConnection():
791+
additional_properties = {
792+
"connection": {
793+
"kind": tool_resource.connection.kind,
794+
"name": tool_resource.connection.name,
795+
"authenticationMode": tool_resource.connection.authenticationMode,
796+
}
797+
}
798+
case AnonymousConnection():
799+
pass
800+
case _:
801+
raise ValueError(f"Unsupported connection kind: {tool_resource.connection.kind}")
802+
597803
return HostedMCPTool(
598804
name=tool_resource.name, # type: ignore
599805
description=tool_resource.description,
600806
url=tool_resource.url, # type: ignore
601807
allowed_tools=tool_resource.allowedTools,
602808
approval_mode=approval_mode,
809+
headers=headers,
810+
additional_properties=additional_properties,
603811
)
604812
case _:
605813
raise ValueError(f"Unsupported tool kind: {tool_resource.kind}")

0 commit comments

Comments
 (0)