|
9 | 9 | from memos.api.handlers.formatters_handler import ( |
10 | 10 | format_memory_item, |
11 | 11 | post_process_pref_mem, |
12 | | - post_process_textual_mem, |
13 | 12 | ) |
14 | 13 | from memos.api.product_models import ( |
15 | 14 | DeleteMemoryRequest, |
@@ -250,22 +249,68 @@ def handle_get_memories( |
250 | 249 | get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube |
251 | 250 | ) -> GetMemoryResponse: |
252 | 251 | results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []} |
253 | | - memories = naive_mem_cube.text_mem.get_all( |
| 252 | + text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] |
| 253 | + text_memories_info = naive_mem_cube.text_mem.get_all( |
254 | 254 | user_name=get_mem_req.mem_cube_id, |
255 | 255 | user_id=get_mem_req.user_id, |
256 | 256 | page=get_mem_req.page, |
257 | 257 | page_size=get_mem_req.page_size, |
258 | 258 | filter=get_mem_req.filter, |
259 | | - )["nodes"] |
| 259 | + memory_type=text_memory_type, |
| 260 | + ) |
| 261 | + text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"] |
| 262 | + results["text_mem"] = [ |
| 263 | + { |
| 264 | + "cube_id": get_mem_req.mem_cube_id, |
| 265 | + "memories": text_memories, |
| 266 | + "total_nodes": total_text_nodes, |
| 267 | + } |
| 268 | + ] |
260 | 269 |
|
261 | | - results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id) |
| 270 | + if get_mem_req.include_tool_memory: |
| 271 | + tool_memories_info = naive_mem_cube.text_mem.get_all( |
| 272 | + user_name=get_mem_req.mem_cube_id, |
| 273 | + user_id=get_mem_req.user_id, |
| 274 | + page=get_mem_req.page, |
| 275 | + page_size=get_mem_req.page_size, |
| 276 | + filter=get_mem_req.filter, |
| 277 | + memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"], |
| 278 | + ) |
| 279 | + tool_memories, total_tool_nodes = ( |
| 280 | + tool_memories_info["nodes"], |
| 281 | + tool_memories_info["total_nodes"], |
| 282 | + ) |
262 | 283 |
|
263 | | - if not get_mem_req.include_tool_memory: |
264 | | - results["tool_mem"] = [] |
265 | | - if not get_mem_req.include_skill_memory: |
266 | | - results["skill_mem"] = [] |
| 284 | + results["tool_mem"] = [ |
| 285 | + { |
| 286 | + "cube_id": get_mem_req.mem_cube_id, |
| 287 | + "memories": tool_memories, |
| 288 | + "total_nodes": total_tool_nodes, |
| 289 | + } |
| 290 | + ] |
| 291 | + if get_mem_req.include_skill_memory: |
| 292 | + skill_memories_info = naive_mem_cube.text_mem.get_all( |
| 293 | + user_name=get_mem_req.mem_cube_id, |
| 294 | + user_id=get_mem_req.user_id, |
| 295 | + page=get_mem_req.page, |
| 296 | + page_size=get_mem_req.page_size, |
| 297 | + filter=get_mem_req.filter, |
| 298 | + memory_type=["SkillMemory"], |
| 299 | + ) |
| 300 | + skill_memories, total_skill_nodes = ( |
| 301 | + skill_memories_info["nodes"], |
| 302 | + skill_memories_info["total_nodes"], |
| 303 | + ) |
267 | 304 |
|
| 305 | + results["skill_mem"] = [ |
| 306 | + { |
| 307 | + "cube_id": get_mem_req.mem_cube_id, |
| 308 | + "memories": skill_memories, |
| 309 | + "total_nodes": total_skill_nodes, |
| 310 | + } |
| 311 | + ] |
268 | 312 | preferences: list[TextualMemoryItem] = [] |
| 313 | + total_preference_nodes = 0 |
269 | 314 |
|
270 | 315 | format_preferences = [] |
271 | 316 | if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: |
@@ -294,14 +339,16 @@ def handle_get_memories( |
294 | 339 |
|
295 | 340 | filter_params.update(filter_copy) |
296 | 341 |
|
297 | | - preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter( |
| 342 | + preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( |
298 | 343 | filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size |
299 | 344 | ) |
300 | 345 | format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] |
301 | 346 |
|
302 | 347 | results = post_process_pref_mem( |
303 | 348 | results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference |
304 | 349 | ) |
| 350 | + if total_preference_nodes > 0 and results.get("pref_mem", []): |
| 351 | + results["pref_mem"][0]["total_nodes"] = total_preference_nodes |
305 | 352 |
|
306 | 353 | # Filter to only keep text_mem, pref_mem, tool_mem |
307 | 354 | filtered_results = { |
|
0 commit comments