diff --git a/.github/workflows/python-check-coverage.py b/.github/workflows/python-check-coverage.py index 0e6811ccd0..0ad12de993 100644 --- a/.github/workflows/python-check-coverage.py +++ b/.github/workflows/python-check-coverage.py @@ -34,8 +34,8 @@ "packages.core.agent_framework._workflows", "packages.purview.agent_framework_purview", "packages.anthropic.agent_framework_anthropic", - # Add more modules here as coverage improves: - # "packages.azure-ai-search.agent_framework_azure_ai_search", + "packages.azure-ai-search.agent_framework_azure_ai_search", + # Add more modules here as coverage improves } diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index fe95dc877c..153a6f879b 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -2,12 +2,14 @@ # pyright: reportPrivateUsage=false import os -from unittest.mock import AsyncMock, patch +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch import pytest from agent_framework import Message from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError +from azure.core.credentials import AzureKeyCredential from agent_framework_azure_ai_search._context_provider import AzureAISearchContextProvider @@ -32,6 +34,18 @@ async def __anext__(self): return doc +def _make_mock_index( + fields: list[SimpleNamespace] | None = None, + profiles: list[SimpleNamespace] | None = None, + has_vector_search: bool = True, +) -> SimpleNamespace: + """Create a mock search index with the given fields and vector search profiles.""" + vector_search = None + if has_vector_search: + vector_search = SimpleNamespace(profiles=profiles or []) + return SimpleNamespace(fields=fields or [], 
vector_search=vector_search) + + @pytest.fixture def mock_search_client() -> AsyncMock: """Create a mock SearchClient that returns one document.""" @@ -116,6 +130,62 @@ def test_env_variable_fallback(self) -> None: assert provider.endpoint == "https://env.search.windows.net" assert provider.index_name == "env-index" + def test_top_k_and_semantic_config(self) -> None: + provider = _make_provider(top_k=10, semantic_configuration_name="my-config") + assert provider.top_k == 10 + assert provider.semantic_configuration_name == "my-config" + + def test_default_context_prompt(self) -> None: + provider = _make_provider() + assert provider.context_prompt == AzureAISearchContextProvider._DEFAULT_SEARCH_CONTEXT_PROMPT + + def test_custom_context_prompt(self) -> None: + provider = _make_provider(context_prompt="Custom prompt:") + assert provider.context_prompt == "Custom prompt:" + + def test_model_name_falls_back_to_deployment_name(self) -> None: + """model_name defaults to model_deployment_name when not explicitly set.""" + provider = _make_provider(model_deployment_name="my-deploy") + assert provider.model_name == "my-deploy" + + def test_model_name_explicit(self) -> None: + provider = _make_provider(model_deployment_name="deploy", model_name="gpt-4") + assert provider.model_name == "gpt-4" + + +# -- Initialization: credential resolution ------------------------------------ + + +class TestInitCredentialResolution: + """Tests for credential resolution paths.""" + + def test_token_credential_used(self) -> None: + mock_cred = AsyncMock() + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="idx", + credential=mock_cred, + ) + provider._auto_discovered_vector_field = True + assert provider.credential is mock_cred + + def test_azure_key_credential_passed_through(self) -> None: + akc = AzureKeyCredential("my-key") + provider = AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="idx", + 
api_key=akc, + ) + provider._auto_discovered_vector_field = True + assert provider.credential is akc + + def test_no_credential_raises(self) -> None: + with pytest.raises(ServiceInitializationError, match="Azure credential is required"): + AzureAISearchContextProvider( + endpoint="https://test.search.windows.net", + index_name="idx", + ) + # -- Initialization: agentic mode validation ----------------------------------- @@ -166,6 +236,69 @@ def test_vector_field_without_embedding_raises(self) -> None: vector_field_name="embedding", ) + def test_agentic_missing_aoai_url_with_index_raises(self) -> None: + with pytest.raises(ValueError, match="azure_openai_resource_url"): + AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name="idx", + api_key="key", + mode="agentic", + model_deployment_name="deploy", + ) + + def test_agentic_with_kb_name_sets_use_existing(self) -> None: + provider = AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + knowledge_base_name="my-kb", + api_key="key", + mode="agentic", + ) + assert provider._use_existing_knowledge_base is True + assert provider.knowledge_base_name == "my-kb" + + def test_agentic_with_index_generates_kb_name(self) -> None: + provider = AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name="idx", + api_key="key", + mode="agentic", + model_deployment_name="deploy", + azure_openai_resource_url="https://aoai.openai.azure.com", + ) + assert provider._use_existing_knowledge_base is False + assert provider.knowledge_base_name == "idx-kb" + + +# -- __aenter__ / __aexit__ --------------------------------------------------- + + +class TestAsyncContextManager: + """Tests for async context manager.""" + + async def test_aenter_returns_self(self) -> None: + provider = _make_provider() + result = await provider.__aenter__() + assert result is provider + + async def 
test_closes_retrieval_client(self) -> None: + provider = _make_provider() + mock_retrieval = AsyncMock() + provider._retrieval_client = mock_retrieval + + await provider.__aexit__(None, None, None) + + mock_retrieval.close.assert_awaited_once() + assert provider._retrieval_client is None + + async def test_no_retrieval_client_no_error(self) -> None: + provider = _make_provider() + assert provider._retrieval_client is None + + await provider.__aexit__(None, None, None) # should not raise + # -- before_run: semantic mode ------------------------------------------------- @@ -281,25 +414,881 @@ async def test_only_system_messages_no_search(self, mock_search_client: AsyncMoc mock_search_client.search.assert_not_awaited() + async def test_whitespace_only_messages_filtered(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client -# -- __aexit__ ----------------------------------------------------------------- + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[Message(role="user", contents=[" "])], + session_id="s1", + ) + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + mock_search_client.search.assert_not_awaited() -class TestAexit: - """Tests for async context manager cleanup.""" + async def test_assistant_messages_included(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client - async def test_closes_retrieval_client(self) -> None: + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + Message(role="user", contents=["first question"]), + Message(role="assistant", contents=["first answer"]), + Message(role="user", contents=["follow up"]), + ], + session_id="s1", + ) + await provider.before_run( + agent=None, session=session, context=ctx, 
state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + call_kwargs = mock_search_client.search.call_args[1] + assert "first question" in call_kwargs["search_text"] + assert "first answer" in call_kwargs["search_text"] + assert "follow up" in call_kwargs["search_text"] + + +# -- _find_vector_fields ------------------------------------------------------- + + +class TestFindVectorFields: + """Tests for _find_vector_fields helper.""" + + def test_finds_fields_with_dimensions(self) -> None: + provider = _make_provider() + index = _make_mock_index( + fields=[ + SimpleNamespace(name="embedding", vector_search_dimensions=1536), + SimpleNamespace(name="content", vector_search_dimensions=None), + SimpleNamespace(name="title", vector_search_dimensions=0), + ] + ) + result = provider._find_vector_fields(index) + assert result == ["embedding"] + + def test_returns_empty_for_no_vector_fields(self) -> None: + provider = _make_provider() + index = _make_mock_index( + fields=[ + SimpleNamespace(name="content", vector_search_dimensions=None), + SimpleNamespace(name="title", vector_search_dimensions=0), + ] + ) + result = provider._find_vector_fields(index) + assert result == [] + + def test_multiple_vector_fields(self) -> None: + provider = _make_provider() + index = _make_mock_index( + fields=[ + SimpleNamespace(name="emb1", vector_search_dimensions=768), + SimpleNamespace(name="emb2", vector_search_dimensions=1536), + ] + ) + result = provider._find_vector_fields(index) + assert result == ["emb1", "emb2"] + + +# -- _find_vectorizable_fields ------------------------------------------------ + + +class TestFindVectorizableFields: + """Tests for _find_vectorizable_fields helper.""" + + def test_finds_vectorizable_fields(self) -> None: + provider = _make_provider() + profiles = [SimpleNamespace(name="profile1", vectorizer_name="my-vectorizer")] + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, 
vector_search_profile_name="profile1"), + ] + index = _make_mock_index(fields=fields, profiles=profiles) + result = provider._find_vectorizable_fields(index, ["embedding"]) + assert result == ["embedding"] + + def test_returns_empty_when_no_vector_search(self) -> None: + provider = _make_provider() + index = _make_mock_index(has_vector_search=False) + result = provider._find_vectorizable_fields(index, ["embedding"]) + assert result == [] + + def test_returns_empty_when_no_profiles(self) -> None: + provider = _make_provider() + index = _make_mock_index(profiles=None) + index.vector_search = SimpleNamespace(profiles=None) + result = provider._find_vectorizable_fields(index, ["embedding"]) + assert result == [] + + def test_field_not_in_vector_fields_excluded(self) -> None: + provider = _make_provider() + profiles = [SimpleNamespace(name="profile1", vectorizer_name="my-vectorizer")] + fields = [ + SimpleNamespace(name="other_field", vector_search_dimensions=1536, vector_search_profile_name="profile1"), + ] + index = _make_mock_index(fields=fields, profiles=profiles) + result = provider._find_vectorizable_fields(index, ["embedding"]) + assert result == [] + + def test_profile_without_vectorizer_not_included(self) -> None: + provider = _make_provider() + profiles = [SimpleNamespace(name="profile1", vectorizer_name=None)] + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, vector_search_profile_name="profile1"), + ] + index = _make_mock_index(fields=fields, profiles=profiles) + result = provider._find_vectorizable_fields(index, ["embedding"]) + assert result == [] + + def test_field_without_profile_name_excluded(self) -> None: + provider = _make_provider() + profiles = [SimpleNamespace(name="profile1", vectorizer_name="my-vectorizer")] + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, vector_search_profile_name=None), + ] + index = _make_mock_index(fields=fields, profiles=profiles) + result = 
provider._find_vectorizable_fields(index, ["embedding"]) + assert result == [] + + +# -- _auto_discover_vector_field ----------------------------------------------- + + +class TestAutoDiscoverVectorField: + """Tests for _auto_discover_vector_field.""" + + async def test_skip_if_already_discovered(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = True + await provider._auto_discover_vector_field() + # No error, no side effects + + async def test_skip_if_vector_field_set(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + provider.vector_field_name = "my_field" + await provider._auto_discover_vector_field() + # Should return immediately + + async def test_no_index_name_warns(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + provider.index_name = None + provider._index_client = AsyncMock() + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + + async def test_no_vector_fields_sets_flag(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = _make_mock_index( + fields=[SimpleNamespace(name="content", vector_search_dimensions=None)] + ) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + async def test_single_vectorizable_field_discovered(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + profiles = [SimpleNamespace(name="profile1", vectorizer_name="my-vectorizer")] + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, vector_search_profile_name="profile1"), + ] + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = 
_make_mock_index(fields=fields, profiles=profiles) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider.vector_field_name == "embedding" + assert provider._use_vectorizable_query is True + assert provider._auto_discovered_vector_field is True + + async def test_multiple_vectorizable_fields_warns(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + profiles = [ + SimpleNamespace(name="profile1", vectorizer_name="v1"), + SimpleNamespace(name="profile2", vectorizer_name="v2"), + ] + fields = [ + SimpleNamespace(name="emb1", vector_search_dimensions=768, vector_search_profile_name="profile1"), + SimpleNamespace(name="emb2", vector_search_dimensions=1536, vector_search_profile_name="profile2"), + ] + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = _make_mock_index(fields=fields, profiles=profiles) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + # vector_field_name should not be set when multiple found + assert provider.vector_field_name is None + + async def test_single_vector_field_without_embedding_clears_field(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + provider.embedding_function = None + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, vector_search_profile_name=None), + ] + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = _make_mock_index(fields=fields, profiles=[]) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + async def test_single_vector_field_with_embedding_function(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + 
provider.embedding_function = AsyncMock(return_value=[0.1] * 1536) + fields = [ + SimpleNamespace(name="embedding", vector_search_dimensions=1536, vector_search_profile_name=None), + ] + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = _make_mock_index(fields=fields, profiles=[]) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider.vector_field_name == "embedding" + assert provider._use_vectorizable_query is False + + async def test_multiple_vector_fields_no_vectorizable_warns(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + fields = [ + SimpleNamespace(name="emb1", vector_search_dimensions=768, vector_search_profile_name=None), + SimpleNamespace(name="emb2", vector_search_dimensions=1536, vector_search_profile_name=None), + ] + mock_index_client = AsyncMock() + mock_index_client.get_index.return_value = _make_mock_index(fields=fields, profiles=[]) + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + assert provider.vector_field_name is None + + async def test_exception_falls_back_to_keyword_search(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + mock_index_client = AsyncMock() + mock_index_client.get_index.side_effect = Exception("network error") + provider._index_client = mock_index_client + + await provider._auto_discover_vector_field() + assert provider._auto_discovered_vector_field is True + + async def test_creates_index_client_if_none(self) -> None: + provider = _make_provider() + provider._auto_discovered_vector_field = False + provider._index_client = None + + with patch("agent_framework_azure_ai_search._context_provider.SearchIndexClient") as mock_cls: + mock_client = AsyncMock() + mock_client.get_index.return_value = _make_mock_index( + fields=[SimpleNamespace(name="content", 
vector_search_dimensions=None)] + ) + mock_cls.return_value = mock_client + + await provider._auto_discover_vector_field() + mock_cls.assert_called_once() + assert provider._auto_discovered_vector_field is True + + +# -- _semantic_search ---------------------------------------------------------- + + +class TestSemanticSearch: + """Tests for _semantic_search method.""" + + async def test_basic_keyword_search(self) -> None: + provider = _make_provider() + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([{"id": "d1", "content": "result text"}]) + + mock_client.search = AsyncMock(side_effect=_search) + provider._search_client = mock_client + + results = await provider._semantic_search("test query") + assert len(results) == 1 + assert "result text" in results[0] + call_kwargs = mock_client.search.call_args[1] + assert call_kwargs["search_text"] == "test query" + + async def test_vectorizable_text_query(self) -> None: + provider = _make_provider() + provider._use_vectorizable_query = True + provider.vector_field_name = "embedding" + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([{"id": "d1", "content": "vector result"}]) + + mock_client.search = AsyncMock(side_effect=_search) + provider._search_client = mock_client + + results = await provider._semantic_search("vector query") + assert len(results) == 1 + call_kwargs = mock_client.search.call_args[1] + assert "vector_queries" in call_kwargs + assert len(call_kwargs["vector_queries"]) == 1 + + async def test_vectorized_query_with_embedding_function(self) -> None: + provider = _make_provider() + provider._use_vectorizable_query = False + provider.vector_field_name = "embedding" + provider.embedding_function = AsyncMock(return_value=[0.1, 0.2, 0.3]) + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([{"id": "d1", "content": "embed result"}]) + + mock_client.search = AsyncMock(side_effect=_search) + 
provider._search_client = mock_client + + results = await provider._semantic_search("embed query") + assert len(results) == 1 + provider.embedding_function.assert_awaited_once_with("embed query") + call_kwargs = mock_client.search.call_args[1] + assert "vector_queries" in call_kwargs + + async def test_semantic_configuration_params(self) -> None: + provider = _make_provider(semantic_configuration_name="my-semantic-config") + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([{"id": "d1", "content": "semantic result"}]) + + mock_client.search = AsyncMock(side_effect=_search) + provider._search_client = mock_client + + await provider._semantic_search("sem query") + call_kwargs = mock_client.search.call_args[1] + assert call_kwargs["query_type"] == "semantic" + assert call_kwargs["semantic_configuration_name"] == "my-semantic-config" + assert "query_caption" in call_kwargs + + async def test_vector_k_with_semantic_config(self) -> None: + provider = _make_provider(semantic_configuration_name="sc", top_k=3) + provider._use_vectorizable_query = True + provider.vector_field_name = "embedding" + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([]) + + mock_client.search = AsyncMock(side_effect=_search) + provider._search_client = mock_client + + await provider._semantic_search("query") + call_kwargs = mock_client.search.call_args[1] + assert "vector_queries" in call_kwargs + assert len(call_kwargs["vector_queries"]) == 1 + + async def test_no_search_client_raises(self) -> None: + provider = _make_provider() + provider._search_client = None + + with pytest.raises(RuntimeError, match="Search client is not initialized"): + await provider._semantic_search("query") + + async def test_empty_results_returns_empty_list(self) -> None: + provider = _make_provider() + mock_client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([]) + + mock_client.search = AsyncMock(side_effect=_search) 
+ provider._search_client = mock_client + + results = await provider._semantic_search("query") + assert results == [] + + async def test_doc_without_text_excluded(self) -> None: + provider = _make_provider() + mock_client = AsyncMock() + + async def _search(**kwargs): + # doc with only @search metadata and id - no extractable text + return MockSearchResults([{"id": "d1", "@search.score": 0.9}]) + + mock_client.search = AsyncMock(side_effect=_search) + provider._search_client = mock_client + + results = await provider._semantic_search("query") + assert results == [] + + +# -- _extract_document_text ---------------------------------------------------- + + +class TestExtractDocumentText: + """Tests for _extract_document_text.""" + + def test_content_field_extracted(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"content": "Hello world"}, doc_id="d1") + assert result == "[Source: d1] Hello world" + + def test_text_field_extracted(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"text": "Some text"}, doc_id="d1") + assert result == "[Source: d1] Some text" + + def test_description_field_extracted(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"description": "A description"}, doc_id="d1") + assert result == "[Source: d1] A description" + + def test_body_field_extracted(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"body": "Body content"}, doc_id="d1") + assert result == "[Source: d1] Body content" + + def test_chunk_field_extracted(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"chunk": "Chunk data"}, doc_id="d1") + assert result == "[Source: d1] Chunk data" + + def test_content_field_priority(self) -> None: + provider = _make_provider() + result = provider._extract_document_text( + {"content": "Primary", "text": "Secondary", "description": "Tertiary"}, doc_id="d1" 
+ ) + assert result == "[Source: d1] Primary" + + def test_fallback_to_string_fields(self) -> None: + provider = _make_provider() + result = provider._extract_document_text( + {"title": "My Title", "summary": "My Summary", "id": "skip-this", "@search.score": "skip-meta"}, + doc_id="d1", + ) + assert "title: My Title" in result + assert "summary: My Summary" in result + assert "id" not in result.split("] ")[1] # id should be excluded from fallback + assert "@search.score" not in result + + def test_empty_doc_returns_empty(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({}) + assert result == "" + + def test_no_doc_id_returns_text_only(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"content": "Hello"}, doc_id=None) + assert result == "Hello" + + def test_search_id_fallback(self) -> None: + """Test that a caller-supplied alternate doc_id is used in the source prefix.""" + provider = _make_provider() + result = provider._extract_document_text({"content": "data"}, doc_id="alt-id") + assert result == "[Source: alt-id] data" + + def test_only_id_and_metadata_returns_empty(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"id": "d1", "@search.score": 0.9}) + assert result == "" + + def test_non_string_values_excluded_from_fallback(self) -> None: + provider = _make_provider() + result = provider._extract_document_text({"count": 42, "tags": ["a", "b"]}, doc_id="d1") + # Non-string values should not appear in fallback + assert result == "" + + +# -- _ensure_knowledge_base --------------------------------------------------- + + +class TestEnsureKnowledgeBase: + """Tests for _ensure_knowledge_base.""" + + async def test_already_initialized_returns_early(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + + await provider._ensure_knowledge_base() # should not raise + + async def test_missing_kb_name_raises(self) -> None:
provider = _make_provider() + provider._knowledge_base_initialized = False + provider.knowledge_base_name = None + + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await provider._ensure_knowledge_base() + + async def test_existing_kb_sets_initialized(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = True + provider.knowledge_base_name = "existing-kb" + + with patch("agent_framework_azure_ai_search._context_provider.KnowledgeBaseRetrievalClient") as mock_cls: + mock_cls.return_value = AsyncMock() + await provider._ensure_knowledge_base() + assert provider._knowledge_base_initialized is True + + async def test_missing_index_client_raises(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider._index_client = None + + with pytest.raises(ValueError, match="Index client is required"): + await provider._ensure_knowledge_base() + + async def test_missing_aoai_url_raises(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider._index_client = AsyncMock() + provider.azure_openai_resource_url = None + + with pytest.raises(ValueError, match="azure_openai_resource_url is required"): + await provider._ensure_knowledge_base() + + async def test_missing_deployment_name_raises(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider._index_client = AsyncMock() + provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + provider.azure_openai_deployment_name = None + + with pytest.raises(ValueError, match="model_deployment_name is required"): + await 
provider._ensure_knowledge_base() + + async def test_missing_index_name_raises(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider._index_client = AsyncMock() + provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + provider.azure_openai_deployment_name = "deploy" + provider.index_name = None + + with pytest.raises(ValueError, match="index_name is required"): + await provider._ensure_knowledge_base() + + async def test_creates_knowledge_source_when_not_found(self) -> None: + from azure.core.exceptions import ResourceNotFoundError + + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + provider.azure_openai_deployment_name = "deploy" + provider.model_name = "gpt-4" + provider.index_name = "test-index" + + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.side_effect = ResourceNotFoundError("not found") + mock_index_client.create_knowledge_source = AsyncMock() + mock_index_client.create_or_update_knowledge_base = AsyncMock() + provider._index_client = mock_index_client + + with patch("agent_framework_azure_ai_search._context_provider.KnowledgeBaseRetrievalClient") as mock_cls: + mock_cls.return_value = AsyncMock() + await provider._ensure_knowledge_base() + + mock_index_client.create_knowledge_source.assert_awaited_once() + mock_index_client.create_or_update_knowledge_base.assert_awaited_once() + assert provider._knowledge_base_initialized is True + + async def test_uses_existing_knowledge_source(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + 
provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + provider.azure_openai_deployment_name = "deploy" + provider.model_name = "gpt-4" + provider.index_name = "test-index" + + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.return_value = Mock() # source already exists + mock_index_client.create_or_update_knowledge_base = AsyncMock() + provider._index_client = mock_index_client + + with patch("agent_framework_azure_ai_search._context_provider.KnowledgeBaseRetrievalClient") as mock_cls: + mock_cls.return_value = AsyncMock() + await provider._ensure_knowledge_base() + + mock_index_client.create_knowledge_source.assert_not_awaited() + mock_index_client.create_or_update_knowledge_base.assert_awaited_once() + + async def test_answer_synthesis_output_mode(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + provider.azure_openai_deployment_name = "deploy" + provider.model_name = "gpt-4" + provider.index_name = "test-index" + provider.knowledge_base_output_mode = "answer_synthesis" + + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.return_value = Mock() + mock_index_client.create_or_update_knowledge_base = AsyncMock() + provider._index_client = mock_index_client + + with patch("agent_framework_azure_ai_search._context_provider.KnowledgeBaseRetrievalClient") as mock_cls: + mock_cls.return_value = AsyncMock() + await provider._ensure_knowledge_base() + + assert provider._knowledge_base_initialized is True + + async def test_medium_reasoning_effort(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = False + provider._use_existing_knowledge_base = False + provider.knowledge_base_name = "test-kb" + provider.azure_openai_resource_url = "https://aoai.openai.azure.com" + 
provider.azure_openai_deployment_name = "deploy" + provider.model_name = "gpt-4" + provider.index_name = "test-index" + provider.retrieval_reasoning_effort = "medium" + + mock_index_client = AsyncMock() + mock_index_client.get_knowledge_source.return_value = Mock() + mock_index_client.create_or_update_knowledge_base = AsyncMock() + provider._index_client = mock_index_client + + with patch("agent_framework_azure_ai_search._context_provider.KnowledgeBaseRetrievalClient") as mock_cls: + mock_cls.return_value = AsyncMock() + await provider._ensure_knowledge_base() + + assert provider._knowledge_base_initialized is True + + +# -- _agentic_search ---------------------------------------------------------- + + +class TestAgenticSearch: + """Tests for _agentic_search.""" + + async def test_no_retrieval_client_raises(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider._retrieval_client = None + + with pytest.raises(RuntimeError, match="Retrieval client not initialized"): + await provider._agentic_search([Message(role="user", contents=["query"])]) + + async def test_minimal_reasoning_returns_results(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "minimal" + + mock_content = Mock() + mock_content.text = "Answer text" + mock_message = Mock() + mock_message.content = [mock_content] + mock_result = Mock() + mock_result.response = [mock_message] + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) provider._retrieval_client = mock_retrieval - await provider.__aexit__(None, None, None) + # Patch isinstance check for KnowledgeBaseMessageTextContent + with patch( + "agent_framework_azure_ai_search._context_provider.KnowledgeBaseMessageTextContent", + type(mock_content), + ): + results = await 
provider._agentic_search([Message(role="user", contents=["test query"])]) - mock_retrieval.close.assert_awaited_once() - assert provider._retrieval_client is None + assert results == ["Answer text"] - async def test_no_retrieval_client_no_error(self) -> None: + async def test_non_minimal_reasoning_uses_messages(self) -> None: provider = _make_provider() - assert provider._retrieval_client is None + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "medium" - await provider.__aexit__(None, None, None) # should not raise + mock_content = Mock() + mock_content.text = "Medium answer" + mock_message = Mock() + mock_message.content = [mock_content] + mock_result = Mock() + mock_result.response = [mock_message] + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + with patch( + "agent_framework_azure_ai_search._context_provider.KnowledgeBaseMessageTextContent", + type(mock_content), + ): + results = await provider._agentic_search([ + Message(role="user", contents=["question"]), + Message(role="assistant", contents=["answer"]), + ]) + + assert results == ["Medium answer"] + mock_retrieval.retrieve.assert_awaited_once() + + async def test_no_response_returns_default_message(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "minimal" + + mock_result = Mock() + mock_result.response = [] + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + results = await provider._agentic_search([Message(role="user", contents=["query"])]) + assert results == ["No results found from Knowledge Base."] + + async def test_empty_content_returns_default_message(self) -> None: + provider = _make_provider() + 
provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "minimal" + + mock_message = Mock() + mock_message.content = None + mock_result = Mock() + mock_result.response = [mock_message] + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + results = await provider._agentic_search([Message(role="user", contents=["query"])]) + assert results == ["No results found from Knowledge Base."] + + async def test_answer_synthesis_output_mode(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "low" + provider.knowledge_base_output_mode = "answer_synthesis" + + mock_content = Mock() + mock_content.text = "Synthesized answer" + mock_message = Mock() + mock_message.content = [mock_content] + mock_result = Mock() + mock_result.response = [mock_message] + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + with patch( + "agent_framework_azure_ai_search._context_provider.KnowledgeBaseMessageTextContent", + type(mock_content), + ): + results = await provider._agentic_search([Message(role="user", contents=["query"])]) + + assert results == ["Synthesized answer"] + + async def test_content_without_text_excluded(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "minimal" + + mock_content_with_text = Mock() + mock_content_with_text.text = "Good content" + mock_content_no_text = Mock() + mock_content_no_text.text = None + mock_message = Mock() + mock_message.content = [mock_content_no_text, mock_content_with_text] + mock_result = Mock() + mock_result.response = [mock_message] + + mock_retrieval = AsyncMock() + 
mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + with patch( + "agent_framework_azure_ai_search._context_provider.KnowledgeBaseMessageTextContent", + type(mock_content_with_text), + ): + results = await provider._agentic_search([Message(role="user", contents=["query"])]) + + assert results == ["Good content"] + + async def test_none_response_returns_default_message(self) -> None: + provider = _make_provider() + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + provider.retrieval_reasoning_effort = "minimal" + + mock_result = Mock() + mock_result.response = None + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + results = await provider._agentic_search([Message(role="user", contents=["query"])]) + assert results == ["No results found from Knowledge Base."] + + +# -- before_run: agentic mode -------------------------------------------------- + + +class TestBeforeRunAgentic: + """Tests for before_run in agentic mode.""" + + async def test_agentic_mode_calls_agentic_search(self) -> None: + provider = _make_provider() + provider.mode = "agentic" + provider.agentic_message_history_count = 5 + provider._knowledge_base_initialized = True + provider.knowledge_base_name = "kb" + + mock_content = Mock() + mock_content.text = "agentic result" + mock_message = Mock() + mock_message.content = [mock_content] + mock_result = Mock() + mock_result.response = [mock_message] + + mock_retrieval = AsyncMock() + mock_retrieval.retrieve = AsyncMock(return_value=mock_result) + provider._retrieval_client = mock_retrieval + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[Message(role="user", contents=["agentic question"])], + session_id="s1", + ) + + with patch( + "agent_framework_azure_ai_search._context_provider.KnowledgeBaseMessageTextContent", + 
type(mock_content), + ): + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + msgs = ctx.context_messages.get(provider.source_id, []) + assert len(msgs) >= 2 + assert msgs[0].text == provider.context_prompt + assert msgs[1].text == "agentic result"