Skip to content

Commit da91ffc

Browse files
whipser030黑布林
andauthored
fix: The "total nodes" information returned by the "get memory" interface is inaccurate. (#1018)
* fix: add fileurl to memoryvalue * Extract the phrases from the key and input them into the tags. * The issue of inaccurate retrieval of "total nodes" through the "get memory" interface --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com>
1 parent eaf15ef commit da91ffc

2 files changed

Lines changed: 63 additions & 10 deletions

File tree

src/memos/api/handlers/memory_handler.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from memos.api.handlers.formatters_handler import (
1010
format_memory_item,
1111
post_process_pref_mem,
12-
post_process_textual_mem,
1312
)
1413
from memos.api.product_models import (
1514
DeleteMemoryRequest,
@@ -250,22 +249,68 @@ def handle_get_memories(
250249
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
251250
) -> GetMemoryResponse:
252251
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
253-
memories = naive_mem_cube.text_mem.get_all(
252+
text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
253+
text_memories_info = naive_mem_cube.text_mem.get_all(
254254
user_name=get_mem_req.mem_cube_id,
255255
user_id=get_mem_req.user_id,
256256
page=get_mem_req.page,
257257
page_size=get_mem_req.page_size,
258258
filter=get_mem_req.filter,
259-
)["nodes"]
259+
memory_type=text_memory_type,
260+
)
261+
text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"]
262+
results["text_mem"] = [
263+
{
264+
"cube_id": get_mem_req.mem_cube_id,
265+
"memories": text_memories,
266+
"total_nodes": total_text_nodes,
267+
}
268+
]
260269

261-
results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id)
270+
if get_mem_req.include_tool_memory:
271+
tool_memories_info = naive_mem_cube.text_mem.get_all(
272+
user_name=get_mem_req.mem_cube_id,
273+
user_id=get_mem_req.user_id,
274+
page=get_mem_req.page,
275+
page_size=get_mem_req.page_size,
276+
filter=get_mem_req.filter,
277+
memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
278+
)
279+
tool_memories, total_tool_nodes = (
280+
tool_memories_info["nodes"],
281+
tool_memories_info["total_nodes"],
282+
)
262283

263-
if not get_mem_req.include_tool_memory:
264-
results["tool_mem"] = []
265-
if not get_mem_req.include_skill_memory:
266-
results["skill_mem"] = []
284+
results["tool_mem"] = [
285+
{
286+
"cube_id": get_mem_req.mem_cube_id,
287+
"memories": tool_memories,
288+
"total_nodes": total_tool_nodes,
289+
}
290+
]
291+
if get_mem_req.include_skill_memory:
292+
skill_memories_info = naive_mem_cube.text_mem.get_all(
293+
user_name=get_mem_req.mem_cube_id,
294+
user_id=get_mem_req.user_id,
295+
page=get_mem_req.page,
296+
page_size=get_mem_req.page_size,
297+
filter=get_mem_req.filter,
298+
memory_type=["SkillMemory"],
299+
)
300+
skill_memories, total_skill_nodes = (
301+
skill_memories_info["nodes"],
302+
skill_memories_info["total_nodes"],
303+
)
267304

305+
results["skill_mem"] = [
306+
{
307+
"cube_id": get_mem_req.mem_cube_id,
308+
"memories": skill_memories,
309+
"total_nodes": total_skill_nodes,
310+
}
311+
]
268312
preferences: list[TextualMemoryItem] = []
313+
total_preference_nodes = 0
269314

270315
format_preferences = []
271316
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
@@ -294,14 +339,16 @@ def handle_get_memories(
294339

295340
filter_params.update(filter_copy)
296341

297-
preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter(
342+
preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter(
298343
filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
299344
)
300345
format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]
301346

302347
results = post_process_pref_mem(
303348
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
304349
)
350+
if total_preference_nodes > 0 and results.get("pref_mem", []):
351+
results["pref_mem"][0]["total_nodes"] = total_preference_nodes
305352

306353
# Filter to only keep text_mem, pref_mem, tool_mem
307354
filtered_results = {

src/memos/memories/textual/tree.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,19 @@ def get_all(
364364
page: int | None = None,
365365
page_size: int | None = None,
366366
filter: dict | None = None,
367+
memory_type: list[str] | None = None,
367368
) -> dict:
368369
"""Get all memories.
369370
Returns:
370371
list[TextualMemoryItem]: List of all memories.
371372
"""
372373
graph_output = self.graph_store.export_graph(
373-
user_name=user_name, user_id=user_id, page=page, page_size=page_size, filter=filter
374+
user_name=user_name,
375+
user_id=user_id,
376+
page=page,
377+
page_size=page_size,
378+
filter=filter,
379+
memory_type=memory_type,
374380
)
375381
return graph_output
376382

0 commit comments

Comments
 (0)