Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def before_run(
if not input_text.strip():
return

filters = self._build_filters(session_id=context.session_id)
filters = self._build_filters()

# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
# AsyncMemoryClient (Platform) expects them in a filters dict
Expand Down Expand Up @@ -164,7 +164,6 @@ def get_role_value(role: Any) -> str:
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=context.session_id,
metadata={"application_id": self.application_id},
)

Expand All @@ -177,15 +176,13 @@ def _validate_filters(self) -> None:
"At least one of the filters: agent_id, user_id, or application_id is required."
)

def _build_filters(self, *, session_id: str | None = None) -> dict[str, Any]:
def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters."""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if session_id:
filters["run_id"] = session_id
if self.application_id:
filters["app_id"] = self.application_id
return filters
Expand Down
18 changes: 9 additions & 9 deletions python/packages/mem0/tests/test_mem0_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> N
{"role": "assistant", "content": "answer"},
]
assert call_kwargs["user_id"] == "u1"
assert call_kwargs["run_id"] == "s1"
assert "run_id" not in call_kwargs

async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None:
"""Only stores user/assistant/system messages with text."""
Expand Down Expand Up @@ -298,8 +298,8 @@ async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None:

mock_mem0_client.add.assert_not_awaited()

async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None:
"""Uses session_id as run_id."""
async def test_no_run_id_in_storage(self, mock_mem0_client: AsyncMock) -> None:
"""run_id is not passed to mem0 add, so memories are not scoped to sessions."""
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="my-session")
Expand All @@ -309,7 +309,7 @@ async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> N
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]

assert mock_mem0_client.add.call_args.kwargs["run_id"] == "my-session"
assert "run_id" not in mock_mem0_client.add.call_args.kwargs

async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None:
"""Raises ServiceInitializationError when no filters."""
Expand Down Expand Up @@ -381,10 +381,9 @@ def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
agent_id="a1",
application_id="app1",
)
assert provider._build_filters(session_id="sess1") == {
assert provider._build_filters() == {
"user_id": "u1",
"agent_id": "a1",
"run_id": "sess1",
"app_id": "app1",
}

Expand All @@ -395,10 +394,11 @@ def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
assert "run_id" not in filters
assert "app_id" not in filters

def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None:
def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
"""run_id is excluded from search filters so memories work across sessions."""
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
filters = provider._build_filters(session_id="s99")
assert filters["run_id"] == "s99"
filters = provider._build_filters()
assert "run_id" not in filters

def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def before_run(
if not input_text.strip():
return

memories = await self._redis_search(text=input_text, session_id=context.session_id)
memories = await self._redis_search(text=input_text)
line_separated_memories = "\n".join(
str(memory.get("content", "")) for memory in memories if memory.get("content")
)
Expand Down Expand Up @@ -337,7 +337,7 @@ async def _redis_search(
filter_expression: Any | None = None,
return_fields: list[str] | None = None,
num_results: int = 10,
alpha: float = 0.7,
linear_alpha: float = 0.7,
) -> list[dict[str, Any]]:
"""Runs a text or hybrid vector-text search with optional filters."""
await self._ensure_index()
Expand Down Expand Up @@ -374,7 +374,7 @@ async def _redis_search(
vector_field_name=self.vector_field_name,
text_scorer=text_scorer,
filter_expression=combined_filter,
alpha=alpha,
linear_alpha=linear_alpha,
dtype=self.redis_vectorizer.dtype,
num_results=num_results,
return_fields=return_fields,
Expand Down
20 changes: 20 additions & 0 deletions python/packages/redis/tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,26 @@ async def test_empty_input_no_search(
mock_index.query.assert_not_called()
assert "ctx" not in ctx.context_messages

async def test_before_run_searches_without_session_id(
self,
mock_index: AsyncMock,
patch_index_from_dict: MagicMock, # noqa: ARG002
):
"""Verify that before_run performs cross-session retrieval (no session_id filter)."""
mock_index.query = AsyncMock(return_value=[{"content": "Memory"}])
provider = RedisContextProvider(source_id="ctx", user_id="u1")
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", contents=["test query"])], session_id="s1")

with patch.object(provider, "_redis_search", wraps=provider._redis_search) as spy:
await provider.before_run(
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]

spy.assert_called_once()
# session_id should not be passed to _redis_search (cross-session retrieval)
assert "session_id" not in spy.call_args.kwargs

async def test_empty_results_no_messages(
self,
mock_index: AsyncMock,
Expand Down
10 changes: 5 additions & 5 deletions python/samples/getting_started/sessions/mem0/mem0_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,20 @@ async def main() -> None:
print(f"Agent: {result}\n")

# Mem0 processes and indexes memories asynchronously.
# Wait for memories to be indexed before querying in a new thread.
# Wait for memories to be indexed before querying in a new session.
# In production, consider implementing retry logic or using Mem0's
# eventual consistency handling instead of a fixed delay.
print("Waiting for memories to be processed...")
await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing

print("\nRequest within a new session:")
# Create a new session for the agent.
# The new session has no context of the previous conversation.
# The new session has no conversation history from the previous session.
session = agent.create_session()

# Since we have the mem0 component in the session, the agent should be able to
# retrieve the company report without asking for clarification, as it will
# be able to remember the user preferences from Mem0 component.
# Since we have the Mem0 context provider, the agent should be able to
# retrieve the company report without asking for clarification, as Mem0
# remembers user preferences across sessions.
query = "Please retrieve my company report"
print(f"User: {query}")
result = await agent.run(query, session=session)
Expand Down
116 changes: 30 additions & 86 deletions python/samples/getting_started/sessions/mem0/mem0_sessions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import uuid

from agent_framework import tool
from agent_framework.azure import AzureAIAgentClient
Expand All @@ -20,115 +19,57 @@ def get_user_preferences(user_id: str) -> str:
return preferences.get(user_id, "No specific preferences found")


async def example_global_thread_scope() -> None:
"""Example 1: Global thread_id scope (memories shared across all operations)."""
print("1. Global Thread Scope Example:")
async def example_cross_session_memory() -> None:
"""Example 1: Cross-session memory (memories shared across all sessions for a user)."""
print("1. Cross-Session Memory Example:")
print("-" * 40)

global_thread_id = str(uuid.uuid4())
user_id = "user123"

async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="GlobalMemoryAssistant",
name="MemoryAssistant",
instructions="You are an assistant that remembers user preferences across conversations.",
tools=get_user_preferences,
context_providers=[Mem0ContextProvider(
user_id=user_id,
thread_id=global_thread_id,
scope_to_per_operation_thread_id=False, # Share memories across all sessions
)],
) as global_agent,
context_providers=[Mem0ContextProvider(user_id=user_id)],
) as agent,
):
# Store some preferences in the global scope
# Store some preferences
query = "Remember that I prefer technical responses with code examples when discussing programming."
print(f"User: {query}")
result = await global_agent.run(query)
result = await agent.run(query)
print(f"Agent: {result}\n")

# Create a new session - but memories should still be accessible due to global scope
new_session = global_agent.create_session()
# Mem0 processes and indexes memories asynchronously.
print("Waiting for memories to be processed...")
await asyncio.sleep(12)

# Create a new session - memories should still be accessible
# because Mem0 scopes by user_id, not session
new_session = agent.create_session()
query = "What do you know about my preferences?"
print(f"User (new session): {query}")
result = await global_agent.run(query, session=new_session)
result = await agent.run(query, session=new_session)
print(f"Agent: {result}\n")


async def example_per_operation_thread_scope() -> None:
"""Example 2: Per-operation thread scope (memories isolated per session).

Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single session
throughout its lifetime. Use the same session object for all operations with that provider.
"""
print("2. Per-Operation Thread Scope Example:")
async def example_agent_scoped_memory() -> None:
"""Example 2: Agent-scoped memory (memories isolated per agent)."""
print("2. Agent-Scoped Memory Example:")
print("-" * 40)

user_id = "user123"

async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="ScopedMemoryAssistant",
instructions="You are an assistant with thread-scoped memory.",
tools=get_user_preferences,
context_providers=[Mem0ContextProvider(
user_id=user_id,
scope_to_per_operation_thread_id=True, # Isolate memories per session
)],
) as scoped_agent,
):
# Create a specific session for this scoped provider
dedicated_session = scoped_agent.create_session()

# Store some information in the dedicated session
query = "Remember that for this conversation, I'm working on a Python project about data analysis."
print(f"User (dedicated session): {query}")
result = await scoped_agent.run(query, session=dedicated_session)
print(f"Agent: {result}\n")

# Test memory retrieval in the same dedicated session
query = "What project am I working on?"
print(f"User (same dedicated session): {query}")
result = await scoped_agent.run(query, session=dedicated_session)
print(f"Agent: {result}\n")

# Store more information in the same session
query = "Also remember that I prefer using pandas and matplotlib for this project."
print(f"User (same dedicated session): {query}")
result = await scoped_agent.run(query, session=dedicated_session)
print(f"Agent: {result}\n")

# Test comprehensive memory retrieval
query = "What do you know about my current project and preferences?"
print(f"User (same dedicated session): {query}")
result = await scoped_agent.run(query, session=dedicated_session)
print(f"Agent: {result}\n")


async def example_multiple_agents() -> None:
"""Example 3: Multiple agents with different thread configurations."""
print("3. Multiple Agents with Different Thread Configurations:")
print("-" * 40)

agent_id_1 = "agent_personal"
agent_id_2 = "agent_work"

async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="PersonalAssistant",
instructions="You are a personal assistant that helps with personal tasks.",
context_providers=[Mem0ContextProvider(
agent_id=agent_id_1,
)],
context_providers=[Mem0ContextProvider(agent_id="agent_personal")],
) as personal_agent,
AzureAIAgentClient(credential=credential).as_agent(
name="WorkAssistant",
instructions="You are a work assistant that helps with professional tasks.",
context_providers=[Mem0ContextProvider(
agent_id=agent_id_2,
)],
context_providers=[Mem0ContextProvider(agent_id="agent_work")],
) as work_agent,
):
# Store personal information
Expand All @@ -143,7 +84,11 @@ async def example_multiple_agents() -> None:
result = await work_agent.run(query)
print(f"Work Agent: {result}\n")

# Test memory isolation
# Mem0 processes and indexes memories asynchronously.
print("Waiting for memories to be processed...")
await asyncio.sleep(12)

# Test memory isolation - each agent should only recall its own memories
query = "What do you know about my schedule?"
print(f"User to Personal Agent: {query}")
result = await personal_agent.run(query)
Expand All @@ -155,12 +100,11 @@ async def example_multiple_agents() -> None:


async def main() -> None:
"""Run all Mem0 thread management examples."""
print("=== Mem0 Thread Management Example ===\n")
"""Run all Mem0 session management examples."""
print("=== Mem0 Session Management Example ===\n")

await example_global_thread_scope()
await example_per_operation_thread_scope()
await example_multiple_agents()
await example_cross_session_memory()
await example_agent_scoped_memory()


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions python/samples/getting_started/sessions/redis/redis_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ async def main() -> None:
cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"),
)
# The provider manages persistence and retrieval. application_id/agent_id/user_id
# scope data for multi-tenant separation; thread_id (set later) narrows to a
# specific conversation.
# scope data for multi-tenant separation.
provider = RedisContextProvider(
redis_url="redis://localhost:6379",
index_name="redis_basics",
Expand All @@ -138,16 +137,14 @@ async def main() -> None:
from agent_framework import AgentSession, SessionContext

session = AgentSession(session_id="runA")
context = SessionContext()
context.extend_messages("input", messages)
context = SessionContext(input_messages=messages)
state = session.state

# Store messages via after_run
await provider.after_run(agent=None, session=session, context=context, state=state)

# Retrieve relevant memories via before_run
query_context = SessionContext()
query_context.extend_messages("input", [Message("system", ["B: Assistant Message"])])
query_context = SessionContext(input_messages=[Message("system", ["B: Assistant Message"])])
await provider.before_run(agent=None, session=session, context=query_context, state=state)

# Inspect retrieved memories that would be injected into instructions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ async def main() -> None:
cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"),
)

session_id = "test_session"

provider = RedisContextProvider(
redis_url="redis://localhost:6379",
index_name="redis_conversation",
Expand All @@ -49,7 +47,6 @@ async def main() -> None:
vector_field_name="vector",
vector_algorithm="hnsw",
vector_distance_metric="cosine",
thread_id=session_id,
)

# Create chat client for the agent
Expand Down
Loading