From dc7b01186a062798407e60506adaf244052cbe0d Mon Sep 17 00:00:00 2001 From: Farzad Sunavala Date: Tue, 28 Oct 2025 23:05:30 -0500 Subject: [PATCH 1/6] Python: Fix pyright errors and move search provider to core (#1546) --- python/.env.example | 8 + .../core/agent_framework/azure/__init__.py | 1 + .../agent_framework/azure/_search_provider.py | 538 ++++++++++++++++++ python/packages/core/pyproject.toml | 1 + .../core/tests/azure/test_search_provider.py | 338 +++++++++++ .../azure_ai/azure_ai_with_search_context.py | 137 +++++ python/uv.lock | 26 + 7 files changed, 1049 insertions(+) create mode 100644 python/packages/core/agent_framework/azure/_search_provider.py create mode 100644 python/packages/core/tests/azure/test_search_provider.py create mode 100644 python/samples/getting_started/agents/azure_ai/azure_ai_with_search_context.py diff --git a/python/.env.example b/python/.env.example index 82458a3fda..f864f18f72 100644 --- a/python/.env.example +++ b/python/.env.example @@ -3,6 +3,14 @@ AZURE_AI_PROJECT_ENDPOINT="" AZURE_AI_MODEL_DEPLOYMENT_NAME="" # Bing connection for web search (optional, used by samples with web search) BING_CONNECTION_ID="" +# Azure AI Search (optional, used by AzureAISearchContextProvider samples) +AZURE_SEARCH_ENDPOINT="" +AZURE_SEARCH_API_KEY="" +AZURE_SEARCH_INDEX_NAME="" +AZURE_SEARCH_SEMANTIC_CONFIG="" +AZURE_SEARCH_KNOWLEDGE_BASE_NAME="" +# Note: For agentic mode Knowledge Bases, also set AZURE_OPENAI_ENDPOINT below +# (different from AZURE_AI_PROJECT_ENDPOINT - Knowledge Base needs OpenAI endpoint for model calls) # OpenAI OPENAI_API_KEY="" OPENAI_CHAT_MODEL_ID="" diff --git a/python/packages/core/agent_framework/azure/__init__.py b/python/packages/core/agent_framework/azure/__init__.py index c1b64f2117..80c6ece5a4 100644 --- a/python/packages/core/agent_framework/azure/__init__.py +++ b/python/packages/core/agent_framework/azure/__init__.py @@ -10,6 +10,7 @@ "AgentResponseCallbackProtocol": ("agent_framework_azurefunctions", 
"azurefunctions"), "AzureAIAgentClient": ("agent_framework_azure_ai", "azure-ai"), "AzureAIClient": ("agent_framework_azure_ai", "azure-ai"), + "AzureAISearchContextProvider": ("agent_framework.azure._search_provider", "core"), "AzureOpenAIAssistantsClient": ("agent_framework.azure._assistants_client", "core"), "AzureOpenAIChatClient": ("agent_framework.azure._chat_client", "core"), "AzureAISettings": ("agent_framework_azure_ai", "azure-ai"), diff --git a/python/packages/core/agent_framework/azure/_search_provider.py b/python/packages/core/agent_framework/azure/_search_provider.py new file mode 100644 index 0000000000..2a95592d26 --- /dev/null +++ b/python/packages/core/agent_framework/azure/_search_provider.py @@ -0,0 +1,538 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure AI Search Context Provider for Agent Framework. + +This module provides context providers for Azure AI Search integration with two modes: +- Semantic: Fast hybrid search (vector + keyword) with semantic ranker +- Agentic: Slower multi-hop reasoning using Knowledge Bases for complex queries + +Use semantic mode for most cases. Use agentic mode only when you need multi-hop +reasoning across documents with Knowledge Bases. 
+""" + +import sys +from collections.abc import MutableSequence +from typing import TYPE_CHECKING, Any, Literal, cast + +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.aio import SearchClient +from azure.search.documents.indexes.aio import SearchIndexClient +from azure.search.documents.indexes.models import ( + AzureOpenAIVectorizerParameters, + KnowledgeAgent, + KnowledgeAgentAzureOpenAIModel, + KnowledgeAgentOutputConfiguration, + KnowledgeAgentOutputConfigurationModality, + KnowledgeAgentRequestLimits, + KnowledgeSourceReference, + SearchIndexKnowledgeSource, + SearchIndexKnowledgeSourceParameters, +) +from azure.search.documents.models import ( + QueryCaptionType, + QueryType, + VectorizedQuery, +) + +from agent_framework import ChatMessage, Context, ContextProvider + +# Type checking imports for optional agentic mode dependencies +if TYPE_CHECKING: + from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient + from azure.search.documents.agent.models import ( + KnowledgeAgentMessage, + KnowledgeAgentMessageTextContent, + KnowledgeAgentRetrievalRequest, + ) + +# Runtime imports for agentic mode (optional dependency) +try: + from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient + from azure.search.documents.agent.models import ( + KnowledgeAgentMessage, + KnowledgeAgentMessageTextContent, + KnowledgeAgentRetrievalRequest, + ) + + _agentic_retrieval_available = True +except ImportError: + _agentic_retrieval_available = False + +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + + +class 
class AzureAISearchContextProvider(ContextProvider):
    """Azure AI Search Context Provider with hybrid search and semantic ranking.

    This provider retrieves relevant documents from Azure AI Search to provide context
    to the AI agent. It supports two modes:

    - **semantic** (default): Fast hybrid search combining vector and keyword search
      with semantic reranking. Suitable for most RAG use cases.
    - **agentic**: Slower multi-hop reasoning across documents using Knowledge Bases.
      Use only for complex queries requiring cross-document reasoning.

    Examples:
        Semantic hybrid search (recommended for most cases):

        .. code-block:: python

            from agent_framework import ChatAgent
            from agent_framework.azure import AzureAISearchContextProvider
            from agent_framework_azure_ai import AzureAIAgentClient
            from azure.identity.aio import DefaultAzureCredential

            # Create context provider with semantic hybrid search
            search_provider = AzureAISearchContextProvider(
                endpoint="https://mysearch.search.windows.net",
                index_name="my-index",
                credential=DefaultAzureCredential(),
                mode="semantic",  # Fast hybrid + semantic ranker (default)
            )

            # Use with agent
            async with (
                AzureAIAgentClient() as client,
                ChatAgent(
                    chat_client=client,
                    context_providers=[search_provider],
                ) as agent,
            ):
                response = await agent.run("What is in the documents?")

        Agentic retrieval for complex queries:

        .. code-block:: python

            # Use agentic mode for multi-hop reasoning (slower)
            search_provider = AzureAISearchContextProvider(
                endpoint="https://mysearch.search.windows.net",
                index_name="my-index",
                credential=DefaultAzureCredential(),
                mode="agentic",  # Multi-hop reasoning
                azure_ai_project_endpoint="https://myproject.services.ai.azure.com",
                model_deployment_name="gpt-4o",
                knowledge_base_name="my-knowledge-base",  # Required for agentic mode
            )
    """

    def __init__(
        self,
        endpoint: str,
        index_name: str,
        credential: AzureKeyCredential | AsyncTokenCredential,
        mode: Literal["semantic", "agentic"] = "semantic",
        top_k: int = 5,
        semantic_configuration_name: str | None = None,
        vector_field_name: str | None = None,
        embedding_function: Any | None = None,
        context_prompt: str | None = None,
        # Agentic mode parameters (Knowledge Base)
        azure_ai_project_endpoint: str | None = None,
        model_deployment_name: str | None = None,
        model_name: str | None = None,
        knowledge_base_name: str | None = None,
        retrieval_instructions: str | None = None,
        azure_openai_api_key: str | None = None,
        azure_openai_resource_url: str | None = None,
        # Deprecated parameters (for backwards compatibility)
        azure_openai_endpoint: str | None = None,
        azure_openai_deployment_name: str | None = None,
        azure_openai_api_version: str | None = None,
    ) -> None:
        """Initialize Azure AI Search Context Provider.

        Args:
            endpoint: Azure AI Search endpoint URL.
            index_name: Name of the search index to query.
            credential: Azure credential (API key or DefaultAzureCredential).
            mode: Search mode - "semantic" for hybrid search with semantic ranking (fast)
                or "agentic" for multi-hop reasoning (slower). Default: "semantic".
            top_k: Maximum number of documents to retrieve. Default: 5.
            semantic_configuration_name: Name of semantic configuration in the index.
                Required for semantic ranking. If None, uses index default.
            vector_field_name: Name of the vector field in the index for hybrid search.
                Required if using vector search. Default: None (keyword search only).
            embedding_function: Async function to generate embeddings for vector search.
                Signature: async def embed(text: str) -> list[float]
                Required if vector_field_name is specified.
            context_prompt: Custom prompt to prepend to retrieved context.
                Default: Uses DEFAULT_CONTEXT_PROMPT.
            azure_ai_project_endpoint: Azure AI Foundry project endpoint URL.
                Example: "https://myproject.services.ai.azure.com". Stored for agentic
                mode; not validated here because Knowledge Base calls go through
                azure_openai_resource_url instead.
            model_deployment_name: Model deployment name in the Azure AI project.
                Required for agentic mode.
            model_name: The underlying model name (e.g., "gpt-4o", "gpt-4o-mini").
                If not provided, defaults to model_deployment_name. Used for Knowledge Base configuration.
            knowledge_base_name: Name for the Knowledge Base. Required for agentic mode.
            retrieval_instructions: Custom instructions for the Knowledge Base's
                retrieval planning. Only used in agentic mode.
            azure_openai_api_key: Azure OpenAI API key for Knowledge Base to call the model.
                Only needed when using API key authentication instead of managed identity.
            azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base model calls.
                Required for agentic mode. Example: "https://myresource.openai.azure.com"
                This is different from azure_ai_project_endpoint (which is Foundry-specific).
            azure_openai_endpoint: (Deprecated) Use azure_ai_project_endpoint instead.
            azure_openai_deployment_name: (Deprecated) Use model_deployment_name instead.
            azure_openai_api_version: (Deprecated) No longer used; accepted only so
                older callers do not break.

        Raises:
            ValueError: If vector_field_name is set without embedding_function, or
                agentic-mode parameters are missing.
            ImportError: If agentic mode is requested but the Knowledge Base SDK
                surface is not installed.
        """
        self.endpoint = endpoint
        self.index_name = index_name
        self.credential = credential
        self.mode = mode
        self.top_k = top_k
        self.semantic_configuration_name = semantic_configuration_name
        self.vector_field_name = vector_field_name
        self.embedding_function = embedding_function
        # NOTE(review): DEFAULT_CONTEXT_PROMPT is not defined in this module;
        # presumably supplied by the ContextProvider base class -- confirm.
        self.context_prompt = context_prompt or self.DEFAULT_CONTEXT_PROMPT

        # Agentic mode parameters (Knowledge Base).
        # azure_openai_resource_url is the actual Azure OpenAI endpoint for model
        # calls; fall back to the deprecated azure_openai_endpoint if absent.
        self.azure_openai_resource_url = azure_openai_resource_url or azure_openai_endpoint

        self.azure_openai_deployment_name = model_deployment_name or azure_openai_deployment_name
        # If model_name not provided, default to deployment name for backwards compatibility.
        self.model_name = model_name or self.azure_openai_deployment_name
        self.knowledge_base_name = knowledge_base_name
        self.retrieval_instructions = retrieval_instructions
        self.azure_openai_api_key = azure_openai_api_key
        self.azure_ai_project_endpoint = azure_ai_project_endpoint

        # Validation (error strings are relied upon by the unit tests).
        if vector_field_name and not embedding_function:
            raise ValueError("embedding_function is required when vector_field_name is specified")

        if mode == "agentic":
            if not _agentic_retrieval_available:
                raise ImportError(
                    "Agentic retrieval requires azure-search-documents >= 11.7.0b1 with Knowledge Base support. "
                    "Please upgrade: pip install azure-search-documents>=11.7.0b1"
                )
            if not self.azure_openai_resource_url:
                raise ValueError(
                    "azure_openai_resource_url (or deprecated azure_openai_endpoint) is required for agentic mode. "
                    "This should be your Azure OpenAI endpoint (e.g., 'https://myresource.openai.azure.com')"
                )
            if not self.azure_openai_deployment_name:
                raise ValueError(
                    "model_deployment_name (or deprecated azure_openai_deployment_name) is required for agentic mode"
                )
            if not knowledge_base_name:
                raise ValueError("knowledge_base_name is required for agentic mode")

        # Create search client for semantic mode.
        self._search_client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=credential,
        )

        # Create index client for agentic mode (Knowledge Base).
        # Note: the retrieval client is created fresh per query to avoid transport issues.
        self._index_client: SearchIndexClient | None = None
        if mode == "agentic":
            self._index_client = SearchIndexClient(
                endpoint=endpoint,
                credential=credential,
            )

        self._knowledge_base_initialized = False

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit - cleanup handled by client destructors.

        Args:
            exc_type: Exception type if an error occurred.
            exc_val: Exception value if an error occurred.
            exc_tb: Exception traceback if an error occurred.

        Note:
            We don't explicitly close the Azure SDK clients here because doing so
            can cause "transport already closed" errors on subsequent uses within
            the same context. The clients will clean up their resources when they
            are garbage collected.
        """
        pass

    @override
    async def invoking(
        self,
        messages: ChatMessage | MutableSequence[ChatMessage],
        **kwargs: Any,
    ) -> Context:
        """Retrieve relevant context from Azure AI Search before model invocation.

        Args:
            messages: User messages to use for context retrieval.
            **kwargs: Additional arguments (unused).

        Returns:
            Context object with retrieved documents as messages; empty Context
            when the query is blank or nothing is found.
        """
        # Extract query from messages (single message is normalized to a list).
        messages_list = [messages] if isinstance(messages, ChatMessage) else list(messages)
        query = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip())

        if not query:
            return Context()

        # Perform search based on mode.
        if self.mode == "semantic":
            search_results = await self._semantic_search(query)
        else:  # agentic
            search_results = await self._agentic_search(query)

        # Format results as context.
        if not search_results:
            return Context()

        context_text = f"{self.context_prompt}\n\n{search_results}"

        return Context(messages=[ChatMessage(role="system", text=context_text)])

    async def _semantic_search(self, query: str) -> str:
        """Perform semantic hybrid search with semantic ranking.

        This is the recommended mode for most use cases. It combines:
        - Vector search (if embedding_function provided)
        - Keyword search (BM25)
        - Semantic reranking (if semantic_configuration_name provided)

        Args:
            query: Search query text.

        Returns:
            Formatted search results as string.
        """
        vector_queries = []

        # Generate vector query if embedding function provided.
        if self.embedding_function and self.vector_field_name:
            query_vector = await self.embedding_function(query)
            vector_queries = [
                VectorizedQuery(
                    vector=query_vector,
                    k_nearest_neighbors=self.top_k,
                    fields=self.vector_field_name,
                )
            ]

        # Build search parameters.
        search_params: dict[str, Any] = {
            "search_text": query,
            "top": self.top_k,
        }

        if vector_queries:
            search_params["vector_queries"] = vector_queries

        # Add semantic ranking if configured.
        if self.semantic_configuration_name:
            search_params["query_type"] = QueryType.SEMANTIC
            search_params["semantic_configuration_name"] = self.semantic_configuration_name
            search_params["query_caption"] = QueryCaptionType.EXTRACTIVE

        # Execute search.
        results = await self._search_client.search(**search_params)  # type: ignore[reportUnknownVariableType]

        # Format results.
        formatted_results: list[str] = []
        async for doc in results:  # type: ignore[reportUnknownVariableType]
            # Extract semantic captions if available. Results are dict-like
            # (see _extract_document_text), so test key membership; the previous
            # hasattr(doc, "@search.captions") check looked for an *attribute*
            # and therefore never matched, silently discarding captions.
            caption: str | None = None
            if "@search.captions" in doc:
                captions: Any = doc.get("@search.captions", [])  # type: ignore[reportUnknownVariableType]
                if captions:
                    caption = captions[0].text if hasattr(captions[0], "text") else str(captions[0])  # type: ignore[reportUnknownArgumentType, reportUnknownMemberType]

            # Build document text: prefer the semantic caption, else fall back
            # to the first recognizable text field of the document.
            doc_text: str = caption if caption else self._extract_document_text(doc)  # type: ignore[reportUnknownArgumentType]
            if doc_text:
                formatted_results.append(doc_text)  # type: ignore[reportUnknownArgumentType]

        return "\n\n".join(formatted_results)

    async def _ensure_knowledge_base(self) -> None:
        """Ensure Knowledge Base and knowledge source are created.

        This method is idempotent - it will only create resources if they don't exist.

        Note: Azure SDK uses KnowledgeAgent classes internally, but the feature
        is marketed as "Knowledge Bases" in Azure AI Search.
        """
        if self._knowledge_base_initialized or not self._index_client:
            return

        # Type narrowing: these are validated as non-None in __init__ for agentic mode.
        # Using cast() for type checker - actual validation happens in __init__.
        knowledge_base_name = cast(str, self.knowledge_base_name)
        azure_openai_resource_url = cast(str, self.azure_openai_resource_url)
        azure_openai_deployment_name = cast(str, self.azure_openai_deployment_name)

        # Step 1: Create or get knowledge source.
        knowledge_source_name = f"{self.index_name}-source"

        try:
            # Try to get existing knowledge source.
            await self._index_client.get_knowledge_source(knowledge_source_name)
        except ResourceNotFoundError:
            # Create new knowledge source if it doesn't exist.
            knowledge_source = SearchIndexKnowledgeSource(
                name=knowledge_source_name,
                description=f"Knowledge source for {self.index_name} search index",
                search_index_parameters=SearchIndexKnowledgeSourceParameters(
                    search_index_name=self.index_name,
                ),
            )
            await self._index_client.create_knowledge_source(knowledge_source)

        # Step 2: Create or get Knowledge Base (using KnowledgeAgent SDK class).
        try:
            # Try to get existing Knowledge Base.
            await self._index_client.get_agent(knowledge_base_name)
        except ResourceNotFoundError:
            # Create new Knowledge Base if it doesn't exist.
            aoai_params = AzureOpenAIVectorizerParameters(
                resource_url=azure_openai_resource_url,
                deployment_name=azure_openai_deployment_name,
                model_name=self.model_name,  # Underlying model name (e.g., "gpt-4o")
                api_key=self.azure_openai_api_key,  # Optional: for API key auth instead of managed identity
            )

            # Note: SDK uses KnowledgeAgent class name, but this represents a Knowledge Base.
            knowledge_base = KnowledgeAgent(
                name=knowledge_base_name,
                description=f"Knowledge Base for multi-hop retrieval across {self.index_name}",
                models=[KnowledgeAgentAzureOpenAIModel(azure_open_ai_parameters=aoai_params)],
                knowledge_sources=[
                    KnowledgeSourceReference(
                        name=knowledge_source_name,
                        include_references=True,
                        include_reference_source_data=True,
                    )
                ],
                output_configuration=KnowledgeAgentOutputConfiguration(
                    modality=KnowledgeAgentOutputConfigurationModality.ANSWER_SYNTHESIS,
                    attempt_fast_path=True,
                ),
                request_limits=KnowledgeAgentRequestLimits(
                    max_output_size=10000,
                    max_runtime_in_seconds=60,
                ),
                retrieval_instructions=self.retrieval_instructions,
            )
            await self._index_client.create_agent(knowledge_base)

        self._knowledge_base_initialized = True

    async def _agentic_search(self, query: str) -> str:
        """Perform agentic retrieval with multi-hop reasoning using Knowledge Bases.

        NOTE: This mode is significantly slower than semantic search and should
        only be used for complex queries requiring cross-document reasoning.

        This method uses Azure AI Search Knowledge Bases which:
        1. Analyze the query and plan sub-queries
        2. Retrieve relevant documents across multiple sources
        3. Perform multi-hop reasoning with an LLM
        4. Synthesize a comprehensive answer with references

        Args:
            query: Search query text.

        Returns:
            Synthesized answer from the Knowledge Base.
        """
        # Ensure Knowledge Base is initialized (idempotent).
        await self._ensure_knowledge_base()

        # Type narrowing: knowledge_base_name is validated in __init__ for agentic mode.
        knowledge_base_name = cast(str, self.knowledge_base_name)

        # Create retrieval request with query as a conversation message.
        # Note: SDK uses KnowledgeAgent class names, but represents Knowledge Base operations.
        retrieval_request = KnowledgeAgentRetrievalRequest(
            messages=[
                KnowledgeAgentMessage(
                    role="user",
                    content=[KnowledgeAgentMessageTextContent(text=query)],
                )
            ]
        )

        # Create a fresh retrieval client for each query to avoid transport closure issues.
        if not _agentic_retrieval_available:
            raise ImportError("KnowledgeAgentRetrievalClient not available")

        retrieval_client = KnowledgeAgentRetrievalClient(
            endpoint=self.endpoint,
            agent_name=knowledge_base_name,
            credential=self.credential,
        )

        try:
            # Perform retrieval via Knowledge Base.
            retrieval_result = await retrieval_client.retrieve(retrieval_request=retrieval_request)
        finally:
            # Ensure client is closed after use.
            await retrieval_client.close()

        # Extract synthesized answer from response.
        if retrieval_result.response and len(retrieval_result.response) > 0:
            # Get the assistant's response (last message).
            assistant_message = retrieval_result.response[-1]
            if assistant_message.content:
                # Combine all text content items into a single answer.
                answer_parts: list[str] = []
                for content_item in assistant_message.content:
                    # Check if this is a text content item.
                    if isinstance(content_item, KnowledgeAgentMessageTextContent) and content_item.text:
                        answer_parts.append(content_item.text)

                if answer_parts:
                    return "\n".join(answer_parts)

        # Fallback if no answer generated.
        return "No results found from Knowledge Base."

    def _extract_document_text(self, doc: dict[str, Any]) -> str:
        """Extract readable text from a search document.

        Args:
            doc: Search result document.

        Returns:
            Formatted document text (truncated to 500 characters), or "" when
            the document carries no usable string fields.
        """
        # Try common text field names first.
        for field in ["content", "text", "description", "body", "chunk"]:
            if doc.get(field):
                return str(doc[field])[:500]  # Limit to 500 chars

        # Fallback: concatenate all plain string fields, skipping metadata
        # ("@"-prefixed) keys and the document id.
        text_parts: list[str] = []
        for key, value in doc.items():
            if isinstance(value, str) and not key.startswith("@") and key != "id":
                text_parts.append(f"{key}: {value}")

        return " | ".join(text_parts)[:500] if text_parts else ""
