diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index de9df819ec..7c698847cc 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -4,22 +4,29 @@ import json import logging +import re import sys -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import suppress from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, Agent, + Annotation, BaseContextProvider, ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, + ChatResponse, + ChatResponseUpdate, + Content, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, Message, MiddlewareTypes, + ResponseStream, + TextSpanRegion, ) from agent_framework._settings import load_settings from agent_framework._tools import ToolTypes @@ -79,6 +86,8 @@ class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False): covariant=True, ) +_DOC_INDEX_PATTERN = re.compile(r"doc_(\d+)") + class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[AzureAIClientOptionsT]): """Raw Azure AI client without middleware, telemetry, or function invocation layers. @@ -616,6 +625,206 @@ def _update_agent_name_and_description(self, agent_name: str | None, description if description and not self.agent_description: self.agent_description = description + # region Azure AI Search Citation Enhancement + + def _extract_azure_search_urls(self, output_items: Any) -> list[str]: + """Extract document URLs from azure_ai_search_call_output items. + + Args: + output_items: The response output items to scan. + + Returns: + A flat list of get_urls from all azure_ai_search_call_output items. + """ + get_urls: list[str] = [] + for item in output_items: + if item.type != "azure_ai_search_call_output": + continue + output = item.output + if isinstance(output, str): + try: + output = json.loads(output) + except (json.JSONDecodeError, TypeError): + continue + if isinstance(output, list): + # Streaming "added" events send output as an empty list; skip. + continue + if output is not None: + urls = output.get("get_urls") if isinstance(output, dict) else output.get_urls + if urls and isinstance(urls, list): + get_urls.extend(urls) + return get_urls + + def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) -> str | None: + """Map a citation title like 'doc_0' to its corresponding get_url. + + Args: + citation_title: The annotation title (e.g., "doc_0"). + get_urls: The list of document URLs from azure_ai_search_call_output. + + Returns: + The matching document URL if found, otherwise None. + """ + if not citation_title or not get_urls: + return None + match = _DOC_INDEX_PATTERN.search(citation_title) + if not match: + return None + doc_index = int(match.group(1)) + if 0 <= doc_index < len(get_urls): + return str(get_urls[doc_index]) + return None + + def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls: list[str]) -> None: + """Enrich citation annotations in contents with real document URLs from Azure AI Search. + + Looks for annotations with ``type == "citation"`` and a ``title`` matching ``doc_N``, + then adds the corresponding document URL from *get_urls* to ``additional_properties["get_url"]``. + + Args: + contents: The parsed content list from a ChatResponse or ChatResponseUpdate. + get_urls: Document URLs extracted from azure_ai_search_call_output. + """ + if not get_urls: + return + for content in contents: + if not content.annotations: + continue + for annotation in content.annotations: + if not isinstance(annotation, dict): + continue + if annotation.get("type") != "citation": + continue + title = annotation.get("title") + doc_url = self._get_search_doc_url(title, get_urls) + if doc_url: + annotation.setdefault("additional_properties", {})["get_url"] = doc_url + + def _build_url_citation_content( + self, annotation_data: dict[str, Any], get_urls: list[str], raw_event: Any + ) -> Content: + """Build a Content with a citation Annotation from a url_citation streaming event. + + The base class does not handle ``url_citation`` annotations in streaming, so this + method creates the appropriate framework content for them. + + Args: + annotation_data: The raw annotation dict from the streaming event. + get_urls: Captured document URLs for enrichment. + raw_event: The raw streaming event for raw_representation. + + Returns: + A Content object containing the citation annotation. + """ + ann_title = str(annotation_data.get("title") or "") + ann_url = str(annotation_data.get("url") or "") + ann_start = annotation_data.get("start_index") + ann_end = annotation_data.get("end_index") + + additional_props: dict[str, Any] = { + "annotation_index": raw_event.annotation_index, + } + doc_url = self._get_search_doc_url(ann_title, get_urls) + if doc_url: + additional_props["get_url"] = doc_url + + annotation_obj = Annotation( + type="citation", + title=ann_title, + url=ann_url, + additional_properties=additional_props, + raw_representation=annotation_data, + ) + if ann_start is not None and ann_end is not None: + annotation_obj["annotated_regions"] = [ + TextSpanRegion(type="text_span", start_index=ann_start, end_index=ann_end) + ] + + return Content.from_text(text="", annotations=[annotation_obj], raw_representation=raw_event) + + @override + def _inner_get_response( + self, + *, + messages: Sequence[Message], + options: Mapping[str, Any], + stream: bool = False, + **kwargs: Any, + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + """Wrap base response to enrich Azure AI Search citation annotations. + + For non-streaming responses, the ``ChatResponse.raw_representation`` carries the + full response including ``azure_ai_search_call_output`` items. After the base class + parses the response, ``url_citation`` annotations are enriched with per-document URLs. + + For streaming responses, a transform hook is registered on the ``ResponseStream`` to + capture ``get_urls`` from search output events and enrich ``url_citation`` annotations + as they arrive. The captured URL state is local to the stream closure, so concurrent + streams do not interfere. + """ + if not stream: + + async def _enrich_response() -> ChatResponse: + response = await super(RawAzureAIClient, self)._inner_get_response( + messages=messages, options=options, stream=False, **kwargs + ) + get_urls = self._extract_azure_search_urls(response.raw_representation.output) # type: ignore[union-attr] + if get_urls: + for msg in response.messages: + self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls) + return response + + return _enrich_response() + + # Streaming: use a closure-local list so concurrent streams don't interfere + stream_result = super()._inner_get_response( # type: ignore[assignment] + messages=messages, options=options, stream=True, **kwargs + ) + search_get_urls: list[str] = [] + + def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: + raw = update.raw_representation + if raw is None: + return update + event_type = raw.type + + # Capture get_urls from azure_ai_search_call_output items. + # Check both "added" and "done" events because the output data (including + # get_urls) may only be fully populated in the "done" event. + if event_type in ("response.output_item.added", "response.output_item.done"): + urls = self._extract_azure_search_urls([raw.item]) + if urls: + search_get_urls.extend(urls) + + # Handle url_citation annotations (not handled by the base class in streaming) + if event_type == "response.output_text.annotation.added": + ann = raw.annotation + if ann.get("type") == "url_citation": + citation_content = self._build_url_citation_content(ann, search_get_urls, raw) + contents_list = list(update.contents or []) + contents_list.append(citation_content) + return ChatResponseUpdate( + contents=contents_list, + conversation_id=update.conversation_id, + response_id=update.response_id, + role=update.role, + model_id=update.model_id, + continuation_token=update.continuation_token, + additional_properties=update.additional_properties, + raw_representation=update.raw_representation, + ) + + # Enrich any citation annotations already parsed by the base class + if update.contents and search_get_urls: + self._enrich_annotations_with_search_urls(list(update.contents), search_get_urls) + + return update + + stream_result.with_transform_hook(_enrich_update) # type: ignore[union-attr] + return stream_result + + # endregion + # region Hosted Tool Factory Methods (Azure-specific overrides) @staticmethod diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 5824b668f1..5e08ddc066 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -13,14 +13,18 @@ from agent_framework import ( Agent, AgentResponse, + Annotation, ChatOptions, ChatResponse, + ChatResponseUpdate, Content, Message, + ResponseStream, SupportsChatGetResponse, tool, ) from agent_framework._settings import load_settings +from agent_framework.openai._responses_client import RawOpenAIResponsesClient from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ( ApproximateLocation, @@ -1774,3 +1778,370 @@ def test_get_image_generation_tool_with_options() -> None: # endregion + + +# region Azure AI Search Citation Enhancement Tests + + +def test_extract_azure_search_urls_with_dict_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with dict-style output (after JSON parsing).""" + client = create_test_azure_ai_client(mock_project_client) + mock_output = { + "documents": [{"id": "1", "url": "https://search.example.com/"}], + "get_urls": [ + "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01", + ], + } + mock_search_item = MagicMock() + mock_search_item.type = "azure_ai_search_call_output" + mock_search_item.output = mock_output + + mock_call_item = MagicMock() + mock_call_item.type = "azure_ai_search_call" + + mock_msg_item = MagicMock() + mock_msg_item.type = "message" + + urls = client._extract_azure_search_urls([mock_call_item, mock_search_item, mock_msg_item]) + assert len(urls) == 2 + assert urls[0] == "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01" + assert urls[1] == "https://search.example.com/indexes/idx/docs/2?api-version=2024-07-01" + + +def test_extract_azure_search_urls_with_object_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with object-style output items.""" + client = create_test_azure_ai_client(mock_project_client) + mock_output = MagicMock() + mock_output.get_urls = ["https://example.com/doc/1", "https://example.com/doc/2"] + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = mock_output + + urls = client._extract_azure_search_urls([mock_item]) + assert urls == ["https://example.com/doc/1", "https://example.com/doc/2"] + + +def test_extract_azure_search_urls_no_search_items(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with no search output items.""" + client = create_test_azure_ai_client(mock_project_client) + mock_item = MagicMock() + mock_item.type = "message" + urls = client._extract_azure_search_urls([mock_item]) + assert urls == [] + + +def test_extract_azure_search_urls_with_json_string_output(mock_project_client: MagicMock) -> None: + """Test _extract_azure_search_urls with JSON string output (non-streaming pydantic extra field).""" + client = create_test_azure_ai_client(mock_project_client) + json_output = json.dumps({ + "documents": [{"id": "1"}], + "get_urls": [ + "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01", + ], + }) + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = json_output + + urls = client._extract_azure_search_urls([mock_item]) + assert len(urls) == 1 + assert urls[0] == "https://search.example.com/indexes/idx/docs/1?api-version=2024-07-01" + + +def test_get_search_doc_url_valid(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with valid doc_N title.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0", "https://example.com/doc/1", "https://example.com/doc/2"] + + assert client._get_search_doc_url("doc_0", get_urls) == "https://example.com/doc/0" + assert client._get_search_doc_url("doc_1", get_urls) == "https://example.com/doc/1" + assert client._get_search_doc_url("doc_2", get_urls) == "https://example.com/doc/2" + + +def test_get_search_doc_url_out_of_range(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with out-of-range index.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0"] + assert client._get_search_doc_url("doc_5", get_urls) is None + + +def test_get_search_doc_url_no_match(mock_project_client: MagicMock) -> None: + """Test _get_search_doc_url with non-matching title.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://example.com/doc/0"] + assert client._get_search_doc_url("some_title", get_urls) is None + assert client._get_search_doc_url(None, get_urls) is None + assert client._get_search_doc_url("doc_0", []) is None + + +def test_enrich_annotations_with_search_urls(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls enriches citation annotations.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", + ] + + content = Content.from_text(text="test response") + content.annotations = [ + { + "type": "citation", + "title": "doc_0", + "url": "https://search.example.com/", + }, + { + "type": "citation", + "title": "doc_1", + "url": "https://search.example.com/", + }, + ] + + client._enrich_annotations_with_search_urls([content], get_urls) + + assert content.annotations[0]["additional_properties"]["get_url"] == get_urls[0] + assert content.annotations[1]["additional_properties"]["get_url"] == get_urls[1] + + +def test_enrich_annotations_no_match(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls with non-matching titles.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] + + content = Content.from_text(text="test response") + content.annotations = [ + { + "type": "citation", + "title": "some_title", + "url": "https://search.example.com/", + }, + ] + + client._enrich_annotations_with_search_urls([content], get_urls) + assert "additional_properties" not in content.annotations[0] or "get_url" not in content.annotations[0].get( + "additional_properties", {} + ) + + +def test_enrich_annotations_empty_get_urls(mock_project_client: MagicMock) -> None: + """Test _enrich_annotations_with_search_urls with empty get_urls.""" + client = create_test_azure_ai_client(mock_project_client) + content = Content.from_text(text="test") + content.annotations = [{"type": "citation", "title": "doc_0", "url": "https://example.com/"}] + + # Should not raise or modify + client._enrich_annotations_with_search_urls([content], []) + assert "additional_properties" not in content.annotations[0] + + +async def test_inner_get_response_enriches_non_streaming(mock_project_client: MagicMock) -> None: + """Test _inner_get_response enriches url_citation annotations for non-streaming responses.""" + client = create_test_azure_ai_client(mock_project_client) + + # Build a ChatResponse with citation annotations and a raw_representation carrying search output + content = Content.from_text(text="Here is the result【5:0†source】.") + content.annotations = [ + Annotation(type="citation", title="doc_0", url="https://search.example.com/"), + ] + msg = Message(role="assistant", contents=[content]) + mock_raw = MagicMock() + mock_search_output = MagicMock() + mock_search_output.type = "azure_ai_search_call_output" + mock_search_output_data = MagicMock() + mock_search_output_data.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + ] + mock_search_output.output = mock_search_output_data + mock_raw.output = [mock_search_output] + + base_response = ChatResponse(messages=[msg], raw_representation=mock_raw) + + async def _fake_awaitable() -> ChatResponse: + return base_response + + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable()): + result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) + result = await result_awaitable # type: ignore[misc] + + ann = result.messages[0].contents[0].annotations[0] + assert ann["additional_properties"]["get_url"] == ( + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) + + +async def test_inner_get_response_no_search_output_non_streaming(mock_project_client: MagicMock) -> None: + """Test _inner_get_response passes through when no search output exists.""" + client = create_test_azure_ai_client(mock_project_client) + + content = Content.from_text(text="Hello world") + msg = Message(role="assistant", contents=[content]) + mock_raw = MagicMock() + mock_raw.output = [] + base_response = ChatResponse(messages=[msg], raw_representation=mock_raw) + + async def _fake_awaitable() -> ChatResponse: + return base_response + + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=_fake_awaitable()): + result_awaitable = client._inner_get_response(messages=[], options={}, stream=False) + result = await result_awaitable # type: ignore[misc] + + assert result.messages[0].contents[0].text == "Hello world" + + +def _create_mock_stream() -> MagicMock: + """Create a mock ResponseStream with working with_transform_hook.""" + mock_stream = MagicMock(spec=ResponseStream) + mock_stream._transform_hooks = [] + mock_stream.with_transform_hook.side_effect = lambda hook: mock_stream._transform_hooks.append(hook) or mock_stream + return mock_stream + + +def test_inner_get_response_streaming_registers_hook(mock_project_client: MagicMock) -> None: + """Test _inner_get_response appends a transform hook to the stream for streaming responses.""" + client = create_test_azure_ai_client(mock_project_client) + + mock_stream = _create_mock_stream() + + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): + result = client._inner_get_response(messages=[], options={}, stream=True) + + assert result is mock_stream + assert len(mock_stream._transform_hooks) == 1 + + +def test_streaming_hook_captures_search_urls(mock_project_client: MagicMock) -> None: + """Test the streaming transform hook captures get_urls from search output events.""" + client = create_test_azure_ai_client(mock_project_client) + + mock_stream = _create_mock_stream() + + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): + client._inner_get_response(messages=[], options={}, stream=True) + + hook = mock_stream._transform_hooks[0] + + # Simulate azure_ai_search_call_output event + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = MagicMock() + mock_item.output.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + ] + + raw_event = MagicMock() + raw_event.type = "response.output_item.added" + raw_event.item = mock_item + + update = ChatResponseUpdate(raw_representation=raw_event) + result = hook(update) + assert result is update # passes through (no annotations to enrich) + + +def test_streaming_hook_enriches_url_citation(mock_project_client: MagicMock) -> None: + """Test the streaming transform hook enriches url_citation annotations with get_urls.""" + client = create_test_azure_ai_client(mock_project_client) + + mock_stream = _create_mock_stream() + + with patch.object(RawOpenAIResponsesClient, "_inner_get_response", return_value=mock_stream): + client._inner_get_response(messages=[], options={}, stream=True) + + hook = mock_stream._transform_hooks[0] + + # Step 1: Feed search output event to capture URLs + mock_item = MagicMock() + mock_item.type = "azure_ai_search_call_output" + mock_item.output = MagicMock() + mock_item.output.get_urls = [ + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01", + "https://search.example.com/indexes/idx/docs/41?api-version=2024-07-01", + ] + raw_output_event = MagicMock() + raw_output_event.type = "response.output_item.added" + raw_output_event.item = mock_item + hook(ChatResponseUpdate(raw_representation=raw_output_event)) + + # Step 2: Feed url_citation annotation event (annotation is always a dict in streaming) + raw_ann_event = MagicMock() + raw_ann_event.type = "response.output_text.annotation.added" + raw_ann_event.annotation = { + "type": "url_citation", + "title": "doc_0", + "url": "https://search.example.com/", + "start_index": 100, + "end_index": 112, + } + raw_ann_event.annotation_index = 0 + + result = hook(ChatResponseUpdate(raw_representation=raw_ann_event)) + + # Verify the result has enriched annotation + assert result.contents is not None + found = False + for content_item in result.contents: + if hasattr(content_item, "annotations") and content_item.annotations: + for ann in content_item.annotations: + if isinstance(ann, dict) and ann.get("title") == "doc_0": + found = True + assert ann["additional_properties"]["get_url"] == ( + "https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01" + ) + assert found, "Expected url_citation annotation with enriched get_url" + + +def test_build_url_citation_content(mock_project_client: MagicMock) -> None: + """Test _build_url_citation_content creates Content with enriched Annotation.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] + + annotation_data = { + "type": "url_citation", + "title": "doc_0", + "url": "https://search.example.com/", + "start_index": 100, + "end_index": 112, + } + + raw_event = MagicMock() + raw_event.annotation_index = 0 + + content = client._build_url_citation_content(annotation_data, get_urls, raw_event) + + assert content.annotations is not None + ann = content.annotations[0] + assert ann["type"] == "citation" + assert ann["title"] == "doc_0" + assert ann["url"] == "https://search.example.com/" + assert ann["additional_properties"]["get_url"] == get_urls[0] + assert ann["annotated_regions"][0]["start_index"] == 100 + assert ann["annotated_regions"][0]["end_index"] == 112 + + +def test_build_url_citation_content_with_dict(mock_project_client: MagicMock) -> None: + """Test _build_url_citation_content handles dict-style annotation data.""" + client = create_test_azure_ai_client(mock_project_client) + get_urls = ["https://search.example.com/indexes/idx/docs/16?api-version=2024-07-01"] + + annotation_data = { + "type": "url_citation", + "title": "doc_1", + "url": "https://search.example.com/", + "start_index": 200, + "end_index": 215, + } + + raw_event = MagicMock() + raw_event.annotation_index = 1 + + content = client._build_url_citation_content(annotation_data, get_urls, raw_event) + + assert content.annotations is not None + ann = content.annotations[0] + assert ann["type"] == "citation" + assert ann["title"] == "doc_1" + # doc_1 is out of range for a 1-element get_urls, so no get_url + assert "get_url" not in ann.get("additional_properties", {}) + + +# endregion diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py index 25648751d8..3e7ce71096 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_azure_ai_search.py @@ -2,6 +2,7 @@ import asyncio import os +from agent_framework import Annotation from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv @@ -15,6 +16,9 @@ This sample demonstrates usage of AzureAIProjectAgentProvider with Azure AI Search to search through indexed data and answer user questions about it. +Citations from Azure AI Search are automatically enriched with document-specific +URLs (get_url) that can be used to retrieve the original documents. + Prerequisites: 1. Set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME environment variables. 2. Ensure you have an Azure AI Search connection configured in your Azure AI project @@ -29,8 +33,10 @@ async def main() -> None: ): agent = await provider.create_agent( name="MySearchAgent", - instructions="""You are a helpful assistant. You must always provide citations for - answers using the tool and render them as: `[message_idx:search_idx†source]`.""", + instructions=( + "You are a helpful agent that searches hotel information using Azure AI Search. " + "Always use the search tool and index to find hotel data and provide accurate information." + ), tools={ "type": "azure_ai_search", "azure_ai_search": { @@ -46,11 +52,59 @@ async def main() -> None: }, ) - query = "Tell me about insurance options" + query = ( + "Use Azure AI search knowledge tool to find detailed information about a winter hotel." + " Use the search tool and index." # You can modify prompt to force tool usage + ) print(f"User: {query}") + + # Non-streaming: get response with enriched citations result = await agent.run(query) print(f"Result: {result}\n") + # Display citations with document-specific URLs + if result.messages: + citations: list[Annotation] = [] + for msg in result.messages: + for content in msg.contents: + if hasattr(content, "annotations") and content.annotations: + citations.extend(content.annotations) + + if citations: + print("Citations:") + for i, citation in enumerate(citations, 1): + url = citation.get("url", "N/A") + # get_url contains the document-specific REST API URL from Azure AI Search + get_url = (citation.get("additional_properties") or {}).get("get_url") + print(f" [{i}] {citation.get('title', 'N/A')}") + print(f" URL: {url}") + if get_url: + print(f" Document URL: {get_url}") + + # Streaming: collect citations from streamed response + print("\n--- Streaming ---") + print(f"User: {query}") + print("Agent: ", end="", flush=True) + streaming_citations: list[Annotation] = [] + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + for content in getattr(chunk, "contents", []): + annotations = getattr(content, "annotations", []) + if annotations: + streaming_citations.extend(annotations) + + print() + if streaming_citations: + print("\nStreaming Citations:") + for i, citation in enumerate(streaming_citations, 1): + url = citation.get("url", "N/A") + get_url = (citation.get("additional_properties") or {}).get("get_url") + print(f" [{i}] {citation.get('title', 'N/A')}") + print(f" URL: {url}") + if get_url: + print(f" Document URL: {get_url}") + if __name__ == "__main__": asyncio.run(main())