diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index e4bcaf2723..603018fd48 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -16,6 +16,7 @@ from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext from agent_framework._settings import SecretString, load_settings +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import ServiceInitializationError from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential @@ -148,7 +149,7 @@ def __init__( endpoint: str | None = None, index_name: str | None = None, api_key: str | AzureKeyCredential | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, *, mode: Literal["semantic", "agentic"] = "semantic", top_k: int = 5, @@ -175,7 +176,8 @@ def __init__( endpoint: Azure AI Search endpoint URL. index_name: Name of the search index to query. api_key: API key for authentication. - credential: AsyncTokenCredential for managed identity authentication. + credential: Azure credential for managed identity authentication. + Accepts a TokenCredential, AsyncTokenCredential, or a callable token provider. mode: Search mode - "semantic" or "agentic". Default: "semantic". top_k: Maximum number of documents to retrieve. Default: 5. semantic_configuration_name: Name of semantic configuration in the index. @@ -223,7 +225,7 @@ def __init__( resolved_credential: AzureKeyCredential | AsyncTokenCredential if credential: - resolved_credential = credential + resolved_credential = credential # type: ignore[assignment] elif isinstance(api_key, AzureKeyCredential): resolved_credential = api_key elif settings.get("api_key"): diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py index 436fc602cb..ec2205759d 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py @@ -17,11 +17,11 @@ from agent_framework._mcp import MCPTool from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import ServiceInitializationError from azure.ai.agents.aio import AgentsClient from azure.ai.agents.models import Agent as AzureAgent from azure.ai.agents.models import ResponseFormatJsonSchema, ResponseFormatJsonSchemaType -from azure.core.credentials_async import AsyncTokenCredential from pydantic import BaseModel from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions @@ -93,7 +93,7 @@ def __init__( agents_client: AgentsClient | None = None, *, project_endpoint: str | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: @@ -106,7 +106,8 @@ def __init__( Keyword Args: project_endpoint: The Azure AI Project endpoint URL. Can also be set via AZURE_AI_PROJECT_ENDPOINT environment variable. - credential: Azure async credential for authentication. + credential: Azure credential for authentication. Accepts a TokenCredential, + AsyncTokenCredential, or a callable token provider. Required if agents_client is not provided. env_file_path: Path to .env file for loading settings. env_file_encoding: Encoding of the .env file. @@ -137,7 +138,7 @@ def __init__( raise ServiceInitializationError("Azure credential is required when agents_client is not provided.") self._agents_client = AgentsClient( endpoint=resolved_endpoint, - credential=credential, + credential=credential, # type: ignore[arg-type] user_agent=AGENT_FRAMEWORK_USER_AGENT, ) self._should_close_client = True diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index c6028442fc..8131bdd09a 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -35,6 +35,7 @@ ) from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException from agent_framework.observability import ChatTelemetryLayer from azure.ai.agents.aio import AgentsClient @@ -84,7 +85,6 @@ ToolDefinition, ToolOutput, ) -from azure.core.credentials_async import AsyncTokenCredential from pydantic import BaseModel from ._shared import AzureAISettings, to_azure_ai_agent_tools @@ -415,7 +415,7 @@ def __init__( thread_id: str | None = None, project_endpoint: str | None = None, model_deployment_name: str | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, should_cleanup_agent: bool = True, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, @@ -439,7 +439,8 @@ def __init__( Ignored when a agents_client is passed. model_deployment_name: The model deployment name to use for agent creation. Can also be set via environment variable AZURE_AI_MODEL_DEPLOYMENT_NAME. - credential: Azure async credential to use for authentication. + credential: Azure credential for authentication. Accepts a TokenCredential, + AsyncTokenCredential, or a callable token provider. should_cleanup_agent: Whether to cleanup (delete) agents created by this client when the client is closed or context is exited. Defaults to True. Only affects agents created by this client instance; existing agents passed via agent_id are never deleted. @@ -513,7 +514,7 @@ class MyOptions(AzureAIAgentOptions, total=False): raise ServiceInitializationError("Azure credential is required when agents_client is not provided.") agents_client = AgentsClient( endpoint=resolved_endpoint, - credential=credential, + credential=credential, # type: ignore[arg-type] user_agent=AGENT_FRAMEWORK_USER_AGENT, ) should_close_client = True diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index afbbf6cea3..5c869e6396 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -23,6 +23,7 @@ ) from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import ServiceInitializationError from agent_framework.observability import ChatTelemetryLayer from agent_framework.openai import OpenAIResponsesOptions @@ -41,7 +42,6 @@ WebSearchPreviewTool, ) from azure.ai.projects.models import FileSearchTool as ProjectsFileSearchTool -from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import ResourceNotFoundError from ._shared import AzureAISettings, create_text_format_config @@ -109,7 +109,7 @@ def __init__( conversation_id: str | None = None, project_endpoint: str | None = None, model_deployment_name: str | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, @@ -132,7 +132,8 @@ def __init__( Ignored when a project_client is passed. model_deployment_name: The model deployment name to use for agent creation. Can also be set via environment variable AZURE_AI_MODEL_DEPLOYMENT_NAME. - credential: Azure async credential to use for authentication. + credential: Azure credential for authentication. Accepts a TokenCredential, + AsyncTokenCredential, or a callable token provider. use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. env_file_path: Path to environment file for loading settings. @@ -197,7 +198,7 @@ class MyOptions(ChatOptions, total=False): raise ServiceInitializationError("Azure credential is required when project_client is not provided.") project_client = AIProjectClient( endpoint=resolved_endpoint, - credential=credential, + credential=credential, # type: ignore[arg-type] user_agent=AGENT_FRAMEWORK_USER_AGENT, ) should_close_client = True @@ -950,7 +951,7 @@ def __init__( conversation_id: str | None = None, project_endpoint: str | None = None, model_deployment_name: str | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, @@ -972,7 +973,8 @@ def __init__( Ignored when a project_client is passed. model_deployment_name: The model deployment name to use for agent creation. Can also be set via environment variable AZURE_AI_MODEL_DEPLOYMENT_NAME. - credential: Azure async credential to use for authentication. + credential: Azure credential for authentication. Accepts a TokenCredential + or AsyncTokenCredential. use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. middleware: Optional sequence of chat middlewares to include. diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py index b3f7b35147..7c7fa55214 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -18,6 +18,7 @@ from agent_framework._mcp import MCPTool from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import ServiceInitializationError from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ( @@ -29,7 +30,6 @@ from azure.ai.projects.models import ( FunctionTool as AzureFunctionTool, ) -from azure.core.credentials_async import AsyncTokenCredential from ._client import AzureAIClient, AzureAIProjectAgentOptions from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools @@ -103,7 +103,7 @@ def __init__( *, project_endpoint: str | None = None, model: str | None = None, - credential: AsyncTokenCredential | None = None, + credential: AzureCredentialTypes | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: @@ -116,7 +116,8 @@ def __init__( Ignored when a project_client is passed. model: The default model deployment name to use for agent creation. Can also be set via environment variable AZURE_AI_MODEL_DEPLOYMENT_NAME. - credential: Azure async credential to use for authentication. + credential: Azure credential for authentication. Accepts a TokenCredential, + AsyncTokenCredential, or a callable token provider. Required when project_client is not provided. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. @@ -149,7 +150,7 @@ def __init__( project_client = AIProjectClient( endpoint=resolved_endpoint, - credential=credential, + credential=credential, # type: ignore[arg-type] user_agent=AGENT_FRAMEWORK_USER_AGENT, ) self._should_close_client = True diff --git a/python/packages/core/agent_framework/azure/__init__.py b/python/packages/core/agent_framework/azure/__init__.py index 7ac132bf3a..a485ee7aa7 100644 --- a/python/packages/core/agent_framework/azure/__init__.py +++ b/python/packages/core/agent_framework/azure/__init__.py @@ -30,6 +30,8 @@ "AzureAISearchSettings": ("agent_framework_azure_ai_search", "agent-framework-azure-ai-search"), "AzureAISettings": ("agent_framework_azure_ai", "agent-framework-azure-ai"), "AzureAIAgentsProvider": ("agent_framework_azure_ai", "agent-framework-azure-ai"), + "AzureCredentialTypes": ("agent_framework.azure._entra_id_authentication", "agent-framework-core"), + "AzureTokenProvider": ("agent_framework.azure._entra_id_authentication", "agent-framework-core"), "AzureOpenAIAssistantsClient": ("agent_framework.azure._assistants_client", "agent-framework-core"), "AzureOpenAIAssistantsOptions": ("agent_framework.azure._assistants_client", "agent-framework-core"), "AzureOpenAIChatClient": ("agent_framework.azure._chat_client", "agent-framework-core"), @@ -42,7 +44,6 @@ "DurableAIAgentClient": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentOrchestrationContext": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentWorker": ("agent_framework_durabletask", "agent-framework-durabletask"), - "get_entra_auth_token": ("agent_framework.azure._entra_id_authentication", "agent-framework-core"), } diff --git a/python/packages/core/agent_framework/azure/__init__.pyi b/python/packages/core/agent_framework/azure/__init__.pyi index a819019039..4d6e3b914c 100644 --- a/python/packages/core/agent_framework/azure/__init__.pyi +++ b/python/packages/core/agent_framework/azure/__init__.pyi @@ -21,7 +21,7 @@ from agent_framework_durabletask import ( from agent_framework.azure._assistants_client import AzureOpenAIAssistantsClient from agent_framework.azure._chat_client import AzureOpenAIChatClient -from agent_framework.azure._entra_id_authentication import get_entra_auth_token +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider from agent_framework.azure._responses_client import AzureOpenAIResponsesClient from agent_framework.azure._shared import AzureOpenAISettings @@ -37,13 +37,14 @@ __all__ = [ "AzureAISearchContextProvider", "AzureAISearchSettings", "AzureAISettings", + "AzureCredentialTypes", "AzureOpenAIAssistantsClient", "AzureOpenAIChatClient", "AzureOpenAIResponsesClient", "AzureOpenAISettings", + "AzureTokenProvider", "DurableAIAgent", "DurableAIAgentClient", "DurableAIAgentOrchestrationContext", "DurableAIAgentWorker", - "get_entra_auth_token", ] diff --git a/python/packages/core/agent_framework/azure/_assistants_client.py b/python/packages/core/agent_framework/azure/_assistants_client.py index ee3e05c4ed..8f77cf3c63 100644 --- a/python/packages/core/agent_framework/azure/_assistants_client.py +++ b/python/packages/core/agent_framework/azure/_assistants_client.py @@ -4,19 +4,16 @@ import sys from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from typing import Any, ClassVar, Generic -from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI +from openai.lib.azure import AsyncAzureOpenAI from .._settings import load_settings from ..exceptions import ServiceInitializationError from ..openai import OpenAIAssistantsClient from ..openai._assistants_client import OpenAIAssistantsOptions -from ._entra_id_authentication import get_entra_auth_token -from ._shared import DEFAULT_AZURE_TOKEN_ENDPOINT, AzureOpenAISettings, _apply_azure_defaults - -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential +from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider, resolve_credential_to_token_provider +from ._shared import AzureOpenAISettings, _apply_azure_defaults if sys.version_info >= (3, 13): from typing import TypeVar # type: ignore # pragma: no cover @@ -61,10 +58,8 @@ def __init__( endpoint: str | None = None, base_url: str | None = None, api_version: str | None = None, - ad_token: str | None = None, - ad_token_provider: AsyncAzureADTokenProvider | None = None, token_endpoint: str | None = None, - credential: TokenCredential | None = None, + credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, async_client: AsyncAzureOpenAI | None = None, env_file_path: str | None = None, @@ -93,11 +88,12 @@ def __init__( api_version: The deployment API version. If provided will override the value in the env vars or .env file. Can also be set via environment variable AZURE_OPENAI_API_VERSION. - ad_token: The Azure Active Directory token. - ad_token_provider: The Azure Active Directory token provider. token_endpoint: The token endpoint to request an Azure token. Can also be set via environment variable AZURE_OPENAI_TOKEN_ENDPOINT. - credential: The Azure credential to use for authentication. + credential: Azure credential or token provider for authentication. Accepts a + ``TokenCredential``, ``AsyncTokenCredential``, or a callable that returns a + bearer token string (sync or async), for example from + ``azure.identity.get_bearer_token_provider()``. default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. @@ -156,20 +152,15 @@ class MyOptions(AzureOpenAIAssistantsOptions, total=False): "or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable." ) - # Handle authentication: try API key first, then AD token, then Entra ID - if ( - not async_client - and not azure_openai_settings["api_key"] - and not ad_token - and not ad_token_provider - and azure_openai_settings["token_endpoint"] - and credential - ): - token_ep = azure_openai_settings["token_endpoint"] or DEFAULT_AZURE_TOKEN_ENDPOINT - ad_token = get_entra_auth_token(credential, token_ep) - - if not async_client and not azure_openai_settings["api_key"] and not ad_token and not ad_token_provider: - raise ServiceInitializationError("The Azure OpenAI API key, ad_token, or ad_token_provider is required.") + # Resolve credential to token provider + ad_token_provider = None + if not async_client and not azure_openai_settings["api_key"] and credential: + ad_token_provider = resolve_credential_to_token_provider( + credential, azure_openai_settings["token_endpoint"] + ) + + if not async_client and not azure_openai_settings["api_key"] and not ad_token_provider: + raise ServiceInitializationError("Please provide either api_key, credential, or a client.") # Create Azure client if not provided if not async_client: @@ -180,8 +171,6 @@ class MyOptions(AzureOpenAIAssistantsOptions, total=False): if azure_openai_settings["api_key"]: client_params["api_key"] = azure_openai_settings["api_key"].get_secret_value() - elif ad_token: - client_params["azure_ad_token"] = ad_token elif ad_token_provider: client_params["azure_ad_token_provider"] = ad_token_provider diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index b195969563..159d1a7abd 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -8,8 +8,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic -from azure.core.credentials import TokenCredential -from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI +from openai.lib.azure import AsyncAzureOpenAI from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from pydantic import BaseModel @@ -29,6 +28,7 @@ from agent_framework.openai._chat_client import RawOpenAIChatClient from .._settings import load_settings +from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, @@ -169,10 +169,8 @@ def __init__( endpoint: str | None = None, base_url: str | None = None, api_version: str | None = None, - ad_token: str | None = None, - ad_token_provider: AsyncAzureADTokenProvider | None = None, token_endpoint: str | None = None, - credential: TokenCredential | None = None, + credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, async_client: AsyncAzureOpenAI | None = None, env_file_path: str | None = None, @@ -199,11 +197,12 @@ def __init__( api_version: The deployment API version. If provided will override the value in the env vars or .env file. Can also be set via environment variable AZURE_OPENAI_API_VERSION. - ad_token: The Azure Active Directory token. - ad_token_provider: The Azure Active Directory token provider. token_endpoint: The token endpoint to request an Azure token. Can also be set via environment variable AZURE_OPENAI_TOKEN_ENDPOINT. - credential: The Azure credential for authentication. + credential: Azure credential or token provider for authentication. Accepts a + ``TokenCredential``, ``AsyncTokenCredential``, or a callable that returns a + bearer token string (sync or async), for example from + ``azure.identity.get_bearer_token_provider()``. default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. @@ -274,8 +273,6 @@ class MyOptions(AzureOpenAIChatOptions, total=False): base_url=azure_openai_settings["base_url"], api_version=azure_openai_settings["api_version"], # type: ignore api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, - ad_token=ad_token, - ad_token_provider=ad_token_provider, token_endpoint=azure_openai_settings["token_endpoint"], credential=credential, default_headers=default_headers, diff --git a/python/packages/core/agent_framework/azure/_entra_id_authentication.py b/python/packages/core/agent_framework/azure/_entra_id_authentication.py index 229db60d31..430a767df5 100644 --- a/python/packages/core/agent_framework/azure/_entra_id_authentication.py +++ b/python/packages/core/agent_framework/azure/_entra_id_authentication.py @@ -3,80 +3,66 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from collections.abc import Awaitable, Callable +from typing import Union -from azure.core.exceptions import ClientAuthenticationError +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from ..exceptions import ServiceInvalidAuthError -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - from azure.core.credentials_async import AsyncTokenCredential - logger: logging.Logger = logging.getLogger(__name__) +AzureTokenProvider = Callable[[], Union[str, Awaitable[str]]] +"""A callable that returns a bearer token string, either synchronously or asynchronously.""" -def get_entra_auth_token( - credential: TokenCredential, - token_endpoint: str, - **kwargs: Any, -) -> str | None: - """Retrieve a Microsoft Entra Auth Token for a given token endpoint. - - The token endpoint may be specified as an environment variable, via the .env - file or as an argument. If the token endpoint is not provided, the default is None. - - Args: - credential: The Azure credential to use for authentication. - token_endpoint: The token endpoint to use to retrieve the authentication token. - - Keyword Args: - **kwargs: Additional keyword arguments to pass to the token retrieval method. - - Returns: - The Azure token or None if the token could not be retrieved. - """ - if not token_endpoint: - raise ServiceInvalidAuthError( - "A token endpoint must be provided either in settings, as an environment variable, or as an argument." - ) - - try: - auth_token = credential.get_token(token_endpoint, **kwargs) - except ClientAuthenticationError as ex: - logger.error(f"Failed to retrieve Azure token for the specified endpoint: `{token_endpoint}`, with error: {ex}") - return None +AzureCredentialTypes = Union[TokenCredential, AsyncTokenCredential] +"""Union of Azure credential types. - return auth_token.token if auth_token else None +Accepts: +- ``TokenCredential`` — synchronous Azure credential (e.g. ``DefaultAzureCredential()``) +- ``AsyncTokenCredential`` — asynchronous Azure credential (e.g. ``azure.identity.aio.DefaultAzureCredential()``) +""" -async def get_entra_auth_token_async( - credential: AsyncTokenCredential, token_endpoint: str, **kwargs: Any -) -> str | None: - """Retrieve a async Microsoft Entra Auth Token for a given token endpoint. +def resolve_credential_to_token_provider( + credential: AzureCredentialTypes | AzureTokenProvider, + token_endpoint: str | None, +) -> AzureTokenProvider: + """Convert an Azure credential or token provider into an ``ad_token_provider`` callable. - The token endpoint may be specified as an environment variable, via the .env - file or as an argument. If the token endpoint is not provided, the default is None. + If the credential is already a callable token provider, it is returned as-is + (``token_endpoint`` is not required in this case). + If it is a ``TokenCredential`` or ``AsyncTokenCredential``, it is wrapped using + ``azure.identity.get_bearer_token_provider`` (sync or async variant) which + handles token caching and automatic refresh. Args: - credential: The async Azure credential to use for authentication. - token_endpoint: The token endpoint to use to retrieve the authentication token. - - Keyword Args: - **kwargs: Additional keyword arguments to pass to the token retrieval method. + credential: An Azure credential or token provider callable. + token_endpoint: The token scope/endpoint + (e.g. ``"https://cognitiveservices.azure.com/.default"``). + Required when ``credential`` is a ``TokenCredential`` or ``AsyncTokenCredential``. Returns: - The Azure token or None if the token could not be retrieved. + A callable that returns a bearer token string (sync or async). + + Raises: + ServiceInvalidAuthError: If the token endpoint is empty when needed for credential wrapping. """ + # Already a token provider callable (not a credential object) — use directly + if callable(credential) and not isinstance(credential, (TokenCredential, AsyncTokenCredential)): + return credential + if not token_endpoint: raise ServiceInvalidAuthError( "A token endpoint must be provided either in settings, as an environment variable, or as an argument." ) - try: - auth_token = await credential.get_token(token_endpoint, **kwargs) - except ClientAuthenticationError as ex: - logger.error(f"Failed to retrieve Azure token for the specified endpoint: `{token_endpoint}`, with error: {ex}") - return None + if isinstance(credential, AsyncTokenCredential): + from azure.identity.aio import get_bearer_token_provider as get_async_bearer_token_provider + + return get_async_bearer_token_provider(credential, token_endpoint) + + from azure.identity import get_bearer_token_provider - return auth_token.token if auth_token else None + return get_bearer_token_provider(credential, token_endpoint) # type: ignore[arg-type] diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index cf6be12af5..576236a8a3 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -8,9 +8,7 @@ from urllib.parse import urljoin, urlparse from azure.ai.projects.aio import AIProjectClient -from azure.core.credentials import TokenCredential from openai import AsyncOpenAI -from openai.lib.azure import AsyncAzureADTokenProvider from .._middleware import ChatMiddlewareLayer from .._settings import load_settings @@ -19,6 +17,7 @@ from ..exceptions import ServiceInitializationError from ..observability import ChatTelemetryLayer from ..openai._responses_client import RawOpenAIResponsesClient +from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, @@ -69,10 +68,8 @@ def __init__( endpoint: str | None = None, base_url: str | None = None, api_version: str | None = None, - ad_token: str | None = None, - ad_token_provider: AsyncAzureADTokenProvider | None = None, token_endpoint: str | None = None, - credential: TokenCredential | None = None, + credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, async_client: AsyncOpenAI | None = None, project_client: Any | None = None, @@ -109,11 +106,12 @@ def __init__( api_version: The deployment API version. If provided will override the value in the env vars or .env file. Currently, the api_version must be "preview". Can also be set via environment variable AZURE_OPENAI_API_VERSION. - ad_token: The Azure Active Directory token. - ad_token_provider: The Azure Active Directory token provider. token_endpoint: The token endpoint to request an Azure token. Can also be set via environment variable AZURE_OPENAI_TOKEN_ENDPOINT. - credential: The Azure credential for authentication. + credential: Azure credential or token provider for authentication. Accepts a + ``TokenCredential``, ``AsyncTokenCredential``, or a callable that returns a + bearer token string (sync or async), for example from + ``azure.identity.get_bearer_token_provider()``. default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. @@ -230,8 +228,6 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False): base_url=azure_openai_settings["base_url"], api_version=azure_openai_settings["api_version"], # type: ignore api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, - ad_token=ad_token, - ad_token_provider=ad_token_provider, token_endpoint=azure_openai_settings["token_endpoint"], credential=credential, default_headers=default_headers, @@ -246,7 +242,7 @@ def _create_client_from_project( *, project_client: AIProjectClient | None, project_endpoint: str | None, - credential: TokenCredential | None, + credential: AzureCredentialTypes | AzureTokenProvider | None, ) -> AsyncOpenAI: """Create an AsyncOpenAI client from an Azure AI Foundry project. diff --git a/python/packages/core/agent_framework/azure/_shared.py b/python/packages/core/agent_framework/azure/_shared.py index 02529ee8b3..3fe10d50a5 100644 --- a/python/packages/core/agent_framework/azure/_shared.py +++ b/python/packages/core/agent_framework/azure/_shared.py @@ -4,11 +4,10 @@ import logging import sys -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Mapping from copy import copy from typing import Any, ClassVar, Final -from azure.core.credentials import TokenCredential from openai import AsyncOpenAI from openai.lib.azure import AsyncAzureOpenAI @@ -16,7 +15,7 @@ from .._telemetry import APP_INFO, prepend_agent_framework_to_user_agent from ..exceptions import ServiceInitializationError from ..openai._shared import OpenAIBase -from ._entra_id_authentication import get_entra_auth_token +from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider, resolve_credential_to_token_provider logger: logging.Logger = logging.getLogger(__name__) @@ -135,10 +134,8 @@ def __init__( base_url: str | None = None, api_version: str = DEFAULT_AZURE_API_VERSION, api_key: str | None = None, - ad_token: str | None = None, - ad_token_provider: Callable[[], str | Awaitable[str]] | None = None, token_endpoint: str | None = None, - credential: TokenCredential | None = None, + credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, client: AsyncOpenAI | None = None, instruction_role: str | None = None, @@ -155,10 +152,10 @@ def __init__( base_url: The base URL for Azure services. api_version: Azure API version. Defaults to the defined DEFAULT_AZURE_API_VERSION. api_key: API key for Azure services. - ad_token: Azure AD token for authentication. - ad_token_provider: A callable or coroutine function providing Azure AD tokens. - token_endpoint: Azure AD token endpoint use to get the token. - credential: Azure credential for authentication. + token_endpoint: Azure AD token scope used to obtain a bearer token from a credential. + credential: Azure credential or token provider for authentication. Accepts a + ``TokenCredential``, ``AsyncTokenCredential``, or a callable that returns a + bearer token string (sync or async). default_headers: Default headers for HTTP requests. client: An existing client to use. instruction_role: The role to use for 'instruction' messages, for example, summarization @@ -172,16 +169,13 @@ def __init__( merged_headers.update(APP_INFO) merged_headers = prepend_agent_framework_to_user_agent(merged_headers) if not client: - # If the client is None, the api_key is none, the ad_token is none, and the ad_token_provider is none, - # then we will attempt to get the ad_token using the default endpoint specified in the Azure OpenAI - # settings. - if not api_key and not ad_token_provider and not ad_token and token_endpoint and credential: - ad_token = get_entra_auth_token(credential, token_endpoint) + # Resolve credential to a token provider if needed + ad_token_provider = None + if not api_key and credential: + ad_token_provider = resolve_credential_to_token_provider(credential, token_endpoint) - if not api_key and not ad_token and not ad_token_provider: - raise ServiceInitializationError( - "Please provide either api_key, ad_token or ad_token_provider or a client." - ) + if not api_key and not ad_token_provider: + raise ServiceInitializationError("Please provide either api_key, credential, or a client.") if not endpoint and not base_url: raise ServiceInitializationError("Please provide an endpoint or a base_url") @@ -191,8 +185,6 @@ def __init__( } if api_version: args["api_version"] = api_version - if ad_token: - args["azure_ad_token"] = ad_token if ad_token_provider: args["azure_ad_token_provider"] = ad_token_provider if api_key: diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index c87419216b..6c9cdf501e 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -551,12 +551,16 @@ async def test_azure_assistants_client_agent_level_tool_persistence(): def test_azure_assistants_client_entra_id_authentication() -> None: - """Test Entra ID authentication path with credential.""" + """Test credential authentication path with sync credential.""" mock_credential = MagicMock() + mock_provider = MagicMock(return_value="token-string") with ( patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings, - patch("agent_framework.azure._assistants_client.get_entra_auth_token") as mock_get_token, + patch( + "agent_framework.azure._assistants_client.resolve_credential_to_token_provider", + return_value=mock_provider, + ) as mock_resolve, patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client, patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None), ): @@ -564,28 +568,26 @@ def test_azure_assistants_client_entra_id_authentication() -> None: "chat_deployment_name": "test-deployment", "responses_deployment_name": None, "api_key": None, - "token_endpoint": "https://login.microsoftonline.com/test", + "token_endpoint": "https://cognitiveservices.azure.com/.default", "api_version": "2024-05-01-preview", "endpoint": "https://test-endpoint.openai.azure.com", "base_url": None, } - mock_get_token.return_value = "entra-token-12345" client = AzureOpenAIAssistantsClient( deployment_name="test-deployment", - api_key="placeholder-key", endpoint="https://test-endpoint.openai.azure.com", credential=mock_credential, - token_endpoint="https://login.microsoftonline.com/test", + token_endpoint="https://cognitiveservices.azure.com/.default", ) - # Verify Entra ID token was requested - mock_get_token.assert_called_once_with(mock_credential, "https://login.microsoftonline.com/test") + # Verify credential was resolved to a token provider + mock_resolve.assert_called_once_with(mock_credential, "https://cognitiveservices.azure.com/.default") - # Verify client was created with the token + # Verify client was created with the token provider mock_azure_client.assert_called_once() call_args = mock_azure_client.call_args[1] - assert call_args["azure_ad_token"] == "entra-token-12345" + assert call_args["azure_ad_token_provider"] is mock_provider assert client is not None assert isinstance(client, AzureOpenAIAssistantsClient) @@ -605,7 +607,7 @@ def test_azure_assistants_client_no_authentication_error() -> None: } # Test missing authentication raises error - with pytest.raises(ServiceInitializationError, match="API key, ad_token, or ad_token_provider is required"): + with pytest.raises(ServiceInitializationError, match="api_key, credential, or a client"): AzureOpenAIAssistantsClient( deployment_name="test-deployment", endpoint="https://test-endpoint.openai.azure.com", @@ -613,46 +615,16 @@ def test_azure_assistants_client_no_authentication_error() -> None: ) -def test_azure_assistants_client_ad_token_authentication() -> None: - """Test ad_token authentication client parameter path.""" - with ( - patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings, - patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client, - patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None), - ): - mock_load_settings.return_value = { - "chat_deployment_name": "test-deployment", - "responses_deployment_name": None, - "api_key": None, - "token_endpoint": None, - "api_version": "2024-05-01-preview", - "endpoint": "https://test-endpoint.openai.azure.com", - "base_url": None, - } - - client = AzureOpenAIAssistantsClient( - deployment_name="test-deployment", - endpoint="https://test-endpoint.openai.azure.com", - ad_token="test-ad-token-12345", - ) - - # ad_token path - mock_azure_client.assert_called_once() - call_args = mock_azure_client.call_args[1] - assert call_args["azure_ad_token"] == "test-ad-token-12345" - - assert client is not None - assert isinstance(client, AzureOpenAIAssistantsClient) - - -def test_azure_assistants_client_ad_token_provider_authentication() -> None: - """Test ad_token_provider authentication client parameter path.""" - from openai.lib.azure import AsyncAzureADTokenProvider - - mock_token_provider = MagicMock(spec=AsyncAzureADTokenProvider) +def test_azure_assistants_client_callable_credential() -> None: + """Test callable token provider as credential.""" + mock_provider = MagicMock(return_value="my-token") with ( patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings, + patch( + "agent_framework.azure._assistants_client.resolve_credential_to_token_provider", + return_value=mock_provider, + ), patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client, patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None), ): @@ -660,7 +632,7 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None: "chat_deployment_name": "test-deployment", "responses_deployment_name": None, "api_key": None, - "token_endpoint": None, + "token_endpoint": "https://cognitiveservices.azure.com/.default", "api_version": "2024-05-01-preview", "endpoint": "https://test-endpoint.openai.azure.com", "base_url": None, @@ -669,13 +641,14 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None: client = AzureOpenAIAssistantsClient( deployment_name="test-deployment", endpoint="https://test-endpoint.openai.azure.com", - ad_token_provider=mock_token_provider, + credential=mock_provider, + token_endpoint="https://cognitiveservices.azure.com/.default", ) - # ad_token_provider path + # Verify client was created with the token provider mock_azure_client.assert_called_once() call_args = mock_azure_client.call_args[1] - assert call_args["azure_ad_token_provider"] is mock_token_provider + assert call_args["azure_ad_token_provider"] is mock_provider assert client is not None assert isinstance(client, AzureOpenAIAssistantsClient) diff --git a/python/packages/core/tests/azure/test_entra_id_authentication.py b/python/packages/core/tests/azure/test_entra_id_authentication.py index b8e8c543b9..6e2686bdd1 100644 --- a/python/packages/core/tests/azure/test_entra_id_authentication.py +++ b/python/packages/core/tests/azure/test_entra_id_authentication.py @@ -1,156 +1,61 @@ # Copyright (c) Microsoft. All rights reserved. -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock, patch import pytest -from azure.core.exceptions import ClientAuthenticationError +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from agent_framework.azure._entra_id_authentication import ( - get_entra_auth_token, - get_entra_auth_token_async, + resolve_credential_to_token_provider, ) from agent_framework.exceptions import ServiceInvalidAuthError +TOKEN_ENDPOINT = "https://cognitiveservices.azure.com/.default" -@pytest.fixture -def mock_credential() -> MagicMock: - """Mock synchronous TokenCredential.""" - mock_cred = MagicMock() - # Create a mock token object with a .token attribute - mock_token = MagicMock() - mock_token.token = "test-access-token-12345" - mock_cred.get_token.return_value = mock_token - return mock_cred +def test_resolve_sync_credential_returns_provider() -> None: + """Test that a sync TokenCredential is resolved via azure.identity.get_bearer_token_provider.""" + mock_credential = MagicMock(spec=TokenCredential) + mock_provider = MagicMock(return_value="token-string") -@pytest.fixture -def mock_async_credential() -> MagicMock: - """Mock asynchronous AsyncTokenCredential.""" - mock_cred = MagicMock() - # Create a mock token object with a .token attribute - mock_token = MagicMock() - mock_token.token = "test-async-access-token-12345" - mock_cred.get_token = AsyncMock(return_value=mock_token) - return mock_cred + with patch("azure.identity.get_bearer_token_provider", return_value=mock_provider) as mock_gbtp: + result = resolve_credential_to_token_provider(mock_credential, TOKEN_ENDPOINT) + mock_gbtp.assert_called_once_with(mock_credential, TOKEN_ENDPOINT) + assert result is mock_provider -def test_get_entra_auth_token_success(mock_credential: MagicMock) -> None: - """Test successful token retrieval with sync function.""" - token_endpoint = "https://test-endpoint.com/.default" +def test_resolve_async_credential_returns_provider() -> None: + """Test that an AsyncTokenCredential is resolved via azure.identity.aio.get_bearer_token_provider.""" + mock_credential = MagicMock(spec=AsyncTokenCredential) + mock_provider = MagicMock(return_value="token-string") - result = get_entra_auth_token(mock_credential, token_endpoint) + with patch("azure.identity.aio.get_bearer_token_provider", return_value=mock_provider) as mock_gbtp: + result = resolve_credential_to_token_provider(mock_credential, TOKEN_ENDPOINT) - # Assert - check the results - assert result == "test-access-token-12345" - mock_credential.get_token.assert_called_once_with(token_endpoint) + mock_gbtp.assert_called_once_with(mock_credential, TOKEN_ENDPOINT) + assert result is mock_provider -async def test_get_entra_auth_token_async_success(mock_async_credential: MagicMock) -> None: - """Test successful token retrieval with async function.""" +def test_resolve_callable_provider_passthrough() -> None: + """Test that a callable token provider is returned as-is, without needing token_endpoint.""" + my_provider = lambda: "my-token" # noqa: E731 - token_endpoint = "https://test-endpoint.com/.default" + # Works with token_endpoint + assert resolve_credential_to_token_provider(my_provider, TOKEN_ENDPOINT) is my_provider - result = await get_entra_auth_token_async(mock_async_credential, token_endpoint) + # Also works without token_endpoint + assert resolve_credential_to_token_provider(my_provider, None) is my_provider + assert resolve_credential_to_token_provider(my_provider, "") is my_provider - # Assert - check the results - assert result == "test-async-access-token-12345" - mock_async_credential.get_token.assert_called_once_with(token_endpoint) - -def test_get_entra_auth_token_missing_endpoint(mock_credential: MagicMock) -> None: +def test_resolve_missing_endpoint_raises() -> None: """Test that missing token endpoint raises ServiceInvalidAuthError.""" - # Test with empty string - with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"): - get_entra_auth_token(mock_credential, "") - - # Test with None - with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"): - get_entra_auth_token(mock_credential, None) # type: ignore - + mock_credential = MagicMock(spec=TokenCredential) -async def test_get_entra_auth_token_async_missing_endpoint(mock_async_credential: MagicMock) -> None: - """Test that missing token endpoint raises ServiceInvalidAuthError in async function.""" - # Test with empty string with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"): - await get_entra_auth_token_async(mock_async_credential, "") + resolve_credential_to_token_provider(mock_credential, "") - # Test with None with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"): - await get_entra_auth_token_async(mock_async_credential, None) # type: ignore - - -def test_get_entra_auth_token_auth_failure(mock_credential: MagicMock) -> None: - """Test that Azure authentication failure returns None.""" - - mock_credential.get_token.side_effect = ClientAuthenticationError("Auth failed") - token_endpoint = "https://test-endpoint.com/.default" - - result = get_entra_auth_token(mock_credential, token_endpoint) - - # Assert - should return None on auth failure - assert result is None - mock_credential.get_token.assert_called_once_with(token_endpoint) - - -async def test_get_entra_auth_token_async_auth_failure(mock_async_credential: MagicMock) -> None: - """Test that Azure authentication failure returns None in async function.""" - - mock_async_credential.get_token.side_effect = ClientAuthenticationError("Auth failed") - token_endpoint = "https://test-endpoint.com/.default" - - result = await get_entra_auth_token_async(mock_async_credential, token_endpoint) - - # Assert - should return None on auth failure - assert result is None - mock_async_credential.get_token.assert_called_once_with(token_endpoint) - - -def test_get_entra_auth_token_none_token_response(mock_credential: MagicMock) -> None: - """Test that None token response returns None.""" - mock_credential.get_token.return_value = None - token_endpoint = "https://test-endpoint.com/.default" - - result = get_entra_auth_token(mock_credential, token_endpoint) - - # Assert - assert result is None - mock_credential.get_token.assert_called_once_with(token_endpoint) - - -async def test_get_entra_auth_token_async_none_token_response(mock_async_credential: MagicMock) -> None: - """Test that None token response returns None in async function.""" - mock_async_credential.get_token.return_value = None - token_endpoint = "https://test-endpoint.com/.default" - - result = await get_entra_auth_token_async(mock_async_credential, token_endpoint) - - # Assert - assert result is None - mock_async_credential.get_token.assert_called_once_with(token_endpoint) - - -def test_get_entra_auth_token_with_kwargs(mock_credential: MagicMock) -> None: - """Test that kwargs are passed through to get_token.""" - - token_endpoint = "https://test-endpoint.com/.default" - extra_kwargs = {"scopes": ["read", "write"], "tenant_id": "test-tenant"} - - result = get_entra_auth_token(mock_credential, token_endpoint, **extra_kwargs) - - # Assert - assert result == "test-access-token-12345" - mock_credential.get_token.assert_called_once_with(token_endpoint, **extra_kwargs) - - -async def test_get_entra_auth_token_async_with_kwargs(mock_async_credential: MagicMock) -> None: - """Test that kwargs are passed through to async get_token.""" - - token_endpoint = "https://test-endpoint.com/.default" - extra_kwargs = {"scopes": ["read", "write"], "tenant_id": "test-tenant"} - - result = await get_entra_auth_token_async(mock_async_credential, token_endpoint, **extra_kwargs) - - # Assert - assert result == "test-async-access-token-12345" - mock_async_credential.get_token.assert_called_once_with(token_endpoint, **extra_kwargs) + resolve_credential_to_token_provider(mock_credential, None) # type: ignore[arg-type] diff --git a/python/packages/purview/agent_framework_purview/_client.py b/python/packages/purview/agent_framework_purview/_client.py index 2de5340cb7..a1f404849b 100644 --- a/python/packages/purview/agent_framework_purview/_client.py +++ b/python/packages/purview/agent_framework_purview/_client.py @@ -11,6 +11,7 @@ import httpx from agent_framework import AGENT_FRAMEWORK_USER_AGENT +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider from agent_framework.observability import get_tracer from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -39,18 +40,19 @@ class PurviewClient: """Async client for calling Graph Purview endpoints. - Supports both synchronous TokenCredential and asynchronous AsyncTokenCredential implementations. - A sync credential will be invoked in a thread to avoid blocking the event loop. + Supports synchronous TokenCredential, asynchronous AsyncTokenCredential, + or callable token providers. A sync credential will be invoked in a thread + to avoid blocking the event loop. """ def __init__( self, - credential: TokenCredential | AsyncTokenCredential, + credential: AzureCredentialTypes | AzureTokenProvider, settings: PurviewSettings, *, timeout: float | None = 10.0, ): - self._credential: TokenCredential | AsyncTokenCredential = credential + self._credential: AzureCredentialTypes | AzureTokenProvider = credential self._settings = settings self._graph_uri = (settings.get("graph_base_uri") or "https://graph.microsoft.com/v1.0/").rstrip("/") self._timeout = timeout @@ -60,10 +62,14 @@ async def close(self) -> None: await self._client.aclose() async def _get_token(self, *, tenant_id: str | None = None) -> str: - """Acquire an access token using either async or sync credential.""" - scopes = get_purview_scopes(self._settings) + """Acquire an access token using either async or sync credential, or callable token provider.""" cred = self._credential - token = cred.get_token(*scopes, tenant_id=tenant_id) + # Callable token provider — returns a token string directly + if callable(cred) and not isinstance(cred, (TokenCredential, AsyncTokenCredential)): + result = cred() + return await result if inspect.isawaitable(result) else result # type: ignore[return-value] + scopes = get_purview_scopes(self._settings) + token = cred.get_token(*scopes, tenant_id=tenant_id) # type: ignore[union-attr] token = await token if inspect.isawaitable(token) else token return token.token diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 3a4a628294..2296122135 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -4,8 +4,7 @@ from collections.abc import Awaitable, Callable from agent_framework import AgentContext, AgentMiddleware, ChatContext, ChatMiddleware, MiddlewareTermination -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider from ._cache import CacheProvider from ._client import PurviewClient @@ -20,7 +19,7 @@ class PurviewPolicyMiddleware(AgentMiddleware): """Agent middleware that enforces Purview policies on prompt and response. - Accepts either a synchronous TokenCredential or an AsyncTokenCredential. + Accepts a TokenCredential, AsyncTokenCredential, or callable token provider. Usage: @@ -28,14 +27,14 @@ class PurviewPolicyMiddleware(AgentMiddleware): from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings from agent_framework import Agent - credential = ... # TokenCredential or AsyncTokenCredential + credential = ... # TokenCredential, AsyncTokenCredential, or callable settings = PurviewSettings(app_name="My App") agent = Agent(client=client, instructions="...", middleware=[PurviewPolicyMiddleware(credential, settings)]) """ def __init__( self, - credential: TokenCredential | AsyncTokenCredential, + credential: AzureCredentialTypes | AzureTokenProvider, settings: PurviewSettings, cache_provider: CacheProvider | None = None, ) -> None: @@ -153,14 +152,14 @@ class PurviewChatPolicyMiddleware(ChatMiddleware): from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings from agent_framework import ChatClient - credential = ... # TokenCredential or AsyncTokenCredential + credential = ... # TokenCredential, AsyncTokenCredential, or callable settings = PurviewSettings(app_name="My App") client = ChatClient(..., middleware=[PurviewChatPolicyMiddleware(credential, settings)]) """ def __init__( self, - credential: TokenCredential | AsyncTokenCredential, + credential: AzureCredentialTypes | AzureTokenProvider, settings: PurviewSettings, cache_provider: CacheProvider | None = None, ) -> None: