@@ -26,6 +26,17 @@ def mock_mem0_client() -> AsyncMock:
2626 return mock_client
2727
2828
29+ @pytest .fixture
30+ def mock_oss_mem0_client () -> AsyncMock :
31+ """Create a mock Mem0 OSS AsyncMemory client."""
32+ from mem0 import AsyncMemory
33+
34+ mock_client = AsyncMock (spec = AsyncMemory )
35+ mock_client .add = AsyncMock ()
36+ mock_client .search = AsyncMock ()
37+ return mock_client
38+
39+
2940# -- Initialization tests ------------------------------------------------------
3041
3142
@@ -157,6 +168,50 @@ async def test_search_query_combines_input_messages(self, mock_mem0_client: Asyn
157168 call_kwargs = mock_mem0_client .search .call_args .kwargs
158169 assert call_kwargs ["query" ] == "Hello\n World"
159170
171+ async def test_oss_client_passes_direct_kwargs (self , mock_oss_mem0_client : AsyncMock ) -> None :
172+ """OSS AsyncMemory client should receive user_id as direct kwarg, not in filters."""
173+ mock_oss_mem0_client .search .return_value = [{"memory" : "User likes Python" }]
174+ provider = Mem0ContextProvider (source_id = "mem0" , mem0_client = mock_oss_mem0_client , user_id = "u1" )
175+ session = AgentSession (session_id = "test-session" )
176+ ctx = SessionContext (input_messages = [Message (role = "user" , text = "Hello" )], session_id = "s1" )
177+
178+ await provider .before_run (agent = None , session = session , context = ctx , state = session .state ) # type: ignore[arg-type]
179+
180+ call_kwargs = mock_oss_mem0_client .search .call_args .kwargs
181+ assert call_kwargs ["query" ] == "Hello"
182+ assert call_kwargs ["user_id" ] == "u1"
183+ assert "filters" not in call_kwargs
184+
185+ async def test_oss_client_all_scoping_params (self , mock_oss_mem0_client : AsyncMock ) -> None :
186+ """OSS client with all scoping parameters passes them as direct kwargs."""
187+ mock_oss_mem0_client .search .return_value = []
188+ provider = Mem0ContextProvider (
189+ source_id = "mem0" , mem0_client = mock_oss_mem0_client , user_id = "u1" , agent_id = "a1" , application_id = "app1"
190+ )
191+ session = AgentSession (session_id = "test-session" )
192+ ctx = SessionContext (input_messages = [Message (role = "user" , text = "Hello" )], session_id = "s1" )
193+
194+ await provider .before_run (agent = None , session = session , context = ctx , state = session .state ) # type: ignore[arg-type]
195+
196+ call_kwargs = mock_oss_mem0_client .search .call_args .kwargs
197+ assert call_kwargs ["user_id" ] == "u1"
198+ assert call_kwargs ["agent_id" ] == "a1"
199+ assert "filters" not in call_kwargs
200+
201+ async def test_platform_client_passes_filters_dict (self , mock_mem0_client : AsyncMock ) -> None :
202+ """Platform AsyncMemoryClient should receive scoping params in a filters dict."""
203+ mock_mem0_client .search .return_value = []
204+ provider = Mem0ContextProvider (source_id = "mem0" , mem0_client = mock_mem0_client , user_id = "u1" )
205+ session = AgentSession (session_id = "test-session" )
206+ ctx = SessionContext (input_messages = [Message (role = "user" , text = "Hello" )], session_id = "s1" )
207+
208+ await provider .before_run (agent = None , session = session , context = ctx , state = session .state ) # type: ignore[arg-type]
209+
210+ call_kwargs = mock_mem0_client .search .call_args .kwargs
211+ assert call_kwargs ["query" ] == "Hello"
212+ assert "filters" in call_kwargs
213+ assert call_kwargs ["filters" ]["user_id" ] == "u1"
214+
160215
161216# -- after_run tests -----------------------------------------------------------
162217
0 commit comments