From ab36c7ed1f0ca13eaef9b6c60b6ca22c08cedd30 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Tue, 17 Feb 2026 20:27:28 -0800 Subject: [PATCH 1/8] Python: Enhance Azure AI Search citations with document URLs in Foundry V2 (Responses API) Override _parse_response_from_openai and _parse_chunk_from_openai in RawAzureAIClient to extract get_urls from azure_ai_search_call_output items and enrich url_citation annotations with document-specific URLs. - Non-streaming: first pass collects get_urls, post-processes annotations - Streaming: captures search output state, enriches url_citation events (also handles url_citation annotation type not handled by base class) - Updated V2 sample to demonstrate citation URL extraction - Added 14 unit tests covering extraction, enrichment, and edge cases Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_azure_ai/_client.py | 211 ++++++++++++ .../azure-ai/tests/test_azure_ai_client.py | 314 ++++++++++++++++++ .../azure_ai/azure_ai_with_azure_ai_search.py | 60 +++- 3 files changed, 582 insertions(+), 3 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index afbbf6cea3..80a54a8e12 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -4,6 +4,7 @@ import json import logging +import re import sys from collections.abc import Callable, Mapping, Sequence from contextlib import suppress @@ -12,14 +13,19 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, Agent, + Annotation, BaseContextProvider, ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, + ChatResponse, + ChatResponseUpdate, + Content, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, Message, MiddlewareTypes, + TextSpanRegion, ) from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes @@ -43,6 +49,12 @@ from azure.ai.projects.models import FileSearchTool as ProjectsFileSearchTool from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import ResourceNotFoundError +from openai.types.responses.parsed_response import ParsedResponse +from openai.types.responses.response import Response as OpenAIResponse +from openai.types.responses.response_stream_event import ( + ResponseStreamEvent as OpenAIResponseStreamEvent, +) +from pydantic import BaseModel from ._shared import AzureAISettings, create_text_format_config @@ -616,6 +628,205 @@ def _update_agent_name_and_description(self, agent_name: str | None, description if description and not self.agent_description: self.agent_description = description + # region Azure AI Search Citation Enhancement + + def _extract_azure_search_urls(self, output_items: Any) -> list[str]: + """Extract document URLs from azure_ai_search_call_output items. + + Args: + output_items: The response output items to scan. + + Returns: + A flat list of get_urls from all azure_ai_search_call_output items. + """ + get_urls: list[str] = [] + for item in output_items: + item_type = getattr(item, "type", None) + if isinstance(item, dict): + item_type = item.get("type") + if item_type == "azure_ai_search_call_output": + output = getattr(item, "output", None) + if isinstance(item, dict): + output = item.get("output") + if output is not None: + urls = getattr(output, "get_urls", None) + if isinstance(output, dict): + urls = output.get("get_urls") + if urls and isinstance(urls, list): + get_urls.extend(urls) + return get_urls + + def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) -> str | None: + """Map a citation title like 'doc_0' to its corresponding get_url. + + Args: + citation_title: The annotation title (e.g., "doc_0"). + get_urls: The list of document URLs from azure_ai_search_call_output. + + Returns: + The matching document URL if found, otherwise None. + """ + if not citation_title or not get_urls: + return None + match = re.search(r"doc_(\d+)", citation_title) + if not match: + return None + doc_index = int(match.group(1)) + if 0 <= doc_index < len(get_urls): + return str(get_urls[doc_index]) + return None + + def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls: list[str]) -> None: + """Enrich url_citation annotations in contents with real document URLs from Azure AI Search. + + Args: + contents: The parsed content list from a ChatResponse. + get_urls: Document URLs extracted from azure_ai_search_call_output. + """ + if not get_urls: + return + for content in contents: + annotations = getattr(content, "annotations", None) + if not annotations: + continue + for annotation in annotations: + if not isinstance(annotation, dict): + continue + if annotation.get("type") != "citation": + continue + title = annotation.get("title") + doc_url = self._get_search_doc_url(title, get_urls) + if doc_url: + props = annotation.get("additional_properties") or {} + props["get_url"] = doc_url + annotation["additional_properties"] = props + + @override + def _parse_response_from_openai( + self, + response: OpenAIResponse | ParsedResponse[BaseModel], + options: dict[str, Any], + ) -> ChatResponse: + """Parse response with Azure AI Search citation enrichment.""" + # Extract search URLs before parsing + get_urls = self._extract_azure_search_urls(response.output) + + # Let base class do the standard parsing + result = super()._parse_response_from_openai(response, options) + + # Enrich url_citation annotations with real document URLs + if get_urls and result.messages: + for msg in result.messages: + self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) + + return result + + @override + def _parse_chunk_from_openai( + self, + event: OpenAIResponseStreamEvent, + options: dict[str, Any], + function_call_ids: dict[int, tuple[str, str]], + ) -> ChatResponseUpdate: + """Parse streaming event with Azure AI Search citation enrichment.""" + # Capture search output URLs when azure_ai_search_call_output items arrive + if event.type == "response.output_item.added": + event_item = event.item + item_type = getattr(event_item, "type", None) + if isinstance(event_item, dict): + item_type = event_item.get("type") + if item_type == "azure_ai_search_call_output": + urls = self._extract_azure_search_urls([event_item]) + if urls: + if not hasattr(self, "_streaming_search_get_urls"): + self._streaming_search_get_urls: list[str] = [] + self._streaming_search_get_urls.extend(urls) + + # Let base class parse the event + result = super()._parse_chunk_from_openai(event, options, function_call_ids) + + # Handle url_citation annotations in streaming — base class doesn't handle these, + # so we produce an Annotation with the enriched URL from captured search data. + if event.type == "response.output_text.annotation.added": + annotation_data: Any = event.annotation + ann_type = ( + annotation_data.get("type") + if isinstance(annotation_data, dict) + else getattr(annotation_data, "type", None) + ) + if ann_type == "url_citation": + ann_title = ( + annotation_data.get("title") + if isinstance(annotation_data, dict) + else getattr(annotation_data, "title", None) + ) + ann_url = ( + annotation_data.get("url") + if isinstance(annotation_data, dict) + else getattr(annotation_data, "url", None) + ) + ann_start = ( + annotation_data.get("start_index") + if isinstance(annotation_data, dict) + else getattr(annotation_data, "start_index", None) + ) + ann_end = ( + annotation_data.get("end_index") + if isinstance(annotation_data, dict) + else getattr(annotation_data, "end_index", None) + ) + + additional_props: dict[str, Any] = { + "annotation_index": getattr(event, "annotation_index", None), + } + + # Enrich with get_url from captured search data + if hasattr(self, "_streaming_search_get_urls") and self._streaming_search_get_urls: + doc_url = self._get_search_doc_url(ann_title, self._streaming_search_get_urls) + if doc_url: + additional_props["get_url"] = doc_url + + annotation_obj = Annotation( + type="citation", + title=ann_title, + url=ann_url, + additional_properties=additional_props, + raw_representation=annotation_data, + ) + if ann_start is not None and ann_end is not None: + annotation_obj["annotated_regions"] = [ + TextSpanRegion( + type="text_span", + start_index=ann_start, + end_index=ann_end, + ) + ] + + text_content = Content.from_text(text="", raw_representation=event) + text_content.annotations = [annotation_obj] + # Add to result contents + contents_list = list(result.contents or []) + contents_list.append(text_content) + result = ChatResponseUpdate( + contents=contents_list, + conversation_id=result.conversation_id, + response_id=result.response_id, + role=result.role, + model_id=result.model_id, + continuation_token=result.continuation_token, + additional_properties=result.additional_properties, + raw_representation=result.raw_representation, + ) + + # Clear streaming state when response completes + if event.type == "response.completed": + if hasattr(self, "_streaming_search_get_urls"): + del self._streaming_search_get_urls + + return result + + # endregion + # region Hosted Tool Factory Methods (Azure-specific overrides) @staticmethod diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 38839b32d0..dfc5acc269 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -1778,3 +1778,317 @@ def test_get_image_generation_tool_with_options() -> None: # endregion + + +# region Azure AI Search Citation Enhancement Tests + + +def test_extract_azure_search_urls_with_dict_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with dict-style output items.""" + client = create_test_azure_ai_client(mock_project_client) + output_items = [ + { + "type": "azure_ai_search_call", + "arguments": '{"query":"test"}', + }, + { + "type": "azure_ai_search_call_output", + "output": { + "documents": [{"id": "1", "url": "https://search.example.com/"}], + "get_urls": [ + "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01", + ], + }, + }, + {"type": "message", "content": [{"type": "output_text", "text": "hello"}]}, + ] + urls = client._extract_azure_search_urls(output_items) + assert len(urls) == 2 + assert urls[0] == "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01" + assert urls[1] == "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01" + + +def test_extract_azure_search_urls_with_object_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with object-style output items.""" + client = create_test_azure_ai_client(mock_project_client) + mock_output = MagicMock() + mock_output.get_urls = ["https://example.com/doc/1", "https://example.com/doc/2"] + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = mock_output + + urls = client._extract_azure_search_urls([mock_item]) + assert urls == ["https://example.com/doc/1", "https://example.com/doc/2"] + + +def test_extract_azure_search_urls_no_search_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with no search output items.""" + client = create_test_azure_ai_client(mock_project_client) + output_items = [{"type": "message", "content": []}] + urls = client._extract_azure_search_urls(output_items) + assert urls == [] + + +def test_get_search_doc_url_valid(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with valid doc_N title.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0", "https://example.com/doc/1", "https://example.com/doc/2"] + + assert client._get_search_doc_url("doc_0", get_urls) == "https://example.com/doc/0" + assert client._get_search_doc_url("doc_1", get_urls) == "https://example.com/doc/1" + assert client._get_search_doc_url("doc_2", get_urls) == "https://example.com/doc/2" + + +def test_get_search_doc_url_out_of_range(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with out-of-range index.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0"] + assert client._get_search_doc_url("doc_5", get_urls) is None + + +def test_get_search_doc_url_no_match(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with non-matching title.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0"] + assert client._get_search_doc_url("some_title", get_urls) is None + assert client._get_search_doc_url(None, get_urls) is None + assert client._get_search_doc_url("doc_0", []) is None + + +def test_enrich_annotations_with_search_urls(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls enriches citation annotations.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", + ] + + content = Content.from_text(text="test response") + content.annotations = [ + { + "type": "citation", + "title": "doc_0", + "url": "https://search.example.com/", + }, + { + "type": "citation", + "title": "doc_1", + "url": "https://search.example.com/", + }, + ] + + client._enrich_annotations_with_search_urls([content], get_urls) + + assert content.annotations[0]["additional_properties"]["get_url"] == get_urls[0] + assert content.annotations[1]["additional_properties"]["get_url"] == get_urls[1] + + +def test_enrich_annotations_no_match(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls with non-matching titles.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] + + content = Content.from_text(text="test response") + content.annotations = [ + { + "type": "citation", + "title": "some_title", + "url": "https://search.example.com/", + }, + ] + + client._enrich_annotations_with_search_urls([content], get_urls) + assert "additional_properties" not in content.annotations[0] or "get_url" not in content.annotations[0].get( + "additional_properties", {} + ) + + +def test_enrich_annotations_empty_get_urls(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls with empty get_urls.""" + client = create_test_azure_ai_client(mock_project_client) + content = Content.from_text(text="test") + content.annotations = [{"type": "citation", "title": "doc_0", "url": "https://example.com/"}] + + # Should not raise or modify + client._enrich_annotations_with_search_urls([content], []) + assert "additional_properties" not in content.annotations[0] + + +def test_parse_response_from_openai_with_search_citations(mock_project_client: MagicMock) -> None: + """Test _parse_response_from_openai enriches url_citation annotations with search URLs.""" + client = create_test_azure_ai_client(mock_project_client) + + # Build a mock OpenAI response with azure_ai_search_call_output and message with url_citation + mock_search_output = MagicMock() + mock_search_output.type = "azure_ai_search_call_output" + mock_search_output_data = MagicMock() + mock_search_output_data.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", + ] + mock_search_output.output = mock_search_output_data + + mock_annotation = MagicMock() + mock_annotation.type = "url_citation" + mock_annotation.title = "doc_0" + mock_annotation.url = "https://search.example.com/" + mock_annotation.start_index = 100 + mock_annotation.end_index = 112 + + mock_text_content = MagicMock() + mock_text_content.type = "output_text" + mock_text_content.text = "Here is the result【5:0†source】." + mock_text_content.annotations = [mock_annotation] + mock_text_content.logprobs = None + + mock_message = MagicMock() + mock_message.type = "message" + mock_message.content = [mock_text_content] + mock_message.role = "assistant" + mock_message.id = "msg_123" + mock_message.status = "completed" + + mock_response = MagicMock(spec=OpenAIResponse) + mock_response.output = [mock_search_output, mock_message] + mock_response.id = "resp_123" + mock_response.created_at = 1700000000 + mock_response.model = "gpt-4" + mock_response.metadata = {} + mock_response.usage = None + mock_response.status = "completed" + + result = client._parse_response_from_openai(mock_response, options={"store": False}) + + # Find the text content with annotations + assert result.messages is not None + found_citation = False + for msg in result.messages: + for content in msg.contents: + if hasattr(content, "annotations") and content.annotations: + for ann in content.annotations: + if isinstance(ann, dict) and ann.get("title") == "doc_0": + found_citation = True + assert ann["additional_properties"]["get_url"] == ( + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) + assert found_citation, "Expected to find enriched citation annotation" + + +def test_parse_response_from_openai_without_search_output(mock_project_client: MagicMock) -> None: + """Test _parse_response_from_openai works normally when no search output exists.""" + client = create_test_azure_ai_client(mock_project_client) + + mock_text_content = MagicMock() + mock_text_content.type = "output_text" + mock_text_content.text = "Hello world" + mock_text_content.annotations = [] + mock_text_content.logprobs = None + + mock_message = MagicMock() + mock_message.type = "message" + mock_message.content = [mock_text_content] + mock_message.role = "assistant" + mock_message.id = "msg_456" + mock_message.status = "completed" + + mock_response = MagicMock(spec=OpenAIResponse) + mock_response.output = [mock_message] + mock_response.id = "resp_456" + mock_response.created_at = 1700000000 + mock_response.model = "gpt-4" + mock_response.metadata = {} + mock_response.usage = None + mock_response.status = "completed" + + result = client._parse_response_from_openai(mock_response, options={"store": False}) + assert result.messages is not None + + +def test_parse_chunk_from_openai_captures_search_urls(mock_project_client: MagicMock) -> None: + """Test _parse_chunk_from_openai captures search URLs from azure_ai_search_call_output events.""" + client = create_test_azure_ai_client(mock_project_client) + + # Simulate azure_ai_search_call_output item arriving via response.output_item.added + mock_output_data = MagicMock() + mock_output_data.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + ] + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = mock_output_data + + mock_event = MagicMock() + mock_event.type = "response.output_item.added" + mock_event.item = mock_item + mock_event.output_index = 0 + + client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + + assert hasattr(client, "_streaming_search_get_urls") + assert len(client._streaming_search_get_urls) == 1 + assert ( + client._streaming_search_get_urls[0] == "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) + + +def test_parse_chunk_from_openai_enriches_url_citation(mock_project_client: MagicMock) -> None: + """Test _parse_chunk_from_openai enriches url_citation annotations with search URLs.""" + client = create_test_azure_ai_client(mock_project_client) + + # Pre-populate search URLs (simulating earlier capture) + client._streaming_search_get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", + ] + + # Simulate url_citation annotation event + mock_annotation = MagicMock() + mock_annotation.type = "url_citation" + mock_annotation.title = "doc_0" + mock_annotation.url = "https://search.example.com/" + mock_annotation.start_index = 100 + mock_annotation.end_index = 112 + + mock_event = MagicMock() + mock_event.type = "response.output_text.annotation.added" + mock_event.annotation = mock_annotation + mock_event.annotation_index = 0 + + result = client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + + # Should have content with annotation containing get_url + assert result.contents is not None + found = False + for content in result.contents: + if hasattr(content, "annotations") and content.annotations: + for ann in content.annotations: + if isinstance(ann, dict) and ann.get("title") == "doc_0": + found = True + assert ann["additional_properties"]["get_url"] == ( + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) + assert found, "Expected url_citation annotation with enriched get_url" + + +def test_parse_chunk_from_openai_clears_state_on_completed(mock_project_client: MagicMock) -> None: + """Test _parse_chunk_from_openai clears streaming state on response.completed.""" + client = create_test_azure_ai_client(mock_project_client) + client._streaming_search_get_urls = ["https://example.com/doc/0"] + + mock_response = MagicMock() + mock_response.id = "resp_123" + mock_response.model = "gpt-4" + mock_response.usage = None + mock_response.status = "completed" + + mock_event = MagicMock() + mock_event.type = "response.completed" + mock_event.response = mock_response + + client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + + assert not hasattr(client, "_streaming_search_get_urls") + + +# endregion diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py index c4ee686d87..c4a0c277ad 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py @@ -2,6 +2,7 @@ import asyncio import os +from agent_framework import Annotation from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential @@ -11,6 +12,9 @@ This sample demonstrates usage of AzureAIProjectAgentProvider with Azure AI Search to search through indexed data and answer user questions about it. +Citations from Azure AI Search are automatically enriched with document-specific +URLs (get_url) that can be used to retrieve the original documents. + Prerequisites: 1. Set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME environment variables. 2. Ensure you have an Azure AI Search connection configured in your Azure AI project @@ -25,8 +29,10 @@ async def main() -> None: ): agent = await provider.create_agent( name="MySearchAgent", - instructions="""You are a helpful assistant. You must always provide citations for - answers using the tool and render them as: `[message_idx:search_idx†source]`.""", + instructions=( + "You are a helpful agent that searches hotel information using Azure AI Search. " + "Always use the search tool and index to find hotel data and provide accurate information." + ), tools={ "type": "azure_ai_search", "azure_ai_search": { @@ -42,11 +48,59 @@ async def main() -> None: }, ) - query = "Tell me about insurance options" + query = ( + "Use Azure AI search knowledge tool to find detailed information about a winter hotel." + " Use the search tool and index." # You can modify prompt to force tool usage + ) print(f"User: {query}") + + # Non-streaming: get response with enriched citations result = await agent.run(query) print(f"Result: {result}\n") + # Display citations with document-specific URLs + if result.messages: + citations: list[Annotation] = [] + for msg in result.messages: + for content in msg.contents: + if hasattr(content, "annotations") and content.annotations: + citations.extend(content.annotations) + + if citations: + print("Citations:") + for i, citation in enumerate(citations, 1): + url = citation.get("url", "N/A") + # get_url contains the document-specific REST API URL from Azure AI Search + get_url = (citation.get("additional_properties") or {}).get("get_url") + print(f" [{i}] {citation.get('title', 'N/A')}") + print(f" URL: {url}") + if get_url: + print(f" Document URL: {get_url}") + + # Streaming: collect citations from streamed response + print("\n--- Streaming ---") + print(f"User: {query}") + print("Agent: ", end="", flush=True) + streaming_citations: list[Annotation] = [] + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + for content in getattr(chunk, "contents", []): + annotations = getattr(content, "annotations", []) + if annotations: + streaming_citations.extend(annotations) + + print() + if streaming_citations: + print("\nStreaming Citations:") + for i, citation in enumerate(streaming_citations, 1): + url = citation.get("url", "N/A") + get_url = (citation.get("additional_properties") or {}).get("get_url") + print(f" [{i}] {citation.get('title', 'N/A')}") + print(f" URL: {url}") + if get_url: + print(f" Document URL: {get_url}") + if __name__ == "__main__": asyncio.run(main()) From 37ec78e14e5525a48f21def7dfae73c99135718d Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Thu, 19 Feb 2026 12:03:08 -0800 Subject: [PATCH 2/8] refactor: rework search citation enrichment to override _inner_get_response - Remove all direct openai/pydantic imports from _client.py - Override _inner_get_response instead of _parse_response_from_openai/_parse_chunk_from_openai - Use closure-local state for streaming instead of instance-level _streaming_search_get_urls - Add _build_url_citation_content helper for streaming url_citation handling - Fix mypy errors by using str(value or '') for Annotation TypedDict fields - Fix docstring to say 'citation' instead of 'url_citation' - Update tests to match new approach Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_azure_ai/_client.py | 251 ++++++++-------- .../azure-ai/tests/test_azure_ai_client.py | 277 ++++++++++-------- 2 files changed, 290 insertions(+), 238 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 0e0656b9b5..857be639c2 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -6,7 +6,7 @@ import logging import re import sys -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import suppress from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast @@ -25,6 +25,7 @@ FunctionTool, Message, MiddlewareTypes, + ResponseStream, TextSpanRegion, ) from agent_framework._settings import load_settings @@ -48,12 +49,6 @@ ) from azure.ai.projects.models import FileSearchTool as ProjectsFileSearchTool from azure.core.exceptions import ResourceNotFoundError -from openai.types.responses.parsed_response import ParsedResponse -from openai.types.responses.response import Response as OpenAIResponse -from openai.types.responses.response_stream_event import ( - ResponseStreamEvent as OpenAIResponseStreamEvent, -) -from pydantic import BaseModel from ._shared import AzureAISettings, create_text_format_config @@ -677,10 +672,13 @@ def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) - return None def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls: list[str]) -> None: - """Enrich url_citation annotations in contents with real document URLs from Azure AI Search. + """Enrich citation annotations in contents with real document URLs from Azure AI Search. + + Looks for annotations with ``type == "citation"`` and a ``title`` matching ``doc_N``, + then adds the corresponding document URL from *get_urls* to ``additional_properties["get_url"]``. Args: - contents: The parsed content list from a ChatResponse. + contents: The parsed content list from a ChatResponse or ChatResponseUpdate. get_urls: Document URLs extracted from azure_ai_search_call_output. """ if not get_urls: @@ -701,129 +699,138 @@ def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls props["get_url"] = doc_url annotation["additional_properties"] = props - @override - def _parse_response_from_openai( - self, - response: OpenAIResponse | ParsedResponse[BaseModel], - options: dict[str, Any], - ) -> ChatResponse: - """Parse response with Azure AI Search citation enrichment.""" - # Extract search URLs before parsing - get_urls = self._extract_azure_search_urls(response.output) + def _build_url_citation_content( + self, annotation_data: Any, get_urls: list[str], raw_event: Any + ) -> Content: + """Build a Content with a citation Annotation from a url_citation streaming event. - # Let base class do the standard parsing - result = super()._parse_response_from_openai(response, options) + The base class does not handle ``url_citation`` annotations in streaming, so this + method creates the appropriate framework content for them. - # Enrich url_citation annotations with real document URLs - if get_urls and result.messages: - for msg in result.messages: - self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) + Args: + annotation_data: The raw annotation object/dict from the streaming event. + get_urls: Captured document URLs for enrichment. + raw_event: The raw streaming event for raw_representation. + + Returns: + A Content object containing the citation annotation. + """ - return result + def _val(key: str) -> Any: + if isinstance(annotation_data, dict): + return annotation_data.get(key) + return getattr(annotation_data, key, None) + + ann_title = str(_val("title") or "") + ann_url = str(_val("url") or "") + ann_start = _val("start_index") + ann_end = _val("end_index") + + additional_props: dict[str, Any] = { + "annotation_index": getattr(raw_event, "annotation_index", None), + } + doc_url = self._get_search_doc_url(ann_title, get_urls) + if doc_url: + additional_props["get_url"] = doc_url + + annotation_obj = Annotation( + type="citation", + title=ann_title, + url=ann_url, + additional_properties=additional_props, + raw_representation=annotation_data, + ) + if ann_start is not None and ann_end is not None: + annotation_obj["annotated_regions"] = [ + TextSpanRegion(type="text_span", start_index=ann_start, end_index=ann_end) + ] + + text_content = Content.from_text(text="", raw_representation=raw_event) + text_content.annotations = [annotation_obj] + return text_content @override - def _parse_chunk_from_openai( + def _inner_get_response( self, - event: OpenAIResponseStreamEvent, - options: dict[str, Any], - function_call_ids: dict[int, tuple[str, str]], - ) -> ChatResponseUpdate: - """Parse streaming event with Azure AI Search citation enrichment.""" - # Capture search output URLs when azure_ai_search_call_output items arrive - if event.type == "response.output_item.added": - event_item = event.item - item_type = getattr(event_item, "type", None) - if isinstance(event_item, dict): - item_type = event_item.get("type") - if item_type == "azure_ai_search_call_output": - urls = self._extract_azure_search_urls([event_item]) - if urls: - if not hasattr(self, "_streaming_search_get_urls"): - self._streaming_search_get_urls: list[str] = [] - self._streaming_search_get_urls.extend(urls) - - # Let base class parse the event - result = super()._parse_chunk_from_openai(event, options, function_call_ids) - - # Handle url_citation annotations in streaming — base class doesn't handle these, - # so we produce an Annotation with the enriched URL from captured search data. - if event.type == "response.output_text.annotation.added": - annotation_data: Any = event.annotation - ann_type = ( - annotation_data.get("type") - if isinstance(annotation_data, dict) - else getattr(annotation_data, "type", None) - ) - if ann_type == "url_citation": - ann_title = ( - annotation_data.get("title") - if isinstance(annotation_data, dict) - else getattr(annotation_data, "title", None) - ) - ann_url = ( - annotation_data.get("url") - if isinstance(annotation_data, dict) - else getattr(annotation_data, "url", None) - ) - ann_start = ( - annotation_data.get("start_index") - if isinstance(annotation_data, dict) - else getattr(annotation_data, "start_index", None) - ) - ann_end = ( - annotation_data.get("end_index") - if isinstance(annotation_data, dict) - else getattr(annotation_data, "end_index", None) - ) + *, + messages: Sequence[Message], + options: Mapping[str, Any], + stream: bool = False, + **kwargs: Any, + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + """Wrap base response to enrich Azure AI Search citation annotations. - additional_props: dict[str, Any] = { - "annotation_index": getattr(event, "annotation_index", None), - } - - # Enrich with get_url from captured search data - if hasattr(self, "_streaming_search_get_urls") and self._streaming_search_get_urls: - doc_url = self._get_search_doc_url(ann_title, self._streaming_search_get_urls) - if doc_url: - additional_props["get_url"] = doc_url - - annotation_obj = Annotation( - type="citation", - title=ann_title, - url=ann_url, - additional_properties=additional_props, - raw_representation=annotation_data, - ) - if ann_start is not None and ann_end is not None: - annotation_obj["annotated_regions"] = [ - TextSpanRegion( - type="text_span", - start_index=ann_start, - end_index=ann_end, + For non-streaming responses, the ``ChatResponse.raw_representation`` carries the + full response including ``azure_ai_search_call_output`` items. After the base class + parses the response, ``url_citation`` annotations are enriched with per-document URLs. + + For streaming responses, a transform hook is registered on the ``ResponseStream`` to + capture ``get_urls`` from search output events and enrich ``url_citation`` annotations + as they arrive. The captured URL state is local to the stream closure, so concurrent + streams do not interfere. + """ + result = super()._inner_get_response(messages=messages, options=options, stream=stream, **kwargs) + + if not stream: + original_awaitable: Awaitable[ChatResponse] = result # type: ignore[assignment] + + async def _enrich_response() -> ChatResponse: + response = await original_awaitable + raw = getattr(response, "raw_representation", None) + get_urls = self._extract_azure_search_urls(getattr(raw, "output", [])) + if get_urls: + for msg in response.messages: + self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) + return response + + return _enrich_response() + + # Streaming: use a closure-local list so concurrent streams don't interfere + stream_result: ResponseStream[ChatResponseUpdate, ChatResponse] = result # type: ignore[assignment] + search_get_urls: list[str] = [] + + def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: + raw = getattr(update, "raw_representation", None) + if raw is None: + return update + event_type = getattr(raw, "type", None) + + # Capture get_urls from azure_ai_search_call_output items + if event_type == "response.output_item.added": + item = getattr(raw, "item", None) + if item is not None: + urls = self._extract_azure_search_urls([item]) + if urls: + search_get_urls.extend(urls) + + # Handle url_citation annotations (not handled by the base class in streaming) + if event_type == "response.output_text.annotation.added" and search_get_urls: + ann = getattr(raw, "annotation", None) + if ann is not None: + ann_type = ann.get("type") if isinstance(ann, dict) else getattr(ann, "type", None) + if ann_type == "url_citation": + citation_content = self._build_url_citation_content(ann, search_get_urls, raw) + contents_list = list(update.contents or []) + contents_list.append(citation_content) + return ChatResponseUpdate( + contents=contents_list, + conversation_id=update.conversation_id, + response_id=update.response_id, + role=update.role, + model_id=update.model_id, + continuation_token=update.continuation_token, + additional_properties=update.additional_properties, + raw_representation=update.raw_representation, ) - ] - - text_content = Content.from_text(text="", raw_representation=event) - text_content.annotations = [annotation_obj] - # Add to result contents - contents_list = list(result.contents or []) - contents_list.append(text_content) - result = ChatResponseUpdate( - contents=contents_list, - conversation_id=result.conversation_id, - response_id=result.response_id, - role=result.role, - model_id=result.model_id, - continuation_token=result.continuation_token, - additional_properties=result.additional_properties, - raw_representation=result.raw_representation, - ) - # Clear streaming state when response completes - if event.type == "response.completed": - if hasattr(self, "_streaming_search_get_urls"): - del self._streaming_search_get_urls + # Enrich any citation annotations already parsed by the base class + if update.contents and search_get_urls: + self._enrich_annotations_with_search_urls(list(update.contents), search_get_urls) + + return update - return result + stream_result._transform_hooks.append(_enrich_update) + return stream_result # endregion diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 3215620e28..78633e9890 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -13,14 +13,18 @@ from agent_framework import ( Agent, AgentResponse, + Annotation, ChatOptions, ChatResponse, + ChatResponseUpdate, Content, Message, + ResponseStream, SupportsChatGetResponse, tool, ) from agent_framework._settings import load_settings +from agent_framework.openai._responses_client import RawOpenAIResponsesClient from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ( ApproximateLocation, @@ -1911,134 +1915,142 @@ def test_enrich_annotations_empty_get_urls(mock_project_client: MagicMock) -> No assert "additional_properties" not in content.annotations[0] -def test_parse_response_from_openai_with_search_citations(mock_project_client: MagicMock) -> None: - """Test _parse_response_from_openai enriches url_citation annotations with search URLs.""" +@pytest.mark.asyncio +async def test_inner_get_response_enriches_non_streaming(mock_project_client: MagicMock) -> None: + """Test _inner_get_response enriches url_citation annotations for non-streaming responses.""" client = create_test_azure_ai_client(mock_project_client) - # Build a mock OpenAI response with azure_ai_search_call_output and message with url_citation + # Build a ChatResponse with citation annotations and a raw_representation carrying search output + content = Content.from_text(text="Here is the result【5:0†source】.") + content.annotations = [ + Annotation(type="citation", title="doc_0", url="https://search.example.com/"), + ] + msg = Message(role="assistant", contents=[content]) + mock_raw = MagicMock() mock_search_output = MagicMock() mock_search_output.type = "azure_ai_search_call_output" mock_search_output_data = MagicMock() mock_search_output_data.get_urls = [ "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", - "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", ] mock_search_output.output = mock_search_output_data + mock_raw.output = [mock_search_output] - mock_annotation = MagicMock() - mock_annotation.type = "url_citation" - mock_annotation.title = "doc_0" - mock_annotation.url = "https://search.example.com/" - mock_annotation.start_index = 100 - mock_annotation.end_index = 112 + base_response = ChatResponse(messages=[msg], raw_representation=mock_raw) + + async def _fake_awaitable() -> ChatResponse: + return base_response - mock_text_content = MagicMock() - mock_text_content.type = "output_text" - mock_text_content.text = "Here is the result【5:0†source】." - mock_text_content.annotations = [mock_annotation] - mock_text_content.logprobs = None + with patch.object( + RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable() + ): + result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) + result = await result_awaitable # type: ignore[misc] - mock_message = MagicMock() - mock_message.type = "message" - mock_message.content = [mock_text_content] - mock_message.role = "assistant" - mock_message.id = "msg_123" - mock_message.status = "completed" + ann = result.messages[0].contents[0].annotations[0] + assert ann["additional_properties"]["get_url"] == ( + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) - mock_response = MagicMock(spec=OpenAIResponse) - mock_response.output = [mock_search_output, mock_message] - mock_response.id = "resp_123" - mock_response.created_at = 1700000000 - mock_response.model = "gpt-4" - mock_response.metadata = {} - mock_response.usage = None - mock_response.status = "completed" - - result = client._parse_response_from_openai(mock_response, options={"store": False}) - - # Find the text content with annotations - assert result.messages is not None - found_citation = False - for msg in result.messages: - for content in msg.contents: - if hasattr(content, "annotations") and content.annotations: - for ann in content.annotations: - if isinstance(ann, dict) and ann.get("title") == "doc_0": - found_citation = True - assert ann["additional_properties"]["get_url"] == ( - "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" - ) - assert found_citation, "Expected to find enriched citation annotation" - - -def test_parse_response_from_openai_without_search_output(mock_project_client: MagicMock) -> None: - """Test _parse_response_from_openai works normally when no search output exists.""" + +@pytest.mark.asyncio +async def test_inner_get_response_no_search_output_non_streaming(mock_project_client: MagicMock) -> None: + """Test _inner_get_response passes through when no search output exists.""" client = create_test_azure_ai_client(mock_project_client) - mock_text_content = MagicMock() - mock_text_content.type = "output_text" - mock_text_content.text = "Hello world" - mock_text_content.annotations = [] - mock_text_content.logprobs = None + content = Content.from_text(text="Hello world") + msg = Message(role="assistant", contents=[content]) + mock_raw = MagicMock() + mock_raw.output = [] + base_response = ChatResponse(messages=[msg], raw_representation=mock_raw) - mock_message = MagicMock() - mock_message.type = "message" - mock_message.content = [mock_text_content] - mock_message.role = "assistant" - mock_message.id = "msg_456" - mock_message.status = "completed" + async def _fake_awaitable() -> ChatResponse: + return base_response - mock_response = MagicMock(spec=OpenAIResponse) - mock_response.output = [mock_message] - mock_response.id = "resp_456" - mock_response.created_at = 1700000000 - mock_response.model = "gpt-4" - mock_response.metadata = {} - mock_response.usage = None - mock_response.status = "completed" + with patch.object( + RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable() + ): + result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) + result = await result_awaitable # type: ignore[misc] - result = client._parse_response_from_openai(mock_response, options={"store": False}) - assert result.messages is not None + assert result.messages[0].contents[0].text == "Hello world" -def test_parse_chunk_from_openai_captures_search_urls(mock_project_client: MagicMock) -> None: - """Test _parse_chunk_from_openai captures search URLs from azure_ai_search_call_output events.""" +def test_inner_get_response_streaming_registers_hook(mock_project_client: MagicMock) -> None: + """Test _inner_get_response appends a transform hook to the stream for streaming responses.""" client = create_test_azure_ai_client(mock_project_client) - # Simulate azure_ai_search_call_output item arriving via response.output_item.added - mock_output_data = MagicMock() - mock_output_data.get_urls = [ - "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", - ] + mock_stream = MagicMock(spec=ResponseStream) + mock_stream._transform_hooks = [] + + with patch.object( + RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream + ): + result = client._inner_get_response(messages=[], options={}, stream=True) + + assert result is mock_stream + assert len(mock_stream._transform_hooks) == 1 + + +def test_streaming_hook_captures_search_urls(mock_project_client: MagicMock) -> None: + """Test the streaming transform hook captures get_urls from search output events.""" + client = create_test_azure_ai_client(mock_project_client) + + mock_stream = MagicMock(spec=ResponseStream) + mock_stream._transform_hooks = [] + + with patch.object( + RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream + ): + client._inner_get_response(messages=[], options={}, stream=True) + + hook = mock_stream._transform_hooks[0] + + # Simulate azure_ai_search_call_output event mock_item = MagicMock() mock_item.type = "azure_ai_search_call_output" - mock_item.output = mock_output_data - - mock_event = MagicMock() - mock_event.type = "response.output_item.added" - mock_event.item = mock_item - mock_event.output_index = 0 + mock_item.output = MagicMock() + mock_item.output.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + ] - client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + raw_event = MagicMock() + raw_event.type = "response.output_item.added" + raw_event.item = mock_item - assert hasattr(client, "_streaming_search_get_urls") - assert len(client._streaming_search_get_urls) == 1 - assert ( - client._streaming_search_get_urls[0] == "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" - ) + update = ChatResponseUpdate(raw_representation=raw_event) + result = hook(update) + assert result is update # passes through (no annotations to enrich) -def test_parse_chunk_from_openai_enriches_url_citation(mock_project_client: MagicMock) -> None: - """Test _parse_chunk_from_openai enriches url_citation annotations with search URLs.""" +def test_streaming_hook_enriches_url_citation(mock_project_client: MagicMock) -> None: + """Test the streaming transform hook enriches url_citation annotations with get_urls.""" client = create_test_azure_ai_client(mock_project_client) - # Pre-populate search URLs (simulating earlier capture) - client._streaming_search_get_urls = [ + mock_stream = MagicMock(spec=ResponseStream) + mock_stream._transform_hooks = [] + + with patch.object( + RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream + ): + client._inner_get_response(messages=[], options={}, stream=True) + + hook = mock_stream._transform_hooks[0] + + # Step 1: Feed search output event to capture URLs + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = MagicMock() + mock_item.output.get_urls = [ "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", ] + raw_output_event = MagicMock() + raw_output_event.type = "response.output_item.added" + raw_output_event.item = mock_item + hook(ChatResponseUpdate(raw_representation=raw_output_event)) - # Simulate url_citation annotation event + # Step 2: Feed url_citation annotation event mock_annotation = MagicMock() mock_annotation.type = "url_citation" mock_annotation.title = "doc_0" @@ -2046,19 +2058,19 @@ def test_parse_chunk_from_openai_enriches_url_citation(mock_project_client: Magi mock_annotation.start_index = 100 mock_annotation.end_index = 112 - mock_event = MagicMock() - mock_event.type = "response.output_text.annotation.added" - mock_event.annotation = mock_annotation - mock_event.annotation_index = 0 + raw_ann_event = MagicMock() + raw_ann_event.type = "response.output_text.annotation.added" + raw_ann_event.annotation = mock_annotation + raw_ann_event.annotation_index = 0 - result = client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + result = hook(ChatResponseUpdate(raw_representation=raw_ann_event)) - # Should have content with annotation containing get_url + # Verify the result has enriched annotation assert result.contents is not None found = False - for content in result.contents: - if hasattr(content, "annotations") and content.annotations: - for ann in content.annotations: + for content_item in result.contents: + if hasattr(content_item, "annotations") and content_item.annotations: + for ann in content_item.annotations: if isinstance(ann, dict) and ann.get("title") == "doc_0": found = True assert ann["additional_properties"]["get_url"] == ( @@ -2067,24 +2079,57 @@ def test_parse_chunk_from_openai_enriches_url_citation(mock_project_client: Magi assert found, "Expected url_citation annotation with enriched get_url" -def test_parse_chunk_from_openai_clears_state_on_completed(mock_project_client: MagicMock) -> None: - """Test _parse_chunk_from_openai clears streaming state on response.completed.""" +def test_build_url_citation_content(mock_project_client: MagicMock) -> None: + """Test _build_url_citation_content creates Content with enriched Annotation.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] + + annotation_data = MagicMock() + annotation_data.type = "url_citation" + annotation_data.title = "doc_0" + annotation_data.url = "https://search.example.com/" + annotation_data.start_index = 100 + annotation_data.end_index = 112 + + raw_event = MagicMock() + raw_event.annotation_index = 0 + + content = client._build_url_citation_content(annotation_data, get_urls, raw_event) + + assert content.annotations is not None + ann = content.annotations[0] + assert ann["type"] == "citation" + assert ann["title"] == "doc_0" + assert ann["url"] == "https://search.example.com/" + assert ann["additional_properties"]["get_url"] == get_urls[0] + assert ann["annotated_regions"][0]["start_index"] == 100 + assert ann["annotated_regions"][0]["end_index"] == 112 + + +def test_build_url_citation_content_with_dict(mock_project_client: MagicMock) -> None: + """Test _build_url_citation_content handles dict-style annotation data.""" client = create_test_azure_ai_client(mock_project_client) - client._streaming_search_get_urls = ["https://example.com/doc/0"] + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] - mock_response = MagicMock() - mock_response.id = "resp_123" - mock_response.model = "gpt-4" - mock_response.usage = None - mock_response.status = "completed" + annotation_data = { + "type": "url_citation", + "title": "doc_1", + "url": "https://search.example.com/", + "start_index": 200, + "end_index": 215, + } - mock_event = MagicMock() - mock_event.type = "response.completed" - mock_event.response = mock_response + raw_event = MagicMock() + raw_event.annotation_index = 1 - client._parse_chunk_from_openai(mock_event, options={}, function_call_ids={}) + content = client._build_url_citation_content(annotation_data, get_urls, raw_event) - assert not hasattr(client, "_streaming_search_get_urls") + assert content.annotations is not None + ann = content.annotations[0] + assert ann["type"] == "citation" + assert ann["title"] == "doc_1" + # doc_1 is out of range for a 1-element get_urls, so no get_url + assert "get_url" not in ann.get("additional_properties", {}) # endregion From e1d1f84b31e7a5a1c62a884d72410dba09544b90 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Thu, 19 Feb 2026 13:12:45 -0800 Subject: [PATCH 3/8] fix: handle streaming search citations from output_item.done events The azure_ai_search_call_output item only has populated output data (including get_urls) in the response.output_item.done event, not in the response.output_item.added event. Also removed the search_get_urls guard on url_citation handling so annotations are always produced even if get_urls haven't been captured yet. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/azure-ai/agent_framework_azure_ai/_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 857be639c2..4d94c15477 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -795,8 +795,10 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: return update event_type = getattr(raw, "type", None) - # Capture get_urls from azure_ai_search_call_output items - if event_type == "response.output_item.added": + # Capture get_urls from azure_ai_search_call_output items. + # Check both "added" and "done" events because the output data (including + # get_urls) may only be fully populated in the "done" event. + if event_type in ("response.output_item.added", "response.output_item.done"): item = getattr(raw, "item", None) if item is not None: urls = self._extract_azure_search_urls([item]) @@ -804,7 +806,7 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: search_get_urls.extend(urls) # Handle url_citation annotations (not handled by the base class in streaming) - if event_type == "response.output_text.annotation.added" and search_get_urls: + if event_type == "response.output_text.annotation.added": ann = getattr(raw, "annotation", None) if ann is not None: ann_type = ann.get("type") if isinstance(ann, dict) else getattr(ann, "type", None) From 51a07654ef426880b457171f2e63450696e0e466 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Fri, 20 Feb 2026 12:10:25 -0800 Subject: [PATCH 4/8] addressed comments --- .../agent_framework_azure_ai/_client.py | 24 ++++---- .../azure-ai/tests/test_azure_ai_client.py | 55 ++++++++++++------- 2 files changed, 45 insertions(+), 34 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 4d94c15477..566030e577 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -643,6 +643,11 @@ def _extract_azure_search_urls(self, output_items: Any) -> list[str]: output = getattr(item, "output", None) if isinstance(item, dict): output = item.get("output") + if isinstance(output, str): + try: + output = json.loads(output) + except (json.JSONDecodeError, TypeError): + continue if output is not None: urls = getattr(output, "get_urls", None) if isinstance(output, dict): @@ -684,10 +689,9 @@ def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls if not get_urls: return for content in contents: - annotations = getattr(content, "annotations", None) - if not annotations: + if not content.annotations: continue - for annotation in annotations: + for annotation in content.annotations: if not isinstance(annotation, dict): continue if annotation.get("type") != "citation": @@ -695,13 +699,9 @@ def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls title = annotation.get("title") doc_url = self._get_search_doc_url(title, get_urls) if doc_url: - props = annotation.get("additional_properties") or {} - props["get_url"] = doc_url - annotation["additional_properties"] = props + annotation.setdefault("additional_properties", {})["get_url"] = doc_url - def _build_url_citation_content( - self, annotation_data: Any, get_urls: list[str], raw_event: Any - ) -> Content: + def _build_url_citation_content(self, annotation_data: Any, get_urls: list[str], raw_event: Any) -> Content: """Build a Content with a citation Annotation from a url_citation streaming event. The base class does not handle ``url_citation`` annotations in streaming, so this @@ -745,9 +745,7 @@ def _val(key: str) -> Any: TextSpanRegion(type="text_span", start_index=ann_start, end_index=ann_end) ] - text_content = Content.from_text(text="", raw_representation=raw_event) - text_content.annotations = [annotation_obj] - return text_content + return Content.from_text(text="", annotations=[annotation_obj], raw_representation=raw_event) @override def _inner_get_response( @@ -831,7 +829,7 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: return update - stream_result._transform_hooks.append(_enrich_update) + stream_result.with_transform_hook(_enrich_update) return stream_result # endregion diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 78633e9890..9f102a70e8 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -1830,6 +1830,24 @@ def test_extract_azure_search_urls_no_search_items(mock_project_client: MagicMoc assert urls == [] +def test_extract_azure_search_urls_with_json_string_output(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with JSON string output (non-streaming pydantic extra field).""" + client = create_test_azure_ai_client(mock_project_client) + json_output = json.dumps({ + "documents": [{"id": "1"}], + "get_urls": [ + "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", + ], + }) + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = json_output + + urls = client._extract_azure_search_urls([mock_item]) + assert len(urls) == 1 + assert urls[0] == "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01" + + def test_get_search_doc_url_valid(mock_project_client: MagicMock) -> None: """Test _get_search_doc_url with valid doc_N title.""" client = create_test_azure_ai_client(mock_project_client) @@ -1941,9 +1959,7 @@ async def test_inner_get_response_enriches_non_streaming(mock_project_client: Ma async def _fake_awaitable() -> ChatResponse: return base_response - with patch.object( - RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable() - ): + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable()): result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) result = await result_awaitable # type: ignore[misc] @@ -1967,25 +1983,28 @@ async def test_inner_get_response_no_search_output_non_streaming(mock_project_cl async def _fake_awaitable() -> ChatResponse: return base_response - with patch.object( - RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable() - ): + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable()): result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) result = await result_awaitable # type: ignore[misc] assert result.messages[0].contents[0].text == "Hello world" +def _create_mock_stream() -> MagicMock: + """Create a mock ResponseStream with working with_transform_hook.""" + mock_stream = MagicMock(spec=ResponseStream) + mock_stream._transform_hooks = [] + mock_stream.with_transform_hook.side_effect = lambda hook: mock_stream._transform_hooks.append(hook) or mock_stream + return mock_stream + + def test_inner_get_response_streaming_registers_hook(mock_project_client: MagicMock) -> None: """Test _inner_get_response appends a transform hook to the stream for streaming responses.""" client = create_test_azure_ai_client(mock_project_client) - mock_stream = MagicMock(spec=ResponseStream) - mock_stream._transform_hooks = [] + mock_stream = _create_mock_stream() - with patch.object( - RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream - ): + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): result = client._inner_get_response(messages=[], options={}, stream=True) assert result is mock_stream @@ -1996,12 +2015,9 @@ def test_streaming_hook_captures_search_urls(mock_project_client: MagicMock) -> """Test the streaming transform hook captures get_urls from search output events.""" client = create_test_azure_ai_client(mock_project_client) - mock_stream = MagicMock(spec=ResponseStream) - mock_stream._transform_hooks = [] + mock_stream = _create_mock_stream() - with patch.object( - RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream - ): + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): client._inner_get_response(messages=[], options={}, stream=True) hook = mock_stream._transform_hooks[0] @@ -2027,12 +2043,9 @@ def test_streaming_hook_enriches_url_citation(mock_project_client: MagicMock) -> """Test the streaming transform hook enriches url_citation annotations with get_urls.""" client = create_test_azure_ai_client(mock_project_client) - mock_stream = MagicMock(spec=ResponseStream) - mock_stream._transform_hooks = [] + mock_stream = _create_mock_stream() - with patch.object( - RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream - ): + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): client._inner_get_response(messages=[], options={}, stream=True) hook = mock_stream._transform_hooks[0] From 90935cbd6e7d8aa4f128c19cd0efd67295b2c35f Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Mon, 23 Feb 2026 08:50:58 -0800 Subject: [PATCH 5/8] refactor: address PR review - eliminate type: ignore[assignment] pattern Call super()._inner_get_response() independently in each branch instead of once at the top with union type reassignment. Non-streaming uses two-arg super() in the closure; streaming uses cast() for type narrowing. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../azure-ai/agent_framework_azure_ai/_client.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 566030e577..f42b9257d1 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -767,13 +767,12 @@ def _inner_get_response( as they arrive. The captured URL state is local to the stream closure, so concurrent streams do not interfere. """ - result = super()._inner_get_response(messages=messages, options=options, stream=stream, **kwargs) - if not stream: - original_awaitable: Awaitable[ChatResponse] = result # type: ignore[assignment] async def _enrich_response() -> ChatResponse: - response = await original_awaitable + response = await super(RawAzureAIClient, self)._inner_get_response( + messages=messages, options=options, stream=False, **kwargs + ) raw = getattr(response, "raw_representation", None) get_urls = self._extract_azure_search_urls(getattr(raw, "output", [])) if get_urls: @@ -784,7 +783,10 @@ async def _enrich_response() -> ChatResponse: return _enrich_response() # Streaming: use a closure-local list so concurrent streams don't interfere - stream_result: ResponseStream[ChatResponseUpdate, ChatResponse] = result # type: ignore[assignment] + stream_result = cast( + ResponseStream[ChatResponseUpdate, ChatResponse], + super()._inner_get_response(messages=messages, options=options, stream=True, **kwargs), + ) search_get_urls: list[str] = [] def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: From d06d2387394f7a295ea94b88e008bd0628e0c787 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Mon, 23 Feb 2026 12:03:56 -0800 Subject: [PATCH 6/8] refactor: remove defensive patterns per PR review - Replace all getattr() with direct attribute access - Remove cast() for streaming branch, use type: ignore[assignment] - Simplify _build_url_citation_content to use dict access directly - Simplify _extract_azure_search_urls to use item.type/item.output - Handle empty list output from streaming 'added' events - Update tests to match actual runtime types (objects, not dicts) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_azure_ai/_client.py | 107 ++++++++---------- .../azure-ai/tests/test_azure_ai_client.py | 73 ++++++------ 2 files changed, 84 insertions(+), 96 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index f42b9257d1..4e5e101a98 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -636,24 +636,21 @@ def _extract_azure_search_urls(self, output_items: Any) -> list[str]: """ get_urls: list[str] = [] for item in output_items: - item_type = getattr(item, "type", None) - if isinstance(item, dict): - item_type = item.get("type") - if item_type == "azure_ai_search_call_output": - output = getattr(item, "output", None) - if isinstance(item, dict): - output = item.get("output") - if isinstance(output, str): - try: - output = json.loads(output) - except (json.JSONDecodeError, TypeError): - continue - if output is not None: - urls = getattr(output, "get_urls", None) - if isinstance(output, dict): - urls = output.get("get_urls") - if urls and isinstance(urls, list): - get_urls.extend(urls) + if item.type != "azure_ai_search_call_output": + continue + output = item.output + if isinstance(output, str): + try: + output = json.loads(output) + except (json.JSONDecodeError, TypeError): + continue + if isinstance(output, list): + # Streaming "added" events send output as an empty list; skip. + continue + if output is not None: + urls = output.get("get_urls") if isinstance(output, dict) else output.get_urls + if urls and isinstance(urls, list): + get_urls.extend(urls) return get_urls def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) -> str | None: @@ -701,33 +698,29 @@ def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls if doc_url: annotation.setdefault("additional_properties", {})["get_url"] = doc_url - def _build_url_citation_content(self, annotation_data: Any, get_urls: list[str], raw_event: Any) -> Content: + def _build_url_citation_content( + self, annotation_data: dict[str, Any], get_urls: list[str], raw_event: Any + ) -> Content: """Build a Content with a citation Annotation from a url_citation streaming event. The base class does not handle ``url_citation`` annotations in streaming, so this method creates the appropriate framework content for them. Args: - annotation_data: The raw annotation object/dict from the streaming event. + annotation_data: The raw annotation dict from the streaming event. get_urls: Captured document URLs for enrichment. raw_event: The raw streaming event for raw_representation. Returns: A Content object containing the citation annotation. """ - - def _val(key: str) -> Any: - if isinstance(annotation_data, dict): - return annotation_data.get(key) - return getattr(annotation_data, key, None) - - ann_title = str(_val("title") or "") - ann_url = str(_val("url") or "") - ann_start = _val("start_index") - ann_end = _val("end_index") + ann_title = str(annotation_data.get("title") or "") + ann_url = str(annotation_data.get("url") or "") + ann_start = annotation_data.get("start_index") + ann_end = annotation_data.get("end_index") additional_props: dict[str, Any] = { - "annotation_index": getattr(raw_event, "annotation_index", None), + "annotation_index": raw_event.annotation_index, } doc_url = self._get_search_doc_url(ann_title, get_urls) if doc_url: @@ -773,8 +766,7 @@ async def _enrich_response() -> ChatResponse: response = await super(RawAzureAIClient, self)._inner_get_response( messages=messages, options=options, stream=False, **kwargs ) - raw = getattr(response, "raw_representation", None) - get_urls = self._extract_azure_search_urls(getattr(raw, "output", [])) + get_urls = self._extract_azure_search_urls(response.raw_representation.output) if get_urls: for msg in response.messages: self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) @@ -783,47 +775,42 @@ async def _enrich_response() -> ChatResponse: return _enrich_response() # Streaming: use a closure-local list so concurrent streams don't interfere - stream_result = cast( - ResponseStream[ChatResponseUpdate, ChatResponse], - super()._inner_get_response(messages=messages, options=options, stream=True, **kwargs), + stream_result = super()._inner_get_response( # type: ignore[assignment] + messages=messages, options=options, stream=True, **kwargs ) search_get_urls: list[str] = [] def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: - raw = getattr(update, "raw_representation", None) + raw = update.raw_representation if raw is None: return update - event_type = getattr(raw, "type", None) + event_type = raw.type # Capture get_urls from azure_ai_search_call_output items. # Check both "added" and "done" events because the output data (including # get_urls) may only be fully populated in the "done" event. if event_type in ("response.output_item.added", "response.output_item.done"): - item = getattr(raw, "item", None) - if item is not None: - urls = self._extract_azure_search_urls([item]) - if urls: - search_get_urls.extend(urls) + urls = self._extract_azure_search_urls([raw.item]) + if urls: + search_get_urls.extend(urls) # Handle url_citation annotations (not handled by the base class in streaming) if event_type == "response.output_text.annotation.added": - ann = getattr(raw, "annotation", None) - if ann is not None: - ann_type = ann.get("type") if isinstance(ann, dict) else getattr(ann, "type", None) - if ann_type == "url_citation": - citation_content = self._build_url_citation_content(ann, search_get_urls, raw) - contents_list = list(update.contents or []) - contents_list.append(citation_content) - return ChatResponseUpdate( - contents=contents_list, - conversation_id=update.conversation_id, - response_id=update.response_id, - role=update.role, - model_id=update.model_id, - continuation_token=update.continuation_token, - additional_properties=update.additional_properties, - raw_representation=update.raw_representation, - ) + ann = raw.annotation + if ann.get("type") == "url_citation": + citation_content = self._build_url_citation_content(ann, search_get_urls, raw) + contents_list = list(update.contents or []) + contents_list.append(citation_content) + return ChatResponseUpdate( + contents=contents_list, + conversation_id=update.conversation_id, + response_id=update.response_id, + role=update.role, + model_id=update.model_id, + continuation_token=update.continuation_token, + additional_properties=update.additional_properties, + raw_representation=update.raw_representation, + ) # Enrich any citation annotations already parsed by the base class if update.contents and search_get_urls: diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 9f102a70e8..ff5366b87c 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -1784,26 +1784,26 @@ def test_get_image_generation_tool_with_options() -> None: def test_extract_azure_search_urls_with_dict_items(mock_project_client: MagicMock) -> None: - """Test _extract_azure_search_urls with dict-style output items.""" + """Test _extract_azure_search_urls with dict-style output (after JSON parsing).""" client = create_test_azure_ai_client(mock_project_client) - output_items = [ - { - "type": "azure_ai_search_call", - "arguments": '{"query":"test"}', - }, - { - "type": "azure_ai_search_call_output", - "output": { - "documents": [{"id": "1", "url": "https://search.example.com/"}], - "get_urls": [ - "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", - "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01", - ], - }, - }, - {"type": "message", "content": [{"type": "output_text", "text": "hello"}]}, - ] - urls = client._extract_azure_search_urls(output_items) + mock_output = { + "documents": [{"id": "1", "url": "https://search.example.com/"}], + "get_urls": [ + "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01", + ], + } + mock_search_item = MagicMock() + mock_search_item.type = "azure_ai_search_call_output" + mock_search_item.output = mock_output + + mock_call_item = MagicMock() + mock_call_item.type = "azure_ai_search_call" + + mock_msg_item = MagicMock() + mock_msg_item.type = "message" + + urls = client._extract_azure_search_urls([mock_call_item, mock_search_item, mock_msg_item]) assert len(urls) == 2 assert urls[0] == "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01" assert urls[1] == "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01" @@ -1825,8 +1825,9 @@ def test_extract_azure_search_urls_with_object_items(mock_project_client: MagicM def test_extract_azure_search_urls_no_search_items(mock_project_client: MagicMock) -> None: """Test _extract_azure_search_urls with no search output items.""" client = create_test_azure_ai_client(mock_project_client) - output_items = [{"type": "message", "content": []}] - urls = client._extract_azure_search_urls(output_items) + mock_item = MagicMock() + mock_item.type = "message" + urls = client._extract_azure_search_urls([mock_item]) assert urls == [] @@ -2063,17 +2064,16 @@ def test_streaming_hook_enriches_url_citation(mock_project_client: MagicMock) -> raw_output_event.item = mock_item hook(ChatResponseUpdate(raw_representation=raw_output_event)) - # Step 2: Feed url_citation annotation event - mock_annotation = MagicMock() - mock_annotation.type = "url_citation" - mock_annotation.title = "doc_0" - mock_annotation.url = "https://search.example.com/" - mock_annotation.start_index = 100 - mock_annotation.end_index = 112 - + # Step 2: Feed url_citation annotation event (annotation is always a dict in streaming) raw_ann_event = MagicMock() raw_ann_event.type = "response.output_text.annotation.added" - raw_ann_event.annotation = mock_annotation + raw_ann_event.annotation = { + "type": "url_citation", + "title": "doc_0", + "url": "https://search.example.com/", + "start_index": 100, + "end_index": 112, + } raw_ann_event.annotation_index = 0 result = hook(ChatResponseUpdate(raw_representation=raw_ann_event)) @@ -2097,12 +2097,13 @@ def test_build_url_citation_content(mock_project_client: MagicMock) -> None: client = create_test_azure_ai_client(mock_project_client) get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] - annotation_data = MagicMock() - annotation_data.type = "url_citation" - annotation_data.title = "doc_0" - annotation_data.url = "https://search.example.com/" - annotation_data.start_index = 100 - annotation_data.end_index = 112 + annotation_data = { + "type": "url_citation", + "title": "doc_0", + "url": "https://search.example.com/", + "start_index": 100, + "end_index": 112, + } raw_event = MagicMock() raw_event.annotation_index = 0 From e4f5ef5e769780b99311d77fd4513d7b9435e898 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Mon, 23 Feb 2026 12:13:26 -0800 Subject: [PATCH 7/8] mypy fix --- python/packages/azure-ai/agent_framework_azure_ai/_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 4e5e101a98..db6f4dfec6 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -766,7 +766,7 @@ async def _enrich_response() -> ChatResponse: response = await super(RawAzureAIClient, self)._inner_get_response( messages=messages, options=options, stream=False, **kwargs ) - get_urls = self._extract_azure_search_urls(response.raw_representation.output) + get_urls = self._extract_azure_search_urls(response.raw_representation.output) # type: ignore[union-attr] if get_urls: for msg in response.messages: self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) @@ -818,7 +818,7 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: return update - stream_result.with_transform_hook(_enrich_update) + stream_result.with_transform_hook(_enrich_update) # type: ignore[union-attr] return stream_result # endregion From bc736afd4abc62856c62b76526086fce8896bf37 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Mon, 23 Feb 2026 14:03:43 -0800 Subject: [PATCH 8/8] small fixes --- python/packages/azure-ai/agent_framework_azure_ai/_client.py | 4 +++- python/packages/azure-ai/tests/test_azure_ai_client.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index db6f4dfec6..7c698847cc 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -86,6 +86,8 @@ class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False): covariant=True, ) +_DOC_INDEX_PATTERN = re.compile(r"doc_(\d+)") + class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[AzureAIClientOptionsT]): """Raw Azure AI client without middleware, telemetry, or function invocation layers. @@ -665,7 +667,7 @@ def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) - """ if not citation_title or not get_urls: return None - match = re.search(r"doc_(\d+)", citation_title) + match = _DOC_INDEX_PATTERN.search(citation_title) if not match: return None doc_index = int(match.group(1)) diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index ff5366b87c..5e08ddc066 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -1934,7 +1934,6 @@ def test_enrich_annotations_empty_get_urls(mock_project_client: MagicMock) -> No assert "additional_properties" not in content.annotations[0] -@pytest.mark.asyncio async def test_inner_get_response_enriches_non_streaming(mock_project_client: MagicMock) -> None: """Test _inner_get_response enriches url_citation annotations for non-streaming responses.""" client = create_test_azure_ai_client(mock_project_client) @@ -1970,7 +1969,6 @@ async def _fake_awaitable() -> ChatResponse: ) -@pytest.mark.asyncio async def test_inner_get_response_no_search_output_non_streaming(mock_project_client: MagicMock) -> None: """Test _inner_get_response passes through when no search output exists.""" client = create_test_azure_ai_client(mock_project_client)