Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 219 additions & 1 deletion python/packages/azure-ai/agent_framework_azure_ai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,29 @@

import json
import logging
import re
import sys
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from contextlib import suppress
from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast

from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
Agent,
Annotation,
BaseContextProvider,
ChatAndFunctionMiddlewareTypes,
ChatMiddlewareLayer,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionInvocationConfiguration,
FunctionInvocationLayer,
FunctionTool,
Message,
MiddlewareTypes,
ResponseStream,
TextSpanRegion,
)
from agent_framework._settings import load_settings
from agent_framework._tools import ToolTypes
Expand Down Expand Up @@ -616,6 +623,217 @@ def _update_agent_name_and_description(self, agent_name: str | None, description
if description and not self.agent_description:
self.agent_description = description

# region Azure AI Search Citation Enhancement

def _extract_azure_search_urls(self, output_items: Any) -> list[str]:
"""Extract document URLs from azure_ai_search_call_output items.

Args:
output_items: The response output items to scan.

Returns:
A flat list of get_urls from all azure_ai_search_call_output items.
"""
get_urls: list[str] = []
for item in output_items:
item_type = getattr(item, "type", None)
if isinstance(item, dict):
item_type = item.get("type")
if item_type == "azure_ai_search_call_output":
output = getattr(item, "output", None)
if isinstance(item, dict):
output = item.get("output")
if isinstance(output, str):
try:
output = json.loads(output)
except (json.JSONDecodeError, TypeError):
continue
if output is not None:
urls = getattr(output, "get_urls", None)
if isinstance(output, dict):
urls = output.get("get_urls")
if urls and isinstance(urls, list):
get_urls.extend(urls)
return get_urls

def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) -> str | None:
"""Map a citation title like 'doc_0' to its corresponding get_url.

Args:
citation_title: The annotation title (e.g., "doc_0").
get_urls: The list of document URLs from azure_ai_search_call_output.

Returns:
The matching document URL if found, otherwise None.
"""
if not citation_title or not get_urls:
return None
match = re.search(r"doc_(\d+)", citation_title)
if not match:
return None
doc_index = int(match.group(1))
if 0 <= doc_index < len(get_urls):
return str(get_urls[doc_index])
return None

def _enrich_annotations_with_search_urls(self, contents: list[Content], get_urls: list[str]) -> None:
    """Attach real Azure AI Search document URLs to citation annotations, in place.

    Every dict annotation whose ``type`` is ``"citation"`` has its ``title``
    resolved through ``_get_search_doc_url``; when a URL is found it is stored
    under ``additional_properties["get_url"]`` (the ``additional_properties``
    dict is created if the annotation does not have one yet).

    Args:
        contents: The parsed content list from a ChatResponse or ChatResponseUpdate.
        get_urls: Document URLs extracted from azure_ai_search_call_output.
    """
    if not get_urls:
        return
    for entry in contents:
        # Contents without annotations contribute nothing to iterate over.
        for note in entry.annotations or ():
            is_citation = isinstance(note, dict) and note.get("type") == "citation"
            if not is_citation:
                continue
            resolved = self._get_search_doc_url(note.get("title"), get_urls)
            if resolved:
                extras = note.setdefault("additional_properties", {})
                extras["get_url"] = resolved

def _build_url_citation_content(self, annotation_data: Any, get_urls: list[str], raw_event: Any) -> Content:
    """Turn a streamed ``url_citation`` annotation into framework citation content.

    ``url_citation`` annotations are not converted by the base class while
    streaming, so the corresponding Content is assembled here.

    Args:
        annotation_data: The raw annotation object/dict from the streaming event.
        get_urls: Captured document URLs for enrichment.
        raw_event: The raw streaming event for raw_representation.

    Returns:
        An empty-text Content carrying a single citation Annotation. The
        annotation's ``additional_properties`` always holds ``annotation_index``
        (taken from the event, possibly None) and, when the title resolves to a
        captured document URL, ``get_url``.
    """
    is_mapping = isinstance(annotation_data, dict)

    def _read(field: str) -> Any:
        # The annotation may be a plain dict or an attribute-bearing object.
        return annotation_data.get(field) if is_mapping else getattr(annotation_data, field, None)

    title_text = str(_read("title") or "")
    url_text = str(_read("url") or "")
    span_start = _read("start_index")
    span_end = _read("end_index")

    extras: dict[str, Any] = {"annotation_index": getattr(raw_event, "annotation_index", None)}
    resolved_url = self._get_search_doc_url(title_text, get_urls)
    if resolved_url:
        extras["get_url"] = resolved_url

    citation = Annotation(
        type="citation",
        title=title_text,
        url=url_text,
        additional_properties=extras,
        raw_representation=annotation_data,
    )
    # A text span is recorded only when both of its boundaries are present.
    if not (span_start is None or span_end is None):
        span = TextSpanRegion(type="text_span", start_index=span_start, end_index=span_end)
        citation["annotated_regions"] = [span]

    return Content.from_text(text="", annotations=[citation], raw_representation=raw_event)

