Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict

from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._settings import SecretString, load_settings
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
Expand Down Expand Up @@ -47,8 +47,12 @@
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
from azure.search.documents.knowledgebases.models import (
KnowledgeBaseMessage,
KnowledgeBaseMessageImageContent,
KnowledgeBaseMessageImageContentImage,
KnowledgeBaseMessageTextContent,
KnowledgeBaseReference,
KnowledgeBaseRetrievalRequest,
KnowledgeBaseRetrievalResponse,
KnowledgeRetrievalIntent,
KnowledgeRetrievalSemanticIntent,
)
Expand Down Expand Up @@ -78,8 +82,12 @@
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
from azure.search.documents.knowledgebases.models import (
KnowledgeBaseMessage,
KnowledgeBaseMessageImageContent,
KnowledgeBaseMessageImageContentImage,
KnowledgeBaseMessageTextContent,
KnowledgeBaseReference,
KnowledgeBaseRetrievalRequest,
KnowledgeBaseRetrievalResponse,
KnowledgeRetrievalIntent,
KnowledgeRetrievalSemanticIntent,
)
Expand Down Expand Up @@ -154,7 +162,9 @@ def __init__(
top_k: int = 5,
semantic_configuration_name: str | None = None,
vector_field_name: str | None = None,
embedding_function: Callable[[str], Awaitable[list[float]]] | None = None,
embedding_function: Callable[[str], Awaitable[list[float]]]
| SupportsGetEmbeddings[str, list[float], Any]
| None = None,
context_prompt: str | None = None,
azure_openai_resource_url: str | None = None,
model_deployment_name: str | None = None,
Expand All @@ -181,7 +191,7 @@ def __init__(
top_k: Maximum number of documents to retrieve. Default: 5.
semantic_configuration_name: Name of semantic configuration in the index.
vector_field_name: Name of the vector field in the index.
embedding_function: Async function to generate embeddings.
embedding_function: Async function to generate embeddings or a SupportsGetEmbeddings instance.
context_prompt: Custom prompt to prepend to retrieved context.
azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base.
model_deployment_name: Model deployment name in Azure OpenAI.
Expand Down Expand Up @@ -309,9 +319,20 @@ async def __aexit__(
exc_tb: Any,
) -> None:
"""Async context manager exit - cleanup clients."""
await self.close()

async def close(self) -> None:
    """Close all the open clients.

    Closes the retrieval, search and index clients (in that order) when they
    are set. Every client is attempted even if an earlier ``close()`` call
    raises, so one failing client cannot leak the others. Each attribute is
    reset to ``None`` before its client is awaited, which means a failed close
    still leaves the provider without a stale client reference and makes
    repeated calls to this method safe.

    Raises:
        Exception: Whatever a client's ``close()`` raises. When several
            clients fail, the most recent error propagates and the earlier
            ones remain attached as its exception context.
    """
    try:
        if self._retrieval_client is not None:
            # Detach first so the reference is dropped even if close() fails.
            retrieval_client, self._retrieval_client = self._retrieval_client, None
            # The knowledge-base flag is cleared together with the retrieval client.
            self._knowledge_base_initialized = False
            await retrieval_client.close()
    finally:
        try:
            if self._search_client is not None:
                search_client, self._search_client = self._search_client, None
                await search_client.close()
        finally:
            if self._index_client is not None:
                index_client, self._index_client = self._index_client, None
                await index_client.close()

# -- Hooks pattern ---------------------------------------------------------

Expand All @@ -326,32 +347,23 @@ async def before_run(
"""Retrieve relevant context from Azure AI Search and add to session context."""
messages_list = list(context.input_messages)

def get_role_value(role: str | Any) -> str:
return role.value if hasattr(role, "value") else str(role)

filtered_messages = [
msg
for msg in messages_list
if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"]
msg for msg in messages_list if msg and msg.text and msg.text.strip() and msg.role in ["user", "assistant"]
]
if not filtered_messages:
return

if self.mode == "semantic":
query = "\n".join(msg.text for msg in filtered_messages)
search_result_parts = await self._semantic_search(query)
result_messages = await self._semantic_search(query)
else:
recent_messages = filtered_messages[-self.agentic_message_history_count :]
search_result_parts = await self._agentic_search(recent_messages)
result_messages = await self._agentic_search(recent_messages)

if not search_result_parts:
if not result_messages:
return

context_messages = [Message(role="user", text=self.context_prompt)]
context_messages.extend([Message(role="user", text=part) for part in search_result_parts])
context.extend_messages(self.source_id, context_messages)

# -- Internal methods (ported from AzureAISearchContextProvider) -----------
context.extend_messages(self.source_id, [Message(role="user", text=self.context_prompt), *result_messages])

def _find_vector_fields(self, index: Any) -> list[str]:
"""Find all fields that can store vectors."""
Expand Down Expand Up @@ -432,22 +444,22 @@ async def _auto_discover_vector_field(self) -> None:

self._auto_discovered_vector_field = True

async def _semantic_search(self, query: str) -> list[str]:
async def _semantic_search(self, query: str) -> list[Message]:
"""Perform semantic hybrid search."""
await self._auto_discover_vector_field()

vector_queries: list[VectorizableTextQuery | VectorizedQuery] = []
if self.vector_field_name:
vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k
if self._use_vectorizable_query:
vector_queries = [
VectorizableTextQuery(text=query, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
]
vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)]
elif self.embedding_function:
query_vector = await self.embedding_function(query)
vector_queries = [
VectorizedQuery(vector=query_vector, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
]
if isinstance(self.embedding_function, SupportsGetEmbeddings):
embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType]
query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType]
else:
query_vector = await self.embedding_function(query)
vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)]

