From 1c42f3692bd9c79645afb3128a8b8402d81067dc Mon Sep 17 00:00:00 2001
From: eavanvalkenburg
Date: Mon, 9 Feb 2026 15:20:02 +0100
Subject: [PATCH 01/12] PR1: Add core context provider types and tests

New types in _sessions.py (no changes to existing code):
- SessionContext: per-invocation state with extend_messages/get_messages/
  extend_instructions/extend_tools and read-only response property
- BaseContextProvider: base class with before_run/after_run hooks
- BaseHistoryProvider: storage base with load/store flags, abstract
  get_messages/save_messages, default before_run/after_run
- AgentSession: lightweight session with state dict, to_dict/from_dict
- InMemoryHistoryProvider: built-in provider storing in session.state

35 unit tests covering all classes and configuration flags.
---
 .../core/agent_framework/_sessions.py         | 447 ++++++++++++++++++
 .../packages/core/tests/core/test_sessions.py | 376 +++++++++++++++
 2 files changed, 823 insertions(+)
 create mode 100644 python/packages/core/agent_framework/_sessions.py
 create mode 100644 python/packages/core/tests/core/test_sessions.py

diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py
new file mode 100644
index 0000000000..7c3bb9e691
--- /dev/null
+++ b/python/packages/core/agent_framework/_sessions.py
@@ -0,0 +1,447 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from __future__ import annotations
+
+import uuid
+from abc import abstractmethod
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
+
+from ._tools import ToolProtocol
+from ._types import AgentResponse, ChatMessage
+
+if TYPE_CHECKING:
+    from ._agents import SupportsAgentRun
+
+"""Unified context management types for the agent framework.
+ +This module provides the core types for the context provider pipeline: +- SessionContext: Per-invocation state passed through providers +- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2) +- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2) +- AgentSession: Lightweight session state container +- InMemoryHistoryProvider: Built-in in-memory history provider +""" + + +class SessionContext: + """Per-invocation state passed through the context provider pipeline. + + Created fresh for each agent.run() call. Providers read from and write to + the mutable fields to add context before invocation and process responses after. + + Attributes: + session_id: The ID of the current session. + service_session_id: Service-managed session ID (if present, service handles storage). + input_messages: The new messages being sent to the agent (set by caller). + context_messages: Dict mapping source_id -> messages added by that provider. + Maintains insertion order (provider execution order). + instructions: Additional instructions added by providers. + tools: Additional tools added by providers. + response: After invocation, contains the full AgentResponse (read-only property). + options: Options passed to agent.run() - read-only, for reflection only. + metadata: Shared metadata dictionary for cross-provider communication. + """ + + def __init__( + self, + *, + session_id: str | None = None, + service_session_id: str | None = None, + input_messages: list[ChatMessage], + context_messages: dict[str, list[ChatMessage]] | None = None, + instructions: list[str] | None = None, + tools: list[ToolProtocol] | None = None, + options: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ): + """Initialize the session context. + + Args: + session_id: The ID of the current session. + service_session_id: Service-managed session ID. + input_messages: The new messages being sent to the agent. 
+ context_messages: Pre-populated context messages by source. + instructions: Pre-populated instructions. + tools: Pre-populated tools. + options: Options from agent.run() - read-only for providers. + metadata: Shared metadata for cross-provider communication. + """ + self.session_id = session_id + self.service_session_id = service_session_id + self.input_messages = input_messages + self.context_messages: dict[str, list[ChatMessage]] = context_messages or {} + self.instructions: list[str] = instructions or [] + self.tools: list[ToolProtocol] = tools or [] + self._response: AgentResponse | None = None + self.options: dict[str, Any] = options or {} + self.metadata: dict[str, Any] = metadata or {} + + @property + def response(self) -> AgentResponse | None: + """The agent's response. Set by the framework after invocation, read-only for providers.""" + return self._response + + def extend_messages(self, source_id: str, messages: Sequence[ChatMessage]) -> None: + """Add context messages from a specific source. + + Messages are stored keyed by source_id, maintaining insertion order + based on provider execution order. + + Args: + source_id: The provider source_id adding these messages. + messages: The messages to add. + """ + if source_id not in self.context_messages: + self.context_messages[source_id] = [] + self.context_messages[source_id].extend(messages) + + def extend_instructions(self, source_id: str, instructions: str | Sequence[str]) -> None: + """Add instructions to be prepended to the conversation. + + Args: + source_id: The provider source_id adding these instructions. + instructions: A single instruction string or sequence of strings. + """ + if isinstance(instructions, str): + instructions = [instructions] + self.instructions.extend(instructions) + + def extend_tools(self, source_id: str, tools: Sequence[ToolProtocol]) -> None: + """Add tools to be available for this invocation. + + Tools are added with source attribution in their metadata. 
+
+        Args:
+            source_id: The provider source_id adding these tools.
+            tools: The tools to add.
+        """
+        for tool in tools:
+            if hasattr(tool, "additional_properties") and isinstance(tool.additional_properties, dict):
+                tool.additional_properties["context_source"] = source_id
+        self.tools.extend(tools)
+
+    def get_messages(
+        self,
+        *,
+        sources: Sequence[str] | None = None,
+        exclude_sources: Sequence[str] | None = None,
+        include_input: bool = False,
+        include_response: bool = False,
+    ) -> list[ChatMessage]:
+        """Get context messages, optionally filtered and including input/response.
+
+        Returns messages in provider execution order (dict insertion order),
+        with input and response appended if requested.
+
+        Args:
+            sources: If provided, only include context messages from these sources.
+            exclude_sources: If provided, exclude context messages from these sources.
+            include_input: If True, append input_messages after context.
+            include_response: If True, append response.messages at the end.
+
+        Returns:
+            Flattened list of messages in conversation order.
+        """
+        result: list[ChatMessage] = []
+        for source_id, messages in self.context_messages.items():
+            if sources is not None and source_id not in sources:
+                continue
+            if exclude_sources is not None and source_id in exclude_sources:
+                continue
+            result.extend(messages)
+        if include_input and self.input_messages:
+            result.extend(self.input_messages)
+        if include_response and self.response and self.response.messages:
+            result.extend(self.response.messages)
+        return result
+
+
+class BaseContextProvider:
+    """Base class for context providers (hooks pattern).
+
+    Context providers participate in the context engineering pipeline,
+    adding context before model invocation and processing responses after.
+
+    Note:
+        This class uses a temporary ``Base``-prefixed name to avoid collision
+        with the existing ``ContextProvider`` in ``_memory.py``. It will be
+        renamed to ``ContextProvider`` in PR2 when the old class is removed.
+
+    Attributes:
+        source_id: Unique identifier for this provider instance (required).
+            Used for message/tool attribution so other providers can filter.
+    """
+
+    def __init__(self, source_id: str):
+        """Initialize the provider.
+
+        Args:
+            source_id: Unique identifier for this provider instance.
+        """
+        self.source_id = source_id
+
+    async def before_run(
+        self,
+        agent: SupportsAgentRun,
+        session: AgentSession,
+        context: SessionContext,
+        state: dict[str, Any],
+    ) -> None:
+        """Called before model invocation.
+
+        Override to add context (messages, instructions, tools) to the
+        SessionContext before the model is invoked.
+
+        Args:
+            agent: The agent running this invocation.
+            session: The current session.
+            context: The invocation context - add messages/instructions/tools here.
+            state: The session's mutable state dict.
+        """
+
+    async def after_run(
+        self,
+        agent: SupportsAgentRun,
+        session: AgentSession,
+        context: SessionContext,
+        state: dict[str, Any],
+    ) -> None:
+        """Called after model invocation.
+
+        Override to process the response (store messages, extract info, etc.).
+        The context.response will be populated at this point.
+
+        Args:
+            agent: The agent that ran this invocation.
+            session: The current session.
+            context: The invocation context with response populated.
+            state: The session's mutable state dict.
+        """
+
+
+class BaseHistoryProvider(BaseContextProvider):
+    """Base class for conversation history storage providers.
+
+    A single class configurable for different use cases:
+    - Primary memory storage (loads + stores messages)
+    - Audit/logging storage (stores only, doesn't load)
+    - Evaluation storage (stores only for later analysis)
+
+    Note:
+        This class uses a temporary ``Base``-prefixed name to avoid collision
+        with existing types. It will be renamed to ``HistoryProvider`` in PR2.
+
+    Subclasses only need to implement ``get_messages()`` and ``save_messages()``.
+ The default ``before_run``/``after_run`` handle loading and storing based on + configuration flags. Override them for custom behavior. + + Attributes: + load_messages: Whether to load messages before invocation (default True). + When False, the agent skips calling ``before_run`` entirely. + store_responses: Whether to store response messages (default True). + store_inputs: Whether to store input messages (default True). + store_context_messages: Whether to store context from other providers (default False). + store_context_from: If set, only store context from these source_ids. + """ + + def __init__( + self, + source_id: str, + *, + load_messages: bool = True, + store_responses: bool = True, + store_inputs: bool = True, + store_context_messages: bool = False, + store_context_from: Sequence[str] | None = None, + ): + """Initialize the history provider. + + Args: + source_id: Unique identifier for this provider instance. + load_messages: Whether to load messages before invocation. + store_responses: Whether to store response messages. + store_inputs: Whether to store input messages. + store_context_messages: Whether to store context from other providers. + store_context_from: If set, only store context from these source_ids. + """ + super().__init__(source_id) + self.load_messages = load_messages + self.store_responses = store_responses + self.store_inputs = store_inputs + self.store_context_messages = store_context_messages + self.store_context_from = list(store_context_from) if store_context_from else None + + @abstractmethod + async def get_messages(self, session_id: str | None) -> list[ChatMessage]: + """Retrieve stored messages for this session. + + Args: + session_id: The session ID to retrieve messages for. + + Returns: + List of stored messages. + """ + ... + + @abstractmethod + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: + """Persist messages for this session. 
+ + Args: + session_id: The session ID to store messages for. + messages: The messages to persist. + """ + ... + + def _get_context_messages_to_store(self, context: SessionContext) -> list[ChatMessage]: + """Get context messages that should be stored based on configuration.""" + if not self.store_context_messages: + return [] + if self.store_context_from is not None: + return context.get_messages(sources=self.store_context_from) + return context.get_messages(exclude_sources=[self.source_id]) + + async def before_run( + self, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Load history into context. Skipped by the agent when load_messages=False.""" + history = await self.get_messages(context.session_id) + context.extend_messages(self.source_id, history) + + async def after_run( + self, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Store messages based on configuration.""" + messages_to_store: list[ChatMessage] = [] + messages_to_store.extend(self._get_context_messages_to_store(context)) + if self.store_inputs: + messages_to_store.extend(context.input_messages) + if self.store_responses and context.response and context.response.messages: + messages_to_store.extend(context.response.messages) + if messages_to_store: + await self.save_messages(context.session_id, messages_to_store) + + +class AgentSession: + """A conversation session with an agent. + + Lightweight state container. Provider instances are owned by the agent, + not the session. The session only holds session IDs and a mutable state dict. + + Attributes: + session_id: Unique identifier for this session. + service_session_id: Service-managed session ID (if using service-side storage). + state: Mutable state dict shared with all providers. 
+ """ + + def __init__( + self, + *, + session_id: str | None = None, + service_session_id: str | None = None, + ): + """Initialize the session. + + Args: + session_id: Optional session ID (generated if not provided). + service_session_id: Optional service-managed session ID. + """ + self._session_id = session_id or str(uuid.uuid4()) + self.service_session_id = service_session_id + self.state: dict[str, Any] = {} + + @property + def session_id(self) -> str: + """The unique identifier for this session.""" + return self._session_id + + def to_dict(self) -> dict[str, Any]: + """Serialize session to a plain dict for storage/transfer.""" + return { + "type": "session", + "session_id": self._session_id, + "service_session_id": self.service_session_id, + "state": self.state, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AgentSession: + """Restore session from a previously serialized dict. + + Args: + data: Dict from a previous ``to_dict()`` call. + + Returns: + Restored AgentSession instance. + """ + session = cls( + session_id=data["session_id"], + service_session_id=data.get("service_session_id"), + ) + session.state = data.get("state", {}) + return session + + +class InMemoryHistoryProvider(BaseHistoryProvider): + """Built-in history provider that stores messages in session.state. + + Messages are stored in ``state[source_id]["messages"]`` as a list + of serialized ChatMessage dicts, making the session natively serializable. + + This is the default provider auto-added by the agent when no providers + are configured and ``conversation_id`` or ``store=True`` is set. + """ + + async def get_messages(self, session_id: str | None) -> list[ChatMessage]: + """Retrieve messages from session state. 
Requires state to be set via before_run.""" + return self._current_messages + + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: + """Persist messages to session state.""" + state = self._current_state + my_state = state.setdefault(self.source_id, {}) + existing = my_state.get("messages", []) + my_state["messages"] = [*existing, *[m.to_dict() for m in messages]] + + async def before_run( + self, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Load history from session state into context.""" + self._current_state = state + my_state = state.get(self.source_id, {}) + raw_messages = my_state.get("messages", []) + self._current_messages = [ChatMessage.from_dict(m) for m in raw_messages] + context.extend_messages(self.source_id, self._current_messages) + + async def after_run( + self, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Store messages to session state.""" + self._current_state = state + await super().after_run(agent, session, context, state) + + +__all__ = [ + "AgentSession", + "BaseContextProvider", + "BaseHistoryProvider", + "InMemoryHistoryProvider", + "SessionContext", +] diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py new file mode 100644 index 0000000000..407b56a6bd --- /dev/null +++ b/python/packages/core/tests/core/test_sessions.py @@ -0,0 +1,376 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import json +from collections.abc import Sequence + +from agent_framework import ChatMessage +from agent_framework._sessions import ( + AgentSession, + BaseContextProvider, + BaseHistoryProvider, + InMemoryHistoryProvider, + SessionContext, +) + +# --------------------------------------------------------------------------- +# SessionContext tests +# --------------------------------------------------------------------------- + + +class TestSessionContext: + def test_init_defaults(self) -> None: + ctx = SessionContext(input_messages=[]) + assert ctx.session_id is None + assert ctx.service_session_id is None + assert ctx.input_messages == [] + assert ctx.context_messages == {} + assert ctx.instructions == [] + assert ctx.tools == [] + assert ctx.response is None + assert ctx.options == {} + assert ctx.metadata == {} + + def test_extend_messages_creates_key(self) -> None: + ctx = SessionContext(input_messages=[]) + msg = ChatMessage(role="user", contents=["hello"]) + ctx.extend_messages("rag", [msg]) + assert "rag" in ctx.context_messages + assert len(ctx.context_messages["rag"]) == 1 + assert ctx.context_messages["rag"][0].text == "hello" + + def test_extend_messages_appends_to_existing(self) -> None: + ctx = SessionContext(input_messages=[]) + msg1 = ChatMessage(role="user", contents=["first"]) + msg2 = ChatMessage(role="user", contents=["second"]) + ctx.extend_messages("src", [msg1]) + ctx.extend_messages("src", [msg2]) + assert len(ctx.context_messages["src"]) == 2 + + def test_extend_messages_preserves_source_order(self) -> None: + ctx = SessionContext(input_messages=[]) + ctx.extend_messages("a", [ChatMessage(role="user", contents=["a"])]) + ctx.extend_messages("b", [ChatMessage(role="user", contents=["b"])]) + ctx.extend_messages("c", [ChatMessage(role="user", contents=["c"])]) + assert list(ctx.context_messages.keys()) == ["a", "b", "c"] + + def test_extend_instructions_string(self) -> None: + ctx = SessionContext(input_messages=[]) + 
ctx.extend_instructions("sys", "Be helpful") + assert ctx.instructions == ["Be helpful"] + + def test_extend_instructions_sequence(self) -> None: + ctx = SessionContext(input_messages=[]) + ctx.extend_instructions("sys", ["Be helpful", "Be concise"]) + assert ctx.instructions == ["Be helpful", "Be concise"] + + def test_get_messages_all(self) -> None: + ctx = SessionContext(input_messages=[]) + ctx.extend_messages("a", [ChatMessage(role="user", contents=["a"])]) + ctx.extend_messages("b", [ChatMessage(role="user", contents=["b"])]) + result = ctx.get_messages() + assert len(result) == 2 + assert result[0].text == "a" + assert result[1].text == "b" + + def test_get_messages_filter_sources(self) -> None: + ctx = SessionContext(input_messages=[]) + ctx.extend_messages("a", [ChatMessage(role="user", contents=["a"])]) + ctx.extend_messages("b", [ChatMessage(role="user", contents=["b"])]) + result = ctx.get_messages(sources=["a"]) + assert len(result) == 1 + assert result[0].text == "a" + + def test_get_messages_exclude_sources(self) -> None: + ctx = SessionContext(input_messages=[]) + ctx.extend_messages("a", [ChatMessage(role="user", contents=["a"])]) + ctx.extend_messages("b", [ChatMessage(role="user", contents=["b"])]) + result = ctx.get_messages(exclude_sources=["a"]) + assert len(result) == 1 + assert result[0].text == "b" + + def test_get_messages_include_input(self) -> None: + input_msg = ChatMessage(role="user", contents=["input"]) + ctx = SessionContext(input_messages=[input_msg]) + ctx.extend_messages("a", [ChatMessage(role="user", contents=["context"])]) + result = ctx.get_messages(include_input=True) + assert len(result) == 2 + assert result[1].text == "input" + + def test_get_messages_include_response(self) -> None: + from agent_framework import AgentResponse + + ctx = SessionContext(input_messages=[]) + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["reply"])]) + result = ctx.get_messages(include_response=True) + assert 
len(result) == 1 + assert result[0].text == "reply" + + def test_response_readonly(self) -> None: + ctx = SessionContext(input_messages=[]) + assert ctx.response is None + # Can set via _response internally + from agent_framework import AgentResponse + + resp = AgentResponse(messages=[]) + ctx._response = resp + assert ctx.response is resp + + +# --------------------------------------------------------------------------- +# BaseContextProvider tests +# --------------------------------------------------------------------------- + + +class TestContextProviderBase: + def test_source_id_required(self) -> None: + provider = BaseContextProvider(source_id="test") + assert provider.source_id == "test" + + async def test_before_run_is_noop(self) -> None: + provider = BaseContextProvider(source_id="test") + session = AgentSession() + ctx = SessionContext(input_messages=[]) + # Should not raise + await provider.before_run(None, session, ctx, {}) # type: ignore[arg-type] + + async def test_after_run_is_noop(self) -> None: + provider = BaseContextProvider(source_id="test") + session = AgentSession() + ctx = SessionContext(input_messages=[]) + await provider.after_run(None, session, ctx, {}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# BaseHistoryProvider tests +# --------------------------------------------------------------------------- + + +class ConcreteHistoryProvider(BaseHistoryProvider): + """Concrete test implementation.""" + + def __init__(self, source_id: str, stored_messages: list[ChatMessage] | None = None, **kwargs) -> None: + super().__init__(source_id, **kwargs) + self.stored: list[ChatMessage] = [] + self._stored_messages = stored_messages or [] + + async def get_messages(self, session_id: str | None) -> list[ChatMessage]: + return list(self._stored_messages) + + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: + self.stored.extend(messages) + + +class 
TestHistoryProviderBase: + def test_default_flags(self) -> None: + provider = ConcreteHistoryProvider("mem") + assert provider.load_messages is True + assert provider.store_responses is True + assert provider.store_inputs is True + assert provider.store_context_messages is False + assert provider.store_context_from is None + + def test_custom_flags(self) -> None: + provider = ConcreteHistoryProvider( + "audit", + load_messages=False, + store_inputs=False, + store_context_messages=True, + store_context_from=["rag"], + ) + assert provider.load_messages is False + assert provider.store_inputs is False + assert provider.store_context_messages is True + assert provider.store_context_from == ["rag"] + + async def test_before_run_loads_messages(self) -> None: + msgs = [ChatMessage(role="user", contents=["history"])] + provider = ConcreteHistoryProvider("mem", stored_messages=msgs) + session = AgentSession() + ctx = SessionContext(session_id="s1", input_messages=[]) + await provider.before_run(None, session, ctx, {}) # type: ignore[arg-type] + assert len(ctx.context_messages["mem"]) == 1 + assert ctx.context_messages["mem"][0].text == "history" + + async def test_after_run_stores_inputs_and_responses(self) -> None: + from agent_framework import AgentResponse + + provider = ConcreteHistoryProvider("mem") + session = AgentSession() + input_msg = ChatMessage(role="user", contents=["hello"]) + resp_msg = ChatMessage(role="assistant", contents=["hi"]) + ctx = SessionContext(session_id="s1", input_messages=[input_msg]) + ctx._response = AgentResponse(messages=[resp_msg]) + await provider.after_run(None, session, ctx, {}) # type: ignore[arg-type] + assert len(provider.stored) == 2 + assert provider.stored[0].text == "hello" + assert provider.stored[1].text == "hi" + + async def test_after_run_skips_inputs_when_disabled(self) -> None: + from agent_framework import AgentResponse + + provider = ConcreteHistoryProvider("mem", store_inputs=False) + ctx = 
SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) + await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + assert len(provider.stored) == 1 + assert provider.stored[0].text == "hi" + + async def test_after_run_skips_responses_when_disabled(self) -> None: + from agent_framework import AgentResponse + + provider = ConcreteHistoryProvider("mem", store_responses=False) + ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) + await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + assert len(provider.stored) == 1 + assert provider.stored[0].text == "hello" + + async def test_after_run_stores_context_messages(self) -> None: + from agent_framework import AgentResponse + + provider = ConcreteHistoryProvider("audit", load_messages=False, store_context_messages=True) + ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) + ctx.extend_messages("rag", [ChatMessage(role="system", contents=["context"])]) + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) + await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + # Should store: context from rag + input + response + texts = [m.text for m in provider.stored] + assert "context" in texts + assert "hello" in texts + assert "hi" in texts + + async def test_after_run_stores_context_from_specific_sources(self) -> None: + from agent_framework import AgentResponse + + provider = ConcreteHistoryProvider( + "audit", load_messages=False, store_context_messages=True, store_context_from=["rag"] + ) + ctx = SessionContext(session_id="s1", input_messages=[]) + ctx.extend_messages("rag", 
[ChatMessage(role="system", contents=["rag-context"])]) + ctx.extend_messages("other", [ChatMessage(role="system", contents=["other-context"])]) + ctx._response = AgentResponse(messages=[]) + await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + texts = [m.text for m in provider.stored] + assert "rag-context" in texts + assert "other-context" not in texts + + +# --------------------------------------------------------------------------- +# AgentSession tests +# --------------------------------------------------------------------------- + + +class TestAgentSession: + def test_auto_generates_session_id(self) -> None: + session = AgentSession() + assert session.session_id is not None + assert len(session.session_id) > 0 + + def test_custom_session_id(self) -> None: + session = AgentSession(session_id="custom-123") + assert session.session_id == "custom-123" + + def test_state_starts_empty(self) -> None: + session = AgentSession() + assert session.state == {} + + def test_service_session_id(self) -> None: + session = AgentSession(service_session_id="svc-456") + assert session.service_session_id == "svc-456" + + def test_to_dict(self) -> None: + session = AgentSession(session_id="s1", service_session_id="svc1") + session.state = {"key": "value"} + d = session.to_dict() + assert d["type"] == "session" + assert d["session_id"] == "s1" + assert d["service_session_id"] == "svc1" + assert d["state"] == {"key": "value"} + + def test_from_dict(self) -> None: + data = { + "type": "session", + "session_id": "s1", + "service_session_id": "svc1", + "state": {"key": "value"}, + } + session = AgentSession.from_dict(data) + assert session.session_id == "s1" + assert session.service_session_id == "svc1" + assert session.state == {"key": "value"} + + def test_roundtrip(self) -> None: + session = AgentSession(session_id="rt-1") + session.state = {"messages": ["a", "b"], "count": 42} + json_str = json.dumps(session.to_dict()) + restored = 
AgentSession.from_dict(json.loads(json_str)) + assert restored.session_id == "rt-1" + assert restored.state == {"messages": ["a", "b"], "count": 42} + + def test_from_dict_missing_state(self) -> None: + data = {"session_id": "s1"} + session = AgentSession.from_dict(data) + assert session.state == {} + + +# --------------------------------------------------------------------------- +# InMemoryHistoryProvider tests +# --------------------------------------------------------------------------- + + +class TestInMemoryHistoryProvider: + async def test_empty_state_returns_no_messages(self) -> None: + provider = InMemoryHistoryProvider("memory") + session = AgentSession() + ctx = SessionContext(session_id="s1", input_messages=[]) + await provider.before_run(None, session, ctx, session.state) # type: ignore[arg-type] + assert ctx.context_messages.get("memory", []) == [] + + async def test_stores_and_loads_messages(self) -> None: + from agent_framework import AgentResponse + + provider = InMemoryHistoryProvider("memory") + session = AgentSession() + + # First run: send input, get response + input_msg = ChatMessage(role="user", contents=["hello"]) + resp_msg = ChatMessage(role="assistant", contents=["hi there"]) + ctx1 = SessionContext(session_id="s1", input_messages=[input_msg]) + await provider.before_run(None, session, ctx1, session.state) # type: ignore[arg-type] + ctx1._response = AgentResponse(messages=[resp_msg]) + await provider.after_run(None, session, ctx1, session.state) # type: ignore[arg-type] + + # Second run: should load previous messages + ctx2 = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["again"])]) + await provider.before_run(None, session, ctx2, session.state) # type: ignore[arg-type] + loaded = ctx2.context_messages.get("memory", []) + assert len(loaded) == 2 + assert loaded[0].text == "hello" + assert loaded[1].text == "hi there" + + async def test_state_is_serializable(self) -> None: + from agent_framework import 
AgentResponse + + provider = InMemoryHistoryProvider("memory") + session = AgentSession() + + input_msg = ChatMessage(role="user", contents=["test"]) + ctx = SessionContext(session_id="s1", input_messages=[input_msg]) + await provider.before_run(None, session, ctx, session.state) # type: ignore[arg-type] + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["reply"])]) + await provider.after_run(None, session, ctx, session.state) # type: ignore[arg-type] + + # State should be JSON-serializable + json_str = json.dumps(session.state) + assert json_str # no error + + # And round-trip through session serialization + restored = AgentSession.from_dict(json.loads(json.dumps(session.to_dict()))) + assert restored.state == session.state + + async def test_source_id_attribution(self) -> None: + provider = InMemoryHistoryProvider("custom-source") + assert provider.source_id == "custom-source" + ctx = SessionContext(session_id="s1", input_messages=[]) + ctx.extend_messages("custom-source", [ChatMessage(role="user", contents=["test"])]) + assert "custom-source" in ctx.context_messages From b78cfe9fa5a9aa5e713cb0f7ebe2fd6424244d42 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 9 Feb 2026 15:38:49 +0100 Subject: [PATCH 02/12] feat: keyword-only params, stateless InMemoryHistoryProvider, deep serialization - Make before_run/after_run parameters keyword-only - InMemoryHistoryProvider stores ChatMessage objects directly (no per-cycle serialization) - Deep serialization via to_dict/from_dict only at session boundary - State type registry for automatic deserialization of registered types - Updated tests for new serialization approach --- .../core/agent_framework/_sessions.py | 134 ++++++++++++------ .../packages/core/tests/core/test_sessions.py | 48 ++++--- 2 files changed, 121 insertions(+), 61 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 
7c3bb9e691..972880878c 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -23,6 +23,54 @@ - InMemoryHistoryProvider: Built-in in-memory history provider """ +# Registry of known types for state deserialization +_STATE_TYPE_REGISTRY: dict[str, type] = {} + + +def _register_state_type(cls: type) -> None: + """Register a type for automatic deserialization in session state.""" + type_id: str = getattr(cls, "_get_type_identifier", lambda: cls.__name__.lower())() + _STATE_TYPE_REGISTRY[type_id] = cls + + +def _serialize_value(value: Any) -> Any: + """Serialize a single value, handling objects with to_dict().""" + if hasattr(value, "to_dict") and callable(value.to_dict): + return value.to_dict() # pyright: ignore[reportUnknownMemberType] + if isinstance(value, list): + return [_serialize_value(item) for item in value] # pyright: ignore[reportUnknownVariableType] + if isinstance(value, dict): + return {str(k): _serialize_value(v) for k, v in value.items()} # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + return value + + +def _deserialize_value(value: Any) -> Any: + """Deserialize a single value, restoring registered types.""" + if isinstance(value, dict) and "type" in value: + type_id = str(value["type"]) # pyright: ignore[reportUnknownArgumentType] + cls = _STATE_TYPE_REGISTRY.get(type_id) + if cls is not None and hasattr(cls, "from_dict"): + return cls.from_dict(value) # type: ignore[union-attr] + if isinstance(value, list): + return [_deserialize_value(item) for item in value] # pyright: ignore[reportUnknownVariableType] + if isinstance(value, dict): + return {str(k): _deserialize_value(v) for k, v in value.items()} # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + return value + + +def _serialize_state(state: dict[str, Any]) -> dict[str, Any]: + """Deep-serialize a state dict, converting SerializationProtocol objects to dicts.""" + return {k: 
_serialize_value(v) for k, v in state.items()} + + +def _deserialize_state(state: dict[str, Any]) -> dict[str, Any]: + """Deep-deserialize a state dict, restoring SerializationProtocol objects.""" + return {k: _deserialize_value(v) for k, v in state.items()} + + +# Register known types +_register_state_type(ChatMessage) + class SessionContext: """Per-invocation state passed through the context provider pipeline. @@ -183,6 +231,7 @@ def __init__(self, source_id: str): async def before_run( self, + *, agent: SupportsAgentRun, session: AgentSession, context: SessionContext, @@ -202,6 +251,7 @@ async def before_run( async def after_run( self, + *, agent: SupportsAgentRun, session: AgentSession, context: SessionContext, @@ -273,11 +323,12 @@ def __init__( self.store_context_from = list(store_context_from) if store_context_from else None @abstractmethod - async def get_messages(self, session_id: str | None) -> list[ChatMessage]: + async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[ChatMessage]: """Retrieve stored messages for this session. Args: session_id: The session ID to retrieve messages for. + **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). Returns: List of stored messages. @@ -285,12 +336,13 @@ async def get_messages(self, session_id: str | None) -> list[ChatMessage]: ... @abstractmethod - async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage], **kwargs: Any) -> None: """Persist messages for this session. Args: session_id: The session ID to store messages for. messages: The messages to persist. + **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). """ ... 
@@ -304,17 +356,19 @@ def _get_context_messages_to_store(self, context: SessionContext) -> list[ChatMe async def before_run( self, + *, agent: SupportsAgentRun, session: AgentSession, context: SessionContext, state: dict[str, Any], ) -> None: """Load history into context. Skipped by the agent when load_messages=False.""" - history = await self.get_messages(context.session_id) + history = await self.get_messages(context.session_id, state=state) context.extend_messages(self.source_id, history) async def after_run( self, + *, agent: SupportsAgentRun, session: AgentSession, context: SessionContext, @@ -328,7 +382,7 @@ async def after_run( if self.store_responses and context.response and context.response.messages: messages_to_store.extend(context.response.messages) if messages_to_store: - await self.save_messages(context.session_id, messages_to_store) + await self.save_messages(context.session_id, messages_to_store, state=state) class AgentSession: @@ -365,18 +419,26 @@ def session_id(self) -> str: return self._session_id def to_dict(self) -> dict[str, Any]: - """Serialize session to a plain dict for storage/transfer.""" + """Serialize session to a plain dict for storage/transfer. + + Values in ``state`` that implement ``SerializationProtocol`` (i.e. have + ``to_dict``/``from_dict``) are serialized automatically. Built-in types + (str, int, float, bool, None, list, dict) are kept as-is. + """ return { "type": "session", "session_id": self._session_id, "service_session_id": self.service_session_id, - "state": self.state, + "state": _serialize_state(self.state), } @classmethod def from_dict(cls, data: dict[str, Any]) -> AgentSession: """Restore session from a previously serialized dict. + Values in ``state`` that were serialized via ``SerializationProtocol`` + (containing a ``type`` key) are restored to their original types. + Args: data: Dict from a previous ``to_dict()`` call. 
@@ -387,55 +449,47 @@ def from_dict(cls, data: dict[str, Any]) -> AgentSession: session_id=data["session_id"], service_session_id=data.get("service_session_id"), ) - session.state = data.get("state", {}) + session.state = _deserialize_state(data.get("state", {})) return session class InMemoryHistoryProvider(BaseHistoryProvider): """Built-in history provider that stores messages in session.state. - Messages are stored in ``state[source_id]["messages"]`` as a list - of serialized ChatMessage dicts, making the session natively serializable. + Messages are stored in ``state[source_id]["messages"]`` as a list of + ``ChatMessage`` objects. Serialization to/from dicts is handled by + ``AgentSession.to_dict()``/``from_dict()`` using ``SerializationProtocol``. + + This provider holds no instance state — all data lives in the session's + state dict, passed as a named ``state`` parameter to ``get_messages``/``save_messages``. This is the default provider auto-added by the agent when no providers are configured and ``conversation_id`` or ``store=True`` is set. """ - async def get_messages(self, session_id: str | None) -> list[ChatMessage]: - """Retrieve messages from session state. 
Requires state to be set via before_run.""" - return self._current_messages - - async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: - """Persist messages to session state.""" - state = self._current_state - my_state = state.setdefault(self.source_id, {}) - existing = my_state.get("messages", []) - my_state["messages"] = [*existing, *[m.to_dict() for m in messages]] - - async def before_run( - self, - agent: SupportsAgentRun, - session: AgentSession, - context: SessionContext, - state: dict[str, Any], - ) -> None: - """Load history from session state into context.""" - self._current_state = state + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[ChatMessage]: + """Retrieve messages from session state.""" + if state is None: + return [] my_state = state.get(self.source_id, {}) - raw_messages = my_state.get("messages", []) - self._current_messages = [ChatMessage.from_dict(m) for m in raw_messages] - context.extend_messages(self.source_id, self._current_messages) + return list(my_state.get("messages", [])) - async def after_run( + async def save_messages( self, - agent: SupportsAgentRun, - session: AgentSession, - context: SessionContext, - state: dict[str, Any], + session_id: str | None, + messages: Sequence[ChatMessage], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, ) -> None: - """Store messages to session state.""" - self._current_state = state - await super().after_run(agent, session, context, state) + """Persist messages to session state.""" + if state is None: + return + my_state = state.setdefault(self.source_id, {}) + existing = my_state.get("messages", []) + my_state["messages"] = [*existing, *messages] __all__ = [ diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 407b56a6bd..4673032f36 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ 
b/python/packages/core/tests/core/test_sessions.py @@ -131,13 +131,13 @@ async def test_before_run_is_noop(self) -> None: session = AgentSession() ctx = SessionContext(input_messages=[]) # Should not raise - await provider.before_run(None, session, ctx, {}) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] async def test_after_run_is_noop(self) -> None: provider = BaseContextProvider(source_id="test") session = AgentSession() ctx = SessionContext(input_messages=[]) - await provider.after_run(None, session, ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # --------------------------------------------------------------------------- @@ -153,10 +153,10 @@ def __init__(self, source_id: str, stored_messages: list[ChatMessage] | None = N self.stored: list[ChatMessage] = [] self._stored_messages = stored_messages or [] - async def get_messages(self, session_id: str | None) -> list[ChatMessage]: + async def get_messages(self, session_id: str | None, **kwargs) -> list[ChatMessage]: return list(self._stored_messages) - async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage]) -> None: + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage], **kwargs) -> None: self.stored.extend(messages) @@ -187,7 +187,7 @@ async def test_before_run_loads_messages(self) -> None: provider = ConcreteHistoryProvider("mem", stored_messages=msgs) session = AgentSession() ctx = SessionContext(session_id="s1", input_messages=[]) - await provider.before_run(None, session, ctx, {}) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] assert len(ctx.context_messages["mem"]) == 1 assert ctx.context_messages["mem"][0].text == "history" @@ -200,7 +200,7 @@ async def 
test_after_run_stores_inputs_and_responses(self) -> None: resp_msg = ChatMessage(role="assistant", contents=["hi"]) ctx = SessionContext(session_id="s1", input_messages=[input_msg]) ctx._response = AgentResponse(messages=[resp_msg]) - await provider.after_run(None, session, ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] assert len(provider.stored) == 2 assert provider.stored[0].text == "hello" assert provider.stored[1].text == "hi" @@ -211,7 +211,7 @@ async def test_after_run_skips_inputs_when_disabled(self) -> None: provider = ConcreteHistoryProvider("mem", store_inputs=False) ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) - await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] assert len(provider.stored) == 1 assert provider.stored[0].text == "hi" @@ -221,7 +221,7 @@ async def test_after_run_skips_responses_when_disabled(self) -> None: provider = ConcreteHistoryProvider("mem", store_responses=False) ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) - await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] assert len(provider.stored) == 1 assert provider.stored[0].text == "hello" @@ -232,7 +232,7 @@ async def test_after_run_stores_context_messages(self) -> None: ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) ctx.extend_messages("rag", [ChatMessage(role="system", 
contents=["context"])]) ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) - await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # Should store: context from rag + input + response texts = [m.text for m in provider.stored] assert "context" in texts @@ -249,7 +249,7 @@ async def test_after_run_stores_context_from_specific_sources(self) -> None: ctx.extend_messages("rag", [ChatMessage(role="system", contents=["rag-context"])]) ctx.extend_messages("other", [ChatMessage(role="system", contents=["other-context"])]) ctx._response = AgentResponse(messages=[]) - await provider.after_run(None, AgentSession(), ctx, {}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] texts = [m.text for m in provider.stored] assert "rag-context" in texts assert "other-context" not in texts @@ -323,7 +323,7 @@ async def test_empty_state_returns_no_messages(self) -> None: provider = InMemoryHistoryProvider("memory") session = AgentSession() ctx = SessionContext(session_id="s1", input_messages=[]) - await provider.before_run(None, session, ctx, session.state) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] assert ctx.context_messages.get("memory", []) == [] async def test_stores_and_loads_messages(self) -> None: @@ -336,13 +336,13 @@ async def test_stores_and_loads_messages(self) -> None: input_msg = ChatMessage(role="user", contents=["hello"]) resp_msg = ChatMessage(role="assistant", contents=["hi there"]) ctx1 = SessionContext(session_id="s1", input_messages=[input_msg]) - await provider.before_run(None, session, ctx1, session.state) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx1, 
state=session.state) # type: ignore[arg-type] ctx1._response = AgentResponse(messages=[resp_msg]) - await provider.after_run(None, session, ctx1, session.state) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx1, state=session.state) # type: ignore[arg-type] # Second run: should load previous messages ctx2 = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["again"])]) - await provider.before_run(None, session, ctx2, session.state) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx2, state=session.state) # type: ignore[arg-type] loaded = ctx2.context_messages.get("memory", []) assert len(loaded) == 2 assert loaded[0].text == "hello" @@ -356,17 +356,23 @@ async def test_state_is_serializable(self) -> None: input_msg = ChatMessage(role="user", contents=["test"]) ctx = SessionContext(session_id="s1", input_messages=[input_msg]) - await provider.before_run(None, session, ctx, session.state) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["reply"])]) - await provider.after_run(None, session, ctx, session.state) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - # State should be JSON-serializable - json_str = json.dumps(session.state) + # State contains ChatMessage objects (not dicts) + assert isinstance(session.state["memory"]["messages"][0], ChatMessage) + + # to_dict() serializes them via SerializationProtocol + session_dict = session.to_dict() + json_str = json.dumps(session_dict) assert json_str # no error - # And round-trip through session serialization - restored = AgentSession.from_dict(json.loads(json.dumps(session.to_dict()))) - assert restored.state == session.state + # Round-trip 
through session serialization restores ChatMessage objects + restored = AgentSession.from_dict(json.loads(json_str)) + assert isinstance(restored.state["memory"]["messages"][0], ChatMessage) + assert restored.state["memory"]["messages"][0].text == "test" + assert restored.state["memory"]["messages"][1].text == "reply" async def test_source_id_attribution(self) -> None: provider = InMemoryHistoryProvider("custom-source") From 1a99e7594fb10137ce0a6272a575e4006e293234 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 9 Feb 2026 15:58:53 +0100 Subject: [PATCH 03/12] feat: add new-pattern provider implementations for external packages - _RedisContextProvider(BaseContextProvider) - Redis search/vector context - _RedisHistoryProvider(BaseHistoryProvider) - Redis-backed message storage - _Mem0ContextProvider(BaseContextProvider) - Mem0 semantic memory - _AzureAISearchContextProvider(BaseContextProvider) - Azure AI Search (semantic + agentic) All use temporary _ prefix names for side-by-side coexistence with existing providers. Will be renamed in PR2 when old ContextProvider/ChatMessageStore are removed. 
--- .../__init__.py | 2 + .../_context_provider.py | 625 ++++++++++++++++++ .../mem0/agent_framework_mem0/__init__.py | 2 + .../agent_framework_mem0/_context_provider.py | 193 ++++++ .../redis/agent_framework_redis/__init__.py | 4 + .../_context_provider.py | 432 ++++++++++++ .../_history_provider.py | 187 ++++++ 7 files changed, 1445 insertions(+) create mode 100644 python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py create mode 100644 python/packages/mem0/agent_framework_mem0/_context_provider.py create mode 100644 python/packages/redis/agent_framework_redis/_context_provider.py create mode 100644 python/packages/redis/agent_framework_redis/_history_provider.py diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py index fedfb05bcd..4509c46d3e 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py @@ -2,6 +2,7 @@ import importlib.metadata +from ._context_provider import _AzureAISearchContextProvider from ._search_provider import AzureAISearchContextProvider, AzureAISearchSettings try: @@ -12,5 +13,6 @@ __all__ = [ "AzureAISearchContextProvider", "AzureAISearchSettings", + "_AzureAISearchContextProvider", "__version__", ] diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py new file mode 100644 index 0000000000..de0a1b2adc --- /dev/null +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -0,0 +1,625 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""New-pattern Azure AI Search context provider using BaseContextProvider. 
+ +This module provides ``_AzureAISearchContextProvider``, a side-by-side implementation of +:class:`AzureAISearchContextProvider` built on the new :class:`BaseContextProvider` hooks +pattern. It will replace the existing class in PR2. +""" + +from __future__ import annotations + +import sys +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +from agent_framework import AGENT_FRAMEWORK_USER_AGENT, ChatMessage +from agent_framework._logging import get_logger +from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext +from agent_framework.exceptions import ServiceInitializationError +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.aio import SearchClient +from azure.search.documents.indexes.aio import SearchIndexClient +from azure.search.documents.indexes.models import ( + AzureOpenAIVectorizerParameters, + KnowledgeBase, + KnowledgeBaseAzureOpenAIModel, + KnowledgeRetrievalLowReasoningEffort, + KnowledgeRetrievalMediumReasoningEffort, + KnowledgeRetrievalMinimalReasoningEffort, + KnowledgeRetrievalOutputMode, + KnowledgeRetrievalReasoningEffort, + KnowledgeSourceReference, + SearchIndexKnowledgeSource, + SearchIndexKnowledgeSourceParameters, +) +from azure.search.documents.models import ( + QueryCaptionType, + QueryType, + VectorizableTextQuery, + VectorizedQuery, +) +from pydantic import ValidationError + +from ._search_provider import AzureAISearchSettings + +if TYPE_CHECKING: + from agent_framework._agents import SupportsAgentRun + from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalRequest, + KnowledgeRetrievalIntent, + 
KnowledgeRetrievalSemanticIntent, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalLowReasoningEffort as KBRetrievalLowReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalMediumReasoningEffort as KBRetrievalMediumReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalMinimalReasoningEffort as KBRetrievalMinimalReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalOutputMode as KBRetrievalOutputMode, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalReasoningEffort as KBRetrievalReasoningEffort, + ) + +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + +# Runtime imports for agentic mode (optional dependency) +try: + from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient + from azure.search.documents.knowledgebases.models import ( + KnowledgeBaseMessage, + KnowledgeBaseMessageTextContent, + KnowledgeBaseRetrievalRequest, + KnowledgeRetrievalIntent, + KnowledgeRetrievalSemanticIntent, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalLowReasoningEffort as KBRetrievalLowReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalMediumReasoningEffort as KBRetrievalMediumReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalMinimalReasoningEffort as KBRetrievalMinimalReasoningEffort, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalOutputMode as KBRetrievalOutputMode, + ) + from azure.search.documents.knowledgebases.models import ( + KnowledgeRetrievalReasoningEffort as KBRetrievalReasoningEffort, + ) + + _agentic_retrieval_available = True +except ImportError: + 
_agentic_retrieval_available = False + +logger = get_logger(__name__) + +_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10 + + +class _AzureAISearchContextProvider(BaseContextProvider): + """Azure AI Search context provider using the new BaseContextProvider hooks pattern. + + Retrieves relevant context from Azure AI Search using semantic or agentic search + modes. This is the new-pattern equivalent of :class:`AzureAISearchContextProvider`. + + Note: + This class uses a temporary ``_`` prefix to coexist with the existing + :class:`AzureAISearchContextProvider`. It will replace the existing class + in PR2. + """ + + _DEFAULT_SEARCH_CONTEXT_PROMPT: ClassVar[str] = "Use the following context to answer the question:" + + def __init__( + self, + source_id: str, + endpoint: str | None = None, + index_name: str | None = None, + api_key: str | AzureKeyCredential | None = None, + credential: AsyncTokenCredential | None = None, + *, + mode: Literal["semantic", "agentic"] = "semantic", + top_k: int = 5, + semantic_configuration_name: str | None = None, + vector_field_name: str | None = None, + embedding_function: Callable[[str], Awaitable[list[float]]] | None = None, + context_prompt: str | None = None, + azure_openai_resource_url: str | None = None, + model_deployment_name: str | None = None, + model_name: str | None = None, + knowledge_base_name: str | None = None, + retrieval_instructions: str | None = None, + azure_openai_api_key: str | None = None, + knowledge_base_output_mode: Literal["extractive_data", "answer_synthesis"] = "extractive_data", + retrieval_reasoning_effort: Literal["minimal", "medium", "low"] = "minimal", + agentic_message_history_count: int = _DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize Azure AI Search Context Provider. + + Args: + source_id: Unique identifier for this provider instance. + endpoint: Azure AI Search endpoint URL. 
+ index_name: Name of the search index to query. + api_key: API key for authentication. + credential: AsyncTokenCredential for managed identity authentication. + mode: Search mode - "semantic" or "agentic". Default: "semantic". + top_k: Maximum number of documents to retrieve. Default: 5. + semantic_configuration_name: Name of semantic configuration in the index. + vector_field_name: Name of the vector field in the index. + embedding_function: Async function to generate embeddings. + context_prompt: Custom prompt to prepend to retrieved context. + azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base. + model_deployment_name: Model deployment name in Azure OpenAI. + model_name: The underlying model name. + knowledge_base_name: Name of an existing Knowledge Base to use. + retrieval_instructions: Custom instructions for Knowledge Base retrieval. + azure_openai_api_key: Azure OpenAI API key. + knowledge_base_output_mode: Output mode for Knowledge Base retrieval. + retrieval_reasoning_effort: Reasoning effort for Knowledge Base query planning. + agentic_message_history_count: Number of recent messages for agentic mode. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + """ + super().__init__(source_id) + + # Load settings from environment/file + try: + settings = AzureAISearchSettings( + endpoint=endpoint, + index_name=index_name, + knowledge_base_name=knowledge_base_name, + api_key=api_key if isinstance(api_key, str) else None, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as ex: + raise ServiceInitializationError("Failed to create Azure AI Search settings.", ex) from ex + + if not settings.endpoint: + raise ServiceInitializationError( + "Azure AI Search endpoint is required. Set via 'endpoint' parameter " + "or 'AZURE_SEARCH_ENDPOINT' environment variable." 
+ ) + + if mode == "semantic": + if not settings.index_name: + raise ServiceInitializationError( + "Azure AI Search index name is required for semantic mode. " + "Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable." + ) + elif mode == "agentic": + if settings.index_name and settings.knowledge_base_name: + raise ServiceInitializationError( + "For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both." + ) + if not settings.index_name and not settings.knowledge_base_name: + raise ServiceInitializationError( + "For agentic mode, provide either 'index_name' or 'knowledge_base_name'." + ) + if settings.index_name and not model_deployment_name: + raise ServiceInitializationError( + "model_deployment_name is required for agentic mode when creating Knowledge Base from index." + ) + + resolved_credential: AzureKeyCredential | AsyncTokenCredential + if credential: + resolved_credential = credential + elif isinstance(api_key, AzureKeyCredential): + resolved_credential = api_key + elif settings.api_key: + resolved_credential = AzureKeyCredential(settings.api_key.get_secret_value()) + else: + raise ServiceInitializationError( + "Azure credential is required. Provide 'api_key' or 'credential' parameter " + "or set 'AZURE_SEARCH_API_KEY' environment variable." 
+ ) + + self.endpoint = settings.endpoint + self.index_name = settings.index_name + self.credential = resolved_credential + self.mode = mode + self.top_k = top_k + self.semantic_configuration_name = semantic_configuration_name + self.vector_field_name = vector_field_name + self.embedding_function = embedding_function + self.context_prompt = context_prompt or self._DEFAULT_SEARCH_CONTEXT_PROMPT + + self.azure_openai_resource_url = azure_openai_resource_url + self.azure_openai_deployment_name = model_deployment_name + self.model_name = model_name or model_deployment_name + self.knowledge_base_name = settings.knowledge_base_name + self.retrieval_instructions = retrieval_instructions + self.azure_openai_api_key = azure_openai_api_key + self.knowledge_base_output_mode = knowledge_base_output_mode + self.retrieval_reasoning_effort = retrieval_reasoning_effort + self.agentic_message_history_count = agentic_message_history_count + + self._use_existing_knowledge_base = False + if mode == "agentic": + if settings.knowledge_base_name: + self._use_existing_knowledge_base = True + else: + self.knowledge_base_name = f"{settings.index_name}-kb" + + self._auto_discovered_vector_field = False + self._use_vectorizable_query = False + + if vector_field_name and not embedding_function: + raise ValueError("embedding_function is required when vector_field_name is specified") + + if mode == "agentic": + if not _agentic_retrieval_available: + raise ImportError( + "Agentic retrieval requires azure-search-documents >= 11.7.0b1 with Knowledge Base support." + ) + if not self._use_existing_knowledge_base and not self.azure_openai_resource_url: + raise ValueError( + "azure_openai_resource_url is required for agentic mode when creating Knowledge Base from index." 
+ ) + + self._search_client: SearchClient | None = None + if self.index_name: + self._search_client = SearchClient( + endpoint=self.endpoint, + index_name=self.index_name, + credential=self.credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + + self._index_client: SearchIndexClient | None = None + self._retrieval_client: KnowledgeBaseRetrievalClient | None = None + if mode == "agentic": + self._index_client = SearchIndexClient( + endpoint=self.endpoint, + credential=self.credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + + self._knowledge_base_initialized = False + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit - cleanup clients.""" + if self._retrieval_client is not None: + await self._retrieval_client.close() + self._retrieval_client = None + + # -- Hooks pattern --------------------------------------------------------- + + async def before_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Retrieve relevant context from Azure AI Search and add to session context.""" + messages_list = list(context.input_messages) + + def get_role_value(role: str | Any) -> str: + return role.value if hasattr(role, "value") else str(role) + + filtered_messages = [ + msg + for msg in messages_list + if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"] + ] + if not filtered_messages: + return + + if self.mode == "semantic": + query = "\n".join(msg.text for msg in filtered_messages) + search_result_parts = await self._semantic_search(query) + else: + recent_messages = filtered_messages[-self.agentic_message_history_count :] + search_result_parts = await self._agentic_search(recent_messages) + + if not search_result_parts: + return + + 
context_messages = [ChatMessage(role="user", text=self.context_prompt)] + context_messages.extend([ChatMessage(role="user", text=part) for part in search_result_parts]) + context.extend_messages(self.source_id, context_messages) + + # -- Internal methods (ported from AzureAISearchContextProvider) ----------- + + def _find_vector_fields(self, index: Any) -> list[str]: + """Find all fields that can store vectors.""" + return [ + field.name + for field in index.fields + if field.vector_search_dimensions is not None and field.vector_search_dimensions > 0 + ] + + def _find_vectorizable_fields(self, index: Any, vector_fields: list[str]) -> list[str]: + """Find vector fields that have auto-vectorization configured.""" + vectorizable_fields: list[str] = [] + if not index.vector_search or not index.vector_search.profiles: + return vectorizable_fields + for field in index.fields: + if field.name in vector_fields and field.vector_search_profile_name: + profile = next( + (p for p in index.vector_search.profiles if p.name == field.vector_search_profile_name), None + ) + if profile and hasattr(profile, "vectorizer_name") and profile.vectorizer_name: + vectorizable_fields.append(field.name) + return vectorizable_fields + + async def _auto_discover_vector_field(self) -> None: + """Auto-discover vector field from index schema.""" + if self._auto_discovered_vector_field or self.vector_field_name: + return + + try: + if not self._index_client: + self._index_client = SearchIndexClient( + endpoint=self.endpoint, + credential=self.credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + if not self.index_name: + logger.warning("Cannot auto-discover vector field: index_name is not set.") + self._auto_discovered_vector_field = True + return + + index = await self._index_client.get_index(self.index_name) + vector_fields = self._find_vector_fields(index) + if not vector_fields: + logger.info(f"No vector fields found in index '{self.index_name}'. 
Using keyword-only search.") + self._auto_discovered_vector_field = True + return + + vectorizable_fields = self._find_vectorizable_fields(index, vector_fields) + if vectorizable_fields: + if len(vectorizable_fields) == 1: + self.vector_field_name = vectorizable_fields[0] + self._auto_discovered_vector_field = True + self._use_vectorizable_query = True + logger.info( + f"Auto-discovered vectorizable field '{self.vector_field_name}' with server-side vectorization." + ) + else: + logger.warning( + f"Multiple vectorizable fields found: {vectorizable_fields}. " + f"Please specify vector_field_name explicitly." + ) + elif len(vector_fields) == 1: + self.vector_field_name = vector_fields[0] + self._auto_discovered_vector_field = True + self._use_vectorizable_query = False + if not self.embedding_function: + logger.warning( + f"Auto-discovered vector field '{self.vector_field_name}' without server-side vectorization. " + f"Provide embedding_function for vector search." + ) + self.vector_field_name = None + else: + logger.warning( + f"Multiple vector fields found: {vector_fields}. Please specify vector_field_name explicitly." + ) + except Exception as e: + logger.warning(f"Failed to auto-discover vector field: {e}. 
Using keyword-only search.") + + self._auto_discovered_vector_field = True + + async def _semantic_search(self, query: str) -> list[str]: + """Perform semantic hybrid search.""" + await self._auto_discover_vector_field() + + vector_queries: list[VectorizableTextQuery | VectorizedQuery] = [] + if self.vector_field_name: + vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k + if self._use_vectorizable_query: + vector_queries = [ + VectorizableTextQuery(text=query, k_nearest_neighbors=vector_k, fields=self.vector_field_name) + ] + elif self.embedding_function: + query_vector = await self.embedding_function(query) + vector_queries = [ + VectorizedQuery(vector=query_vector, k_nearest_neighbors=vector_k, fields=self.vector_field_name) + ] + + search_params: dict[str, Any] = {"search_text": query, "top": self.top_k} + if vector_queries: + search_params["vector_queries"] = vector_queries + if self.semantic_configuration_name: + search_params["query_type"] = QueryType.SEMANTIC + search_params["semantic_configuration_name"] = self.semantic_configuration_name + search_params["query_caption"] = QueryCaptionType.EXTRACTIVE + + if not self._search_client: + raise RuntimeError("Search client is not initialized.") + results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType] + + formatted_results: list[str] = [] + async for doc in results: # type: ignore[reportUnknownVariableType] + doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType] + doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType] + if doc_text: + formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType] + return formatted_results + + async def _ensure_knowledge_base(self) -> None: + """Ensure Knowledge Base and knowledge source are created or use existing KB.""" + if self._knowledge_base_initialized: + return + + if not 
self.knowledge_base_name: + raise ValueError("knowledge_base_name is required for agentic mode") + + knowledge_base_name = self.knowledge_base_name + + if self._use_existing_knowledge_base: + if _agentic_retrieval_available and self._retrieval_client is None: + self._retrieval_client = KnowledgeBaseRetrievalClient( + endpoint=self.endpoint, + knowledge_base_name=knowledge_base_name, + credential=self.credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + self._knowledge_base_initialized = True + return + + if not self._index_client: + raise ValueError("Index client is required when creating Knowledge Base from index") + if not self.azure_openai_resource_url: + raise ValueError("azure_openai_resource_url is required when creating Knowledge Base from index") + if not self.azure_openai_deployment_name: + raise ValueError("model_deployment_name is required when creating Knowledge Base from index") + if not self.index_name: + raise ValueError("index_name is required when creating Knowledge Base from index") + + knowledge_source_name = f"{self.index_name}-source" + try: + await self._index_client.get_knowledge_source(knowledge_source_name) + except ResourceNotFoundError: + knowledge_source = SearchIndexKnowledgeSource( + name=knowledge_source_name, + description=f"Knowledge source for {self.index_name} search index", + search_index_parameters=SearchIndexKnowledgeSourceParameters( + search_index_name=self.index_name, + ), + ) + await self._index_client.create_knowledge_source(knowledge_source) + + aoai_params = AzureOpenAIVectorizerParameters( + resource_url=self.azure_openai_resource_url, + deployment_name=self.azure_openai_deployment_name, + model_name=self.model_name, + api_key=self.azure_openai_api_key, + ) + + output_mode = ( + KnowledgeRetrievalOutputMode.EXTRACTIVE_DATA + if self.knowledge_base_output_mode == "extractive_data" + else KnowledgeRetrievalOutputMode.ANSWER_SYNTHESIS + ) + reasoning_effort_map: dict[str, KnowledgeRetrievalReasoningEffort] = { + 
"minimal": KnowledgeRetrievalMinimalReasoningEffort(), + "medium": KnowledgeRetrievalMediumReasoningEffort(), + "low": KnowledgeRetrievalLowReasoningEffort(), + } + reasoning_effort = reasoning_effort_map[self.retrieval_reasoning_effort] + + knowledge_base = KnowledgeBase( + name=knowledge_base_name, + description=f"Knowledge Base for multi-hop retrieval across {self.index_name}", + knowledge_sources=[KnowledgeSourceReference(name=knowledge_source_name)], + models=[KnowledgeBaseAzureOpenAIModel(azure_open_ai_parameters=aoai_params)], + output_mode=output_mode, + retrieval_reasoning_effort=reasoning_effort, + ) + await self._index_client.create_or_update_knowledge_base(knowledge_base) + self._knowledge_base_initialized = True + + if _agentic_retrieval_available and self._retrieval_client is None: + self._retrieval_client = KnowledgeBaseRetrievalClient( + endpoint=self.endpoint, + knowledge_base_name=knowledge_base_name, + credential=self.credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + + async def _agentic_search(self, messages: list[ChatMessage]) -> list[str]: + """Perform agentic retrieval with multi-hop reasoning.""" + await self._ensure_knowledge_base() + + reasoning_effort_map: dict[str, KBRetrievalReasoningEffort] = { + "minimal": KBRetrievalMinimalReasoningEffort(), + "medium": KBRetrievalMediumReasoningEffort(), + "low": KBRetrievalLowReasoningEffort(), + } + reasoning_effort = reasoning_effort_map[self.retrieval_reasoning_effort] + + output_mode = ( + KBRetrievalOutputMode.EXTRACTIVE_DATA + if self.knowledge_base_output_mode == "extractive_data" + else KBRetrievalOutputMode.ANSWER_SYNTHESIS + ) + + if self.retrieval_reasoning_effort == "minimal": + query = "\n".join(msg.text for msg in messages if msg.text) + intents: list[KnowledgeRetrievalIntent] = [KnowledgeRetrievalSemanticIntent(search=query)] + retrieval_request = KnowledgeBaseRetrievalRequest( + intents=intents, + retrieval_reasoning_effort=reasoning_effort, + output_mode=output_mode, + 
+                include_activity=True,
+            )
+        else:
+            kb_messages = [
+                KnowledgeBaseMessage(
+                    role=msg.role.value if hasattr(msg.role, "value") else str(msg.role),
+                    content=[KnowledgeBaseMessageTextContent(text=msg.text)],
+                )
+                for msg in messages
+                if msg.text
+            ]
+            retrieval_request = KnowledgeBaseRetrievalRequest(
+                messages=kb_messages,
+                retrieval_reasoning_effort=reasoning_effort,
+                output_mode=output_mode,
+                include_activity=True,
+            )
+
+        if not self._retrieval_client:
+            raise RuntimeError("Retrieval client not initialized.")
+        retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request)
+
+        if retrieval_result.response and len(retrieval_result.response) > 0:
+            assistant_message = retrieval_result.response[-1]
+            if assistant_message.content:
+                answer_parts: list[str] = []
+                for content_item in assistant_message.content:
+                    if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text:
+                        answer_parts.append(content_item.text)
+                if answer_parts:
+                    return answer_parts
+
+        return ["No results found from Knowledge Base."]
+
+    def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str:
+        """Extract readable text from a search document with optional citation."""
+        text = ""
+        for field in ["content", "text", "description", "body", "chunk"]:
+            if doc.get(field):
+                text = str(doc[field])
+                break
+        if not text:
+            text_parts: list[str] = []
+            for key, value in doc.items():
+                if isinstance(value, str) and not key.startswith("@") and key != "id":
+                    text_parts.append(f"{key}: {value}")
+            text = " | ".join(text_parts) if text_parts else ""
+        if doc_id and text:
+            return f"[Source: {doc_id}] {text}"
+        return text
+
+
+__all__ = ["_AzureAISearchContextProvider"]
diff --git a/python/packages/mem0/agent_framework_mem0/__init__.py b/python/packages/mem0/agent_framework_mem0/__init__.py
index 7ff88aaa42..b43f1dba2c 100644
--- a/python/packages/mem0/agent_framework_mem0/__init__.py
+++ 
b/python/packages/mem0/agent_framework_mem0/__init__.py @@ -8,6 +8,7 @@ if os.environ.get("MEM0_TELEMETRY") is None: os.environ["MEM0_TELEMETRY"] = "false" +from ._context_provider import _Mem0ContextProvider from ._provider import Mem0Provider try: @@ -17,5 +18,6 @@ __all__ = [ "Mem0Provider", + "_Mem0ContextProvider", "__version__", ] diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py new file mode 100644 index 0000000000..04befe294b --- /dev/null +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -0,0 +1,193 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""New-pattern Mem0 context provider using BaseContextProvider. + +This module provides ``_Mem0ContextProvider``, a side-by-side implementation of +:class:`Mem0Provider` built on the new :class:`BaseContextProvider` hooks pattern. +It will be renamed to ``Mem0ContextProvider`` in PR2 when the old class is removed. +""" + +from __future__ import annotations + +import sys +from contextlib import AbstractAsyncContextManager +from typing import TYPE_CHECKING, Any + +from agent_framework import ChatMessage +from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext +from agent_framework.exceptions import ServiceInitializationError +from mem0 import AsyncMemory, AsyncMemoryClient + +if sys.version_info >= (3, 11): + from typing import NotRequired, Self, TypedDict # pragma: no cover +else: + from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover + +if TYPE_CHECKING: + from agent_framework._agents import SupportsAgentRun + + +class _MemorySearchResponse_v1_1(TypedDict): + results: list[dict[str, Any]] + relations: NotRequired[list[dict[str, Any]]] + + +_MemorySearchResponse_v2 = list[dict[str, Any]] + + +class _Mem0ContextProvider(BaseContextProvider): + """Mem0 context provider using the new BaseContextProvider hooks pattern. 
+ + Integrates Mem0 for persistent semantic memory, searching and storing + memories via the Mem0 API. This is the new-pattern equivalent of + :class:`Mem0Provider`. + + Note: + This class uses a temporary ``_`` prefix to coexist with the existing + :class:`Mem0Provider`. It will be renamed to ``Mem0ContextProvider`` + in PR2. + """ + + DEFAULT_CONTEXT_PROMPT = "## Memories\nConsider the following memories when answering user questions:" + + def __init__( + self, + source_id: str, + mem0_client: AsyncMemory | AsyncMemoryClient | None = None, + api_key: str | None = None, + application_id: str | None = None, + agent_id: str | None = None, + user_id: str | None = None, + *, + context_prompt: str | None = None, + ) -> None: + """Initialize the Mem0 context provider. + + Args: + source_id: Unique identifier for this provider instance. + mem0_client: A pre-created Mem0 MemoryClient or None to create a default client. + api_key: The API key for authenticating with the Mem0 API. + application_id: The application ID for scoping memories. + agent_id: The agent ID for scoping memories. + user_id: The user ID for scoping memories. + context_prompt: The prompt to prepend to retrieved memories. 
+ """ + super().__init__(source_id) + should_close_client = False + if mem0_client is None: + mem0_client = AsyncMemoryClient(api_key=api_key) + should_close_client = True + + self.api_key = api_key + self.application_id = application_id + self.agent_id = agent_id + self.user_id = user_id + self.context_prompt = context_prompt or self.DEFAULT_CONTEXT_PROMPT + self.mem0_client = mem0_client + self._should_close_client = should_close_client + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + if self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager): + await self.mem0_client.__aenter__() + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + """Async context manager exit.""" + if self._should_close_client and self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager): + await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb) + + # -- Hooks pattern --------------------------------------------------------- + + async def before_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Search Mem0 for relevant memories and add to the session context.""" + self._validate_filters() + input_text = "\n".join(msg.text for msg in context.input_messages if msg and msg.text and msg.text.strip()) + if not input_text.strip(): + return + + filters = self._build_filters(session_id=context.session_id) + + search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc] + query=input_text, + filters=filters, + ) + + if isinstance(search_response, list): + memories = search_response + elif isinstance(search_response, dict) and "results" in search_response: + memories = search_response["results"] + else: + memories = [search_response] + + line_separated_memories = "\n".join(memory.get("memory", 
"") for memory in memories) + if line_separated_memories: + context.extend_messages( + self.source_id, + [ChatMessage(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")], + ) + + async def after_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Store request/response messages to Mem0 for future retrieval.""" + self._validate_filters() + + messages_to_store: list[ChatMessage] = list(context.input_messages) + if context.response and context.response.messages: + messages_to_store.extend(context.response.messages) + + def get_role_value(role: Any) -> str: + return role.value if hasattr(role, "value") else str(role) + + messages: list[dict[str, str]] = [ + {"role": get_role_value(message.role), "content": message.text} + for message in messages_to_store + if get_role_value(message.role) in {"user", "assistant", "system"} and message.text and message.text.strip() + ] + + if messages: + await self.mem0_client.add( # type: ignore[misc] + messages=messages, + user_id=self.user_id, + agent_id=self.agent_id, + run_id=context.session_id, + metadata={"application_id": self.application_id}, + ) + + # -- Internal methods ------------------------------------------------------ + + def _validate_filters(self) -> None: + """Validates that at least one filter is provided.""" + if not self.agent_id and not self.user_id and not self.application_id: + raise ServiceInitializationError( + "At least one of the filters: agent_id, user_id, or application_id is required." 
+ ) + + def _build_filters(self, *, session_id: str | None = None) -> dict[str, Any]: + """Build search filters from initialization parameters.""" + filters: dict[str, Any] = {} + if self.user_id: + filters["user_id"] = self.user_id + if self.agent_id: + filters["agent_id"] = self.agent_id + if session_id: + filters["run_id"] = session_id + if self.application_id: + filters["app_id"] = self.application_id + return filters + + +__all__ = ["_Mem0ContextProvider"] diff --git a/python/packages/redis/agent_framework_redis/__init__.py b/python/packages/redis/agent_framework_redis/__init__.py index fd63a69cae..9453401441 100644 --- a/python/packages/redis/agent_framework_redis/__init__.py +++ b/python/packages/redis/agent_framework_redis/__init__.py @@ -2,6 +2,8 @@ import importlib.metadata from ._chat_message_store import RedisChatMessageStore +from ._context_provider import _RedisContextProvider +from ._history_provider import _RedisHistoryProvider from ._provider import RedisProvider try: @@ -12,5 +14,7 @@ __all__ = [ "RedisChatMessageStore", "RedisProvider", + "_RedisContextProvider", + "_RedisHistoryProvider", "__version__", ] diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py new file mode 100644 index 0000000000..b5fe36384d --- /dev/null +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -0,0 +1,432 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""New-pattern Redis context provider using BaseContextProvider. + +This module provides ``_RedisContextProvider``, a side-by-side implementation of +:class:`RedisProvider` built on the new :class:`BaseContextProvider` hooks pattern. +It will be renamed to ``RedisContextProvider`` in PR2 when the old class is removed. 
+""" + +from __future__ import annotations + +import json +import sys +from functools import reduce +from operator import and_ +from typing import Any, Literal, cast + +import numpy as np +from agent_framework import ChatMessage +from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext +from agent_framework.exceptions import ( + AgentException, + ServiceInitializationError, + ServiceInvalidRequestError, +) +from redisvl.index import AsyncSearchIndex +from redisvl.query import HybridQuery, TextQuery +from redisvl.query.filter import FilterExpression, Tag +from redisvl.utils.token_escaper import TokenEscaper +from redisvl.utils.vectorize import BaseVectorizer + +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + +if sys.version_info >= (3, 12): + from typing import TYPE_CHECKING, override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +if TYPE_CHECKING: + from agent_framework._agents import SupportsAgentRun + + +class _RedisContextProvider(BaseContextProvider): + """Redis context provider using the new BaseContextProvider hooks pattern. + + Stores context in Redis and retrieves scoped context via full-text or + optional hybrid vector search. This is the new-pattern equivalent of + :class:`RedisProvider`. + + Note: + This class uses a temporary ``_`` prefix to coexist with the existing + :class:`RedisProvider`. It will be renamed to ``RedisContextProvider`` + in PR2. 
+ """ + + DEFAULT_CONTEXT_PROMPT = "## Memories\nConsider the following memories when answering user questions:" + + def __init__( + self, + source_id: str, + redis_url: str = "redis://localhost:6379", + index_name: str = "context", + prefix: str = "context", + *, + redis_vectorizer: BaseVectorizer | None = None, + vector_field_name: str | None = None, + vector_algorithm: Literal["flat", "hnsw"] | None = None, + vector_distance_metric: Literal["cosine", "ip", "l2"] | None = None, + application_id: str | None = None, + agent_id: str | None = None, + user_id: str | None = None, + context_prompt: str | None = None, + redis_index: Any = None, + overwrite_index: bool = False, + ): + """Create a Redis Context Provider. + + Args: + source_id: Unique identifier for this provider instance. + redis_url: The Redis server URL. + index_name: The name of the Redis index. + prefix: The prefix for all keys in the Redis database. + redis_vectorizer: The vectorizer to use for Redis. + vector_field_name: The name of the vector field in Redis. + vector_algorithm: The algorithm to use for vector search. + vector_distance_metric: The distance metric to use for vector search. + application_id: The application ID to scope the context. + agent_id: The agent ID to scope the context. + user_id: The user ID to scope the context. + context_prompt: The context prompt to use for the provider. + redis_index: The Redis index to use for the provider. + overwrite_index: Whether to overwrite the existing Redis index. + """ + super().__init__(source_id) + self.redis_url = redis_url + self.index_name = index_name + self.prefix = prefix + if redis_vectorizer is not None and not isinstance(redis_vectorizer, BaseVectorizer): + raise AgentException( + f"The redis vectorizer is not a valid type, got: {type(redis_vectorizer)}, expected: BaseVectorizer." 
+ ) + self.redis_vectorizer = redis_vectorizer + self.vector_field_name = vector_field_name + self.vector_algorithm: Literal["flat", "hnsw"] | None = vector_algorithm + self.vector_distance_metric: Literal["cosine", "ip", "l2"] | None = vector_distance_metric + self.application_id = application_id + self.agent_id = agent_id + self.user_id = user_id + self.context_prompt = context_prompt or self.DEFAULT_CONTEXT_PROMPT + self.overwrite_index = overwrite_index + self._token_escaper: TokenEscaper = TokenEscaper() + self._index_initialized: bool = False + self._schema_dict: dict[str, Any] | None = None + self.redis_index = redis_index or AsyncSearchIndex.from_dict( + self.schema_dict, redis_url=self.redis_url, validate_on_load=True + ) + + # -- Hooks pattern --------------------------------------------------------- + + @override + async def before_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Retrieve scoped context from Redis and add to the session context.""" + self._validate_filters() + input_text = "\n".join(msg.text for msg in context.input_messages if msg and msg.text and msg.text.strip()) + if not input_text.strip(): + return + + memories = await self._redis_search(text=input_text, session_id=context.session_id) + line_separated_memories = "\n".join( + str(memory.get("content", "")) for memory in memories if memory.get("content") + ) + if line_separated_memories: + context.extend_messages( + self.source_id, + [ChatMessage(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")], + ) + + @override + async def after_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Store request/response messages to Redis for future retrieval.""" + self._validate_filters() + + messages_to_store: list[ChatMessage] = list(context.input_messages) + if context.response and 
context.response.messages: + messages_to_store.extend(context.response.messages) + + messages: list[dict[str, Any]] = [] + for message in messages_to_store: + if message.role in {"user", "assistant", "system"} and message.text and message.text.strip(): + shaped: dict[str, Any] = { + "role": message.role, + "content": message.text, + "conversation_id": context.session_id, + "message_id": message.message_id, + "author_name": message.author_name, + } + messages.append(shaped) + if messages: + await self._add(data=messages, session_id=context.session_id) + + # -- Internal methods (ported from RedisProvider) -------------------------- + + @property + def schema_dict(self) -> dict[str, Any]: + """Get the Redis schema dictionary, computing and caching it on first access.""" + if self._schema_dict is None: + vector_dims = self.redis_vectorizer.dims if self.redis_vectorizer is not None else None + vector_datatype = self.redis_vectorizer.dtype if self.redis_vectorizer is not None else None + self._schema_dict = self._build_schema_dict( + index_name=self.index_name, + prefix=self.prefix, + vector_field_name=self.vector_field_name, + vector_dims=vector_dims, + vector_datatype=vector_datatype, + vector_algorithm=self.vector_algorithm, + vector_distance_metric=self.vector_distance_metric, + ) + return self._schema_dict + + def _build_filter_from_dict(self, filters: dict[str, str | None]) -> Any | None: + """Builds a combined filter expression from simple equality tags.""" + parts = [Tag(k) == v for k, v in filters.items() if v] + return reduce(and_, parts) if parts else None + + def _build_schema_dict( + self, + *, + index_name: str, + prefix: str, + vector_field_name: str | None, + vector_dims: int | None, + vector_datatype: str | None, + vector_algorithm: Literal["flat", "hnsw"] | None, + vector_distance_metric: Literal["cosine", "ip", "l2"] | None, + ) -> dict[str, Any]: + """Builds the RediSearch schema configuration dictionary.""" + fields: list[dict[str, Any]] = [ + 
{"name": "role", "type": "tag"}, + {"name": "mime_type", "type": "tag"}, + {"name": "content", "type": "text"}, + {"name": "conversation_id", "type": "tag"}, + {"name": "message_id", "type": "tag"}, + {"name": "author_name", "type": "tag"}, + {"name": "application_id", "type": "tag"}, + {"name": "agent_id", "type": "tag"}, + {"name": "user_id", "type": "tag"}, + {"name": "thread_id", "type": "tag"}, + ] + if vector_field_name is not None and vector_dims is not None: + fields.append({ + "name": vector_field_name, + "type": "vector", + "attrs": { + "algorithm": (vector_algorithm or "hnsw"), + "dims": int(vector_dims), + "distance_metric": (vector_distance_metric or "cosine"), + "datatype": (vector_datatype or "float32"), + }, + }) + return { + "index": {"name": index_name, "prefix": prefix, "key_separator": ":", "storage_type": "hash"}, + "fields": fields, + } + + async def _ensure_index(self) -> None: + """Initialize the search index.""" + if self._index_initialized: + return + index_exists = await self.redis_index.exists() + if not self.overwrite_index and index_exists: + await self._validate_schema_compatibility() + await self.redis_index.create(overwrite=self.overwrite_index, drop=False) + self._index_initialized = True + + async def _validate_schema_compatibility(self) -> None: + """Validate that existing index schema matches current configuration.""" + TAG_DEFAULTS = {"separator": ",", "case_sensitive": False, "withsuffixtrie": False} + TEXT_DEFAULTS = {"weight": 1.0, "no_stem": False} + + def _significant_index(i: dict[str, Any]) -> dict[str, Any]: + return {k: i.get(k) for k in ("name", "prefix", "key_separator", "storage_type")} + + def _sig_tag(attrs: dict[str, Any] | None) -> dict[str, Any]: + a = {**TAG_DEFAULTS, **(attrs or {})} + return {k: a[k] for k in ("separator", "case_sensitive", "withsuffixtrie")} + + def _sig_text(attrs: dict[str, Any] | None) -> dict[str, Any]: + a = {**TEXT_DEFAULTS, **(attrs or {})} + return {k: a[k] for k in ("weight", 
"no_stem")} + + def _sig_vector(attrs: dict[str, Any] | None) -> dict[str, Any]: + a = {**(attrs or {})} + return {k: a.get(k) for k in ("algorithm", "dims", "distance_metric", "datatype")} + + def _schema_signature(schema: dict[str, Any]) -> dict[str, Any]: + sig: dict[str, Any] = {"index": _significant_index(schema.get("index", {})), "fields": {}} + for f in schema.get("fields", []): + name, ftype = f.get("name"), f.get("type") + if not name: + continue + if ftype == "tag": + sig["fields"][name] = {"type": "tag", "attrs": _sig_tag(f.get("attrs"))} + elif ftype == "text": + sig["fields"][name] = {"type": "text", "attrs": _sig_text(f.get("attrs"))} + elif ftype == "vector": + sig["fields"][name] = {"type": "vector", "attrs": _sig_vector(f.get("attrs"))} + else: + sig["fields"][name] = {"type": ftype} + return sig + + existing_index = await AsyncSearchIndex.from_existing(self.index_name, redis_url=self.redis_url) + existing_schema = existing_index.schema.to_dict() + current_schema = self.schema_dict + existing_sig = _schema_signature(existing_schema) + current_sig = _schema_signature(current_schema) + if existing_sig != current_sig: + raise ServiceInitializationError( + "Existing Redis index schema is incompatible with the current configuration.\n" + f"Existing (significant): {json.dumps(existing_sig, indent=2, sort_keys=True)}\n" + f"Current (significant): {json.dumps(current_sig, indent=2, sort_keys=True)}\n" + "Set overwrite_index=True to rebuild if this change is intentional." 
+ ) + + async def _add( + self, + *, + data: dict[str, Any] | list[dict[str, Any]], + session_id: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Inserts one or many documents with partition fields populated.""" + self._validate_filters() + await self._ensure_index() + docs = data if isinstance(data, list) else [data] + + prepared: list[dict[str, Any]] = [] + for doc in docs: + d = dict(doc) + d.setdefault("application_id", self.application_id) + d.setdefault("agent_id", self.agent_id) + d.setdefault("user_id", self.user_id) + d.setdefault("thread_id", session_id) + d.setdefault("conversation_id", session_id) + if "content" not in d: + raise ServiceInvalidRequestError("add() requires a 'content' field in data") + if self.vector_field_name: + d.setdefault(self.vector_field_name, None) + prepared.append(d) + + if self.redis_vectorizer and self.vector_field_name: + text_list = [d["content"] for d in prepared] + embeddings = await self.redis_vectorizer.aembed_many(text_list, batch_size=len(text_list)) + for i, d in enumerate(prepared): + vec = np.asarray(embeddings[i], dtype=np.float32).tobytes() + field_name: str = self.vector_field_name + d[field_name] = vec + + await self.redis_index.load(prepared) + + async def _redis_search( + self, + text: str, + *, + session_id: str | None = None, + text_scorer: str = "BM25STD", + filter_expression: Any | None = None, + return_fields: list[str] | None = None, + num_results: int = 10, + alpha: float = 0.7, + ) -> list[dict[str, Any]]: + """Runs a text or hybrid vector-text search with optional filters.""" + await self._ensure_index() + self._validate_filters() + + q = (text or "").strip() + if not q: + raise ServiceInvalidRequestError("text_search() requires non-empty text") + num_results = max(int(num_results or 10), 1) + + combined_filter = self._build_filter_from_dict({ + "application_id": self.application_id, + "agent_id": self.agent_id, + "user_id": self.user_id, + "thread_id": session_id, + 
"conversation_id": session_id, + }) + if filter_expression is not None: + combined_filter = (combined_filter & filter_expression) if combined_filter else filter_expression + + return_fields = ( + return_fields + if return_fields is not None + else ["content", "role", "application_id", "agent_id", "user_id", "thread_id"] + ) + + try: + if self.redis_vectorizer and self.vector_field_name: + vector = await self.redis_vectorizer.aembed(q) + query = HybridQuery( + text=q, + text_field_name="content", + vector=vector, + vector_field_name=self.vector_field_name, + text_scorer=text_scorer, + filter_expression=combined_filter, + alpha=alpha, + dtype=self.redis_vectorizer.dtype, + num_results=num_results, + return_fields=return_fields, + stopwords=None, + ) + hybrid_results = await self.redis_index.query(query) + return cast(list[dict[str, Any]], hybrid_results) + query = TextQuery( + text=q, + text_field_name="content", + text_scorer=text_scorer, + filter_expression=combined_filter, + num_results=num_results, + return_fields=return_fields, + stopwords=None, + ) + text_results = await self.redis_index.query(query) + return cast(list[dict[str, Any]], text_results) + except Exception as exc: # pragma: no cover + raise ServiceInvalidRequestError(f"Redis text search failed: {exc}") from exc + + def _validate_filters(self) -> None: + """Validates that at least one filter is provided.""" + if not self.agent_id and not self.user_id and not self.application_id: + raise ServiceInitializationError( + "At least one of the filters: agent_id, user_id, or application_id is required." 
+ ) + + async def search_all(self, page_size: int = 200) -> list[dict[str, Any]]: + """Returns all documents in the index.""" + from redisvl.query import FilterQuery + + out: list[dict[str, Any]] = [] + async for batch in self.redis_index.paginate( + FilterQuery(FilterExpression("*"), return_fields=[], num_results=page_size), + page_size=page_size, + ): + out.extend(batch) + return out + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + """Async context manager exit.""" + + +__all__ = ["_RedisContextProvider"] diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py new file mode 100644 index 0000000000..0c93054526 --- /dev/null +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""New-pattern Redis history provider using BaseHistoryProvider. + +This module provides ``_RedisHistoryProvider``, a side-by-side implementation of +:class:`RedisChatMessageStore` built on the new :class:`BaseHistoryProvider` hooks pattern. +It will be renamed to ``RedisHistoryProvider`` in PR2 when the old class is removed. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import redis.asyncio as redis +from agent_framework import ChatMessage +from agent_framework._sessions import BaseHistoryProvider +from redis.credentials import CredentialProvider + +if TYPE_CHECKING: + pass + + +class _RedisHistoryProvider(BaseHistoryProvider): + """Redis-backed history provider using the new BaseHistoryProvider hooks pattern. + + Stores conversation history in Redis Lists, with each session isolated by a + unique Redis key. This is the new-pattern equivalent of + :class:`RedisChatMessageStore`. 
+ + Note: + This class uses a temporary ``_`` prefix to coexist with the existing + :class:`RedisChatMessageStore`. It will be renamed to ``RedisHistoryProvider`` + in PR2. + """ + + def __init__( + self, + source_id: str, + redis_url: str | None = None, + credential_provider: CredentialProvider | None = None, + host: str | None = None, + port: int = 6380, + ssl: bool = True, + username: str | None = None, + *, + key_prefix: str = "chat_messages", + max_messages: int | None = None, + load_messages: bool = True, + store_responses: bool = True, + store_inputs: bool = True, + store_context_messages: bool = False, + store_context_from: Sequence[str] | None = None, + ) -> None: + """Initialize the Redis history provider. + + Args: + source_id: Unique identifier for this provider instance. + redis_url: Redis connection URL (e.g., "redis://localhost:6379"). + Mutually exclusive with credential_provider. + credential_provider: Redis credential provider for Azure AD authentication. + Requires host parameter. Mutually exclusive with redis_url. + host: Redis host name. Required when using credential_provider. + port: Redis port number. Defaults to 6380 (Azure Redis SSL port). + ssl: Enable SSL/TLS connection. Defaults to True. + username: Redis username. + key_prefix: Prefix for Redis keys. Defaults to 'chat_messages'. + max_messages: Maximum number of messages to retain per session. + When exceeded, oldest messages are automatically trimmed. + None means unlimited storage. + load_messages: Whether to load messages before invocation. + store_responses: Whether to store response messages. + store_inputs: Whether to store input messages. + store_context_messages: Whether to store context from other providers. + store_context_from: If set, only store context from these source_ids. + + Raises: + ValueError: If neither redis_url nor credential_provider is provided. + ValueError: If both redis_url and credential_provider are provided. 
+ ValueError: If credential_provider is used without host parameter. + """ + super().__init__( + source_id, + load_messages=load_messages, + store_responses=store_responses, + store_inputs=store_inputs, + store_context_messages=store_context_messages, + store_context_from=store_context_from, + ) + + if redis_url is None and credential_provider is None: + raise ValueError("Either redis_url or credential_provider must be provided") + if redis_url is not None and credential_provider is not None: + raise ValueError("redis_url and credential_provider are mutually exclusive") + if credential_provider is not None and host is None: + raise ValueError("host is required when using credential_provider") + + self.key_prefix = key_prefix + self.max_messages = max_messages + self.redis_url = redis_url + + if credential_provider is not None and host is not None: + self._redis_client = redis.Redis( + host=host, + port=port, + ssl=ssl, + username=username, + credential_provider=credential_provider, + decode_responses=True, + ) + else: + self._redis_client = redis.from_url(redis_url, decode_responses=True) # type: ignore[no-untyped-call] + + def _redis_key(self, session_id: str | None) -> str: + """Get the Redis key for a given session's messages.""" + return f"{self.key_prefix}:{session_id or 'default'}" + + async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[ChatMessage]: + """Retrieve stored messages for this session from Redis. + + Args: + session_id: The session ID to retrieve messages for. + **kwargs: Additional arguments (unused). + + Returns: + List of stored ChatMessage objects in chronological order. 
+ """ + key = self._redis_key(session_id) + redis_messages = await self._redis_client.lrange(key, 0, -1) # type: ignore[misc] + messages: list[ChatMessage] = [] + if redis_messages: + for serialized in redis_messages: + messages.append(ChatMessage.from_dict(self._deserialize_json(serialized))) + return messages + + async def save_messages(self, session_id: str | None, messages: Sequence[ChatMessage], **kwargs: Any) -> None: + """Persist messages for this session to Redis. + + Args: + session_id: The session ID to store messages for. + messages: The messages to persist. + **kwargs: Additional arguments (unused). + """ + if not messages: + return + + key = self._redis_key(session_id) + serialized_messages = [self._serialize_json(msg) for msg in messages] + + async with self._redis_client.pipeline(transaction=True) as pipe: + for serialized in serialized_messages: + await pipe.rpush(key, serialized) # type: ignore[misc] + await pipe.execute() + + if self.max_messages is not None: + current_count = await self._redis_client.llen(key) # type: ignore[misc] + if current_count > self.max_messages: + await self._redis_client.ltrim(key, -self.max_messages, -1) # type: ignore[misc] + + @staticmethod + def _serialize_json(message: ChatMessage) -> str: + """Serialize a ChatMessage to a JSON string for Redis storage.""" + import json + + return json.dumps(message.to_dict()) + + @staticmethod + def _deserialize_json(data: str) -> dict[str, Any]: + """Deserialize a JSON string from Redis to a dict.""" + import json + + return json.loads(data) # type: ignore[no-any-return] + + async def clear(self, session_id: str | None) -> None: + """Clear all messages for a session. + + Args: + session_id: The session ID to clear messages for. 
+ """ + await self._redis_client.delete(self._redis_key(session_id)) + + async def aclose(self) -> None: + """Close the Redis connection.""" + await self._redis_client.aclose() # type: ignore[misc] + + +__all__ = ["_RedisHistoryProvider"] From 4b95b525fd00d1fb19632c313a46a82a28fc3297 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 9 Feb 2026 16:24:08 +0100 Subject: [PATCH 04/12] test: add tests for new-pattern provider implementations - 32 tests for _RedisContextProvider and _RedisHistoryProvider - 29 tests for _Mem0ContextProvider - 17 tests for _AzureAISearchContextProvider --- .../test_aisearch_new_context_provider.py | 293 +++++++++++ .../tests/test_mem0_new_context_provider.py | 350 ++++++++++++++ .../redis/tests/test_new_providers.py | 455 ++++++++++++++++++ 3 files changed, 1098 insertions(+) create mode 100644 python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py create mode 100644 python/packages/mem0/tests/test_mem0_new_context_provider.py create mode 100644 python/packages/redis/tests/test_new_providers.py diff --git a/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py new file mode 100644 index 0000000000..1bd4a8eb94 --- /dev/null +++ b/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft. All rights reserved. 
+# pyright: reportPrivateUsage=false + +import os +from unittest.mock import AsyncMock, patch + +import pytest +from agent_framework import ChatMessage +from agent_framework._sessions import AgentSession, SessionContext +from agent_framework.exceptions import ServiceInitializationError + +from agent_framework_azure_ai_search._context_provider import _AzureAISearchContextProvider + +# -- Helpers ------------------------------------------------------------------- + + +class MockSearchResults: + """Async-iterable mock for Azure SearchClient.search() results.""" + + def __init__(self, docs: list[dict]): + self._docs = docs + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._docs): + raise StopAsyncIteration + doc = self._docs[self._index] + self._index += 1 + return doc + + +@pytest.fixture +def mock_search_client() -> AsyncMock: + """Create a mock SearchClient that returns one document.""" + client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([{"id": "doc1", "content": "test document"}]) + + client.search = AsyncMock(side_effect=_search) + return client + + +@pytest.fixture +def mock_search_client_empty() -> AsyncMock: + """Create a mock SearchClient that returns no results.""" + client = AsyncMock() + + async def _search(**kwargs): + return MockSearchResults([]) + + client.search = AsyncMock(side_effect=_search) + return client + + +def _make_provider(**overrides) -> _AzureAISearchContextProvider: + """Create a semantic-mode provider with mocked internals (skips auto-discovery).""" + defaults = { + "source_id": "aisearch", + "endpoint": "https://test.search.windows.net", + "index_name": "test-index", + "api_key": "test-key", + } + defaults.update(overrides) + provider = _AzureAISearchContextProvider(**defaults) + provider._auto_discovered_vector_field = True # skip auto-discovery + return provider + + +# -- Initialization: semantic mode 
--------------------------------------------- + + +class TestInitSemantic: + """Initialization tests for semantic mode.""" + + def test_valid_init(self) -> None: + provider = _make_provider() + assert provider.source_id == "aisearch" + assert provider.endpoint == "https://test.search.windows.net" + assert provider.index_name == "test-index" + assert provider.mode == "semantic" + + def test_source_id_set(self) -> None: + provider = _make_provider(source_id="my-source") + assert provider.source_id == "my-source" + + def test_missing_endpoint_raises(self) -> None: + with patch.dict(os.environ, {}, clear=True), pytest.raises(ServiceInitializationError, match="endpoint"): + _AzureAISearchContextProvider( + source_id="s", + endpoint=None, + index_name="idx", + api_key="key", + ) + + def test_missing_index_name_semantic_raises(self) -> None: + with pytest.raises(ServiceInitializationError, match="index name"): + _AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name=None, + api_key="key", + ) + + def test_env_variable_fallback(self) -> None: + env = { + "AZURE_SEARCH_ENDPOINT": "https://env.search.windows.net", + "AZURE_SEARCH_INDEX_NAME": "env-index", + "AZURE_SEARCH_API_KEY": "env-key", + } + with patch.dict(os.environ, env, clear=False): + provider = _AzureAISearchContextProvider(source_id="env-test") + assert provider.endpoint == "https://env.search.windows.net" + assert provider.index_name == "env-index" + + +# -- Initialization: agentic mode validation ----------------------------------- + + +class TestInitAgenticValidation: + """Initialization validation tests for agentic mode.""" + + def test_both_index_and_kb_raises(self) -> None: + with pytest.raises(ServiceInitializationError, match="not both"): + _AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name="idx", + knowledge_base_name="kb", + api_key="key", + mode="agentic", + model_deployment_name="deploy", + 
azure_openai_resource_url="https://aoai.openai.azure.com", + ) + + def test_neither_index_nor_kb_raises(self) -> None: + with pytest.raises(ServiceInitializationError, match="provide either"): + _AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + api_key="key", + mode="agentic", + ) + + def test_missing_model_deployment_name_raises(self) -> None: + with pytest.raises(ServiceInitializationError, match="model_deployment_name"): + _AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name="idx", + api_key="key", + mode="agentic", + azure_openai_resource_url="https://aoai.openai.azure.com", + ) + + def test_vector_field_without_embedding_raises(self) -> None: + with pytest.raises(ValueError, match="embedding_function"): + _AzureAISearchContextProvider( + source_id="s", + endpoint="https://test.search.windows.net", + index_name="idx", + api_key="key", + vector_field_name="embedding", + ) + + +# -- before_run: semantic mode ------------------------------------------------- + + +class TestBeforeRunSemantic: + """Tests for before_run in semantic mode.""" + + async def test_results_added_to_context(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ChatMessage(role="user", contents=["test query"])], + session_id="s1", + ) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_search_client.search.assert_awaited_once() + msgs = ctx.context_messages.get("aisearch", []) + assert len(msgs) >= 2 # context_prompt + at least one result + assert msgs[0].text == provider.context_prompt + + async def test_empty_input_no_search(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client + + session 
= AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[], session_id="s1") + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_search_client.search.assert_not_awaited() + assert ctx.context_messages.get("aisearch") is None + + async def test_no_results_no_messages(self, mock_search_client_empty: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client_empty + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ChatMessage(role="user", contents=["test query"])], + session_id="s1", + ) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_search_client_empty.search.assert_awaited_once() + assert ctx.context_messages.get("aisearch") is None + + async def test_context_prompt_prepended(self, mock_search_client: AsyncMock) -> None: + custom_prompt = "Custom search context:" + provider = _make_provider(context_prompt=custom_prompt) + provider._search_client = mock_search_client + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ChatMessage(role="user", contents=["test query"])], + session_id="s1", + ) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + msgs = ctx.context_messages["aisearch"] + assert msgs[0].text == custom_prompt + + +# -- before_run: message filtering --------------------------------------------- + + +class TestBeforeRunFiltering: + """Tests that only user/assistant messages are used for search.""" + + async def test_filters_non_user_assistant(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + ChatMessage(role="system", 
contents=["system prompt"]), + ChatMessage(role="user", contents=["actual question"]), + ], + session_id="s1", + ) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_search_client.search.assert_awaited_once() + call_kwargs = mock_search_client.search.call_args[1] + # The search text should contain only the user message, not the system message + assert "actual question" in call_kwargs["search_text"] + assert "system prompt" not in call_kwargs["search_text"] + + async def test_only_system_messages_no_search(self, mock_search_client: AsyncMock) -> None: + provider = _make_provider() + provider._search_client = mock_search_client + + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ChatMessage(role="system", contents=["system prompt"])], + session_id="s1", + ) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_search_client.search.assert_not_awaited() + + +# -- __aexit__ ----------------------------------------------------------------- + + +class TestAexit: + """Tests for async context manager cleanup.""" + + async def test_closes_retrieval_client(self) -> None: + provider = _make_provider() + mock_retrieval = AsyncMock() + provider._retrieval_client = mock_retrieval + + await provider.__aexit__(None, None, None) + + mock_retrieval.close.assert_awaited_once() + assert provider._retrieval_client is None + + async def test_no_retrieval_client_no_error(self) -> None: + provider = _make_provider() + assert provider._retrieval_client is None + + await provider.__aexit__(None, None, None) # should not raise diff --git a/python/packages/mem0/tests/test_mem0_new_context_provider.py b/python/packages/mem0/tests/test_mem0_new_context_provider.py new file mode 100644 index 0000000000..7fb8585ef4 --- /dev/null +++ b/python/packages/mem0/tests/test_mem0_new_context_provider.py @@ -0,0 +1,350 @@ 
+# Copyright (c) Microsoft. All rights reserved. +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from agent_framework import AgentResponse, ChatMessage +from agent_framework._sessions import AgentSession, SessionContext +from agent_framework.exceptions import ServiceInitializationError + +from agent_framework_mem0._context_provider import _Mem0ContextProvider + + +@pytest.fixture +def mock_mem0_client() -> AsyncMock: + """Create a mock Mem0 AsyncMemoryClient.""" + from mem0 import AsyncMemoryClient + + mock_client = AsyncMock(spec=AsyncMemoryClient) + mock_client.add = AsyncMock() + mock_client.search = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock() + return mock_client + + +# -- Initialization tests ------------------------------------------------------ + + +class TestInit: + """Test _Mem0ContextProvider initialization.""" + + def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider( + source_id="mem0", + mem0_client=mock_mem0_client, + api_key="key-123", + application_id="app1", + agent_id="agent1", + user_id="user1", + context_prompt="Custom prompt", + ) + assert provider.source_id == "mem0" + assert provider.api_key == "key-123" + assert provider.application_id == "app1" + assert provider.agent_id == "agent1" + assert provider.user_id == "user1" + assert provider.context_prompt == "Custom prompt" + assert provider.mem0_client is mock_mem0_client + assert provider._should_close_client is False + + def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider.context_prompt == _Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT + + def test_init_auto_creates_client_when_none(self) -> None: + """When no client is provided, a default 
AsyncMemoryClient is created and flagged for closing.""" + with patch("agent_framework_mem0._context_provider.AsyncMemoryClient") as mock_client_cls: + provider = _Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") + mock_client_cls.assert_called_once_with(api_key="test-key") + assert provider.mem0_client is mock_client_cls.return_value + assert provider._should_close_client is True + + def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._should_close_client is False + + +# -- before_run tests ---------------------------------------------------------- + + +class TestBeforeRun: + """Test before_run hook.""" + + async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None: + """Mocked mem0 search returns memories → messages added to context with prompt.""" + mock_mem0_client.search.return_value = [ + {"memory": "User likes Python"}, + {"memory": "User prefers dark mode"}, + ] + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="Hello")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_mem0_client.search.assert_awaited_once() + assert "mem0" in ctx.context_messages + added = ctx.context_messages["mem0"] + assert len(added) == 1 + assert "User likes Python" in added[0].text # type: ignore[operator] + assert "User prefers dark mode" in added[0].text # type: ignore[operator] + assert provider.context_prompt in added[0].text # type: ignore[operator] + + async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None: + """Empty input messages → no search performed.""" + provider = 
_Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_mem0_client.search.assert_not_awaited() + assert "mem0" not in ctx.context_messages + + async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None: + """Empty search results → no messages added.""" + mock_mem0_client.search.return_value = [] + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="test")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert "mem0" not in ctx.context_messages + + async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None: + """Raises ServiceInitializationError when no filters.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="test")], session_id="s1") + + with pytest.raises(ServiceInitializationError, match="At least one of the filters"): + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: + """Search response in v1.1 dict format with 'results' key.""" + mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]} + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = 
AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="test")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + added = ctx.context_messages["mem0"] + assert "remembered fact" in added[0].text # type: ignore[operator] + + async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None: + """Multiple input messages are joined for the search query.""" + mock_mem0_client.search.return_value = [] + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + ChatMessage(role="user", text="Hello"), + ChatMessage(role="user", text="World"), + ], + session_id="s1", + ) + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_mem0_client.search.call_args.kwargs + assert call_kwargs["query"] == "Hello\nWorld" + + +# -- after_run tests ----------------------------------------------------------- + + +class TestAfterRun: + """Test after_run hook.""" + + async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None: + """Stores input+response messages to mem0 via client.add.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="question")], session_id="s1") + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", text="answer")]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_mem0_client.add.assert_awaited_once() + call_kwargs = mock_mem0_client.add.call_args.kwargs + assert call_kwargs["messages"] == [ + {"role": 
"user", "content": "question"}, + {"role": "assistant", "content": "answer"}, + ] + assert call_kwargs["user_id"] == "u1" + assert call_kwargs["run_id"] == "s1" + + async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None: + """Only stores user/assistant/system messages with text.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + ChatMessage(role="user", text="hello"), + ChatMessage(role="tool", text="tool output"), + ], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", text="reply")]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_mem0_client.add.call_args.kwargs + roles = [m["role"] for m in call_kwargs["messages"]] + assert "tool" not in roles + assert roles == ["user", "assistant"] + + async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: + """Skips messages with empty text.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + ChatMessage(role="user", text=""), + ChatMessage(role="user", text=" "), + ], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_mem0_client.add.assert_not_awaited() + + async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None: + """Uses session_id as run_id.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="hi")], 
session_id="my-session") + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", text="hey")]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert mock_mem0_client.add.call_args.kwargs["run_id"] == "my-session" + + async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: + """Raises ServiceInitializationError when no filters.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="hi")], session_id="s1") + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", text="hey")]) + + with pytest.raises(ServiceInitializationError, match="At least one of the filters"): + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None: + """application_id is passed in metadata.""" + provider = _Mem0ContextProvider( + source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1" + ) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", text="hi")], session_id="s1") + ctx._response = AgentResponse(messages=[]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"} + + +# -- _validate_filters tests -------------------------------------------------- + + +class TestValidateFilters: + """Test _validate_filters method.""" + + def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + with pytest.raises(ServiceInitializationError, 
match="At least one of the filters"): + provider._validate_filters() + + def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider._validate_filters() # should not raise + + def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1") + provider._validate_filters() + + def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1") + provider._validate_filters() + + +# -- _build_filters tests ----------------------------------------------------- + + +class TestBuildFilters: + """Test _build_filters method.""" + + def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._build_filters() == {"user_id": "u1"} + + def test_all_params(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider( + source_id="mem0", + mem0_client=mock_mem0_client, + user_id="u1", + agent_id="a1", + application_id="app1", + ) + assert provider._build_filters(session_id="sess1") == { + "user_id": "u1", + "agent_id": "a1", + "run_id": "sess1", + "app_id": "app1", + } + + def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + filters = provider._build_filters() + assert "agent_id" not in filters + assert "run_id" not in filters + assert "app_id" not in filters + + def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + filters = provider._build_filters(session_id="s99") + 
assert filters["run_id"] == "s99" + + def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + assert provider._build_filters() == {} + + +# -- Context manager tests ----------------------------------------------------- + + +class TestContextManager: + """Test __aenter__/__aexit__ delegation.""" + + async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + result = await provider.__aenter__() + assert result is provider + mock_mem0_client.__aenter__.assert_awaited_once() + + async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None: + """Auto-created clients (_should_close_client=True) are closed on exit.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider._should_close_client = True + await provider.__aexit__(None, None, None) + mock_mem0_client.__aexit__.assert_awaited_once() + + async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None: + """Provided clients (_should_close_client=False) are NOT closed on exit.""" + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._should_close_client is False + await provider.__aexit__(None, None, None) + mock_mem0_client.__aexit__.assert_not_awaited() + + async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None: + provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + async with provider as p: + assert p is provider diff --git a/python/packages/redis/tests/test_new_providers.py b/python/packages/redis/tests/test_new_providers.py new file mode 100644 index 0000000000..2a37ef5592 --- /dev/null +++ b/python/packages/redis/tests/test_new_providers.py @@ -0,0 
+1,455 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for _RedisContextProvider and _RedisHistoryProvider.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import AgentResponse, ChatMessage +from agent_framework._sessions import AgentSession, SessionContext +from agent_framework.exceptions import ServiceInitializationError + +from agent_framework_redis._context_provider import _RedisContextProvider +from agent_framework_redis._history_provider import _RedisHistoryProvider + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_index() -> AsyncMock: + idx = AsyncMock() + idx.create = AsyncMock() + idx.load = AsyncMock() + idx.query = AsyncMock(return_value=[]) + idx.exists = AsyncMock(return_value=False) + return idx + + +@pytest.fixture +def patch_index_from_dict(mock_index: AsyncMock): + with patch("agent_framework_redis._context_provider.AsyncSearchIndex") as mock_cls: + mock_cls.from_dict = MagicMock(return_value=mock_index) + + async def mock_from_existing(index_name: str, redis_url: str): # noqa: ARG001 + mock_existing = AsyncMock() + mock_existing.schema.to_dict = MagicMock( + side_effect=lambda: mock_cls.from_dict.call_args[0][0] if mock_cls.from_dict.call_args else {} + ) + return mock_existing + + mock_cls.from_existing = AsyncMock(side_effect=mock_from_existing) + yield mock_cls + + +@pytest.fixture +def mock_redis_client(): + client = MagicMock() + client.lrange = AsyncMock(return_value=[]) + client.llen = AsyncMock(return_value=0) + client.ltrim = AsyncMock() + client.delete = AsyncMock() + + mock_pipeline = AsyncMock() + mock_pipeline.rpush = AsyncMock() + mock_pipeline.execute = AsyncMock() + client.pipeline.return_value.__aenter__.return_value = mock_pipeline + + return client + + +# 
=========================================================================== +# _RedisContextProvider tests +# =========================================================================== + + +class TestRedisContextProviderInit: + def test_basic_construction(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + assert provider.source_id == "ctx" + assert provider.user_id == "u1" + assert provider.redis_url == "redis://localhost:6379" + assert provider.index_name == "context" + assert provider.prefix == "context" + + def test_custom_params(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider( + source_id="ctx", + redis_url="redis://custom:6380", + index_name="my_idx", + prefix="my_prefix", + application_id="app1", + agent_id="agent1", + user_id="user1", + context_prompt="Custom prompt", + ) + assert provider.redis_url == "redis://custom:6380" + assert provider.index_name == "my_idx" + assert provider.prefix == "my_prefix" + assert provider.application_id == "app1" + assert provider.agent_id == "agent1" + assert provider.context_prompt == "Custom prompt" + + def test_default_context_prompt(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + assert "Memories" in provider.context_prompt + + def test_invalid_vectorizer_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + from agent_framework.exceptions import AgentException + + with pytest.raises(AgentException, match="not a valid type"): + _RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad") # type: ignore[arg-type] + + +class TestRedisContextProviderValidateFilters: + def test_no_filters_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx") + with pytest.raises(ServiceInitializationError, match="(?i)at least one"): + provider._validate_filters() + 
+ def test_any_single_filter_ok(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + for kwargs in [{"user_id": "u"}, {"agent_id": "a"}, {"application_id": "app"}]: + provider = _RedisContextProvider(source_id="ctx", **kwargs) + provider._validate_filters() # should not raise + + +class TestRedisContextProviderSchema: + def test_schema_has_expected_fields(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + schema = provider.schema_dict + field_names = [f["name"] for f in schema["fields"]] + for expected in ("role", "content", "conversation_id", "message_id", "application_id", "agent_id", "user_id"): + assert expected in field_names + assert schema["index"]["name"] == "context" + assert schema["index"]["prefix"] == "context" + + def test_schema_no_vector_without_vectorizer(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + field_types = [f["type"] for f in provider.schema_dict["fields"]] + assert "vector" not in field_types + + +class TestRedisContextProviderBeforeRun: + async def test_search_results_added_to_context( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # noqa: ARG002 + ): + mock_index.query = AsyncMock(return_value=[{"content": "Memory A"}, {"content": "Memory B"}]) + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["test query"])], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert "ctx" in ctx.context_messages + msgs = ctx.context_messages["ctx"] + assert len(msgs) == 1 + assert "Memory A" in msgs[0].text + assert "Memory B" in msgs[0].text + + async def test_empty_input_no_search( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # 
noqa: ARG002 + ): + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=[" "])], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_index.query.assert_not_called() + assert "ctx" not in ctx.context_messages + + async def test_empty_results_no_messages( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # noqa: ARG002 + ): + mock_index.query = AsyncMock(return_value=[]) + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["hello"])], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert "ctx" not in ctx.context_messages + + +class TestRedisContextProviderAfterRun: + async def test_stores_messages( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # noqa: ARG002 + ): + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + session = AgentSession(session_id="test-session") + response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["response text"])]) + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["user input"])], session_id="s1") + ctx._response = response + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_index.load.assert_called_once() + loaded = mock_index.load.call_args[0][0] + assert len(loaded) == 2 + roles = {d["role"] for d in loaded} + assert roles == {"user", "assistant"} + + async def test_skips_empty_conversations( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # noqa: ARG002 + ): + provider = 
_RedisContextProvider(source_id="ctx", user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=[" "])], session_id="s1") + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_index.load.assert_not_called() + + async def test_stores_partition_fields( + self, + mock_index: AsyncMock, + patch_index_from_dict: MagicMock, # noqa: ARG002 + ): + provider = _RedisContextProvider(source_id="ctx", application_id="app", agent_id="ag", user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + loaded = mock_index.load.call_args[0][0] + doc = loaded[0] + assert doc["application_id"] == "app" + assert doc["agent_id"] == "ag" + assert doc["user_id"] == "u1" + assert doc["conversation_id"] == "s1" + + +class TestRedisContextProviderContextManager: + async def test_aenter_returns_self(self, patch_index_from_dict: MagicMock): # noqa: ARG002 + provider = _RedisContextProvider(source_id="ctx", user_id="u1") + async with provider as p: + assert p is provider + + +# =========================================================================== +# _RedisHistoryProvider tests +# =========================================================================== + + +class TestRedisHistoryProviderInit: + def test_basic_construction(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("memory", redis_url="redis://localhost:6379") + + assert provider.source_id == "memory" + assert provider.key_prefix == "chat_messages" + assert provider.max_messages is None + assert 
provider.load_messages is True + assert provider.store_responses is True + assert provider.store_inputs is True + + def test_custom_params(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider( + "mem", + redis_url="redis://localhost:6379", + key_prefix="custom", + max_messages=50, + load_messages=False, + store_responses=False, + store_inputs=False, + ) + + assert provider.key_prefix == "custom" + assert provider.max_messages == 50 + assert provider.load_messages is False + assert provider.store_responses is False + assert provider.store_inputs is False + + def test_no_redis_url_or_credential_raises(self): + with pytest.raises(ValueError, match="Either redis_url or credential_provider must be provided"): + _RedisHistoryProvider("mem") + + def test_both_url_and_credential_raises(self): + mock_cred = MagicMock() + with pytest.raises(ValueError, match="mutually exclusive"): + _RedisHistoryProvider( + "mem", + redis_url="redis://localhost:6379", + credential_provider=mock_cred, + host="myhost", + ) + + def test_credential_provider_without_host_raises(self): + mock_cred = MagicMock() + with pytest.raises(ValueError, match="host is required"): + _RedisHistoryProvider("mem", credential_provider=mock_cred) + + def test_credential_provider_with_host(self): + mock_cred = MagicMock() + with patch("agent_framework_redis._history_provider.redis.Redis") as mock_redis_cls: + mock_redis_cls.return_value = MagicMock() + provider = _RedisHistoryProvider("mem", credential_provider=mock_cred, host="myhost") + + mock_redis_cls.assert_called_once_with( + host="myhost", + port=6380, + ssl=True, + username=None, + credential_provider=mock_cred, + decode_responses=True, + ) + assert provider.redis_url is None + + +class TestRedisHistoryProviderRedisKey: + def test_key_format(self, mock_redis_client: MagicMock): + with 
patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379", key_prefix="msgs") + + assert provider._redis_key("session-123") == "msgs:session-123" + assert provider._redis_key(None) == "msgs:default" + + +class TestRedisHistoryProviderGetMessages: + async def test_returns_deserialized_messages(self, mock_redis_client: MagicMock): + msg1 = ChatMessage(role="user", contents=["Hello"]) + msg2 = ChatMessage(role="assistant", contents=["Hi!"]) + mock_redis_client.lrange = AsyncMock(return_value=[json.dumps(msg1.to_dict()), json.dumps(msg2.to_dict())]) + + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + messages = await provider.get_messages("s1") + assert len(messages) == 2 + assert messages[0].role == "user" + assert messages[0].text == "Hello" + assert messages[1].role == "assistant" + assert messages[1].text == "Hi!" 
+ + async def test_empty_returns_empty(self, mock_redis_client: MagicMock): + mock_redis_client.lrange = AsyncMock(return_value=[]) + + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + messages = await provider.get_messages("s1") + assert messages == [] + + +class TestRedisHistoryProviderSaveMessages: + async def test_saves_serialized_messages(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + msgs = [ChatMessage(role="user", contents=["Hello"]), ChatMessage(role="assistant", contents=["Hi"])] + await provider.save_messages("s1", msgs) + + pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value + assert pipeline.rpush.call_count == 2 + pipeline.execute.assert_called_once() + + async def test_empty_messages_noop(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + await provider.save_messages("s1", []) + mock_redis_client.pipeline.assert_not_called() + + async def test_max_messages_trimming(self, mock_redis_client: MagicMock): + mock_redis_client.llen = AsyncMock(return_value=15) + + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) + + await provider.save_messages("s1", [ChatMessage(role="user", contents=["msg"])]) + + mock_redis_client.ltrim.assert_called_once_with("chat_messages:s1", -10, -1) + + 
async def test_no_trim_when_under_limit(self, mock_redis_client: MagicMock): + mock_redis_client.llen = AsyncMock(return_value=3) + + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) + + await provider.save_messages("s1", [ChatMessage(role="user", contents=["msg"])]) + + mock_redis_client.ltrim.assert_not_called() + + +class TestRedisHistoryProviderClear: + async def test_clear_calls_delete(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + await provider.clear("session-1") + mock_redis_client.delete.assert_called_once_with("chat_messages:session-1") + + +class TestRedisHistoryProviderBeforeAfterRun: + """Test before_run/after_run integration via BaseHistoryProvider defaults.""" + + async def test_before_run_loads_history(self, mock_redis_client: MagicMock): + msg = ChatMessage(role="user", contents=["old msg"]) + mock_redis_client.lrange = AsyncMock(return_value=[json.dumps(msg.to_dict())]) + + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + session = AgentSession(session_id="test") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["new msg"])], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + assert "mem" in ctx.context_messages + assert len(ctx.context_messages["mem"]) == 1 + assert ctx.context_messages["mem"][0].text == "old msg" + + async def test_after_run_stores_input_and_response(self, 
mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + + session = AgentSession(session_id="test") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["hi"])], session_id="s1") + ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hello"])]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value + assert pipeline.rpush.call_count == 2 + pipeline.execute.assert_called_once() + + async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMock): + with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: + mock_from_url.return_value = mock_redis_client + provider = _RedisHistoryProvider( + "mem", redis_url="redis://localhost:6379", store_inputs=False, store_responses=False + ) + + session = AgentSession(session_id="test") + ctx = SessionContext(input_messages=[ChatMessage(role="user", contents=["hi"])], session_id="s1") + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + mock_redis_client.pipeline.assert_not_called() From 7cf583cb32a486522c0d2950825cf36b73280e8a Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 9 Feb 2026 17:00:05 +0100 Subject: [PATCH 05/12] fix: address PR review comments and CI failures - Move module docstring before imports in _sessions.py (review comment) - Import TYPE_CHECKING unconditionally in Redis _context_provider.py (NameError on Python <3.12) - Fix Mem0 test_init_auto_creates_client_when_none to patch at class level --- .../core/agent_framework/_sessions.py | 20 +++++++++---------- .../tests/test_mem0_new_context_provider.py | 8 +++++--- 
.../_context_provider.py | 4 ++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 972880878c..b916d61670 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -1,5 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. +"""Unified context management types for the agent framework. + +This module provides the core types for the context provider pipeline: +- SessionContext: Per-invocation state passed through providers +- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2) +- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2) +- AgentSession: Lightweight session state container +- InMemoryHistoryProvider: Built-in in-memory history provider +""" + from __future__ import annotations import uuid @@ -13,16 +23,6 @@ if TYPE_CHECKING: from ._agents import SupportsAgentRun -"""Unified context management types for the agent framework. 
- -This module provides the core types for the context provider pipeline: -- SessionContext: Per-invocation state passed through providers -- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2) -- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2) -- AgentSession: Lightweight session state container -- InMemoryHistoryProvider: Built-in in-memory history provider -""" - # Registry of known types for state deserialization _STATE_TYPE_REGISTRY: dict[str, type] = {} diff --git a/python/packages/mem0/tests/test_mem0_new_context_provider.py b/python/packages/mem0/tests/test_mem0_new_context_provider.py index 7fb8585ef4..51d814efaf 100644 --- a/python/packages/mem0/tests/test_mem0_new_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_new_context_provider.py @@ -57,10 +57,12 @@ def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None: def test_init_auto_creates_client_when_none(self) -> None: """When no client is provided, a default AsyncMemoryClient is created and flagged for closing.""" - with patch("agent_framework_mem0._context_provider.AsyncMemoryClient") as mock_client_cls: + with ( + patch("mem0.client.main.AsyncMemoryClient.__init__", return_value=None) as mock_init, + patch("mem0.client.main.AsyncMemoryClient._validate_api_key", return_value=None), + ): provider = _Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") - mock_client_cls.assert_called_once_with(api_key="test-key") - assert provider.mem0_client is mock_client_cls.return_value + mock_init.assert_called_once_with(api_key="test-key") assert provider._should_close_client is True def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None: diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index b5fe36384d..c0532b8051 100644 --- 
a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -13,7 +13,7 @@ import sys from functools import reduce from operator import and_ -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np from agent_framework import ChatMessage @@ -35,7 +35,7 @@ from typing_extensions import Self # pragma: no cover if sys.version_info >= (3, 12): - from typing import TYPE_CHECKING, override # type: ignore # pragma: no cover + from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover From 269521df3699eed9649bf4c07e42c0d702146e99 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 08:42:16 +0100 Subject: [PATCH 06/12] feat: add source attribution to extend_messages Set attribution marker in additional_properties for each message added via extend_messages(), matching the tool attribution pattern. Uses setdefault to preserve any existing attribution. --- python/packages/core/agent_framework/_sessions.py | 5 ++++- python/packages/core/tests/core/test_sessions.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index b916d61670..877b11fb27 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -134,12 +134,15 @@ def extend_messages(self, source_id: str, messages: Sequence[ChatMessage]) -> No """Add context messages from a specific source. Messages are stored keyed by source_id, maintaining insertion order - based on provider execution order. + based on provider execution order. Each message gets an ``attribution`` + marker in ``additional_properties`` for downstream filtering. Args: source_id: The provider source_id adding these messages. 
messages: The messages to add. """ + for message in messages: + message.additional_properties.setdefault("attribution", source_id) if source_id not in self.context_messages: self.context_messages[source_id] = [] self.context_messages[source_id].extend(messages) diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 4673032f36..7596ea1a42 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -53,6 +53,18 @@ def test_extend_messages_preserves_source_order(self) -> None: ctx.extend_messages("c", [ChatMessage(role="user", contents=["c"])]) assert list(ctx.context_messages.keys()) == ["a", "b", "c"] + def test_extend_messages_sets_attribution(self) -> None: + ctx = SessionContext(input_messages=[]) + msg = ChatMessage(role="system", contents=["context"]) + ctx.extend_messages("rag", [msg]) + assert msg.additional_properties["attribution"] == "rag" + + def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: + ctx = SessionContext(input_messages=[]) + msg = ChatMessage(role="system", contents=["context"], additional_properties={"attribution": "custom"}) + ctx.extend_messages("rag", [msg]) + assert msg.additional_properties["attribution"] == "custom" + def test_extend_instructions_string(self) -> None: ctx = SessionContext(input_messages=[]) ctx.extend_instructions("sys", "Be helpful") From d292f7788d54e8dc784889530453f7331e7f11b2 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 09:01:50 +0100 Subject: [PATCH 07/12] refactor: make attribution value a dict with source_id key --- python/packages/core/agent_framework/_sessions.py | 2 +- python/packages/core/tests/core/test_sessions.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 877b11fb27..afa43757a3 100644 --- 
a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -142,7 +142,7 @@ def extend_messages(self, source_id: str, messages: Sequence[ChatMessage]) -> No messages: The messages to add. """ for message in messages: - message.additional_properties.setdefault("attribution", source_id) + message.additional_properties.setdefault("attribution", {"source_id": source_id}) if source_id not in self.context_messages: self.context_messages[source_id] = [] self.context_messages[source_id].extend(messages) diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 7596ea1a42..d53f1c17fa 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -57,13 +57,13 @@ def test_extend_messages_sets_attribution(self) -> None: ctx = SessionContext(input_messages=[]) msg = ChatMessage(role="system", contents=["context"]) ctx.extend_messages("rag", [msg]) - assert msg.additional_properties["attribution"] == "rag" + assert msg.additional_properties["attribution"] == {"source_id": "rag"} def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: ctx = SessionContext(input_messages=[]) - msg = ChatMessage(role="system", contents=["context"], additional_properties={"attribution": "custom"}) + msg = ChatMessage(role="system", contents=["context"], additional_properties={"attribution": {"source_id": "custom"}}) ctx.extend_messages("rag", [msg]) - assert msg.additional_properties["attribution"] == "custom" + assert msg.additional_properties["attribution"] == {"source_id": "custom"} def test_extend_instructions_string(self) -> None: ctx = SessionContext(input_messages=[]) From 02d12979db71ce3e0262b3d7d4a58e2bba6211e1 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 09:36:35 +0100 Subject: [PATCH 08/12] add attribution and use sets for filters --- 
python/packages/core/agent_framework/_sessions.py | 12 ++++++------ python/packages/core/tests/core/test_sessions.py | 10 ++++++---- .../redis/agent_framework_redis/_history_provider.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index afa43757a3..834a0fa198 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -86,7 +86,7 @@ class SessionContext: Maintains insertion order (provider execution order). instructions: Additional instructions added by providers. tools: Additional tools added by providers. - response: After invocation, contains the full AgentResponse (read-only property). + response: After invocation, contains the full AgentResponse, should not be changed. options: Options passed to agent.run() - read-only, for reflection only. metadata: Shared metadata dictionary for cross-provider communication. """ @@ -175,8 +175,8 @@ def extend_tools(self, source_id: str, tools: Sequence[ToolProtocol]) -> None: def get_messages( self, *, - sources: Sequence[str] | None = None, - exclude_sources: Sequence[str] | None = None, + sources: set[str] | None = None, + exclude_sources: set[str] | None = None, include_input: bool = False, include_response: bool = False, ) -> list[ChatMessage]: @@ -306,7 +306,7 @@ def __init__( store_responses: bool = True, store_inputs: bool = True, store_context_messages: bool = False, - store_context_from: Sequence[str] | None = None, + store_context_from: set[str] | None = None, ): """Initialize the history provider. 
@@ -323,7 +323,7 @@ def __init__( self.store_responses = store_responses self.store_inputs = store_inputs self.store_context_messages = store_context_messages - self.store_context_from = list(store_context_from) if store_context_from else None + self.store_context_from = store_context_from @abstractmethod async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[ChatMessage]: @@ -355,7 +355,7 @@ def _get_context_messages_to_store(self, context: SessionContext) -> list[ChatMe return [] if self.store_context_from is not None: return context.get_messages(sources=self.store_context_from) - return context.get_messages(exclude_sources=[self.source_id]) + return context.get_messages(exclude_sources={self.source_id}) async def before_run( self, diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index d53f1c17fa..105bf8e620 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -61,7 +61,9 @@ def test_extend_messages_sets_attribution(self) -> None: def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: ctx = SessionContext(input_messages=[]) - msg = ChatMessage(role="system", contents=["context"], additional_properties={"attribution": {"source_id": "custom"}}) + msg = ChatMessage( + role="system", contents=["context"], additional_properties={"attribution": {"source_id": "custom"}} + ) ctx.extend_messages("rag", [msg]) assert msg.additional_properties["attribution"] == {"source_id": "custom"} @@ -187,12 +189,12 @@ def test_custom_flags(self) -> None: load_messages=False, store_inputs=False, store_context_messages=True, - store_context_from=["rag"], + store_context_from={"rag"}, ) assert provider.load_messages is False assert provider.store_inputs is False assert provider.store_context_messages is True - assert provider.store_context_from == ["rag"] + assert provider.store_context_from == {"rag"} async def 
test_before_run_loads_messages(self) -> None: msgs = [ChatMessage(role="user", contents=["history"])] @@ -255,7 +257,7 @@ async def test_after_run_stores_context_from_specific_sources(self) -> None: from agent_framework import AgentResponse provider = ConcreteHistoryProvider( - "audit", load_messages=False, store_context_messages=True, store_context_from=["rag"] + "audit", load_messages=False, store_context_messages=True, store_context_from={"rag"} ) ctx = SessionContext(session_id="s1", input_messages=[]) ctx.extend_messages("rag", [ChatMessage(role="system", contents=["rag-context"])]) diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index 0c93054526..e2f1192815 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -50,7 +50,7 @@ def __init__( store_responses: bool = True, store_inputs: bool = True, store_context_messages: bool = False, - store_context_from: Sequence[str] | None = None, + store_context_from: set[str] | None = None, ) -> None: """Initialize the Redis history provider. 
From a3b2440f34c26a98fe216f863f533cf379e443ea Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 11:56:24 +0100 Subject: [PATCH 09/12] Add source_type to message attribution and copy messages in extend_messages - SessionContext.extend_messages now accepts source as str or object with source_id attribute; when an object is passed, its class name is recorded as source_type in the attribution dict - Messages are shallow-copied before attribution is added so callers' original objects are never mutated - Filter framework-internal keys (attribution) from A2A wire metadata to prevent leaking internal state over the wire --- .../a2a/agent_framework_a2a/_agent.py | 8 +++-- .../core/agent_framework/_sessions.py | 33 ++++++++++++++----- .../packages/core/tests/core/test_sessions.py | 29 ++++++++++++++-- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 335e4c1f68..3b1b95af21 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -7,7 +7,7 @@ import re import uuid from collections.abc import AsyncIterable, Awaitable, Sequence -from typing import Any, Final, Literal, cast, overload +from typing import Any, Final, Literal, overload import httpx from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card @@ -401,11 +401,15 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage: case _: raise ValueError(f"Unknown content type: {content.type}") + # Exclude framework-internal keys (e.g. 
attribution) from wire metadata + internal_keys = {"attribution"} + metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None + return A2AMessage( role=A2ARole("user"), parts=parts, message_id=message.message_id or uuid.uuid4().hex, - metadata=cast(dict[str, Any], message.additional_properties), + metadata=metadata, ) def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]: diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 834a0fa198..96cfde5303 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -12,6 +12,7 @@ from __future__ import annotations +import copy import uuid from abc import abstractmethod from collections.abc import Sequence @@ -130,22 +131,38 @@ def response(self) -> AgentResponse | None: """The agent's response. Set by the framework after invocation, read-only for providers.""" return self._response - def extend_messages(self, source_id: str, messages: Sequence[ChatMessage]) -> None: + def extend_messages(self, source: str | object, messages: Sequence[ChatMessage]) -> None: """Add context messages from a specific source. - Messages are stored keyed by source_id, maintaining insertion order - based on provider execution order. Each message gets an ``attribution`` - marker in ``additional_properties`` for downstream filtering. + Messages are copied before attribution is added, so the caller's + original message objects are never mutated. The copies are stored + keyed by source_id, maintaining insertion order based on provider + execution order. Each message gets an ``attribution`` marker in + ``additional_properties`` for downstream filtering. Args: - source_id: The provider source_id adding these messages. + source: Either a plain ``source_id`` string, or an object with a + ``source_id`` attribute (e.g. a context provider). 
When an + object is passed, its class name is recorded as + ``source_type`` in the attribution. messages: The messages to add. """ + if isinstance(source, str): + source_id = source + attribution: dict[str, str] = {"source_id": source_id} + else: + source_id = source.source_id # type: ignore[union-attr] + attribution = {"source_id": source_id, "source_type": type(source).__name__} + + copied: list[ChatMessage] = [] for message in messages: - message.additional_properties.setdefault("attribution", {"source_id": source_id}) + msg_copy = copy.copy(message) + msg_copy.additional_properties = dict(message.additional_properties) + msg_copy.additional_properties.setdefault("attribution", attribution) + copied.append(msg_copy) if source_id not in self.context_messages: self.context_messages[source_id] = [] - self.context_messages[source_id].extend(messages) + self.context_messages[source_id].extend(copied) def extend_instructions(self, source_id: str, instructions: str | Sequence[str]) -> None: """Add instructions to be prepended to the conversation. @@ -367,7 +384,7 @@ async def before_run( ) -> None: """Load history into context. 
Skipped by the agent when load_messages=False.""" history = await self.get_messages(context.session_id, state=state) - context.extend_messages(self.source_id, history) + context.extend_messages(self, history) async def after_run( self, diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 105bf8e620..fb3c9989e4 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -57,7 +57,10 @@ def test_extend_messages_sets_attribution(self) -> None: ctx = SessionContext(input_messages=[]) msg = ChatMessage(role="system", contents=["context"]) ctx.extend_messages("rag", [msg]) - assert msg.additional_properties["attribution"] == {"source_id": "rag"} + stored = ctx.context_messages["rag"][0] + assert stored.additional_properties["attribution"] == {"source_id": "rag"} + # Original message is not mutated + assert "attribution" not in msg.additional_properties def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: ctx = SessionContext(input_messages=[]) @@ -65,7 +68,29 @@ def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: role="system", contents=["context"], additional_properties={"attribution": {"source_id": "custom"}} ) ctx.extend_messages("rag", [msg]) - assert msg.additional_properties["attribution"] == {"source_id": "custom"} + stored = ctx.context_messages["rag"][0] + assert stored.additional_properties["attribution"] == {"source_id": "custom"} + + def test_extend_messages_copies_messages(self) -> None: + ctx = SessionContext(input_messages=[]) + msg = ChatMessage(role="user", contents=["hello"]) + ctx.extend_messages("src", [msg]) + stored = ctx.context_messages["src"][0] + assert stored is not msg + assert stored.text == "hello" + # Mutating stored copy does not affect original + stored.additional_properties["extra"] = True + assert "extra" not in msg.additional_properties + + def 
test_extend_messages_sender_sets_source_type(self) -> None: + class MyProvider: + source_id = "rag" + + ctx = SessionContext(input_messages=[]) + msg = ChatMessage(role="system", contents=["ctx"]) + ctx.extend_messages(MyProvider(), [msg]) + stored = ctx.context_messages["rag"][0] + assert stored.additional_properties["attribution"] == {"source_id": "rag", "source_type": "MyProvider"} def test_extend_instructions_string(self) -> None: ctx = SessionContext(input_messages=[]) From 7ccd7a451046d1990ee480d60a1d986e882a7403 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 12:33:14 +0100 Subject: [PATCH 10/12] fix: correct mypy type: ignore comment from union-attr to attr-defined --- python/packages/core/agent_framework/_sessions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 96cfde5303..ee21a879c9 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -151,7 +151,7 @@ def extend_messages(self, source: str | object, messages: Sequence[ChatMessage]) source_id = source attribution: dict[str, str] = {"source_id": source_id} else: - source_id = source.source_id # type: ignore[union-attr] + source_id = source.source_id # type: ignore[attr-defined] attribution = {"source_id": source_id, "source_type": type(source).__name__} copied: list[ChatMessage] = [] From fad81205134ce083f40adcfa81eabaa5d8fb271b Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 19:56:17 +0100 Subject: [PATCH 11/12] set attribution to _attribution --- python/packages/a2a/agent_framework_a2a/_agent.py | 2 +- python/packages/core/agent_framework/_sessions.py | 2 +- python/packages/core/tests/core/test_sessions.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py 
b/python/packages/a2a/agent_framework_a2a/_agent.py index 3b1b95af21..f34d8f08b9 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -402,7 +402,7 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage: raise ValueError(f"Unknown content type: {content.type}") # Exclude framework-internal keys (e.g. attribution) from wire metadata - internal_keys = {"attribution"} + internal_keys = {"_attribution"} metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None return A2AMessage( diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index ee21a879c9..2ef60581d2 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -158,7 +158,7 @@ def extend_messages(self, source: str | object, messages: Sequence[ChatMessage]) for message in messages: msg_copy = copy.copy(message) msg_copy.additional_properties = dict(message.additional_properties) - msg_copy.additional_properties.setdefault("attribution", attribution) + msg_copy.additional_properties.setdefault("_attribution", attribution) copied.append(msg_copy) if source_id not in self.context_messages: self.context_messages[source_id] = [] diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index fb3c9989e4..507ca37979 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -58,18 +58,18 @@ def test_extend_messages_sets_attribution(self) -> None: msg = ChatMessage(role="system", contents=["context"]) ctx.extend_messages("rag", [msg]) stored = ctx.context_messages["rag"][0] - assert stored.additional_properties["attribution"] == {"source_id": "rag"} + assert stored.additional_properties["_attribution"] == {"source_id": "rag"} # Original message is not 
mutated - assert "attribution" not in msg.additional_properties + assert "_attribution" not in msg.additional_properties def test_extend_messages_does_not_overwrite_existing_attribution(self) -> None: ctx = SessionContext(input_messages=[]) msg = ChatMessage( - role="system", contents=["context"], additional_properties={"attribution": {"source_id": "custom"}} + role="system", contents=["context"], additional_properties={"_attribution": {"source_id": "custom"}} ) ctx.extend_messages("rag", [msg]) stored = ctx.context_messages["rag"][0] - assert stored.additional_properties["attribution"] == {"source_id": "custom"} + assert stored.additional_properties["_attribution"] == {"source_id": "custom"} def test_extend_messages_copies_messages(self) -> None: ctx = SessionContext(input_messages=[]) @@ -90,7 +90,7 @@ class MyProvider: msg = ChatMessage(role="system", contents=["ctx"]) ctx.extend_messages(MyProvider(), [msg]) stored = ctx.context_messages["rag"][0] - assert stored.additional_properties["attribution"] == {"source_id": "rag", "source_type": "MyProvider"} + assert stored.additional_properties["_attribution"] == {"source_id": "rag", "source_type": "MyProvider"} def test_extend_instructions_string(self) -> None: ctx = SessionContext(input_messages=[]) From 388b7b3fcc39188894a89c2b44da0b60f4b27c0a Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Feb 2026 20:43:08 +0100 Subject: [PATCH 12/12] adjusted naming of bools --- .../core/agent_framework/_sessions.py | 29 ++++++++++--------- .../packages/core/tests/core/test_sessions.py | 4 +-- .../_history_provider.py | 6 ++-- .../redis/tests/test_new_providers.py | 8 ++--- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 2ef60581d2..770c35a9a9 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -24,6 +24,16 @@ if 
TYPE_CHECKING: from ._agents import SupportsAgentRun + +__all__ = [ + "AgentSession", + "BaseContextProvider", + "BaseHistoryProvider", + "InMemoryHistoryProvider", + "SessionContext", +] + + # Registry of known types for state deserialization _STATE_TYPE_REGISTRY: dict[str, type] = {} @@ -309,10 +319,10 @@ class BaseHistoryProvider(BaseContextProvider): Attributes: load_messages: Whether to load messages before invocation (default True). When False, the agent skips calling ``before_run`` entirely. - store_responses: Whether to store response messages (default True). store_inputs: Whether to store input messages (default True). store_context_messages: Whether to store context from other providers (default False). store_context_from: If set, only store context from these source_ids. + store_outputs: Whether to store response messages (default True). """ def __init__( @@ -320,27 +330,27 @@ def __init__( source_id: str, *, load_messages: bool = True, - store_responses: bool = True, store_inputs: bool = True, store_context_messages: bool = False, store_context_from: set[str] | None = None, + store_outputs: bool = True, ): """Initialize the history provider. Args: source_id: Unique identifier for this provider instance. load_messages: Whether to load messages before invocation. - store_responses: Whether to store response messages. store_inputs: Whether to store input messages. store_context_messages: Whether to store context from other providers. store_context_from: If set, only store context from these source_ids. + store_outputs: Whether to store response messages. 
""" super().__init__(source_id) self.load_messages = load_messages - self.store_responses = store_responses self.store_inputs = store_inputs self.store_context_messages = store_context_messages self.store_context_from = store_context_from + self.store_outputs = store_outputs @abstractmethod async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[ChatMessage]: @@ -399,7 +409,7 @@ async def after_run( messages_to_store.extend(self._get_context_messages_to_store(context)) if self.store_inputs: messages_to_store.extend(context.input_messages) - if self.store_responses and context.response and context.response.messages: + if self.store_outputs and context.response and context.response.messages: messages_to_store.extend(context.response.messages) if messages_to_store: await self.save_messages(context.session_id, messages_to_store, state=state) @@ -510,12 +520,3 @@ async def save_messages( my_state = state.setdefault(self.source_id, {}) existing = my_state.get("messages", []) my_state["messages"] = [*existing, *messages] - - -__all__ = [ - "AgentSession", - "BaseContextProvider", - "BaseHistoryProvider", - "InMemoryHistoryProvider", - "SessionContext", -] diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 507ca37979..8b64a8ca72 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -203,7 +203,7 @@ class TestHistoryProviderBase: def test_default_flags(self) -> None: provider = ConcreteHistoryProvider("mem") assert provider.load_messages is True - assert provider.store_responses is True + assert provider.store_outputs is True assert provider.store_inputs is True assert provider.store_context_messages is False assert provider.store_context_from is None @@ -257,7 +257,7 @@ async def test_after_run_skips_inputs_when_disabled(self) -> None: async def test_after_run_skips_responses_when_disabled(self) -> None: from 
agent_framework import AgentResponse - provider = ConcreteHistoryProvider("mem", store_responses=False) + provider = ConcreteHistoryProvider("mem", store_outputs=False) ctx = SessionContext(session_id="s1", input_messages=[ChatMessage(role="user", contents=["hello"])]) ctx._response = AgentResponse(messages=[ChatMessage(role="assistant", contents=["hi"])]) await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index e2f1192815..a4954058d9 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -47,7 +47,7 @@ def __init__( key_prefix: str = "chat_messages", max_messages: int | None = None, load_messages: bool = True, - store_responses: bool = True, + store_outputs: bool = True, store_inputs: bool = True, store_context_messages: bool = False, store_context_from: set[str] | None = None, @@ -69,7 +69,7 @@ def __init__( When exceeded, oldest messages are automatically trimmed. None means unlimited storage. load_messages: Whether to load messages before invocation. - store_responses: Whether to store response messages. + store_outputs: Whether to store response messages. store_inputs: Whether to store input messages. store_context_messages: Whether to store context from other providers. store_context_from: If set, only store context from these source_ids. 
@@ -82,7 +82,7 @@ def __init__( super().__init__( source_id, load_messages=load_messages, - store_responses=store_responses, + store_outputs=store_outputs, store_inputs=store_inputs, store_context_messages=store_context_messages, store_context_from=store_context_from, diff --git a/python/packages/redis/tests/test_new_providers.py b/python/packages/redis/tests/test_new_providers.py index 2a37ef5592..0a350c0bea 100644 --- a/python/packages/redis/tests/test_new_providers.py +++ b/python/packages/redis/tests/test_new_providers.py @@ -255,7 +255,7 @@ def test_basic_construction(self, mock_redis_client: MagicMock): assert provider.key_prefix == "chat_messages" assert provider.max_messages is None assert provider.load_messages is True - assert provider.store_responses is True + assert provider.store_outputs is True assert provider.store_inputs is True def test_custom_params(self, mock_redis_client: MagicMock): @@ -267,14 +267,14 @@ def test_custom_params(self, mock_redis_client: MagicMock): key_prefix="custom", max_messages=50, load_messages=False, - store_responses=False, + store_outputs=False, store_inputs=False, ) assert provider.key_prefix == "custom" assert provider.max_messages == 50 assert provider.load_messages is False - assert provider.store_responses is False + assert provider.store_outputs is False assert provider.store_inputs is False def test_no_redis_url_or_credential_raises(self): @@ -444,7 +444,7 @@ async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMo with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client provider = _RedisHistoryProvider( - "mem", redis_url="redis://localhost:6379", store_inputs=False, store_responses=False + "mem", redis_url="redis://localhost:6379", store_inputs=False, store_outputs=False ) session = AgentSession(session_id="test")