@override
def _inner_get_response(
    self,
    *,
    messages: Sequence[Message],
    options: Mapping[str, Any],
    stream: bool = False,
    **kwargs: Any,
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
    """Wrap base response to enrich Azure AI Search citation annotations.

    For non-streaming responses, the ``ChatResponse.raw_representation`` carries the
    full response including ``azure_ai_search_call_output`` items. After the base class
    parses the response, citation annotations are enriched with per-document URLs.

    For streaming responses, a transform hook is registered on the ``ResponseStream`` to
    capture ``get_urls`` from search output events and enrich ``url_citation`` annotations
    as they arrive. The captured URL state is local to the stream closure, so concurrent
    streams do not interfere.

    Args:
        messages: The messages forwarded unchanged to the base implementation.
        options: The request options forwarded unchanged to the base implementation.
        stream: Whether a streaming response was requested.
        **kwargs: Extra keyword arguments forwarded unchanged to the base implementation.

    Returns:
        An awaitable resolving to the enriched ``ChatResponse`` when ``stream`` is false,
        otherwise the base ``ResponseStream`` with the enrichment hook attached.
    """
    result = super()._inner_get_response(messages=messages, options=options, stream=stream, **kwargs)

    if not stream:
        original_awaitable: Awaitable[ChatResponse] = result  # type: ignore[assignment]

        async def _enrich_response() -> ChatResponse:
            response = await original_awaitable
            raw = getattr(response, "raw_representation", None)
            # ``output`` may be absent or explicitly ``None``; both mean "no items".
            output_items = getattr(raw, "output", None) or []
            get_urls = self._extract_azure_search_urls(output_items)
            if get_urls:
                for msg in response.messages:
                    self._enrich_annotations_with_search_urls(list(msg.contents or []), get_urls)
            return response

        return _enrich_response()

    # Streaming: use closure-local state so concurrent streams don't interfere
    stream_result: ResponseStream[ChatResponseUpdate, ChatResponse] = result  # type: ignore[assignment]
    # URLs grouped per output item, in first-seen order. Grouping lets the "done"
    # event of an item replace what its "added" event contributed instead of
    # appending a second copy, which would shift the doc_N -> URL positions of
    # every later search call.
    urls_by_item: dict[Any, list[str]] = {}
    search_get_urls: list[str] = []

    def _record_search_urls(raw: Any, item: Any, urls: list[str]) -> None:
        # Prefer the event's output_index, then the item's id, to identify the item.
        item_key = getattr(raw, "output_index", None)
        if item_key is None:
            item_key = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
        if item_key is None:
            # No way to correlate events: fall back to plain appending.
            item_key = object()
        urls_by_item[item_key] = urls
        # Rebuild in place so the same list object stays shared with the hook.
        search_get_urls[:] = [url for group in urls_by_item.values() for url in group]

    def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate:
        raw = getattr(update, "raw_representation", None)
        if raw is None:
            return update
        event_type = getattr(raw, "type", None)

        # Capture get_urls from azure_ai_search_call_output items.
        # Check both "added" and "done" events because the output data (including
        # get_urls) may only be fully populated in the "done" event.
        if event_type in ("response.output_item.added", "response.output_item.done"):
            item = getattr(raw, "item", None)
            if item is not None:
                urls = self._extract_azure_search_urls([item])
                if urls:
                    _record_search_urls(raw, item, urls)

        # Handle url_citation annotations (not handled by the base class in streaming)
        if event_type == "response.output_text.annotation.added":
            ann = getattr(raw, "annotation", None)
            if ann is not None:
                ann_type = ann.get("type") if isinstance(ann, dict) else getattr(ann, "type", None)
                if ann_type == "url_citation":
                    citation_content = self._build_url_citation_content(ann, search_get_urls, raw)
                    contents_list = list(update.contents or [])
                    contents_list.append(citation_content)
                    # Return a new update carrying the extra citation content
                    # alongside the original update's fields.
                    return ChatResponseUpdate(
                        contents=contents_list,
                        conversation_id=update.conversation_id,
                        response_id=update.response_id,
                        role=update.role,
                        model_id=update.model_id,
                        continuation_token=update.continuation_token,
                        additional_properties=update.additional_properties,
                        raw_representation=update.raw_representation,
                    )

        # Enrich any citation annotations already parsed by the base class
        if update.contents and search_get_urls:
            self._enrich_annotations_with_search_urls(list(update.contents), search_get_urls)

        return update

    stream_result.with_transform_hook(_enrich_update)
    return stream_result

# endregion

# region Hosted Tool Factory Methods (Azure-specific overrides)

@staticmethod
Expand Down
Loading