search_params: dict[str, Any] = {"search_text": query, "top": self.top_k}
if vector_queries:
Expand All @@ -461,13 +473,13 @@ async def _semantic_search(self, query: str) -> list[str]:
raise RuntimeError("Search client is not initialized.")
results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType]

formatted_results: list[str] = []
result_messages: list[Message] = []
async for doc in results: # type: ignore[reportUnknownVariableType]
doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType]
doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType]
if doc_text:
formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType]
return formatted_results
result_messages.append(Message(role="user", text=doc_text)) # type: ignore[reportUnknownArgumentType]
return result_messages

async def _ensure_knowledge_base(self) -> None:
"""Ensure Knowledge Base and knowledge source are created or use existing KB."""
Expand Down Expand Up @@ -550,7 +562,7 @@ async def _ensure_knowledge_base(self) -> None:
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)

async def _agentic_search(self, messages: list[Message]) -> list[str]:
async def _agentic_search(self, messages: list[Message]) -> list[Message]:
"""Perform agentic retrieval with multi-hop reasoning."""
await self._ensure_knowledge_base()

Expand All @@ -577,14 +589,7 @@ async def _agentic_search(self, messages: list[Message]) -> list[str]:
include_activity=True,
)
else:
kb_messages = [
KnowledgeBaseMessage(
role=msg.role if hasattr(msg.role, "value") else str(msg.role),
content=[KnowledgeBaseMessageTextContent(text=msg.text)],
)
for msg in messages
if msg.text
]
kb_messages = self._prepare_messages_for_kb_search(messages)
retrieval_request = KnowledgeBaseRetrievalRequest(
messages=kb_messages,
retrieval_reasoning_effort=reasoning_effort,
Expand All @@ -596,17 +601,136 @@ async def _agentic_search(self, messages: list[Message]) -> list[str]:
raise RuntimeError("Retrieval client not initialized.")
retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request)

if retrieval_result.response and len(retrieval_result.response) > 0:
assistant_message = retrieval_result.response[-1]
if assistant_message.content:
answer_parts: list[str] = []
for content_item in assistant_message.content:
if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text:
answer_parts.append(content_item.text)
if answer_parts:
return answer_parts

return ["No results found from Knowledge Base."]
return self._parse_messages_from_kb_response(retrieval_result)