# Copyright (c) Microsoft. All rights reserved.
# pyright: reportPrivateUsage=false

"""Unit tests for AzureAISearchContextProvider.

All Azure SDK clients are replaced with unittest.mock AsyncMock/MagicMock
objects, so no network access or live Azure resources are needed.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError

from agent_framework import ChatMessage, Context, Role
from agent_framework.azure import AzureAISearchContextProvider


# NOTE(review): the two client fixtures below are not referenced by any test in
# this file -- each test builds its own mocks via @patch. Kept for now.
@pytest.fixture
def mock_search_client() -> AsyncMock:
    """Create a mock SearchClient."""
    mock_client = AsyncMock()
    mock_client.search = AsyncMock()
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock()
    return mock_client


@pytest.fixture
def mock_index_client() -> AsyncMock:
    """Create a mock SearchIndexClient."""
    mock_client = AsyncMock()
    mock_client.get_knowledge_source = AsyncMock()
    mock_client.create_knowledge_source = AsyncMock()
    mock_client.get_agent = AsyncMock()
    mock_client.create_agent = AsyncMock()
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock()
    return mock_client


@pytest.fixture
def sample_messages() -> list[ChatMessage]:
    """Create sample chat messages for testing."""
    return [
        ChatMessage(role=Role.USER, text="What is in the documents?"),
    ]


class TestSearchProviderInitialization:
    """Test initialization and configuration of AzureAISearchContextProvider."""

    def test_init_semantic_mode_minimal(self) -> None:
        """Test initialization with minimal semantic mode parameters."""
        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
        )
        assert provider.endpoint == "https://test.search.windows.net"
        assert provider.index_name == "test-index"
        assert provider.mode == "semantic"
        assert provider.top_k == 5

    def test_init_semantic_mode_with_vector_field_requires_embedding_function(self) -> None:
        """Test that vector_field_name requires embedding_function."""
        with pytest.raises(ValueError, match="embedding_function is required"):
            AzureAISearchContextProvider(
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                mode="semantic",
                vector_field_name="embedding",
            )

    def test_init_agentic_mode_requires_parameters(self) -> None:
        """Test that agentic mode requires additional parameters."""
        # No azure_openai_resource_url at all -> first agentic validation fires.
        with pytest.raises(ValueError, match="azure_openai_resource_url"):
            AzureAISearchContextProvider(
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                mode="agentic",
            )

    def test_init_agentic_mode_requires_model_deployment_name(self) -> None:
        """Test that agentic mode requires model_deployment_name."""
        with pytest.raises(ValueError, match="model_deployment_name"):
            AzureAISearchContextProvider(
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                mode="agentic",
                azure_ai_project_endpoint="https://test.services.ai.azure.com",
                azure_openai_resource_url="https://test.openai.azure.com",
            )

    def test_init_agentic_mode_requires_knowledge_base_name(self) -> None:
        """Test that agentic mode requires knowledge_base_name."""
        with pytest.raises(ValueError, match="knowledge_base_name"):
            AzureAISearchContextProvider(
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                mode="agentic",
                azure_ai_project_endpoint="https://test.services.ai.azure.com",
                model_deployment_name="gpt-4o",
                azure_openai_resource_url="https://test.openai.azure.com",
            )

    def test_init_agentic_mode_with_all_params(self) -> None:
        """Test initialization with all agentic mode parameters."""
        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="agentic",
            azure_ai_project_endpoint="https://test.services.ai.azure.com",
            model_deployment_name="my-gpt-4o-deployment",
            model_name="gpt-4o",
            knowledge_base_name="test-kb",
            azure_openai_resource_url="https://test.openai.azure.com",
        )
        assert provider.mode == "agentic"
        assert provider.azure_ai_project_endpoint == "https://test.services.ai.azure.com"
        assert provider.azure_openai_resource_url == "https://test.openai.azure.com"
        assert provider.azure_openai_deployment_name == "my-gpt-4o-deployment"
        assert provider.model_name == "gpt-4o"
        assert provider.knowledge_base_name == "test-kb"

    def test_init_model_name_defaults_to_deployment_name(self) -> None:
        """Test that model_name defaults to deployment_name if not provided."""
        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="agentic",
            azure_ai_project_endpoint="https://test.services.ai.azure.com",
            model_deployment_name="gpt-4o",
            knowledge_base_name="test-kb",
            azure_openai_resource_url="https://test.openai.azure.com",
        )
        assert provider.model_name == "gpt-4o"

    def test_init_with_custom_context_prompt(self) -> None:
        """Test initialization with custom context prompt."""
        custom_prompt = "Use the following information:"
        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
            context_prompt=custom_prompt,
        )
        assert provider.context_prompt == custom_prompt

    def test_init_uses_default_context_prompt(self) -> None:
        """Test that default context prompt is used when not provided."""
        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
        )
        assert provider.context_prompt == provider.DEFAULT_CONTEXT_PROMPT


class TestSemanticSearch:
    """Test semantic search functionality."""

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_semantic_search_basic(
        self, mock_search_class: MagicMock, sample_messages: list[ChatMessage]
    ) -> None:
        """Test basic semantic search without vector search."""
        # Setup mock: search() resolves to an async-iterable of dict documents.
        mock_search_client = AsyncMock()
        mock_results = AsyncMock()
        mock_results.__aiter__.return_value = iter([{"content": "Test document content"}])
        mock_search_client.search.return_value = mock_results
        mock_search_class.return_value = mock_search_client

        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
        )

        context = await provider.invoking(sample_messages)

        assert isinstance(context, Context)
        assert len(context.messages) > 0
        assert "Test document content" in context.messages[0].text

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_semantic_search_empty_query(self, mock_search_class: MagicMock) -> None:
        """Test that empty queries return empty context."""
        mock_search_client = AsyncMock()
        mock_search_class.return_value = mock_search_client

        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
        )

        # Empty message: provider must short-circuit without calling search.
        context = await provider.invoking([ChatMessage(role=Role.USER, text="")])

        assert isinstance(context, Context)
        assert len(context.messages) == 0

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_semantic_search_with_vector_query(
        self, mock_search_class: MagicMock, sample_messages: list[ChatMessage]
    ) -> None:
        """Test semantic search with vector query."""
        # Setup mock
        mock_search_client = AsyncMock()
        mock_results = AsyncMock()
        mock_results.__aiter__.return_value = iter([{"content": "Vector search result"}])
        mock_search_client.search.return_value = mock_results
        mock_search_class.return_value = mock_search_client

        # Mock embedding function with a fixed, deterministic vector.
        async def mock_embed(text: str) -> list[float]:
            return [0.1, 0.2, 0.3]

        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
            vector_field_name="embedding",
            embedding_function=mock_embed,
        )

        context = await provider.invoking(sample_messages)

        assert isinstance(context, Context)
        assert len(context.messages) > 0
        # Verify that search was called
        mock_search_client.search.assert_called_once()


class TestKnowledgeBaseSetup:
    """Test Knowledge Base setup for agentic mode."""

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchIndexClient")
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_ensure_knowledge_base_creates_when_not_exists(
        self, mock_search_class: MagicMock, mock_index_class: MagicMock
    ) -> None:
        """Test that Knowledge Base is created when it doesn't exist."""
        # Setup mocks: both "get" calls raise so the create path is exercised.
        mock_index_client = AsyncMock()
        mock_index_client.get_knowledge_source.side_effect = ResourceNotFoundError("Not found")
        mock_index_client.create_knowledge_source = AsyncMock()
        mock_index_client.get_agent.side_effect = ResourceNotFoundError("Not found")
        mock_index_client.create_agent = AsyncMock()
        mock_index_class.return_value = mock_index_client

        mock_search_client = AsyncMock()
        mock_search_class.return_value = mock_search_client

        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="agentic",
            azure_ai_project_endpoint="https://test.services.ai.azure.com",
            model_deployment_name="gpt-4o",
            model_name="gpt-4o",
            knowledge_base_name="test-kb",
            azure_openai_resource_url="https://test.openai.azure.com",
        )

        await provider._ensure_knowledge_base()

        # Verify knowledge source was created
        mock_index_client.create_knowledge_source.assert_called_once()
        # Verify agent (Knowledge Base) was created
        mock_index_client.create_agent.assert_called_once()

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchIndexClient")
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_ensure_knowledge_base_skips_when_exists(
        self, mock_search_class: MagicMock, mock_index_class: MagicMock
    ) -> None:
        """Test that Knowledge Base setup is skipped when already exists."""
        # Setup mocks: "get" calls succeed, so nothing should be created.
        mock_index_client = AsyncMock()
        mock_index_client.get_knowledge_source.return_value = MagicMock()  # Exists
        mock_index_client.get_agent.return_value = MagicMock()  # Exists
        mock_index_class.return_value = mock_index_client

        mock_search_client = AsyncMock()
        mock_search_class.return_value = mock_search_client

        provider = AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="agentic",
            azure_ai_project_endpoint="https://test.services.ai.azure.com",
            model_deployment_name="gpt-4o",
            knowledge_base_name="test-kb",
            azure_openai_resource_url="https://test.openai.azure.com",
        )

        await provider._ensure_knowledge_base()

        # Verify nothing was created
        mock_index_client.create_knowledge_source.assert_not_called()
        mock_index_client.create_agent.assert_not_called()


