Skip to content

Commit e231231

Browse files
Fix Mem0 OSS client search: pass scoping params as direct kwargs
AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs, while AsyncMemoryClient (Platform) expects them in a filters dict. Adds tests for both client types. Port of fix from #3844 to new Mem0ContextProvider.
1 parent b0a8997 commit e231231

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

python/packages/mem0/agent_framework_mem0/_context_provider.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,16 @@ async def before_run(
108108

109109
filters = self._build_filters(session_id=context.session_id)
110110

111+
# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
112+
# AsyncMemoryClient (Platform) expects them in a filters dict
113+
search_kwargs: dict[str, Any] = {"query": input_text}
114+
if isinstance(self.mem0_client, AsyncMemory):
115+
search_kwargs.update(filters)
116+
else:
117+
search_kwargs["filters"] = filters
118+
111119
search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
112-
query=input_text,
113-
filters=filters,
120+
**search_kwargs,
114121
)
115122

116123
if isinstance(search_response, list):

python/packages/mem0/tests/test_mem0_context_provider.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ def mock_mem0_client() -> AsyncMock:
2626
return mock_client
2727

2828

29+
@pytest.fixture
30+
def mock_oss_mem0_client() -> AsyncMock:
31+
"""Create a mock Mem0 OSS AsyncMemory client."""
32+
from mem0 import AsyncMemory
33+
34+
mock_client = AsyncMock(spec=AsyncMemory)
35+
mock_client.add = AsyncMock()
36+
mock_client.search = AsyncMock()
37+
return mock_client
38+
39+
2940
# -- Initialization tests ------------------------------------------------------
3041

3142

@@ -157,6 +168,50 @@ async def test_search_query_combines_input_messages(self, mock_mem0_client: Asyn
157168
call_kwargs = mock_mem0_client.search.call_args.kwargs
158169
assert call_kwargs["query"] == "Hello\nWorld"
159170

171+
async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: AsyncMock) -> None:
172+
"""OSS AsyncMemory client should receive user_id as direct kwarg, not in filters."""
173+
mock_oss_mem0_client.search.return_value = [{"memory": "User likes Python"}]
174+
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1")
175+
session = AgentSession(session_id="test-session")
176+
ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")
177+
178+
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
179+
180+
call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
181+
assert call_kwargs["query"] == "Hello"
182+
assert call_kwargs["user_id"] == "u1"
183+
assert "filters" not in call_kwargs
184+
185+
async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
186+
"""OSS client with all scoping parameters passes them as direct kwargs."""
187+
mock_oss_mem0_client.search.return_value = []
188+
provider = Mem0ContextProvider(
189+
source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
190+
)
191+
session = AgentSession(session_id="test-session")
192+
ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")
193+
194+
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
195+
196+
call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
197+
assert call_kwargs["user_id"] == "u1"
198+
assert call_kwargs["agent_id"] == "a1"
199+
assert "filters" not in call_kwargs
200+
201+
async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
202+
"""Platform AsyncMemoryClient should receive scoping params in a filters dict."""
203+
mock_mem0_client.search.return_value = []
204+
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
205+
session = AgentSession(session_id="test-session")
206+
ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")
207+
208+
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
209+
210+
call_kwargs = mock_mem0_client.search.call_args.kwargs
211+
assert call_kwargs["query"] == "Hello"
212+
assert "filters" in call_kwargs
213+
assert call_kwargs["filters"]["user_id"] == "u1"
214+
160215

161216
# -- after_run tests -----------------------------------------------------------
162217

0 commit comments

Comments
 (0)