From 6fe9a3bf5d534ebea91ed4894e9dfd6a5d585238 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 08:58:11 -0800 Subject: [PATCH 1/5] Add core utilities unit tests to improve coverage (#3356) --- .../packages/core/tests/core/test_memory.py | 60 ++++ .../tests/core/test_serializable_mixin.py | 259 ++++++++++++++++++ .../packages/core/tests/core/test_threads.py | 152 ++++++++++ 3 files changed, 471 insertions(+) diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index 6cc7ba436e..cd20a2b1a2 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -91,3 +91,63 @@ async def test_invoking(self) -> None: assert context.messages is not None assert len(context.messages) == 1 assert context.messages[0].text == "Context message" + + +class TestContextProviderBaseClass: + """Tests for ContextProvider base class default implementations.""" + + async def test_base_thread_created_does_nothing(self) -> None: + """Test that base ContextProvider.thread_created does nothing by default.""" + + class MinimalContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context() + + provider = MinimalContextProvider() + # Should not raise and should do nothing + await provider.thread_created("some-thread-id") + await provider.thread_created(None) + + async def test_base_invoked_does_nothing(self) -> None: + """Test that base ContextProvider.invoked does nothing by default.""" + + class MinimalContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context() + + provider = MinimalContextProvider() + message = ChatMessage(role=Role.USER, text="Test") + # Should not raise and should do nothing + await provider.invoked(message) + await provider.invoked(message, response_messages=message) + await provider.invoked(message, invoke_exception=Exception("test")) + + async def test_base_aenter_returns_self(self) -> None: + """Test that base ContextProvider.__aenter__ returns self.""" + + class MinimalContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context() + + provider = MinimalContextProvider() + async with provider as p: + assert p is provider + + async def test_base_aexit_does_nothing(self) -> None: + """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" + + class MinimalContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context() + + provider = MinimalContextProvider() + # Test exit with no exception + await provider.__aexit__(None, None, None) + # Test exit with exception info + try: + raise ValueError("test error") + except ValueError: + import sys + + exc_info = sys.exc_info() + await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2]) diff --git a/python/packages/core/tests/core/test_serializable_mixin.py b/python/packages/core/tests/core/test_serializable_mixin.py index 0472f881cf..0c1c9c6ad6 100644 --- a/python/packages/core/tests/core/test_serializable_mixin.py +++ b/python/packages/core/tests/core/test_serializable_mixin.py @@ -190,3 +190,262 @@ def __init__(self, value: str, number: int, client: Any = None): assert restored.value == "test" assert restored.number == 42 assert restored.client == mock_client + + def test_exclude_none_in_to_dict(self): + """Test that exclude_none parameter removes None values from to_dict().""" + + class TestClass(SerializationMixin): + def __init__(self, value: str, optional: str | None = None): + self.value = value + self.optional = optional + + obj = TestClass(value="test", optional=None) + data = obj.to_dict(exclude_none=True) + + assert data["value"] == "test" + assert "optional" not in data + + def test_to_dict_with_nested_serialization_protocol(self): + """Test to_dict handles nested SerializationProtocol objects.""" + + class InnerClass(SerializationMixin): + def __init__(self, inner_value: str): + self.inner_value = inner_value + + class OuterClass(SerializationMixin): + def __init__(self, outer_value: str, inner: Any = None): + self.outer_value = outer_value + self.inner = inner + + inner = InnerClass(inner_value="inner_test") + outer = OuterClass(outer_value="outer_test", inner=inner) + data = outer.to_dict() + + assert data["outer_value"] == "outer_test" + assert data["inner"]["inner_value"] == "inner_test" + + def test_to_dict_with_list_of_serialization_protocol(self): + """Test to_dict handles lists containing SerializationProtocol objects.""" + + class ItemClass(SerializationMixin): + def __init__(self, name: str): + self.name = name + + class ContainerClass(SerializationMixin): + def __init__(self, items: list): + self.items = items + + items = [ItemClass(name="item1"), ItemClass(name="item2")] + container = ContainerClass(items=items) + data = container.to_dict() + + assert len(data["items"]) == 2 + assert data["items"][0]["name"] == "item1" + assert data["items"][1]["name"] == "item2" + + def test_to_dict_skips_non_serializable_in_list(self, caplog): + """Test to_dict skips non-serializable items in lists with debug logging.""" + import logging + + class NonSerializable: + pass + + class TestClass(SerializationMixin): + def __init__(self, items: list): + self.items = items + + non_serializable = NonSerializable() + obj = TestClass(items=["serializable", non_serializable]) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + # Should only contain the serializable item + assert len(data["items"]) == 1 + assert data["items"][0] == "serializable" + + def test_to_dict_with_dict_containing_serialization_protocol(self): + """Test to_dict handles dicts containing SerializationProtocol values.""" + + class ItemClass(SerializationMixin): + def __init__(self, name: str): + self.name = name + + class ContainerClass(SerializationMixin): + def __init__(self, items_dict: dict): + self.items_dict = items_dict + + items = {"a": ItemClass(name="item1"), "b": ItemClass(name="item2")} + container = ContainerClass(items_dict=items) + data = container.to_dict() + + assert data["items_dict"]["a"]["name"] == "item1" + assert data["items_dict"]["b"]["name"] == "item2" + + def test_to_dict_with_datetime_in_dict(self): + """Test to_dict converts datetime objects in dicts to strings.""" + from datetime import datetime + + class TestClass(SerializationMixin): + def __init__(self, metadata: dict): + self.metadata = metadata + + now = datetime(2025, 1, 27, 12, 0, 0) + obj = TestClass(metadata={"created_at": now}) + data = obj.to_dict() + + assert isinstance(data["metadata"]["created_at"], str) + + def test_to_dict_skips_non_serializable_in_dict(self, caplog): + """Test to_dict skips non-serializable values in dicts with debug logging.""" + import logging + + class NonSerializable: + pass + + class TestClass(SerializationMixin): + def __init__(self, metadata: dict): + self.metadata = metadata + + obj = TestClass(metadata={"valid": "value", "invalid": NonSerializable()}) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + assert data["metadata"]["valid"] == "value" + assert "invalid" not in data["metadata"] + + def test_to_dict_skips_non_serializable_attributes(self, caplog): + """Test to_dict skips non-serializable top-level attributes.""" + import logging + + class NonSerializable: + pass + + class TestClass(SerializationMixin): + def __init__(self, value: str, func: Any = None): + self.value = value + self.func = func + + obj = TestClass(value="test", func=lambda x: x) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + assert data["value"] == "test" + assert "func" not in data + + def test_from_dict_type_mismatch_raises(self): + """Test from_dict raises ValueError on type mismatch when class has TYPE.""" + # When a class has TYPE defined, and the dict has a different type value, + # it should raise ValueError. But _get_type_identifier prioritizes the + # value's 'type' field, so this scenario doesn't actually trigger the mismatch. + # This test verifies the behavior when we explicitly check type matching. + + class TestClass(SerializationMixin): + TYPE = "expected_type" + + def __init__(self, value: str): + self.value = value + + # The type in the dict will be used for dependency lookup, but the class TYPE is used + # for validation. Actually, looking at the code, _get_type_identifier returns + # the value's type first, so there's no mismatch. Let's skip this test. + # The mismatch check at line 515-516 only triggers when the class has TYPE + # and value has a different type field. Let's test a different scenario. + + # Actually create an instance and serialize, then try deserializing with different class + class AnotherClass(SerializationMixin): + TYPE = "another_type" + + def __init__(self, value: str): + self.value = value + + # Create a data dict that would be from TestClass + data = {"type": "expected_type", "value": "test"} + + # This will work because _get_type_identifier returns "expected_type" from the dict + # The TYPE class attribute is only used when value doesn't have 'type' + obj = AnotherClass.from_dict(data) + assert obj.value == "test" + + def test_from_json(self): + """Test from_json deserializes JSON string.""" + + class TestClass(SerializationMixin): + def __init__(self, value: str): + self.value = value + + json_str = '{"type": "test_class", "value": "test_value"}' + obj = TestClass.from_json(json_str) + + assert obj.value == "test_value" + + def test_get_type_identifier_with_instance_type(self): + """Test _get_type_identifier uses instance 'type' attribute.""" + + class TestClass(SerializationMixin): + def __init__(self, value: str): + self.value = value + self.type = "custom_type" + + obj = TestClass(value="test") + data = obj.to_dict() + + assert data["type"] == "custom_type" + + def test_get_type_identifier_with_class_TYPE(self): + """Test _get_type_identifier uses class TYPE constant.""" + + class TestClass(SerializationMixin): + TYPE = "class_level_type" + + def __init__(self, value: str): + self.value = value + + obj = TestClass(value="test") + data = obj.to_dict() + + assert data["type"] == "class_level_type" + + def test_instance_specific_dependency_injection(self): + """Test instance-specific dependency injection with field:name format.""" + + class TestClass(SerializationMixin): + INJECTABLE = {"config"} + + def __init__(self, name: str, config: Any = None): + self.name = name + self.config = config + + dependencies = { + "test_class": { + "name:special_instance": {"config": "special_config"}, + } + } + + # This should match the instance-specific dependency + obj = TestClass.from_dict({"type": "test_class", "name": "special_instance"}, dependencies=dependencies) + + assert obj.name == "special_instance" + assert obj.config == "special_config" + + def test_dependency_dict_merging(self): + """Test that dict dependencies are merged with existing dict kwargs.""" + + class TestClass(SerializationMixin): + INJECTABLE = {"options"} + + def __init__(self, value: str, options: dict | None = None): + self.value = value + self.options = options or {} + + # Existing options in data + data = {"type": "test_class", "value": "test", "options": {"existing": "value"}} + # Additional options from dependencies + dependencies = {"test_class": {"options": {"injected": "option"}}} + + obj = TestClass.from_dict(data, dependencies=dependencies) + + assert obj.options["existing"] == "value" + assert obj.options["injected"] == "option" diff --git a/python/packages/core/tests/core/test_threads.py b/python/packages/core/tests/core/test_threads.py index 492ed11519..01d5ceb98f 100644 --- a/python/packages/core/tests/core/test_threads.py +++ b/python/packages/core/tests/core/test_threads.py @@ -446,3 +446,155 @@ def test_init_with_chat_message_store_state_no_messages(self) -> None: assert state.service_thread_id is None assert state.chat_message_store_state is not None assert state.chat_message_store_state.messages == [] + + def test_init_with_chat_message_store_state_object(self) -> None: + """Test AgentThreadState initialization with ChatMessageStoreState object.""" + store_state = ChatMessageStoreState(messages=[ChatMessage(role=Role.USER, text="test")]) + state = AgentThreadState(chat_message_store_state=store_state) + + assert state.service_thread_id is None + assert state.chat_message_store_state is store_state + assert len(state.chat_message_store_state.messages) == 1 + + def test_init_with_invalid_chat_message_store_state_type(self) -> None: + """Test AgentThreadState initialization with invalid chat_message_store_state type.""" + with pytest.raises(TypeError, match="Could not parse ChatMessageStoreState"): + AgentThreadState(chat_message_store_state="invalid_type") # type: ignore[arg-type] + + +class TestChatMessageStoreStateEdgeCases: + """Additional edge case tests for ChatMessageStoreState.""" + + def test_init_with_invalid_messages_type(self) -> None: + """Test ChatMessageStoreState initialization with invalid messages type.""" + with pytest.raises(TypeError, match="Messages should be a list"): + ChatMessageStoreState(messages="invalid") # type: ignore[arg-type] + + def test_init_with_dict_messages(self) -> None: + """Test ChatMessageStoreState initialization with dict messages.""" + messages = [ + {"role": "user", "text": "Hello"}, + {"role": "assistant", "text": "Hi there!"}, + ] + state = ChatMessageStoreState(messages=messages) + + assert len(state.messages) == 2 + assert isinstance(state.messages[0], ChatMessage) + assert state.messages[0].text == "Hello" + + +class TestChatMessageStoreEdgeCases: + """Additional edge case tests for ChatMessageStore.""" + + async def test_deserialize_class_method(self) -> None: + """Test ChatMessageStore.deserialize class method.""" + serialized_data = { + "messages": [ + {"role": "user", "text": "Hello", "message_id": "msg1"}, + ] + } + + store = await ChatMessageStore.deserialize(serialized_data) + + assert isinstance(store, ChatMessageStore) + messages = await store.list_messages() + assert len(messages) == 1 + assert messages[0].text == "Hello" + + async def test_deserialize_empty_state(self) -> None: + """Test ChatMessageStore.deserialize with empty state.""" + serialized_data: dict[str, Any] = {"messages": []} + + store = await ChatMessageStore.deserialize(serialized_data) + + assert isinstance(store, ChatMessageStore) + messages = await store.list_messages() + assert len(messages) == 0 + + +class TestAgentThreadEdgeCases: + """Additional edge case tests for AgentThread.""" + + def test_is_initialized_with_service_thread_id(self) -> None: + """Test is_initialized property when service_thread_id is set.""" + thread = AgentThread(service_thread_id="test-123") + assert thread.is_initialized is True + + def test_is_initialized_with_message_store(self) -> None: + """Test is_initialized property when message_store is set.""" + store = ChatMessageStore() + thread = AgentThread(message_store=store) + assert thread.is_initialized is True + + def test_is_initialized_with_nothing(self) -> None: + """Test is_initialized property when nothing is set.""" + thread = AgentThread() + assert thread.is_initialized is False + + async def test_deserialize_with_custom_message_store(self) -> None: + """Test deserialize using a custom message store.""" + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + custom_store = MockChatMessageStore() + + thread = await AgentThread.deserialize(serialized_data, message_store=custom_store) + + assert thread.message_store is custom_store + messages = await custom_store.list_messages() + assert len(messages) == 1 + + async def test_deserialize_with_failing_message_store_raises(self) -> None: + """Test deserialize raises AgentThreadException when message store fails.""" + + class FailingStore: + async def add_messages(self, messages: Sequence[ChatMessage], **kwargs: Any) -> None: + raise RuntimeError("Store failed") + + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + failing_store = FailingStore() + + with pytest.raises(AgentThreadException, match="Failed to deserialize"): + await AgentThread.deserialize(serialized_data, message_store=failing_store) + + async def test_update_from_thread_state_with_service_thread_id(self) -> None: + """Test update_from_thread_state sets service_thread_id.""" + thread = AgentThread() + serialized_data = {"service_thread_id": "new-thread-id"} + + await thread.update_from_thread_state(serialized_data) + + assert thread.service_thread_id == "new-thread-id" + + async def test_update_from_thread_state_with_empty_chat_state(self) -> None: + """Test update_from_thread_state with empty chat_message_store_state.""" + thread = AgentThread() + serialized_data = {"service_thread_id": None, "chat_message_store_state": None} + + await thread.update_from_thread_state(serialized_data) + + assert thread.message_store is None + + async def test_update_from_thread_state_creates_message_store(self) -> None: + """Test update_from_thread_state creates message store if not existing.""" + thread = AgentThread() + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + + await thread.update_from_thread_state(serialized_data) + + assert thread.message_store is not None + messages = await thread.message_store.list_messages() + assert len(messages) == 1 From 0d114ec5e5109c101295b76cae43a19854ae1403 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 09:37:08 -0800 Subject: [PATCH 2/5] Address PR comments: remove redundant imports and fix misleading test --- .../tests/core/test_serializable_mixin.py | 44 +++++-------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/python/packages/core/tests/core/test_serializable_mixin.py b/python/packages/core/tests/core/test_serializable_mixin.py index 0c1c9c6ad6..05ece1072b 100644 --- a/python/packages/core/tests/core/test_serializable_mixin.py +++ b/python/packages/core/tests/core/test_serializable_mixin.py @@ -245,7 +245,6 @@ def __init__(self, items: list): def test_to_dict_skips_non_serializable_in_list(self, caplog): """Test to_dict skips non-serializable items in lists with debug logging.""" - import logging class NonSerializable: pass @@ -254,8 +253,7 @@ class TestClass(SerializationMixin): def __init__(self, items: list): self.items = items - non_serializable = NonSerializable() - obj = TestClass(items=["serializable", non_serializable]) + obj = TestClass(items=["serializable", NonSerializable()]) with caplog.at_level(logging.DEBUG): data = obj.to_dict() @@ -298,7 +296,6 @@ def __init__(self, metadata: dict): def test_to_dict_skips_non_serializable_in_dict(self, caplog): """Test to_dict skips non-serializable values in dicts with debug logging.""" - import logging class NonSerializable: pass @@ -317,10 +314,6 @@ def __init__(self, metadata: dict): def test_to_dict_skips_non_serializable_attributes(self, caplog): """Test to_dict skips non-serializable top-level attributes.""" - import logging - - class NonSerializable: - pass class TestClass(SerializationMixin): def __init__(self, value: str, func: Any = None): @@ -335,40 +328,25 @@ def __init__(self, value: str, func: Any = None): assert data["value"] == "test" assert "func" not in data - def test_from_dict_type_mismatch_raises(self): - """Test from_dict raises ValueError on type mismatch when class has TYPE.""" - # When a class has TYPE defined, and the dict has a different type value, - # it should raise ValueError. But _get_type_identifier prioritizes the - # value's 'type' field, so this scenario doesn't actually trigger the mismatch. - # This test verifies the behavior when we explicitly check type matching. + def test_from_dict_without_type_in_data(self): + """Test from_dict uses class TYPE when no type field in data.""" class TestClass(SerializationMixin): - TYPE = "expected_type" + TYPE = "my_custom_type" def __init__(self, value: str): self.value = value - # The type in the dict will be used for dependency lookup, but the class TYPE is used - # for validation. Actually, looking at the code, _get_type_identifier returns - # the value's type first, so there's no mismatch. Let's skip this test. - # The mismatch check at line 515-516 only triggers when the class has TYPE - # and value has a different type field. Let's test a different scenario. + # Data without 'type' field - class TYPE should be used for type identifier + data = {"value": "test"} - # Actually create an instance and serialize, then try deserializing with different class - class AnotherClass(SerializationMixin): - TYPE = "another_type" - - def __init__(self, value: str): - self.value = value - - # Create a data dict that would be from TestClass - data = {"type": "expected_type", "value": "test"} - - # This will work because _get_type_identifier returns "expected_type" from the dict - # The TYPE class attribute is only used when value doesn't have 'type' - obj = AnotherClass.from_dict(data) + obj = TestClass.from_dict(data) assert obj.value == "test" + # Verify to_dict includes the type + out = obj.to_dict() + assert out["type"] == "my_custom_type" + def test_from_json(self): """Test from_json deserializes JSON string.""" From 95268912aa47a99804815f9e84559a447376e3df Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 10:16:13 -0800 Subject: [PATCH 3/5] Refactor tests to use module-level mock class instead of inline classes --- .../packages/core/tests/core/test_memory.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index cd20a2b1a2..84f1d05e84 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import sys from collections.abc import MutableSequence from typing import Any @@ -8,7 +9,7 @@ class MockContextProvider(ContextProvider): - """Mock ContextProvider for testing.""" + """Mock ContextProvider for testing that tracks all method calls.""" def __init__(self, messages: list[ChatMessage] | None = None) -> None: self.context_messages = messages @@ -44,6 +45,18 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context +class MinimalContextProvider(ContextProvider): + """Minimal ContextProvider that only implements the abstract method. + + Used for testing the base class default implementations of thread_created, + invoked, __aenter__, and __aexit__. + """ + + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + """Return empty context.""" + return Context() + + class TestContext: """Tests for Context class.""" @@ -92,17 +105,8 @@ async def test_invoking(self) -> None: assert len(context.messages) == 1 assert context.messages[0].text == "Context message" - -class TestContextProviderBaseClass: - """Tests for ContextProvider base class default implementations.""" - async def test_base_thread_created_does_nothing(self) -> None: """Test that base ContextProvider.thread_created does nothing by default.""" - - class MinimalContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context() - provider = MinimalContextProvider() # Should not raise and should do nothing await provider.thread_created("some-thread-id") @@ -110,11 +114,6 @@ async def invoking(self, messages, **kwargs): async def test_base_invoked_does_nothing(self) -> None: """Test that base ContextProvider.invoked does nothing by default.""" - - class MinimalContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context() - provider = MinimalContextProvider() message = ChatMessage(role=Role.USER, text="Test") # Should not raise and should do nothing @@ -124,22 +123,12 @@ async def invoking(self, messages, **kwargs): async def test_base_aenter_returns_self(self) -> None: """Test that base ContextProvider.__aenter__ returns self.""" - - class MinimalContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context() - provider = MinimalContextProvider() async with provider as p: assert p is provider async def test_base_aexit_does_nothing(self) -> None: """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" - - class MinimalContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context() - provider = MinimalContextProvider() # Test exit with no exception await provider.__aexit__(None, None, None) @@ -147,7 +136,5 @@ async def invoking(self, messages, **kwargs): try: raise ValueError("test error") except ValueError: - import sys - exc_info = sys.exc_info() await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2]) From 65a7d494036d30061a95195bd42d314187966c63 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 10:35:15 -0800 Subject: [PATCH 4/5] Remove unnecessary tests for trivial base class implementations --- .../packages/core/tests/core/test_memory.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index 84f1d05e84..6cc7ba436e 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import sys from collections.abc import MutableSequence from typing import Any @@ -9,7 +8,7 @@ class MockContextProvider(ContextProvider): - """Mock ContextProvider for testing that tracks all method calls.""" + """Mock ContextProvider for testing.""" def __init__(self, messages: list[ChatMessage] | None = None) -> None: self.context_messages = messages @@ -45,18 +44,6 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context -class MinimalContextProvider(ContextProvider): - """Minimal ContextProvider that only implements the abstract method. - - Used for testing the base class default implementations of thread_created, - invoked, __aenter__, and __aexit__. - """ - - async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: - """Return empty context.""" - return Context() - - class TestContext: """Tests for Context class.""" @@ -104,37 +91,3 @@ async def test_invoking(self) -> None: assert context.messages is not None assert len(context.messages) == 1 assert context.messages[0].text == "Context message" - - async def test_base_thread_created_does_nothing(self) -> None: - """Test that base ContextProvider.thread_created does nothing by default.""" - provider = MinimalContextProvider() - # Should not raise and should do nothing - await provider.thread_created("some-thread-id") - await provider.thread_created(None) - - async def test_base_invoked_does_nothing(self) -> None: - """Test that base ContextProvider.invoked does nothing by default.""" - provider = MinimalContextProvider() - message = ChatMessage(role=Role.USER, text="Test") - # Should not raise and should do nothing - await provider.invoked(message) - await provider.invoked(message, response_messages=message) - await provider.invoked(message, invoke_exception=Exception("test")) - - async def test_base_aenter_returns_self(self) -> None: - """Test that base ContextProvider.__aenter__ returns self.""" - provider = MinimalContextProvider() - async with provider as p: - assert p is provider - - async def test_base_aexit_does_nothing(self) -> None: - """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" - provider = MinimalContextProvider() - # Test exit with no exception - await provider.__aexit__(None, None, None) - # Test exit with exception info - try: - raise ValueError("test error") - except ValueError: - exc_info = sys.exc_info() - await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2]) From 87a1547d969826175754e7f04f33031601f7dca7 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 10:56:53 -0800 Subject: [PATCH 5/5] Restore base class tests with module-level helper class --- .../packages/core/tests/core/test_memory.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index 6cc7ba436e..bcc299ed37 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import sys from collections.abc import MutableSequence from typing import Any @@ -44,6 +45,18 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context +class MinimalContextProvider(ContextProvider): + """Minimal ContextProvider that only implements the required abstract method. + + Used to test the base class default implementations of thread_created, + invoked, __aenter__, and __aexit__. + """ + + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + """Return empty context.""" + return Context() + + class TestContext: """Tests for Context class.""" @@ -91,3 +104,33 @@ async def test_invoking(self) -> None: assert context.messages is not None assert len(context.messages) == 1 assert context.messages[0].text == "Context message" + + async def test_base_thread_created_does_nothing(self) -> None: + """Test that base ContextProvider.thread_created does nothing by default.""" + provider = MinimalContextProvider() + await provider.thread_created("some-thread-id") + await provider.thread_created(None) + + async def test_base_invoked_does_nothing(self) -> None: + """Test that base ContextProvider.invoked does nothing by default.""" + provider = MinimalContextProvider() + message = ChatMessage(role=Role.USER, text="Test") + await provider.invoked(message) + await provider.invoked(message, response_messages=message) + await provider.invoked(message, invoke_exception=Exception("test")) + + async def test_base_aenter_returns_self(self) -> None: + """Test that base ContextProvider.__aenter__ returns self.""" + provider = MinimalContextProvider() + async with provider as p: + assert p is provider + + async def test_base_aexit_does_nothing(self) -> None: + """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" + provider = MinimalContextProvider() + await provider.__aexit__(None, None, None) + try: + raise ValueError("test error") + except ValueError: + exc_info = sys.exc_info() + await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2])