From 6dfabe415c3750df5680ff276890d36def9fb8ae Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 24 Feb 2026 16:21:39 +0100 Subject: [PATCH 1/3] small updates and improvements in the azure AISearch provider --- .../_context_provider.py | 220 +++++++--- .../tests/test_aisearch_context_provider.py | 379 +++++++++++++++++- .../azure_ai_with_search_context_agentic.py | 7 +- .../azure_ai_with_search_context_semantic.py | 24 +- 4 files changed, 568 insertions(+), 62 deletions(-) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 4445de8015..4be5ad4d74 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -13,7 +13,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict -from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message +from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext from agent_framework._settings import SecretString, load_settings from agent_framework.azure._entra_id_authentication import AzureCredentialTypes @@ -47,8 +47,12 @@ from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient from azure.search.documents.knowledgebases.models import ( KnowledgeBaseMessage, + KnowledgeBaseMessageImageContent, + KnowledgeBaseMessageImageContentImage, KnowledgeBaseMessageTextContent, + KnowledgeBaseReference, KnowledgeBaseRetrievalRequest, + KnowledgeBaseRetrievalResponse, KnowledgeRetrievalIntent, KnowledgeRetrievalSemanticIntent, ) @@ -78,8 +82,12 @@ from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient from azure.search.documents.knowledgebases.models import ( KnowledgeBaseMessage, + KnowledgeBaseMessageImageContent, + KnowledgeBaseMessageImageContentImage, KnowledgeBaseMessageTextContent, + KnowledgeBaseReference, KnowledgeBaseRetrievalRequest, + KnowledgeBaseRetrievalResponse, KnowledgeRetrievalIntent, KnowledgeRetrievalSemanticIntent, ) @@ -154,7 +162,9 @@ def __init__( top_k: int = 5, semantic_configuration_name: str | None = None, vector_field_name: str | None = None, - embedding_function: Callable[[str], Awaitable[list[float]]] | None = None, + embedding_function: Callable[[str], Awaitable[list[float]]] + | SupportsGetEmbeddings[str, list[float], Any] + | None = None, context_prompt: str | None = None, azure_openai_resource_url: str | None = None, model_deployment_name: str | None = None, @@ -181,7 +191,7 @@ def __init__( top_k: Maximum number of documents to retrieve. Default: 5. semantic_configuration_name: Name of semantic configuration in the index. vector_field_name: Name of the vector field in the index. - embedding_function: Async function to generate embeddings. + embedding_function: Async function to generate embeddings or a SupportsGetEmbeddings instance. context_prompt: Custom prompt to prepend to retrieved context. azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base. model_deployment_name: Model deployment name in Azure OpenAI. @@ -309,9 +319,20 @@ async def __aexit__( exc_tb: Any, ) -> None: """Async context manager exit - cleanup clients.""" + await self.close() + + async def close(self) -> None: + """Close all the open clients.""" if self._retrieval_client is not None: await self._retrieval_client.close() self._retrieval_client = None + self._knowledge_base_initialized = False + if self._search_client is not None: + await self._search_client.close() + self._search_client = None + if self._index_client is not None: + await self._index_client.close() + self._index_client = None # -- Hooks pattern --------------------------------------------------------- @@ -326,32 +347,23 @@ async def before_run( """Retrieve relevant context from Azure AI Search and add to session context.""" messages_list = list(context.input_messages) - def get_role_value(role: str | Any) -> str: - return role.value if hasattr(role, "value") else str(role) - filtered_messages = [ - msg - for msg in messages_list - if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"] + msg for msg in messages_list if msg and msg.text and msg.text.strip() and msg.role in ["user", "assistant"] ] if not filtered_messages: return if self.mode == "semantic": query = "\n".join(msg.text for msg in filtered_messages) - search_result_parts = await self._semantic_search(query) + result_messages = await self._semantic_search(query) else: recent_messages = filtered_messages[-self.agentic_message_history_count :] - search_result_parts = await self._agentic_search(recent_messages) + result_messages = await self._agentic_search(recent_messages) - if not search_result_parts: + if not result_messages: return - context_messages = [Message(role="user", text=self.context_prompt)] - context_messages.extend([Message(role="user", text=part) for part in search_result_parts]) - context.extend_messages(self.source_id, context_messages) - - # -- Internal methods (ported from AzureAISearchContextProvider) ----------- + context.extend_messages(self.source_id, [Message(role="user", text=self.context_prompt), *result_messages]) def _find_vector_fields(self, index: Any) -> list[str]: """Find all fields that can store vectors.""" @@ -432,7 +444,7 @@ async def _auto_discover_vector_field(self) -> None: self._auto_discovered_vector_field = True - async def _semantic_search(self, query: str) -> list[str]: + async def _semantic_search(self, query: str) -> list[Message]: """Perform semantic hybrid search.""" await self._auto_discover_vector_field() @@ -440,14 +452,14 @@ async def _semantic_search(self, query: str) -> list[str]: if self.vector_field_name: vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k if self._use_vectorizable_query: - vector_queries = [ - VectorizableTextQuery(text=query, k_nearest_neighbors=vector_k, fields=self.vector_field_name) - ] + vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)] elif self.embedding_function: - query_vector = await self.embedding_function(query) - vector_queries = [ - VectorizedQuery(vector=query_vector, k_nearest_neighbors=vector_k, fields=self.vector_field_name) - ] + if isinstance(self.embedding_function, SupportsGetEmbeddings): + query_vector = await self.embedding_function.get_embeddings(query) # type: ignore[reportUnknownVariableType] + query_vector = query_vector[0].vector # type: ignore[reportUnknownVariableType] + else: + query_vector = await self.embedding_function(query) + vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] search_params: dict[str, Any] = {"search_text": query, "top": self.top_k} if vector_queries: @@ -461,13 +473,13 @@ async def _semantic_search(self, query: str) -> list[str]: raise RuntimeError("Search client is not initialized.") results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType] - formatted_results: list[str] = [] + result_messages: list[Message] = [] async for doc in results: # type: ignore[reportUnknownVariableType] doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType] doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType] if doc_text: - formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType] - return formatted_results + result_messages.append(Message(role="user", text=doc_text)) # type: ignore[reportUnknownArgumentType] + return result_messages async def _ensure_knowledge_base(self) -> None: """Ensure Knowledge Base and knowledge source are created or use existing KB.""" @@ -550,7 +562,7 @@ async def _ensure_knowledge_base(self) -> None: user_agent=AGENT_FRAMEWORK_USER_AGENT, ) - async def _agentic_search(self, messages: list[Message]) -> list[str]: + async def _agentic_search(self, messages: list[Message]) -> list[Message]: """Perform agentic retrieval with multi-hop reasoning.""" await self._ensure_knowledge_base() @@ -577,14 +589,7 @@ async def _agentic_search(self, messages: list[Message]) -> list[str]: include_activity=True, ) else: - kb_messages = [ - KnowledgeBaseMessage( - role=msg.role if hasattr(msg.role, "value") else str(msg.role), - content=[KnowledgeBaseMessageTextContent(text=msg.text)], - ) - for msg in messages - if msg.text - ] + kb_messages = self._prepare_messages_for_kb_search(messages) retrieval_request = KnowledgeBaseRetrievalRequest( messages=kb_messages, retrieval_reasoning_effort=reasoning_effort, @@ -596,17 +601,136 @@ async def _agentic_search(self, messages: list[Message]) -> list[str]: raise RuntimeError("Retrieval client not initialized.") retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request) - if retrieval_result.response and len(retrieval_result.response) > 0: - assistant_message = retrieval_result.response[-1] - if assistant_message.content: - answer_parts: list[str] = [] - for content_item in assistant_message.content: - if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text: - answer_parts.append(content_item.text) - if answer_parts: - return answer_parts - - return ["No results found from Knowledge Base."] + return self._parse_messages_from_kb_response(retrieval_result) + + @staticmethod + def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]: + """Convert framework Messages to KnowledgeBaseMessages for agentic retrieval. + + Handles text and image content types. Other content types (function calls, + errors, etc.) are skipped. + + Args: + messages: Framework messages to convert. + + Returns: + List of KnowledgeBaseMessage objects suitable for retrieval requests. + """ + kb_messages: list[KnowledgeBaseMessage] = [] + for msg in messages: + kb_content: list[KnowledgeBaseMessageTextContent | KnowledgeBaseMessageImageContent] = [] + if msg.contents: + for content in msg.contents: + match content.type: + case "text" if content.text: + kb_content.append(KnowledgeBaseMessageTextContent(text=content.text)) + case "uri" | "data" if ( + content.uri and content.media_type and content.media_type.startswith("image/") + ): + kb_content.append( + KnowledgeBaseMessageImageContent( + image=KnowledgeBaseMessageImageContentImage(url=content.uri), + ) + ) + elif msg.text: + kb_content.append(KnowledgeBaseMessageTextContent(text=msg.text)) + if kb_content: + kb_messages.append(KnowledgeBaseMessage(role=msg.role, content=kb_content)) # type: ignore[arg-type] + return kb_messages + + @staticmethod + def _parse_references_to_annotations(references: list[KnowledgeBaseReference] | None) -> list[Annotation]: + """Convert Knowledge Base references to framework Annotations. + + Captures all available fields from each reference subtype: URLs, doc keys, + reranker scores, source data, and the raw reference object itself. + + Args: + references: The references from a Knowledge Base retrieval response. + + Returns: + List of citation Annotations. + """ + if not references: + return [] + annotations: list[Annotation] = [] + for ref in references: + url: str | None = None + for attr in ("url", "blob_url", "doc_url", "web_url"): + url = getattr(ref, attr, None) + if url: + break + + annotation = Annotation( + type="citation", + url=url or "", + title=getattr(ref, "title", None) or ref.id, + ) + + extra: dict[str, Any] = { + "reference_id": ref.id, + "reference_type": getattr(ref, "type", None), + "activity_source": ref.activity_source, + } + if ref.reranker_score is not None: + extra["reranker_score"] = ref.reranker_score + if ref.source_data: + extra["source_data"] = ref.source_data + doc_key = getattr(ref, "doc_key", None) + if doc_key: + extra["doc_key"] = doc_key + if ref.additional_properties: + extra["sdk_additional_properties"] = ref.additional_properties + sensitivity_info = getattr(ref, "search_sensitivity_label_info", None) + if sensitivity_info: + extra["sensitivity_label"] = { + "display_name": sensitivity_info.display_name, + "sensitivity_label_id": sensitivity_info.sensitivity_label_id, + "is_encrypted": sensitivity_info.is_encrypted, + } + + annotation["additional_properties"] = extra + annotation["raw_representation"] = ref + annotations.append(annotation) + return annotations + + @staticmethod + def _parse_messages_from_kb_response(retrieval_result: KnowledgeBaseRetrievalResponse) -> list[Message]: + """Convert a Knowledge Base retrieval response to framework Messages. + + Each KnowledgeBaseMessage becomes a Message. References from the response + are converted to Annotations and attached to content items. + + Args: + retrieval_result: The full retrieval response including messages and references. + + Returns: + List of Messages, or a single default Message if no results found. + """ + if not retrieval_result.response: + return [Message(role="assistant", text="No results found from Knowledge Base.")] + + annotations = AzureAISearchContextProvider._parse_references_to_annotations(retrieval_result.references) + + result_messages: list[Message] = [] + for kb_msg in retrieval_result.response: + if not kb_msg.content: + continue + contents: list[Content] = [] + for item in kb_msg.content: + if isinstance(item, KnowledgeBaseMessageTextContent) and item.text: + contents.append(Content.from_text(item.text)) + elif isinstance(item, KnowledgeBaseMessageImageContent) and item.image and item.image.url: + contents.append(Content.from_uri(uri=item.image.url, media_type="image/png")) + if contents: + if annotations: + for c in contents: + c.annotations = annotations + result_messages.append(Message(role=kb_msg.role or "assistant", contents=contents)) + + if not result_messages: + return [Message(role="assistant", text="No results found from Knowledge Base.")] + return result_messages def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str: """Extract readable text from a search document with optional citation.""" diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 3a83015872..0b13a5599b 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from agent_framework import Message +from agent_framework import Content, Message from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import SettingNotFoundError from azure.core.credentials import AzureKeyCredential @@ -720,7 +720,7 @@ async def _search(**kwargs): results = await provider._semantic_search("test query") assert len(results) == 1 - assert "result text" in results[0] + assert "result text" in results[0].text call_kwargs = mock_client.search.call_args[1] assert call_kwargs["search_text"] == "test query" @@ -1100,9 +1100,11 @@ async def test_minimal_reasoning_returns_results(self) -> None: mock_content = Mock() mock_content.text = "Answer text" mock_message = Mock() + mock_message.role = "assistant" mock_message.content = [mock_content] mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) @@ -1115,7 +1117,9 @@ async def test_minimal_reasoning_returns_results(self) -> None: ): results = await provider._agentic_search([Message(role="user", contents=["test query"])]) - assert results == ["Answer text"] + assert len(results) == 1 + assert results[0].text == "Answer text" + assert results[0].role == "assistant" async def test_non_minimal_reasoning_uses_messages(self) -> None: provider = _make_provider() @@ -1126,9 +1130,11 @@ async def test_non_minimal_reasoning_uses_messages(self) -> None: mock_content = Mock() mock_content.text = "Medium answer" mock_message = Mock() + mock_message.role = "assistant" mock_message.content = [mock_content] mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) @@ -1143,7 +1149,8 @@ async def test_non_minimal_reasoning_uses_messages(self) -> None: Message(role="assistant", contents=["answer"]), ]) - assert results == ["Medium answer"] + assert len(results) == 1 + assert results[0].text == "Medium answer" mock_retrieval.retrieve.assert_awaited_once() async def test_no_response_returns_default_message(self) -> None: @@ -1154,13 +1161,15 @@ async def test_no_response_returns_default_message(self) -> None: mock_result = Mock() mock_result.response = [] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) provider._retrieval_client = mock_retrieval results = await provider._agentic_search([Message(role="user", contents=["query"])]) - assert results == ["No results found from Knowledge Base."] + assert len(results) == 1 + assert results[0].text == "No results found from Knowledge Base." async def test_empty_content_returns_default_message(self) -> None: provider = _make_provider() @@ -1172,13 +1181,15 @@ async def test_empty_content_returns_default_message(self) -> None: mock_message.content = None mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) provider._retrieval_client = mock_retrieval results = await provider._agentic_search([Message(role="user", contents=["query"])]) - assert results == ["No results found from Knowledge Base."] + assert len(results) == 1 + assert results[0].text == "No results found from Knowledge Base." async def test_answer_synthesis_output_mode(self) -> None: provider = _make_provider() @@ -1190,9 +1201,11 @@ async def test_answer_synthesis_output_mode(self) -> None: mock_content = Mock() mock_content.text = "Synthesized answer" mock_message = Mock() + mock_message.role = "assistant" mock_message.content = [mock_content] mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) @@ -1204,7 +1217,8 @@ async def test_answer_synthesis_output_mode(self) -> None: ): results = await provider._agentic_search([Message(role="user", contents=["query"])]) - assert results == ["Synthesized answer"] + assert len(results) == 1 + assert results[0].text == "Synthesized answer" async def test_content_without_text_excluded(self) -> None: provider = _make_provider() @@ -1217,9 +1231,11 @@ async def test_content_without_text_excluded(self) -> None: mock_content_no_text = Mock() mock_content_no_text.text = None mock_message = Mock() + mock_message.role = "assistant" mock_message.content = [mock_content_no_text, mock_content_with_text] mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) @@ -1231,7 +1247,8 @@ async def test_content_without_text_excluded(self) -> None: ): results = await provider._agentic_search([Message(role="user", contents=["query"])]) - assert results == ["Good content"] + assert len(results) == 1 + assert results[0].text == "Good content" async def test_none_response_returns_default_message(self) -> None: provider = _make_provider() @@ -1241,13 +1258,355 @@ async def test_none_response_returns_default_message(self) -> None: mock_result = Mock() mock_result.response = None + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) provider._retrieval_client = mock_retrieval results = await provider._agentic_search([Message(role="user", contents=["query"])]) - assert results == ["No results found from Knowledge Base."] + assert len(results) == 1 + assert results[0].text == "No results found from Knowledge Base." + + +# -- before_run: agentic mode -------------------------------------------------- + + +# -- _prepare_messages_for_kb_search / _parse_content_from_kb_response -------- + + +class TestPrepareMessagesForKbSearch: + """Tests for _prepare_messages_for_kb_search.""" + + def test_text_only_messages(self) -> None: + messages = [ + Message(role="user", contents=["hello"]), + Message(role="assistant", contents=["world"]), + ] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + # Verify content is KnowledgeBaseMessageTextContent + from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent + + assert isinstance(result[0].content[0], KnowledgeBaseMessageTextContent) + assert result[0].content[0].text == "hello" + + def test_image_uri_content(self) -> None: + from agent_framework import Content + + img = Content.from_uri(uri="https://example.com/photo.png", media_type="image/png") + messages = [Message(role="user", contents=[img])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 1 + from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent + + assert isinstance(result[0].content[0], KnowledgeBaseMessageImageContent) + assert result[0].content[0].image.url == "https://example.com/photo.png" + + def test_mixed_text_and_image_content(self) -> None: + from agent_framework import Content + + text = Content.from_text("describe this image") + img = Content.from_uri(uri="https://example.com/img.jpg", media_type="image/jpeg") + messages = [Message(role="user", contents=[text, img])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 1 + assert len(result[0].content) == 2 + + def test_skips_non_text_non_image_content(self) -> None: + from agent_framework import Content + + error = Content.from_error(message="oops") + messages = [Message(role="user", contents=[error])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 0 # message had no usable content + + def test_skips_empty_text(self) -> None: + from agent_framework import Content + + empty = Content.from_text("") + messages = [Message(role="user", contents=[empty])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 0 + + def test_fallback_to_msg_text_when_no_contents(self) -> None: + msg = Message(role="user", text="fallback text") + result = AzureAISearchContextProvider._prepare_messages_for_kb_search([msg]) + assert len(result) == 1 + assert result[0].content[0].text == "fallback text" + + def test_data_uri_image(self) -> None: + from agent_framework import Content + + img = Content.from_data(data=b"\x89PNG", media_type="image/png") + messages = [Message(role="user", contents=[img])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 1 + from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent + + assert isinstance(result[0].content[0], KnowledgeBaseMessageImageContent) + + def test_non_image_uri_skipped(self) -> None: + from agent_framework import Content + + pdf = Content.from_uri(uri="https://example.com/doc.pdf", media_type="application/pdf") + messages = [Message(role="user", contents=[pdf])] + result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages) + assert len(result) == 0 + + +class TestParseReferencesToAnnotations: + """Tests for _parse_references_to_annotations.""" + + def test_none_references(self) -> None: + result = AzureAISearchContextProvider._parse_references_to_annotations(None) + assert result == [] + + def test_empty_references(self) -> None: + result = AzureAISearchContextProvider._parse_references_to_annotations([]) + assert result == [] + + def test_search_index_reference_captures_doc_key(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference + + ref = KnowledgeBaseSearchIndexReference(id="ref-1", activity_source=0, doc_key="doc-1") + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + assert len(result) == 1 + assert result[0]["type"] == "citation" + assert result[0]["title"] == "ref-1" + extra = result[0]["additional_properties"] + assert extra["reference_id"] == "ref-1" + assert extra["reference_type"] == "searchIndex" + assert extra["activity_source"] == 0 + assert extra["doc_key"] == "doc-1" + + def test_web_reference_with_url_and_title(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseWebReference + + ref = KnowledgeBaseWebReference( + id="ref-2", activity_source=0, url="https://example.com/page", title="Example Page" + ) + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + assert len(result) == 1 + assert result[0]["url"] == "https://example.com/page" + assert result[0]["title"] == "Example Page" + assert result[0]["additional_properties"]["reference_type"] == "web" + + def test_blob_reference_extracts_blob_url(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseAzureBlobReference + + ref = KnowledgeBaseAzureBlobReference( + id="ref-3", activity_source=0, blob_url="https://storage.blob.core.windows.net/doc.pdf" + ) + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + assert result[0]["url"] == "https://storage.blob.core.windows.net/doc.pdf" + assert result[0]["additional_properties"]["reference_type"] == "azureBlob" + + def test_source_data_and_reranker_score(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference + + ref = KnowledgeBaseSearchIndexReference( + id="ref-4", activity_source=0, source_data={"chunk": "some text"}, reranker_score=0.95 + ) + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + extra = result[0]["additional_properties"] + assert extra["source_data"] == {"chunk": "some text"} + assert extra["reranker_score"] == 0.95 + + def test_raw_representation_stores_original_ref(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference + + ref = KnowledgeBaseSearchIndexReference(id="ref-5", activity_source=0) + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + assert result[0]["raw_representation"] is ref + + def test_remote_sharepoint_captures_sensitivity_label(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseRemoteSharePointReference, + SharePointSensitivityLabelInfo, + ) + + label = SharePointSensitivityLabelInfo( + display_name="Confidential", sensitivity_label_id="lbl-1", is_encrypted=True + ) + ref = KnowledgeBaseRemoteSharePointReference( + id="ref-6", activity_source=0, web_url="https://sp.example.com/doc", search_sensitivity_label_info=label + ) + result = AzureAISearchContextProvider._parse_references_to_annotations([ref]) + assert result[0]["url"] == "https://sp.example.com/doc" + sl = result[0]["additional_properties"]["sensitivity_label"] + assert sl["display_name"] == "Confidential" + assert sl["sensitivity_label_id"] == "lbl-1" + assert sl["is_encrypted"] is True + + def test_multiple_references(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseSearchIndexReference, + KnowledgeBaseWebReference, + ) + + refs = [ + KnowledgeBaseSearchIndexReference(id="ref-a", activity_source=0), + KnowledgeBaseWebReference(id="ref-b", activity_source=1, url="https://example.com"), + ] + result = AzureAISearchContextProvider._parse_references_to_annotations(refs) + assert len(result) == 2 + assert result[0]["additional_properties"]["activity_source"] == 0 + assert result[1]["additional_properties"]["activity_source"] == 1 + + +class TestParseMessagesFromKbResponse: + """Tests for _parse_messages_from_kb_response.""" + + def test_converts_all_messages(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalResponse, + ) + + response = KnowledgeBaseRetrievalResponse( + response=[ + KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")]), + KnowledgeBaseMessage(role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")]), + ], + references=None, + ) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 2 + assert result[0].role == "user" + assert result[0].text == "q" + assert result[1].role == "assistant" + assert result[1].text == "answer" + + def test_none_response_returns_default(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse + + response = KnowledgeBaseRetrievalResponse(response=None, references=None) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 1 + assert result[0].text == "No results found from Knowledge Base." + + def test_empty_response_returns_default(self) -> None: + from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse + + response = KnowledgeBaseRetrievalResponse(response=[], references=None) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 1 + assert result[0].text == "No results found from Knowledge Base." + + def test_image_content(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageImageContent, + KnowledgeBaseMessageImageContentImage, + KnowledgeBaseRetrievalResponse, + ) + + response = KnowledgeBaseRetrievalResponse( + response=[ + KnowledgeBaseMessage( + role="assistant", + content=[ + KnowledgeBaseMessageImageContent( + image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/a.png") + ) + ], + ), + ], + references=None, + ) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 1 + assert result[0].contents[0].type == "uri" + assert result[0].contents[0].uri == "https://img.example.com/a.png" + + def test_mixed_text_and_image_content(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageImageContent, + KnowledgeBaseMessageImageContentImage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalResponse, + ) + + response = KnowledgeBaseRetrievalResponse( + response=[ + KnowledgeBaseMessage( + role="assistant", + content=[ + KnowledgeBaseMessageTextContent(text="description"), + KnowledgeBaseMessageImageContent( + image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/b.png") + ), + ], + ), + ], + references=None, + ) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 1 + assert len(result[0].contents) == 2 + assert result[0].contents[0].type == "text" + assert result[0].contents[1].type == "uri" + + def test_references_become_annotations(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalResponse, + KnowledgeBaseWebReference, + ) + + response = KnowledgeBaseRetrievalResponse( + response=[ + KnowledgeBaseMessage(role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")]), + ], + references=[ + KnowledgeBaseWebReference( + id="ref-1", activity_source=0, url="https://example.com", title="Example" + ), + ], + ) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 1 + annotations = result[0].contents[0].annotations + assert annotations is not None + assert len(annotations) == 1 + assert annotations[0]["type"] == "citation" + assert annotations[0]["url"] == "https://example.com" + assert annotations[0]["title"] == "Example" + + def test_multiple_messages_with_references(self) -> None: + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalResponse, + KnowledgeBaseSearchIndexReference, + ) + + response = KnowledgeBaseRetrievalResponse( + response=[ + KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")]), + KnowledgeBaseMessage( + role="assistant", + content=[ + KnowledgeBaseMessageTextContent(text="part1"), + KnowledgeBaseMessageTextContent(text="part2"), + ], + ), + ], + references=[KnowledgeBaseSearchIndexReference(id="doc-1", activity_source=0)], + ) + result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) + assert len(result) == 2 + # All content items get annotations + for msg in result: + for c in msg.contents: + assert c.annotations is not None + assert len(c.annotations) == 1 # -- before_run: agentic mode -------------------------------------------------- @@ -1266,9 +1625,11 @@ async def test_agentic_mode_calls_agentic_search(self) -> None: mock_content = Mock() mock_content.text = "agentic result" mock_message = Mock() + mock_message.role = "assistant" mock_message.content = [mock_content] mock_result = Mock() mock_result.response = [mock_message] + mock_result.references = None mock_retrieval = AsyncMock() mock_retrieval.retrieve = AsyncMock(return_value=mock_result) diff --git a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py index 02bac618df..80c35f9e1f 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py +++ b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py @@ -82,8 +82,8 @@ async def main() -> None: mode="agentic", knowledge_base_name=knowledge_base_name, # Optional: Configure retrieval behavior - knowledge_base_output_mode="extractive_data", # or "answer_synthesis" - retrieval_reasoning_effort="minimal", # or "medium", "low" + knowledge_base_output_mode="answer_synthesis", # or "answer_synthesis" + retrieval_reasoning_effort="medium", # or "medium", "low" ) else: # Auto-create Knowledge Base from index @@ -135,6 +135,9 @@ async def main() -> None: async for chunk in agent.run(user_input, stream=True): if chunk.text: print(chunk.text, end="", flush=True) + for content in chunk.contents: + if content.annotations: + print(f"\n[Sources: {content.annotations}]", end="", flush=True) print("\n") diff --git a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py index e9763531fb..2217cefd23 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py +++ b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py @@ -4,7 +4,7 @@ import os from agent_framework import Agent -from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider +from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider, AzureOpenAIEmbeddingClient from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv @@ -30,6 +30,8 @@ - AZURE_SEARCH_INDEX_NAME: Your search index name - AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint - AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o") + - AZURE_OPENAI_EMBEDDING_MODEL_ID: (Optional) Your embedding model for hybrid search (e.g., "text-embedding-3-small") + - AZURE_OPENAI_ENDPOINT: (Optional) Your Azure OpenAI resource URL, required if using an OpenAI embedding model for hybrid search """ # Sample queries to demonstrate RAG @@ -43,12 +45,24 @@ async def main() -> None: """Main function demonstrating Azure AI Search semantic mode.""" + credential = AzureCliCredential() + # Get configuration from environment search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"] search_key = os.environ.get("AZURE_SEARCH_API_KEY") index_name = os.environ["AZURE_SEARCH_INDEX_NAME"] project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o") + openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + embedding_model = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_ID", "text-embedding-3-small") + + embedding_client = None + if openai_endpoint and embedding_model: + embedding_client = AzureOpenAIEmbeddingClient( + endpoint=openai_endpoint, + deployment_name=embedding_model, + credential=credential, + ) # Create Azure AI Search context provider with semantic mode (recommended, fast) print("Using SEMANTIC mode (hybrid search + semantic ranking, fast)\n") @@ -57,9 +71,13 @@ async def main() -> None: endpoint=search_endpoint, index_name=index_name, api_key=search_key, # Use api_key for API key auth, or credential for managed identity - credential=AzureCliCredential() if not search_key else None, + credential=credential if not search_key else None, mode="semantic", # Default mode top_k=3, # Retrieve top 3 most relevant documents + embedding_function=embedding_client, # Provide embedding function for hybrid search + vector_field_name="DescriptionVector" + if embedding_client + else None, # Set vector field for hybrid search if using embeddings ) # Create agent with search context provider @@ -68,7 +86,7 @@ async def main() -> None: AzureAIAgentClient( project_endpoint=project_endpoint, model_deployment_name=model_deployment, - credential=AzureCliCredential(), + credential=credential, ) as client, Agent( client=client, From 3bb8c8d494903900cf554b04a5cda6a825c70f39 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 24 Feb 2026 16:30:14 +0100 Subject: [PATCH 2/3] Fix mypy errors and embedding function test - Use separate variable for embeddings result to avoid mypy type reassignment error - Fix test_vectorized_query_with_embedding_function: use real async function instead of AsyncMock which falsely matches SupportsGetEmbeddings protocol Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_azure_ai_search/_context_provider.py | 4 ++-- .../tests/test_aisearch_context_provider.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 4be5ad4d74..cd7ec3448e 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -455,8 +455,8 @@ async def _semantic_search(self, query: str) -> list[Message]: vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)] elif self.embedding_function: if isinstance(self.embedding_function, SupportsGetEmbeddings): - query_vector = await self.embedding_function.get_embeddings(query) # type: ignore[reportUnknownVariableType] - query_vector = query_vector[0].vector # type: ignore[reportUnknownVariableType] + embeddings = await self.embedding_function.get_embeddings(query) # type: ignore[reportUnknownVariableType] + query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType] else: query_vector = await self.embedding_function(query) vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 0b13a5599b..7cc51e4930 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -746,7 +746,11 @@ async def test_vectorized_query_with_embedding_function(self) -> None: provider = _make_provider() provider._use_vectorizable_query = False provider.vector_field_name = "embedding" - provider.embedding_function = AsyncMock(return_value=[0.1, 0.2, 0.3]) + + async def _embed(query: str) -> list[float]: + return [0.1, 0.2, 0.3] + + provider.embedding_function = _embed mock_client = AsyncMock() async def _search(**kwargs): @@ -757,7 +761,6 @@ async def _search(**kwargs): results = await provider._semantic_search("embed query") assert len(results) == 1 - provider.embedding_function.assert_awaited_once_with("embed query") call_kwargs = mock_client.search.call_args[1] assert "vector_queries" in call_kwargs From e73908c744f6954f79848f8e73819ea70ea4e155 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 24 Feb 2026 17:01:56 +0100 Subject: [PATCH 3/3] fixes from feedback --- .../agent_framework_azure_ai_search/_context_provider.py | 2 +- .../azure_ai_search/azure_ai_with_search_context_agentic.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index cd7ec3448e..ff245817b7 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -455,7 +455,7 @@ async def _semantic_search(self, query: str) -> list[Message]: vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)] elif self.embedding_function: if isinstance(self.embedding_function, SupportsGetEmbeddings): - embeddings = await self.embedding_function.get_embeddings(query) # type: ignore[reportUnknownVariableType] + embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType] query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType] else: query_vector = await self.embedding_function(query) diff --git a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py index 80c35f9e1f..d3c618d8a8 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py +++ b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py @@ -82,8 +82,8 @@ async def main() -> None: mode="agentic", knowledge_base_name=knowledge_base_name, # Optional: Configure retrieval behavior - knowledge_base_output_mode="answer_synthesis", # or "answer_synthesis" - retrieval_reasoning_effort="medium", # or "medium", "low" + knowledge_base_output_mode="extractive_data", # or "answer_synthesis" + retrieval_reasoning_effort="minimal", # or "medium", "low" ) else: # Auto-create Knowledge Base from index