|
2 | 2 |
|
3 | 3 | from collections.abc import Callable, Mapping |
4 | 4 | from pathlib import Path |
5 | | -from typing import Any, Literal, TypedDict |
| 5 | +from typing import Any, Literal, TypedDict, cast |
6 | 6 |
|
7 | 7 | import yaml |
8 | 8 | from agent_framework import ( |
@@ -89,6 +89,11 @@ class ProviderTypeMapping(TypedDict, total=True): |
89 | 89 | "name": "AzureAIClient", |
90 | 90 | "model_id_field": "model_deployment_name", |
91 | 91 | }, |
| 92 | + "AzureAI.ProjectProvider": { |
| 93 | + "package": "agent_framework.azure", |
| 94 | + "name": "AzureAIProjectAgentProvider", |
| 95 | + "model_id_field": "model", |
| 96 | + }, |
92 | 97 | "Anthropic.Chat": { |
93 | 98 | "package": "agent_framework.anthropic", |
94 | 99 | "name": "AnthropicChatClient", |
@@ -448,6 +453,175 @@ def create_agent_from_dict(self, agent_def: dict[str, Any]) -> ChatAgent: |
448 | 453 | **chat_options, |
449 | 454 | ) |
450 | 455 |
|
| 456 | + async def create_agent_from_yaml_path_async(self, yaml_path: str | Path) -> ChatAgent: |
| 457 | + """Async version: Create a ChatAgent from a YAML file path. |
| 458 | +
|
| 459 | + Use this method when the provider requires async initialization, such as |
| 460 | + AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service. |
| 461 | +
|
| 462 | + Args: |
| 463 | + yaml_path: Path to the YAML file representation of a PromptAgent. |
| 464 | +
|
| 465 | + Returns: |
| 466 | + The ``ChatAgent`` instance created from the YAML file. |
| 467 | +
|
| 468 | + Examples: |
| 469 | + .. code-block:: python |
| 470 | +
|
| 471 | + from agent_framework_declarative import AgentFactory |
| 472 | +
|
| 473 | + factory = AgentFactory( |
| 474 | + client_kwargs={"credential": credential}, |
| 475 | + default_provider="AzureAI.ProjectProvider", |
| 476 | + ) |
| 477 | + agent = await factory.create_agent_from_yaml_path_async("agent.yaml") |
| 478 | + """ |
| 479 | + if not isinstance(yaml_path, Path): |
| 480 | + yaml_path = Path(yaml_path) |
| 481 | + if not yaml_path.exists(): |
| 482 | + raise DeclarativeLoaderError(f"YAML file not found at path: {yaml_path}") |
| 483 | + yaml_str = yaml_path.read_text() |
| 484 | + return await self.create_agent_from_yaml_async(yaml_str) |
| 485 | + |
| 486 | + async def create_agent_from_yaml_async(self, yaml_str: str) -> ChatAgent: |
| 487 | + """Async version: Create a ChatAgent from a YAML string. |
| 488 | +
|
| 489 | + Use this method when the provider requires async initialization, such as |
| 490 | + AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service. |
| 491 | +
|
| 492 | + Args: |
| 493 | + yaml_str: YAML string representation of a PromptAgent. |
| 494 | +
|
| 495 | + Returns: |
| 496 | + The ``ChatAgent`` instance created from the YAML string. |
| 497 | +
|
| 498 | + Examples: |
| 499 | + .. code-block:: python |
| 500 | +
|
| 501 | + from agent_framework_declarative import AgentFactory |
| 502 | +
|
| 503 | + yaml_content = ''' |
| 504 | + kind: Prompt |
| 505 | + name: MyAgent |
| 506 | + instructions: You are a helpful assistant. |
| 507 | + model: |
| 508 | + id: gpt-4o |
| 509 | + provider: AzureAI.ProjectProvider |
| 510 | + ''' |
| 511 | +
|
| 512 | + factory = AgentFactory(client_kwargs={"credential": credential}) |
| 513 | + agent = await factory.create_agent_from_yaml_async(yaml_content) |
| 514 | + """ |
| 515 | + return await self.create_agent_from_dict_async(yaml.safe_load(yaml_str)) |
| 516 | + |
| 517 | + async def create_agent_from_dict_async(self, agent_def: dict[str, Any]) -> ChatAgent: |
| 518 | + """Async version: Create a ChatAgent from a dictionary definition. |
| 519 | +
|
| 520 | + Use this method when the provider requires async initialization, such as |
| 521 | + AzureAI.ProjectProvider which creates agents on the Azure AI Agent Service. |
| 522 | +
|
| 523 | + Args: |
| 524 | + agent_def: Dictionary representation of a PromptAgent. |
| 525 | +
|
| 526 | + Returns: |
| 527 | + The ``ChatAgent`` instance created from the dictionary. |
| 528 | +
|
| 529 | + Examples: |
| 530 | + .. code-block:: python |
| 531 | +
|
| 532 | + from agent_framework_declarative import AgentFactory |
| 533 | +
|
| 534 | + agent_def = { |
| 535 | + "kind": "Prompt", |
| 536 | + "name": "MyAgent", |
| 537 | + "instructions": "You are a helpful assistant.", |
| 538 | + "model": { |
| 539 | + "id": "gpt-4o", |
| 540 | + "provider": "AzureAI.ProjectProvider", |
| 541 | + }, |
| 542 | + } |
| 543 | +
|
| 544 | + factory = AgentFactory(client_kwargs={"credential": credential}) |
| 545 | + agent = await factory.create_agent_from_dict_async(agent_def) |
| 546 | + """ |
| 547 | + # Set safe_mode context before parsing YAML to control PowerFx environment variable access |
| 548 | + _safe_mode_context.set(self.safe_mode) |
| 549 | + prompt_agent = agent_schema_dispatch(agent_def) |
| 550 | + if not isinstance(prompt_agent, PromptAgent): |
| 551 | + raise DeclarativeLoaderError("Only definitions for a PromptAgent are supported for agent creation.") |
| 552 | + |
| 553 | + # Check if we're using a provider-based approach (like AzureAIProjectAgentProvider) |
| 554 | + mapping = self._retrieve_provider_configuration(prompt_agent.model) if prompt_agent.model else None |
| 555 | + if mapping and mapping["name"] == "AzureAIProjectAgentProvider": |
| 556 | + return await self._create_agent_with_provider(prompt_agent, mapping) |
| 557 | + |
| 558 | + # Fall back to standard ChatClient approach |
| 559 | + client = self._get_client(prompt_agent) |
| 560 | + chat_options = self._parse_chat_options(prompt_agent.model) |
| 561 | + if tools := self._parse_tools(prompt_agent.tools): |
| 562 | + chat_options["tools"] = tools |
| 563 | + if output_schema := prompt_agent.outputSchema: |
| 564 | + chat_options["response_format"] = _create_model_from_json_schema("agent", output_schema.to_json_schema()) |
| 565 | + return ChatAgent( |
| 566 | + chat_client=client, |
| 567 | + name=prompt_agent.name, |
| 568 | + description=prompt_agent.description, |
| 569 | + instructions=prompt_agent.instructions, |
| 570 | + **chat_options, |
| 571 | + ) |
| 572 | + |
| 573 | + async def _create_agent_with_provider(self, prompt_agent: PromptAgent, mapping: ProviderTypeMapping) -> ChatAgent: |
| 574 | + """Create a ChatAgent using AzureAIProjectAgentProvider. |
| 575 | +
|
| 576 | + This method handles the special case where we use a provider that creates |
| 577 | + agents on a remote service (like Azure AI Agent Service) and returns |
| 578 | + ChatAgent instances directly. |
| 579 | + """ |
| 580 | + # Import the provider class |
| 581 | + module_name = mapping["package"] |
| 582 | + class_name = mapping["name"] |
| 583 | + module = __import__(module_name, fromlist=[class_name]) |
| 584 | + provider_class = getattr(module, class_name) |
| 585 | + |
| 586 | + # Build provider kwargs from client_kwargs and connection info |
| 587 | + provider_kwargs: dict[str, Any] = {} |
| 588 | + provider_kwargs.update(self.client_kwargs) |
| 589 | + |
| 590 | + # Handle connection settings for the model |
| 591 | + if prompt_agent.model and prompt_agent.model.connection: |
| 592 | + match prompt_agent.model.connection: |
| 593 | + case RemoteConnection() | AnonymousConnection(): |
| 594 | + if prompt_agent.model.connection.endpoint: |
| 595 | + provider_kwargs["project_endpoint"] = prompt_agent.model.connection.endpoint |
| 596 | + case ApiKeyConnection(): |
| 597 | + if prompt_agent.model.connection.endpoint: |
| 598 | + provider_kwargs["project_endpoint"] = prompt_agent.model.connection.endpoint |
| 599 | + |
| 600 | + # Create the provider and use it to create the agent |
| 601 | + provider = provider_class(**provider_kwargs) |
| 602 | + |
| 603 | + # Parse tools |
| 604 | + tools = self._parse_tools(prompt_agent.tools) if prompt_agent.tools else None |
| 605 | + |
| 606 | + # Parse response format |
| 607 | + response_format = None |
| 608 | + if prompt_agent.outputSchema: |
| 609 | + response_format = _create_model_from_json_schema("agent", prompt_agent.outputSchema.to_json_schema()) |
| 610 | + |
| 611 | + # Create the agent using the provider |
| 612 | + # The provider's create_agent returns a ChatAgent directly |
| 613 | + return cast( |
| 614 | + ChatAgent, |
| 615 | + await provider.create_agent( |
| 616 | + name=prompt_agent.name, |
| 617 | + model=prompt_agent.model.id if prompt_agent.model else None, |
| 618 | + instructions=prompt_agent.instructions, |
| 619 | + description=prompt_agent.description, |
| 620 | + tools=tools, |
| 621 | + response_format=response_format, |
| 622 | + ), |
| 623 | + ) |
| 624 | + |
451 | 625 | def _get_client(self, prompt_agent: PromptAgent) -> ChatClientProtocol: |
452 | 626 | """Create the ChatClientProtocol instance based on the PromptAgent model.""" |
453 | 627 | if not prompt_agent.model: |
@@ -594,12 +768,46 @@ def _parse_tool(self, tool_resource: Tool) -> ToolProtocol: |
594 | 768 | ) |
595 | 769 | if not approval_mode: |
596 | 770 | approval_mode = None |
| 771 | + |
| 772 | + # Handle connection settings |
| 773 | + headers: dict[str, str] | None = None |
| 774 | + additional_properties: dict[str, Any] | None = None |
| 775 | + |
| 776 | + if tool_resource.connection is not None: |
| 777 | + match tool_resource.connection: |
| 778 | + case ApiKeyConnection(): |
| 779 | + if tool_resource.connection.apiKey: |
| 780 | + headers = {"Authorization": f"Bearer {tool_resource.connection.apiKey}"} |
| 781 | + case RemoteConnection(): |
| 782 | + additional_properties = { |
| 783 | + "connection": { |
| 784 | + "kind": tool_resource.connection.kind, |
| 785 | + "name": tool_resource.connection.name, |
| 786 | + "authenticationMode": tool_resource.connection.authenticationMode, |
| 787 | + "endpoint": tool_resource.connection.endpoint, |
| 788 | + } |
| 789 | + } |
| 790 | + case ReferenceConnection(): |
| 791 | + additional_properties = { |
| 792 | + "connection": { |
| 793 | + "kind": tool_resource.connection.kind, |
| 794 | + "name": tool_resource.connection.name, |
| 795 | + "authenticationMode": tool_resource.connection.authenticationMode, |
| 796 | + } |
| 797 | + } |
| 798 | + case AnonymousConnection(): |
| 799 | + pass |
| 800 | + case _: |
| 801 | + raise ValueError(f"Unsupported connection kind: {tool_resource.connection.kind}") |
| 802 | + |
597 | 803 | return HostedMCPTool( |
598 | 804 | name=tool_resource.name, # type: ignore |
599 | 805 | description=tool_resource.description, |
600 | 806 | url=tool_resource.url, # type: ignore |
601 | 807 | allowed_tools=tool_resource.allowedTools, |
602 | 808 | approval_mode=approval_mode, |
| 809 | + headers=headers, |
| 810 | + additional_properties=additional_properties, |
603 | 811 | ) |
604 | 812 | case _: |
605 | 813 | raise ValueError(f"Unsupported tool kind: {tool_resource.kind}") |
|
0 commit comments