@staticmethod
def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]:
    """Translate framework Messages into KnowledgeBaseMessages for agentic retrieval.

    Only text items and image items (``uri``/``data`` content whose media type
    starts with ``image/``) are carried over; every other content type is
    dropped. A message without a ``contents`` list falls back to its plain
    ``text``. Messages that yield no convertible content are omitted.

    Args:
        messages: Framework messages to convert.

    Returns:
        The KnowledgeBaseMessage objects to send in a retrieval request.
    """
    converted: list[KnowledgeBaseMessage] = []
    for message in messages:
        parts: list[KnowledgeBaseMessageTextContent | KnowledgeBaseMessageImageContent] = []
        if message.contents:
            for item in message.contents:
                kind = item.type
                if kind == "text":
                    # Empty text items are skipped rather than sent as blank content.
                    if item.text:
                        parts.append(KnowledgeBaseMessageTextContent(text=item.text))
                elif kind in ("uri", "data"):
                    # Only image payloads with a known location are forwarded.
                    if item.uri and item.media_type and item.media_type.startswith("image/"):
                        image = KnowledgeBaseMessageImageContentImage(url=item.uri)
                        parts.append(KnowledgeBaseMessageImageContent(image=image))
        elif message.text:
            parts.append(KnowledgeBaseMessageTextContent(text=message.text))

        if not parts:
            continue
        converted.append(KnowledgeBaseMessage(role=message.role, content=parts))  # type: ignore[arg-type]
    return converted

@staticmethod
def _parse_references_to_annotations(references: list[KnowledgeBaseReference] | None) -> list[Annotation]:
    """Build citation Annotations from Knowledge Base references.

    For each reference the first non-empty of ``url``, ``blob_url``,
    ``doc_url`` and ``web_url`` becomes the citation URL (empty string when
    none is set) and the title falls back to the reference id. Remaining
    details (id, type, activity source, and, when present, reranker score,
    source data, doc key, SDK additional properties and sensitivity label)
    are stored under ``additional_properties``; the reference object itself
    is kept as ``raw_representation``.

    Args:
        references: The references from a Knowledge Base retrieval response.

    Returns:
        One citation Annotation per reference; empty when there are none.
    """
    citations: list[Annotation] = []
    for reference in references or ():
        # Reference subtypes expose their location under different attribute names.
        link: str = next(
            (
                candidate
                for field in ("url", "blob_url", "doc_url", "web_url")
                if (candidate := getattr(reference, field, None))
            ),
            "",
        )

        metadata: dict[str, Any] = {
            "reference_id": reference.id,
            "reference_type": getattr(reference, "type", None),
            "activity_source": reference.activity_source,
        }
        if reference.reranker_score is not None:
            metadata["reranker_score"] = reference.reranker_score
        if reference.source_data:
            metadata["source_data"] = reference.source_data
        if key := getattr(reference, "doc_key", None):
            metadata["doc_key"] = key
        if reference.additional_properties:
            metadata["sdk_additional_properties"] = reference.additional_properties
        if label := getattr(reference, "search_sensitivity_label_info", None):
            metadata["sensitivity_label"] = {
                "display_name": label.display_name,
                "sensitivity_label_id": label.sensitivity_label_id,
                "is_encrypted": label.is_encrypted,
            }

        citation = Annotation(
            type="citation",
            url=link,
            title=getattr(reference, "title", None) or reference.id,
        )
        citation["additional_properties"] = metadata
        citation["raw_representation"] = reference
        citations.append(citation)
    return citations

@staticmethod
def _parse_messages_from_kb_response(retrieval_result: KnowledgeBaseRetrievalResponse) -> list[Message]:
    """Turn a Knowledge Base retrieval response into framework Messages.

    Every response message with usable text or image content becomes one
    Message. The response's references are converted to Annotations once and
    attached to each content item of every produced Message.

    Args:
        retrieval_result: The full retrieval response including messages and references.

    Returns:
        The converted Messages, or a single assistant Message stating that
        nothing was found when the response has no usable content.
    """
    parsed: list[Message] = []
    if retrieval_result.response:
        citations = AzureAISearchContextProvider._parse_references_to_annotations(retrieval_result.references)

        for kb_message in retrieval_result.response:
            pieces: list[Content] = []
            for part in kb_message.content or ():
                if isinstance(part, KnowledgeBaseMessageTextContent) and part.text:
                    pieces.append(Content.from_text(part.text))
                elif isinstance(part, KnowledgeBaseMessageImageContent) and part.image and part.image.url:
                    # The media type is fixed to PNG regardless of the image URL.
                    pieces.append(Content.from_uri(uri=part.image.url, media_type="image/png"))
            if not pieces:
                continue
            if citations:
                # NOTE(review): every content item receives the same list object.
                for piece in pieces:
                    piece.annotations = citations
            parsed.append(Message(role=kb_message.role or "assistant", contents=pieces))

    if parsed:
        return parsed
    return [Message(role="assistant", text="No results found from Knowledge Base.")]

def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str:
"""Extract readable text from a search document with optional citation."""
Expand Down
Loading