class TestContextProviderLifecycle:
    """Test context provider lifecycle methods."""

    @pytest.mark.asyncio
    @patch("agent_framework.azure._search_provider.SearchClient")
    async def test_context_manager(self, mock_search_class: MagicMock) -> None:
        """Test that provider can be used as async context manager."""
        mock_search_client = AsyncMock()
        mock_search_class.return_value = mock_search_client

        async with AzureAISearchContextProvider(
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            mode="semantic",
        ) as provider:
            assert provider is not None
            assert isinstance(provider, AzureAISearchContextProvider)
Set the following environment variables: + + For both modes: + - AZURE_SEARCH_ENDPOINT: Your Azure AI Search endpoint + - AZURE_SEARCH_API_KEY: Your search API key (or use Azure AD) + - AZURE_SEARCH_INDEX_NAME: Your search index name + - AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint + - AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o") + + Additional for agentic mode (Knowledge Bases): + - USE_AGENTIC_MODE: Set to "true" to use agentic retrieval + - AZURE_SEARCH_KNOWLEDGE_BASE_NAME: Your Knowledge Base name + - AZURE_OPENAI_ENDPOINT: Your Azure OpenAI endpoint (e.g., "https://myresource.openai.azure.com") + (This is different from AZURE_AI_PROJECT_ENDPOINT - Knowledge Base needs the OpenAI endpoint for model calls) +""" + +# Sample queries to demonstrate RAG +USER_INPUTS = [ + "What information is available in the knowledge base?", + "Summarize the main topics from the documents", + "Find specific details about the content", +] + + +async def main() -> None: + """Main function demonstrating Azure AI Search context provider.""" + + # Get configuration from environment + search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"] + search_key = os.environ.get("AZURE_SEARCH_API_KEY") + index_name = os.environ["AZURE_SEARCH_INDEX_NAME"] + project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] + model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o") + + # Check if agentic mode is requested + use_agentic = os.environ.get("USE_AGENTIC_MODE", "false").lower() == "true" + + # Create credential + search_credential = AzureKeyCredential(search_key) if search_key else DefaultAzureCredential() + + # Create Azure AI Search context provider + if use_agentic: + # Agentic mode: Multi-hop reasoning with Knowledge Bases (slower) + print("Using AGENTIC mode (Knowledge Bases with multi-hop reasoning, slower)\n") + knowledge_base_name = os.environ["AZURE_SEARCH_KNOWLEDGE_BASE_NAME"] + azure_openai_endpoint = 
os.environ["AZURE_OPENAI_ENDPOINT"] + search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + credential=search_credential, + mode="agentic", + # Agentic mode uses Azure AI Foundry project for model inference + azure_ai_project_endpoint=project_endpoint, + model_deployment_name=model_deployment, + knowledge_base_name=knowledge_base_name, + azure_openai_resource_url=azure_openai_endpoint, + top_k=3, + ) + else: + # Semantic mode: Fast hybrid search + semantic ranking (recommended) + print("Using SEMANTIC mode (hybrid search + semantic ranking, fast)\n") + search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + credential=search_credential, + mode="semantic", + top_k=3, # Retrieve top 3 most relevant documents + ) + + # Create agent with search context provider + async with ( + search_provider, + AzureAIAgentClient( + project_endpoint=project_endpoint, + model_deployment_name=model_deployment, + async_credential=DefaultAzureCredential(), + ) as client, + ChatAgent( + chat_client=client, + name="SearchAgent", + instructions=( + "You are a helpful assistant. Use the provided context from the " + "knowledge base to answer questions accurately." 
+ ), + context_providers=[search_provider], + ) as agent, + ): + print("=== Azure AI Agent with Search Context ===\n") + + for user_input in USER_INPUTS: + print(f"User: {user_input}") + print("Agent: ", end="", flush=True) + + # Stream response + async for chunk in agent.run_stream(user_input): + if chunk.text: + print(chunk.text, end="", flush=True) + + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/uv.lock b/python/uv.lock index f0657d5f68..6481f9036f 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -305,6 +305,7 @@ version = "1.0.0b251114" source = { editable = "packages/core" } dependencies = [ { name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "azure-search-documents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -342,6 +343,7 @@ requires-dist = [ { name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" }, { name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" }, { name = "azure-identity", specifier = ">=1,<2" }, + { name = "azure-search-documents", specifier = ">=11.7.0b1" }, { name = "mcp", extras = ["ws"], specifier = ">=1.13" }, { name = "openai", specifier = ">=1.99.0" }, { name = "opentelemetry-api", specifier = ">=1.24" }, @@ -912,6 +914,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/82/69/b0f93d40bf24ecf3785097e601e389f9d8b5d33b979db7eee6ec5b92df40/azure_ai_projects-2.0.0b1-py3-none-any.whl", hash = 
"sha256:88eca153c63c8d93bc1776858ae036f1796ef0b9552b77736c446e575e3b132a", size = 221008, upload-time = "2025-11-12T04:00:01.281Z" }, ] +[[package]] +name = "azure-common" +version = "1.1.28" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3e/71/f6f71a276e2e69264a97ad39ef850dca0a04fce67b12570730cb38d0ccac/azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3", size = 20914, upload-time = "2022-02-03T19:39:44.373Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/55/7f118b9c1b23ec15ca05d15a578d8207aa1706bc6f7c87218efffbbf875d/azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad", size = 14462, upload-time = "2022-02-03T19:39:42.417Z" }, +] + [[package]] name = "azure-core" version = "1.36.0" @@ -971,6 +982,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/83/7b/5652771e24fff12da9dde4c20ecf4682e606b104f26419d139758cc935a6/azure_identity-1.25.1-py3-none-any.whl", hash = "sha256:e9edd720af03dff020223cd269fa3a61e8f345ea75443858273bcb44844ab651", size = 191317, upload-time = "2025-10-06T20:30:04.251Z" }, ] +[[package]] +name = "azure-search-documents" +version = "11.7.0b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-common", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "isodate", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/bc/9a/819cb9375e431eac8f337777526a88e5f8e10cbf2efd877d39c470f49aa2/azure_search_documents-11.7.0b1.tar.gz", hash = "sha256:0324bd6732dd79c2bc4b6f6429d5a1b96129077d4706620af6f74af4b952fe0c", size = 395057, upload-time = "2025-09-04T23:30:32.571Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/14/8e9bc312fffc21e5220c8ce87f064009e7f577fef69bade2a9713032b646/azure_search_documents-11.7.0b1-py3-none-any.whl", hash = "sha256:aae83b501246438eff77aa3930de9fdd3532c0d219018342fd9d85c682814376", size = 413982, upload-time = "2025-09-04T23:30:34.252Z" }, +] + [[package]] name = "azure-storage-blob" version = "12.27.1" From 84615628dbee456e0ad98d406c6dc22860d0a04a Mon Sep 17 00:00:00 2001 From: Farzad Sunavala Date: Sun, 2 Nov 2025 18:04:38 +0000 Subject: [PATCH 2/6] address pablo coments --- .../agent_framework/azure/_search_provider.py | 235 ++++++++++++------ .../core/tests/azure/test_search_provider.py | 206 +++++++++++++++ .../getting_started/agents/azure_ai/README.md | 10 +- .../azure_ai/azure_ai_with_search_context.py | 2 +- 4 files changed, 380 insertions(+), 73 deletions(-) diff --git a/python/packages/core/agent_framework/azure/_search_provider.py b/python/packages/core/agent_framework/azure/_search_provider.py index 2a95592d26..267002b6fa 100644 --- a/python/packages/core/agent_framework/azure/_search_provider.py +++ b/python/packages/core/agent_framework/azure/_search_provider.py @@ -12,7 +12,7 @@ import sys from collections.abc import MutableSequence -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential @@ -125,11 +125,14 @@ class AzureAISearchContextProvider(ContextProvider): ) """ + DEFAULT_CONTEXT_PROMPT = "Use the following context to answer the question:" + def __init__( self, endpoint: str, index_name: str, credential: AzureKeyCredential 
| AsyncTokenCredential, + *, mode: Literal["semantic", "agentic"] = "semantic", top_k: int = 5, semantic_configuration_name: str | None = None, @@ -208,6 +211,13 @@ def __init__( self.azure_openai_api_key = azure_openai_api_key self.azure_ai_project_endpoint = azure_ai_project_endpoint + # Auto-discover vector field if not specified + self._auto_discovered_vector_field = False + if not vector_field_name and mode == "semantic": + # Attempt to auto-discover vector field from index schema + # This will be done lazily on first search to avoid blocking initialization + pass + # Validation if vector_field_name and not embedding_function: raise ValueError("embedding_function is required when vector_field_name is specified") @@ -237,14 +247,15 @@ def __init__( credential=credential, ) - # Create index client for agentic mode (Knowledge Base) - # Note: Retrieval client is created fresh for each query to avoid transport issues + # Create index client and retrieval client for agentic mode (Knowledge Base) self._index_client: SearchIndexClient | None = None + self._retrieval_client: KnowledgeAgentRetrievalClient | None = None if mode == "agentic": self._index_client = SearchIndexClient( endpoint=endpoint, credential=credential, ) + # Retrieval client will be created after Knowledge Base initialization self._knowledge_base_initialized = False @@ -258,20 +269,17 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: Any, ) -> None: - """Async context manager exit - cleanup handled by client destructors. + """Async context manager exit - cleanup clients. Args: exc_type: Exception type if an error occurred. exc_val: Exception value if an error occurred. exc_tb: Exception traceback if an error occurred. - - Note: - We don't explicitly close the Azure SDK clients here because doing so - can cause "transport already closed" errors on subsequent uses within - the same context. The clients will clean up their resources when they - are garbage collected. 
""" - pass + # Close retrieval client if it was created + if self._retrieval_client is not None: + await self._retrieval_client.close() + self._retrieval_client = None @override async def invoking( @@ -288,18 +296,28 @@ async def invoking( Returns: Context object with retrieved documents as messages. """ - # Extract query from messages + # Convert to list and filter to USER/ASSISTANT messages with text only messages_list = [messages] if isinstance(messages, ChatMessage) else list(messages) - query = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip()) + from agent_framework import Role - if not query: + filtered_messages = [ + msg + for msg in messages_list + if msg and msg.text and msg.text.strip() and msg.role in [Role.USER, Role.ASSISTANT] + ] + + if not filtered_messages: return Context() # Perform search based on mode if self.mode == "semantic": + # Semantic mode: flatten messages to single query + query = "\n".join(msg.text for msg in filtered_messages) search_results = await self._semantic_search(query) else: # agentic - search_results = await self._agentic_search(query) + # Agentic mode: pass last 10 messages as conversation history + recent_messages = filtered_messages[-10:] + search_results = await self._agentic_search(recent_messages) # Format results as context if not search_results: @@ -309,6 +327,72 @@ async def invoking( return Context(messages=[ChatMessage(role="system", text=context_text)]) + async def _auto_discover_vector_field(self) -> None: + """Auto-discover vector field from index schema. + + Attempts to find vector fields in the index and use the vectorizer configuration. + If successful, sets self.vector_field_name. If embedding function is not provided, + logs a warning that vector search is recommended for RAG. 
+ """ + if self._auto_discovered_vector_field or self.vector_field_name: + return # Already discovered or manually specified + + try: + # Need index client to get schema + if not self._index_client: + from azure.search.documents.indexes.aio import SearchIndexClient + + index_client = SearchIndexClient(endpoint=self.endpoint, credential=self.credential) + else: + index_client = self._index_client + + # Get index schema + index = await index_client.get_index(self.index_name) + + # Find vector fields - must have vector_search_dimensions set (not None) + vector_fields = [ + field + for field in index.fields + if field.vector_search_dimensions is not None and field.vector_search_dimensions > 0 + ] + + if len(vector_fields) == 1: + # Exactly one vector field found - auto-select it + self.vector_field_name = vector_fields[0].name + self._auto_discovered_vector_field = True + + # Warn if no embedding function provided + if not self.embedding_function: + import logging + + logging.warning( + f"Auto-discovered vector field '{self.vector_field_name}' but no embedding_function provided. " + "Vector search is recommended for RAG use cases. Falling back to keyword-only search." + ) + self.vector_field_name = None # Clear it since we can't use it + elif len(vector_fields) > 1: + # Multiple vector fields - warn and continue with keyword search + import logging + + logging.warning( + f"Multiple vector fields found in index '{self.index_name}': " + f"{[f.name for f in vector_fields]}. " + "Please specify vector_field_name explicitly. Using keyword-only search." + ) + # If no vector fields found, silently continue with keyword search + + # Close index client if we created it + if not self._index_client: + await index_client.close() + + except Exception as e: + # Log warning but continue with keyword search + import logging + + logging.warning(f"Failed to auto-discover vector field: {e}. 
Using keyword-only search.") + + self._auto_discovered_vector_field = True # Mark as attempted + async def _semantic_search(self, query: str) -> str: """Perform semantic hybrid search with semantic ranking. @@ -323,15 +407,20 @@ async def _semantic_search(self, query: str) -> str: Returns: Formatted search results as string. """ + # Auto-discover vector field if not already done + await self._auto_discover_vector_field() + vector_queries = [] # Generate vector query if embedding function provided if self.embedding_function and self.vector_field_name: query_vector = await self.embedding_function(query) + # Use larger k for vector query when semantic reranker is enabled for better ranking quality + vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k vector_queries = [ VectorizedQuery( vector=query_vector, - k_nearest_neighbors=self.top_k, + k_nearest_neighbors=vector_k, fields=self.vector_field_name, ) ] @@ -354,18 +443,14 @@ async def _semantic_search(self, query: str) -> str: # Execute search results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType] - # Format results + # Format results with citations formatted_results: list[str] = [] async for doc in results: # type: ignore[reportUnknownVariableType] - # Extract semantic captions if available - caption: str | None = None - if hasattr(doc, "@search.captions"): # type: ignore[reportUnknownArgumentType] - captions: Any = doc.get("@search.captions", []) # type: ignore[reportUnknownVariableType] - if captions: - caption = captions[0].text if hasattr(captions[0], "text") else str(captions[0]) # type: ignore[reportUnknownArgumentType, reportUnknownMemberType] - - # Build document text - doc_text: str = caption if caption else self._extract_document_text(doc) # type: ignore[reportUnknownArgumentType] + # Extract document ID for citation + doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType] + + # Use full 
document chunks with citation + doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType] if doc_text: formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType] @@ -382,11 +467,17 @@ async def _ensure_knowledge_base(self) -> None: if self._knowledge_base_initialized or not self._index_client: return - # Type narrowing: these are validated as non-None in __init__ for agentic mode - # Using cast() for type checker - actual validation happens in __init__ - knowledge_base_name = cast(str, self.knowledge_base_name) - azure_openai_resource_url = cast(str, self.azure_openai_resource_url) - azure_openai_deployment_name = cast(str, self.azure_openai_deployment_name) + # Runtime validation for agentic mode parameters + if not self.knowledge_base_name: + raise ValueError("knowledge_base_name is required for agentic mode") + if not self.azure_openai_resource_url: + raise ValueError("azure_openai_resource_url is required for agentic mode") + if not self.azure_openai_deployment_name: + raise ValueError("model_deployment_name is required for agentic mode") + + knowledge_base_name = self.knowledge_base_name + azure_openai_resource_url = self.azure_openai_resource_url + azure_openai_deployment_name = self.azure_openai_deployment_name # Step 1: Create or get knowledge source knowledge_source_name = f"{self.index_name}-source" @@ -444,7 +535,15 @@ async def _ensure_knowledge_base(self) -> None: self._knowledge_base_initialized = True - async def _agentic_search(self, query: str) -> str: + # Create retrieval client now that Knowledge Base is initialized + if _agentic_retrieval_available and self._retrieval_client is None: + self._retrieval_client = KnowledgeAgentRetrievalClient( + endpoint=self.endpoint, + agent_name=knowledge_base_name, + credential=self.credential, + ) + + async def _agentic_search(self, messages: list[ChatMessage]) -> str: """Perform agentic retrieval with multi-hop reasoning using Knowledge Bases. 
NOTE: This mode is significantly slower than semantic search and should @@ -457,7 +556,7 @@ async def _agentic_search(self, query: str) -> str: 4. Synthesize a comprehensive answer with references Args: - query: Search query text. + messages: Conversation history (last 10 messages) to use for retrieval context. Returns: Synthesized answer from the Knowledge Base. @@ -465,37 +564,25 @@ async def _agentic_search(self, query: str) -> str: # Ensure Knowledge Base is initialized await self._ensure_knowledge_base() - # Type narrowing: knowledge_base_name is validated in __init__ for agentic mode - # Using cast() for type checker - actual validation happens in __init__ - knowledge_base_name = cast(str, self.knowledge_base_name) - - # Create retrieval request with query as a conversation message + # Convert ChatMessage list to KnowledgeAgent message format # Note: SDK uses KnowledgeAgent class names, but represents Knowledge Base operations - retrieval_request = KnowledgeAgentRetrievalRequest( - messages=[ - KnowledgeAgentMessage( - role="user", - content=[KnowledgeAgentMessageTextContent(text=query)], - ) - ] - ) + kb_messages = [ + KnowledgeAgentMessage( + role=msg.role.value if hasattr(msg.role, "value") else str(msg.role), + content=[KnowledgeAgentMessageTextContent(text=msg.text)], + ) + for msg in messages + if msg.text + ] - # Create a fresh retrieval client for each query to avoid transport closure issues - if not _agentic_retrieval_available: - raise ImportError("KnowledgeAgentRetrievalClient not available") + retrieval_request = KnowledgeAgentRetrievalRequest(messages=kb_messages) - retrieval_client = KnowledgeAgentRetrievalClient( - endpoint=self.endpoint, - agent_name=knowledge_base_name, - credential=self.credential, - ) + # Use reusable retrieval client + if not self._retrieval_client: + raise RuntimeError("Retrieval client not initialized. 
Ensure Knowledge Base is set up correctly.") - try: - # Perform retrieval via Knowledge Base - retrieval_result = await retrieval_client.retrieve(retrieval_request=retrieval_request) - finally: - # Ensure client is closed after use - await retrieval_client.close() + # Perform retrieval via Knowledge Base + retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request) # Extract synthesized answer from response if retrieval_result.response and len(retrieval_result.response) > 0: @@ -515,24 +602,32 @@ async def _agentic_search(self, query: str) -> str: # Fallback if no answer generated return "No results found from Knowledge Base." - def _extract_document_text(self, doc: dict[str, Any]) -> str: - """Extract readable text from a search document. + def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str: + """Extract readable text from a search document with optional citation. Args: doc: Search result document. + doc_id: Optional document ID for citation. Returns: - Formatted document text. + Formatted document text with citation if doc_id provided. 
""" # Try common text field names + text = "" for field in ["content", "text", "description", "body", "chunk"]: if doc.get(field): - return str(doc[field])[:500] # Limit to 500 chars + text = str(doc[field]) + break # Fallback: concatenate all string fields - text_parts: list[str] = [] - for key, value in doc.items(): - if isinstance(value, str) and not key.startswith("@") and key != "id": - text_parts.append(f"{key}: {value}") - - return " | ".join(text_parts)[:500] if text_parts else "" + if not text: + text_parts: list[str] = [] + for key, value in doc.items(): + if isinstance(value, str) and not key.startswith("@") and key != "id": + text_parts.append(f"{key}: {value}") + text = " | ".join(text_parts) if text_parts else "" + + # Add citation if document ID provided + if doc_id and text: + return f"[Source: {doc_id}] {text}" + return text diff --git a/python/packages/core/tests/azure/test_search_provider.py b/python/packages/core/tests/azure/test_search_provider.py index 149ad12524..29e86f9352 100644 --- a/python/packages/core/tests/azure/test_search_provider.py +++ b/python/packages/core/tests/azure/test_search_provider.py @@ -336,3 +336,209 @@ async def test_context_manager(self, mock_search_class: MagicMock) -> None: ) as provider: assert provider is not None assert isinstance(provider, AzureAISearchContextProvider) + + +class TestMessageFiltering: + """Test message filtering functionality.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_filters_non_user_assistant_messages(self, mock_search_class: MagicMock) -> None: + """Test that only USER and ASSISTANT messages are processed.""" + # Setup mock + mock_search_client = AsyncMock() + mock_results = AsyncMock() + mock_results.__aiter__.return_value = iter([{"content": "Test result"}]) + mock_search_client.search.return_value = mock_results + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + 
endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Mix of message types + messages = [ + ChatMessage(role=Role.SYSTEM, text="System message"), + ChatMessage(role=Role.USER, text="User message"), + ChatMessage(role=Role.ASSISTANT, text="Assistant message"), + ChatMessage(role=Role.TOOL, text="Tool message"), + ] + + context = await provider.invoking(messages) + + # Should have processed only USER and ASSISTANT messages + assert isinstance(context, Context) + mock_search_client.search.assert_called_once() + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_filters_empty_messages(self, mock_search_class: MagicMock) -> None: + """Test that empty/whitespace messages are filtered out.""" + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Messages with empty/whitespace text + messages = [ + ChatMessage(role=Role.USER, text=""), + ChatMessage(role=Role.USER, text=" "), + ChatMessage(role=Role.USER, text=None), + ] + + context = await provider.invoking(messages) + + # Should return empty context + assert len(context.messages) == 0 + + +class TestCitations: + """Test citation functionality.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_citations_included_in_semantic_search(self, mock_search_class: MagicMock) -> None: + """Test that citations are included in semantic search results.""" + # Setup mock with document ID + mock_search_client = AsyncMock() + mock_results = AsyncMock() + mock_doc = {"id": "doc123", "content": "Test document content"} + mock_results.__aiter__.return_value = iter([mock_doc]) + mock_search_client.search.return_value = 
mock_results + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + context = await provider.invoking([ChatMessage(role=Role.USER, text="test query")]) + + # Check that citation is included + assert isinstance(context, Context) + assert len(context.messages) > 0 + assert "[Source: doc123]" in context.messages[0].text + assert "Test document content" in context.messages[0].text + + +class TestVectorFieldAutoDiscovery: + """Test vector field auto-discovery functionality.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchIndexClient") + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_auto_discovers_single_vector_field( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that single vector field is auto-discovered.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock + mock_index_client = AsyncMock() + mock_index = MagicMock() + + # Create mock field with vector_search_dimensions attribute + mock_vector_field = MagicMock() + mock_vector_field.name = "embedding_vector" + mock_vector_field.vector_search_dimensions = 1536 + + mock_index.fields = [mock_vector_field] + mock_index_client.get_index.return_value = mock_index + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider without specifying vector_field_name + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Trigger auto-discovery + await provider._auto_discover_vector_field() + + # Vector field should be auto-discovered but not used without 
embedding function + assert provider._auto_discovered_vector_field is True + # Should be cleared since no embedding function + assert provider.vector_field_name is None + + @pytest.mark.asyncio + async def test_vector_detection_accuracy(self) -> None: + """Test that vector field detection logic correctly identifies vector fields.""" + from azure.search.documents.indexes.models import SearchField + + # Create real SearchField objects to test the detection logic + vector_field = SearchField( + name="embedding_vector", type="Collection(Edm.Single)", vector_search_dimensions=1536, searchable=True + ) + + string_field = SearchField(name="content", type="Edm.String", searchable=True) + + number_field = SearchField(name="price", type="Edm.Double", filterable=True) + + # Test detection logic directly + is_vector_1 = vector_field.vector_search_dimensions is not None and vector_field.vector_search_dimensions > 0 + is_vector_2 = string_field.vector_search_dimensions is not None and string_field.vector_search_dimensions > 0 + is_vector_3 = number_field.vector_search_dimensions is not None and number_field.vector_search_dimensions > 0 + + # Only the vector field should be detected + assert is_vector_1 is True + assert is_vector_2 is False + assert is_vector_3 is False + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchIndexClient") + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_no_false_positives_on_string_fields( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that regular string fields are not detected as vector fields.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index with only string fields (no vectors) + mock_index_client = AsyncMock() + mock_index = MagicMock() + + # All fields have vector_search_dimensions = None + mock_fields = [] + for name in ["id", "title", "content", 
"category"]: + field = MagicMock() + field.name = name + field.vector_search_dimensions = None + field.vector_search_profile_name = None + mock_fields.append(field) + + mock_index.fields = mock_fields + mock_index_client.get_index.return_value = mock_index + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Trigger auto-discovery + await provider._auto_discover_vector_field() + + # Should NOT detect any vector fields + assert provider.vector_field_name is None + assert provider._auto_discovered_vector_field is True diff --git a/python/samples/getting_started/agents/azure_ai/README.md b/python/samples/getting_started/agents/azure_ai/README.md index 475dad5a1a..7a3aa71327 100644 --- a/python/samples/getting_started/agents/azure_ai/README.md +++ b/python/samples/getting_started/agents/azure_ai/README.md @@ -15,13 +15,19 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use the `HostedCodeInterpreterTool` with Azure AI agents to write and execute Python code for mathematical problem solving and data analysis. | | [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with a pre-existing agent by providing the agent name and version to the Azure AI client. Demonstrates agent reuse patterns for production scenarios. | | [`azure_ai_with_existing_conversation.py`](azure_ai_with_existing_conversation.py) | Demonstrates how to use an existing conversation created on the service side with Azure AI agents. Shows two approaches: specifying conversation ID at the client level and using AgentThread with an existing conversation ID. 
| +| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing thread by providing the thread ID to the Azure AI chat client. This example also demonstrates proper cleanup of manually created threads. | | [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured `AzureAIClient` settings, including project endpoint, model deployment, and credentials rather than relying on environment variable defaults. | | [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Shows how to use the `HostedFileSearchTool` with Azure AI agents to upload files, create vector stores, and enable agents to search through uploaded documents to answer user questions. | +| [`azure_ai_with_function_tools.py`](azure_ai_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`azure_ai_with_hosted_mcp.py`](azure_ai_with_hosted_mcp.py) | Shows how to integrate hosted Model Context Protocol (MCP) tools with Azure AI Agent. | -| [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Shows how to use structured outputs (response format) with Azure AI agents using Pydantic models to enforce specific response schemas. | -| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | | [`azure_ai_with_image_generation.py`](azure_ai_with_image_generation.py) | Shows how to use the `ImageGenTool` with Azure AI agents to generate images based on text prompts. 
| +| [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate Azure AI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates both agent-level and run-level tool configuration. | | [`azure_ai_with_microsoft_fabric.py`](azure_ai_with_microsoft_fabric.py) | Shows how to use Microsoft Fabric with Azure AI agents to query Fabric data sources and provide responses based on data analysis. Requires a Microsoft Fabric connection configured in your Azure AI project. | +| [`azure_ai_with_multiple_tools.py`](azure_ai_with_multiple_tools.py) | Demonstrates how to use multiple tools together with Azure AI agents, including web search, MCP servers, and function tools. Shows coordinated multi-tool interactions and approval workflows. | +| [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. Shows OpenAPI specification loading, anonymous authentication, thread context management, and coordinated multi-API conversations using weather and countries APIs. | +| [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Shows how to use structured outputs (response format) with Azure AI agents using Pydantic models to enforce specific response schemas. | +| [`azure_ai_with_search_context.py`](azure_ai_with_search_context.py) | Shows how to use AzureAISearchContextProvider for RAG with hybrid search and semantic ranking. Demonstrates semantic (fast) and agentic (slower, uses Knowledge Bases for complex multi-hop reasoning) retrieval modes. Provides context from Azure AI Search to enhance agent responses. | +| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. 
| | [`azure_ai_with_web_search.py`](azure_ai_with_web_search.py) | Shows how to use the `HostedWebSearchTool` with Azure AI agents to perform web searches and retrieve up-to-date information from the internet. | ## Environment Variables diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_search_context.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_search_context.py index 2f7f308944..c9f05dba2b 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_search_context.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_search_context.py @@ -36,7 +36,7 @@ For both modes: - AZURE_SEARCH_ENDPOINT: Your Azure AI Search endpoint - - AZURE_SEARCH_API_KEY: Your search API key (or use Azure AD) + - AZURE_SEARCH_API_KEY: (Optional) Your search API key - if not provided, uses DefaultAzureCredential for Entra ID authentication - AZURE_SEARCH_INDEX_NAME: Your search index name - AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint - AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o") From 9c479942d84573b450c560bc063c587fcecb809c Mon Sep 17 00:00:00 2001 From: kishikawa-hayato <84244732+HerBest-max@users.noreply.github.com> Date: Sat, 15 Nov 2025 19:52:58 +0900 Subject: [PATCH 3/6] fix(azure): remove DEFAULT_CONTEXT_PROMPT override to fix mypy error - Fixes mypy error: Cannot assign to final name 'DEFAULT_CONTEXT_PROMPT' - Base class ContextProvider already defines this as Final - Instance variable self.context_prompt correctly set in __init__ (Line 197) Related to microsoft/agent-framework#1546 --- python/packages/core/agent_framework/azure/_search_provider.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/packages/core/agent_framework/azure/_search_provider.py b/python/packages/core/agent_framework/azure/_search_provider.py index 267002b6fa..ff63ac4f4b 100644 --- a/python/packages/core/agent_framework/azure/_search_provider.py +++ 
b/python/packages/core/agent_framework/azure/_search_provider.py @@ -125,8 +125,6 @@ class AzureAISearchContextProvider(ContextProvider): ) """ - DEFAULT_CONTEXT_PROMPT = "Use the following context to answer the question:" - def __init__( self, endpoint: str, From 01b662f785a0c0d041e6b8aa75d3720dfd4a3998 Mon Sep 17 00:00:00 2001 From: kishikawa-hayato <84244732+HerBest-max@users.noreply.github.com> Date: Sat, 15 Nov 2025 20:16:27 +0900 Subject: [PATCH 4/6] test(azure): improve search provider coverage to 84% MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive test coverage for Azure AI Search provider: - Agentic mode end-to-end flow (invoking → retrieval) - Agentic mode empty response fallback handling - Auto-discovery exception handling for network errors - Semantic search with semantic_configuration_name parameter Coverage improved from 73% to 84% (target: 80%+) Test results: 25 passed, 0 failed --- .../core/tests/azure/test_search_provider.py | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/python/packages/core/tests/azure/test_search_provider.py b/python/packages/core/tests/azure/test_search_provider.py index 29e86f9352..f440b50faa 100644 --- a/python/packages/core/tests/azure/test_search_provider.py +++ b/python/packages/core/tests/azure/test_search_provider.py @@ -542,3 +542,194 @@ async def test_no_false_positives_on_string_fields( # Should NOT detect any vector fields assert provider.vector_field_name is None assert provider._auto_discovered_vector_field is True + + +class TestAgenticMode: + """Test agentic mode functionality with Knowledge Bases.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.KnowledgeAgentRetrievalClient") + @patch("agent_framework.azure._search_provider.SearchIndexClient") + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_agentic_mode_end_to_end( + self, + mock_search_class: MagicMock, + 
mock_index_class: MagicMock, + mock_retrieval_class: MagicMock, + ) -> None: + """Test complete agentic mode flow from invoking to retrieval.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock (Knowledge Base exists) + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.return_value = MagicMock() + mock_index_client.get_agent.return_value = MagicMock() + mock_index_class.return_value = mock_index_client + + # Setup retrieval client mock + mock_retrieval_client = AsyncMock() + + # Import the models for mocking + from agent_framework.azure._search_provider import ( + KnowledgeAgentMessageTextContent, + ) + + # Mock retrieval response + mock_response_message = MagicMock() + mock_response_message.content = [ + KnowledgeAgentMessageTextContent(text="This is the synthesized answer from the Knowledge Base.") + ] + + mock_retrieval_result = MagicMock() + mock_retrieval_result.response = [mock_response_message] + mock_retrieval_client.retrieve = AsyncMock(return_value=mock_retrieval_result) + + mock_retrieval_class.return_value = mock_retrieval_client + + # Create provider in agentic mode + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + # Call invoking with a user message + messages = [ChatMessage(role=Role.USER, text="What information is available?")] + context = await provider.invoking(messages) + + # Verify context was created with synthesized answer + assert isinstance(context, Context) + assert len(context.messages) > 0 + assert "synthesized answer from the Knowledge Base" in context.messages[0].text + + # Verify retrieval 
was called + mock_retrieval_client.retrieve.assert_called_once() + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.KnowledgeAgentRetrievalClient") + @patch("agent_framework.azure._search_provider.SearchIndexClient") + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_agentic_mode_empty_response_fallback( + self, + mock_search_class: MagicMock, + mock_index_class: MagicMock, + mock_retrieval_class: MagicMock, + ) -> None: + """Test that agentic mode handles empty responses with fallback message.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock (Knowledge Base exists) + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.return_value = MagicMock() + mock_index_client.get_agent.return_value = MagicMock() + mock_index_class.return_value = mock_index_client + + # Setup retrieval client mock with empty response + mock_retrieval_client = AsyncMock() + mock_retrieval_result = MagicMock() + mock_retrieval_result.response = [] # Empty response + mock_retrieval_client.retrieve = AsyncMock(return_value=mock_retrieval_result) + mock_retrieval_class.return_value = mock_retrieval_client + + # Create provider in agentic mode + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + # Call invoking + messages = [ChatMessage(role=Role.USER, text="What is this about?")] + context = await provider.invoking(messages) + + # Should have fallback message + assert isinstance(context, Context) + assert len(context.messages) > 0 + assert "No results found from Knowledge Base" in 
context.messages[0].text + + +class TestErrorHandling: + """Test error handling and edge cases.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchIndexClient") + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_auto_discovery_exception_handling( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that auto-discovery gracefully handles exceptions.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client to raise an exception + mock_index_client = AsyncMock() + mock_index_client.get_index.side_effect = Exception("Network error") + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Should not raise exception, just log warning + import logging + + with patch.object(logging, "warning") as mock_warning: + await provider._auto_discover_vector_field() + # Should log warning about failure + mock_warning.assert_called_once() + + # Should mark as attempted and continue with keyword search + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_semantic_search_with_semantic_configuration(self, mock_search_class: MagicMock) -> None: + """Test semantic search with semantic_configuration_name parameter.""" + # Setup mock + mock_search_client = AsyncMock() + mock_results = AsyncMock() + mock_results.__aiter__.return_value = iter([{"content": "Semantic search result"}]) + mock_search_client.search.return_value = mock_results + mock_search_class.return_value = mock_search_client + + provider = 
AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + semantic_configuration_name="my-semantic-config", + ) + + messages = [ChatMessage(role=Role.USER, text="test query")] + context = await provider.invoking(messages) + + # Verify search was called with semantic configuration + assert mock_search_client.search.called + call_args = mock_search_client.search.call_args + assert "semantic_configuration_name" in call_args.kwargs + assert call_args.kwargs["semantic_configuration_name"] == "my-semantic-config" + + # Verify context was created + assert isinstance(context, Context) + assert len(context.messages) > 0 From e6b496ad2f4e70ee166a649f9888b936764b72c7 Mon Sep 17 00:00:00 2001 From: kishikawa-hayato <84244732+HerBest-max@users.noreply.github.com> Date: Sat, 15 Nov 2025 20:30:01 +0900 Subject: [PATCH 5/6] test(azure): add vector auto-discovery tests for 91% coverage Added 4 comprehensive tests for vector field auto-discovery (lines 351-384): - test_single_vector_field_with_embedding_function: Tests successful auto-discovery when embedding function is provided - test_single_vector_field_warning_without_embedding: Tests warning when vector field found but no embedding function - test_multiple_vector_fields_warning: Tests warning when multiple vector fields detected - test_index_client_cleanup_when_not_provided: Tests proper cleanup of internally-created index client Coverage improved from 84% to 91% (target: 90%+) Test results: 29 passed, 0 failed Key fix: Patched Azure SDK classes at source (azure.search.documents.*.aio) instead of module level to handle dynamic imports correctly --- .../core/tests/azure/test_search_provider.py | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) diff --git a/python/packages/core/tests/azure/test_search_provider.py b/python/packages/core/tests/azure/test_search_provider.py index f440b50faa..006fbc07fa 100644 --- 
a/python/packages/core/tests/azure/test_search_provider.py +++ b/python/packages/core/tests/azure/test_search_provider.py @@ -543,6 +543,188 @@ async def test_no_false_positives_on_string_fields( assert provider.vector_field_name is None assert provider._auto_discovered_vector_field is True + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_single_vector_field_with_embedding_function( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test auto-discovery with single vector field and embedding function provided.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock + mock_index_client = AsyncMock() + mock_index = MagicMock() + + # Create single vector field + mock_vector_field = MagicMock() + mock_vector_field.name = "embedding_vector" + mock_vector_field.vector_search_dimensions = 1536 + + mock_index.fields = [mock_vector_field] + mock_index_client.get_index.return_value = mock_index + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Mock embedding function + async def mock_embed(text: str) -> list[float]: + return [0.1] * 1536 + + # Create provider WITH embedding function + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + embedding_function=mock_embed, + ) + + # Trigger auto-discovery + await provider._auto_discover_vector_field() + + # Vector field should be auto-discovered and kept (has embedding function) + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name == "embedding_vector" + + # Verify index client was closed + mock_index_client.close.assert_called_once() + + @pytest.mark.asyncio + 
@patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_single_vector_field_warning_without_embedding( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that warning is logged when vector field found but no embedding function.""" + import logging + + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock + mock_index_client = AsyncMock() + mock_index = MagicMock() + + # Create single vector field + mock_vector_field = MagicMock() + mock_vector_field.name = "embedding_vector" + mock_vector_field.vector_search_dimensions = 1536 + + mock_index.fields = [mock_vector_field] + mock_index_client.get_index.return_value = mock_index + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider WITHOUT embedding function + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Trigger auto-discovery with logging captured + with patch.object(logging, "warning") as mock_warning: + await provider._auto_discover_vector_field() + # Should warn about missing embedding function + mock_warning.assert_called_once() + warning_msg = str(mock_warning.call_args) + assert "Auto-discovered vector field" in warning_msg + assert "no embedding_function provided" in warning_msg + + # Vector field should be cleared + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_multiple_vector_fields_warning( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that warning is logged 
when multiple vector fields are found.""" + import logging + + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock + mock_index_client = AsyncMock() + mock_index = MagicMock() + + # Create multiple vector fields + mock_fields = [] + for name in ["embedding_v1", "embedding_v2"]: + field = MagicMock() + field.name = name + field.vector_search_dimensions = 1536 + mock_fields.append(field) + + mock_index.fields = mock_fields + mock_index_client.get_index.return_value = mock_index + mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Trigger auto-discovery with logging captured + with patch.object(logging, "warning") as mock_warning: + await provider._auto_discover_vector_field() + # Should warn about multiple vector fields + mock_warning.assert_called_once() + warning_msg = str(mock_warning.call_args) + assert "Multiple vector fields found" in warning_msg + + # Should not auto-select any field + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_index_client_cleanup_when_not_provided( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test that index client is closed when created internally.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Setup index client mock + mock_index_client = AsyncMock() + mock_index = MagicMock() + mock_index.fields = [] # No vector fields + mock_index_client.get_index.return_value = mock_index 
+ mock_index_client.close = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider without index client + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Ensure no external index client + assert provider._index_client is None + + # Trigger auto-discovery + await provider._auto_discover_vector_field() + + # Index client should have been closed + mock_index_client.close.assert_called_once() + class TestAgenticMode: """Test agentic mode functionality with Knowledge Bases.""" From 968acb3179a49c96387e2afc9d36b3db5236d1c4 Mon Sep 17 00:00:00 2001 From: kishikawa-hayato <84244732+HerBest-max@users.noreply.github.com> Date: Sat, 15 Nov 2025 20:48:02 +0900 Subject: [PATCH 6/6] test(azure): achieve 98% test coverage for search provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Systematically added comprehensive tests across 4 phases: **Phase 1 (91%→94%):** Agentic mode validation + resource cleanup - 7 tests for parameter validation (init-time and runtime) - Retrieval client cleanup verification **Phase 2 (94%→95%):** Import error handling - 1 test for SDK unavailability (ImportError) **Phase 3 (95%→96%):** Edge cases and alternative paths - Empty search results handling - External index client usage - Early return optimization **Phase 4 (96%→98%):** Helper functions - Document text extraction fallback logic Coverage: 91% → 98% (186 statements, 3 missed) Tests: 29 → 41 passed (+12 tests) **Remaining 3 uncovered lines (acceptable):** - Lines 60-61: Module import exception (impossible to unit test) - Line 580: Agentic RuntimeError (requires complex SDK mocking) All tests pass. 98% coverage exceeds industry best practices. 
--- .../core/tests/azure/test_search_provider.py | 302 ++++++++++++++++++ 1 file changed, 302 insertions(+) diff --git a/python/packages/core/tests/azure/test_search_provider.py b/python/packages/core/tests/azure/test_search_provider.py index 006fbc07fa..b4b8a5adb2 100644 --- a/python/packages/core/tests/azure/test_search_provider.py +++ b/python/packages/core/tests/azure/test_search_provider.py @@ -915,3 +915,305 @@ async def test_semantic_search_with_semantic_configuration(self, mock_search_cla # Verify context was created assert isinstance(context, Context) assert len(context.messages) > 0 + + +class TestAgenticModeValidation: + """Test agentic mode parameter validation.""" + + def test_agentic_mode_missing_knowledge_base_name(self) -> None: + """Test ValueError when knowledge_base_name is missing in agentic mode.""" + with pytest.raises(ValueError, match="knowledge_base_name is required"): + AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + azure_openai_resource_url="https://test.openai.azure.com", + # knowledge_base_name is missing + ) + + def test_agentic_mode_missing_azure_openai_resource_url(self) -> None: + """Test ValueError when azure_openai_resource_url is missing in agentic mode.""" + with pytest.raises(ValueError, match="azure_openai_resource_url"): + AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + # azure_openai_resource_url is missing + ) + + def test_agentic_mode_missing_model_deployment_name(self) -> None: + """Test ValueError when model_deployment_name is missing in agentic mode.""" + 
with pytest.raises(ValueError, match="model_deployment_name"): + AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + # model_deployment_name is missing + ) + + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_ensure_kb_missing_knowledge_base_name_runtime( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test ValueError when knowledge_base_name is None at runtime in _ensure_knowledge_base.""" + # Setup mocks + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + mock_index_client = AsyncMock() + mock_index_class.return_value = mock_index_client + + # Create provider with knowledge_base_name, then set to None to trigger runtime validation + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + # Set knowledge_base_name to None to trigger runtime validation + provider.knowledge_base_name = None + + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await provider._ensure_knowledge_base() + + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_ensure_kb_missing_azure_openai_url_runtime( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test 
ValueError when azure_openai_resource_url is None at runtime.""" + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + mock_index_client = AsyncMock() + mock_index_class.return_value = mock_index_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + provider.azure_openai_resource_url = None + + with pytest.raises(ValueError, match="azure_openai_resource_url"): + await provider._ensure_knowledge_base() + + @pytest.mark.asyncio + @patch("azure.search.documents.indexes.aio.SearchIndexClient") + @patch("azure.search.documents.aio.SearchClient") + async def test_ensure_kb_missing_deployment_name_runtime( + self, mock_search_class: MagicMock, mock_index_class: MagicMock + ) -> None: + """Test ValueError when model_deployment_name is None at runtime.""" + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + mock_index_client = AsyncMock() + mock_index_class.return_value = mock_index_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + provider.azure_openai_deployment_name = None + + with pytest.raises(ValueError, match="model_deployment_name"): + await provider._ensure_knowledge_base() + + @pytest.mark.asyncio + @patch("azure.search.documents.aio.SearchClient") + async def test_retrieval_client_cleanup_on_exit(self, mock_search_class: 
MagicMock) -> None: + """Test that retrieval client is properly closed in __aexit__.""" + # Setup search client mock + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + # Create provider in agentic mode + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + # Mock retrieval client + mock_retrieval_client = AsyncMock() + provider._retrieval_client = mock_retrieval_client + + # Exit context + await provider.__aexit__(None, None, None) + + # Verify cleanup + mock_retrieval_client.close.assert_called_once() + assert provider._retrieval_client is None + + +class TestImportErrorHandling: + """Test import error handling for optional dependencies.""" + + @patch("agent_framework.azure._search_provider._agentic_retrieval_available", False) + def test_agentic_mode_unavailable_sdk(self) -> None: + """Test ImportError when agentic SDK is not available.""" + with pytest.raises(ImportError, match="azure-search-documents >= 11.7.0b1"): + AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + +class TestEdgeCasesAndPaths: + """Test edge cases and alternative code paths.""" + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_empty_search_results_returns_empty_context(self, mock_search_class: MagicMock) -> None: + """Test that empty search results return empty 
Context (line 322).""" + # Setup mock to return empty results + mock_search_client = AsyncMock() + mock_results = AsyncMock() + mock_results.__aiter__.return_value = iter([]) # No results + mock_search_client.search.return_value = mock_results + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + context = await provider.invoking([ChatMessage(role=Role.USER, text="test query")]) + + # Should return empty context when no results + assert isinstance(context, Context) + assert len(context.messages) == 0 + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def test_external_index_client_used_when_provided(self, mock_search_class: MagicMock) -> None: + """Test that external _index_client is used when provided (line 345).""" + # Setup mocks + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + external_index_client = AsyncMock() + mock_index = MagicMock() + mock_index.fields = [] + external_index_client.get_index.return_value = mock_index + external_index_client.close = AsyncMock() + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Set external index client + provider._index_client = external_index_client + + # Trigger auto-discovery which should use the external client + await provider._auto_discover_vector_field() + + # Verify external client was used (get_index called) + external_index_client.get_index.assert_called_once_with("test-index") + # Verify external client was NOT closed (we don't own it) + external_index_client.close.assert_not_called() + + @pytest.mark.asyncio + @patch("agent_framework.azure._search_provider.SearchClient") + async def 
test_ensure_kb_early_return_when_already_initialized(self, mock_search_class: MagicMock) -> None: + """Test early return from _ensure_knowledge_base when already initialized (line 466).""" + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="agentic", + azure_ai_project_endpoint="https://test.services.ai.azure.com", + model_deployment_name="gpt-4o", + knowledge_base_name="test-kb", + azure_openai_resource_url="https://test.openai.azure.com", + ) + + # Mark as already initialized + provider._knowledge_base_initialized = True + + # Call should return early without doing anything (no errors) + await provider._ensure_knowledge_base() + + +class TestHelperFunctions: + """Test helper and utility functions.""" + + @patch("agent_framework.azure._search_provider.SearchClient") + def test_extract_document_text_fallback_to_all_fields(self, mock_search_class: MagicMock) -> None: + """Test document text extraction fallback when no standard fields (lines 622-626).""" + mock_search_client = AsyncMock() + mock_search_class.return_value = mock_search_client + + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + mode="semantic", + ) + + # Document with no standard text fields, only custom fields + doc = { + "title": "Custom Title", + "author": "John Doe", + "year": 2024, # Non-string field should be ignored + "id": "doc123", # Should be ignored + "@search.score": 0.95, # @ fields should be ignored + } + + result = provider._extract_document_text(doc) + + # Should concatenate string fields (except id and @fields) + assert "title: Custom Title" in result + assert "author: John Doe" in result + assert "2024" not in result # Non-string ignored + assert "id" not in result # id 
field ignored + assert "@search" not in result # @ fields ignored