Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,26 @@ def _extract_fields(mem_item):
or mem_item.get("old_memory")
or mem_item.get("original_content")
)
return mem_id, mem_memory, original_content
source_doc_id = None
if isinstance(mem_item, dict):
source_doc_id = (
mem_item.get("source_doc_id")
or mem_item.get("doc_id")
or (mem_item.get("metadata") or {}).get("source_doc_id")
)
else:
metadata = getattr(mem_item, "metadata", None)
if metadata:
source_doc_id = getattr(metadata, "source_doc_id", None) or getattr(
metadata, "doc_id", None
)

return mem_id, mem_memory, original_content, source_doc_id

kb_log_content: list[dict] = []

for mem_item in add_records or []:
mem_id, mem_memory, _ = _extract_fields(mem_item)
mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item)
if mem_id and mem_memory:
kb_log_content.append(
{
Expand All @@ -640,7 +654,7 @@ def _extract_fields(mem_item):
"memory_id": mem_id,
"content": mem_memory,
"original_content": None,
"source_doc_id": None,
"source_doc_id": source_doc_id,
}
)
else:
Expand All @@ -654,7 +668,7 @@ def _extract_fields(mem_item):
)

for mem_item in update_records or []:
mem_id, mem_memory, original_content = _extract_fields(mem_item)
mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item)
if mem_id and mem_memory:
kb_log_content.append(
{
Expand All @@ -664,7 +678,7 @@ def _extract_fields(mem_item):
"memory_id": mem_id,
"content": mem_memory,
"original_content": original_content,
"source_doc_id": None,
"source_doc_id": source_doc_id,
}
)
else:
Expand Down
Loading