diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index c60e84253..a4052d313 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -90,6 +90,7 @@ status_tracker = TaskStatusTracker(redis_client=redis_client) embedder = components["embedder"] graph_db = components["graph_db"] +vector_db = components["vector_db"] # ============================================================================= @@ -359,6 +360,13 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): ), ) result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids) + if vector_db: + prefs = [] + for collection_name in ["explicit_preference", "implicit_preference"]: + prefs.extend( + vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids) + ) + result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs}) return GetUserNamesByMemoryIdsResponse( code=200, message="Successfully", diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index a34315918..78f4d6e28 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -248,7 +248,7 @@ def get_all(self) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of all memories. """ - all_collections = self.vector_db.list_collections() + all_collections = ["explicit_preference", "implicit_preference"] all_memories = {} for collection_name in all_collections: items = self.vector_db.get_all(collection_name) diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index ee37d638c..cc1781f06 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -90,7 +90,7 @@ def get_with_collection_name( return None return TextualMemoryItem( id=res.id, - memory=res.payload.get("dialog_str", ""), + memory=res.memory, metadata=PreferenceTextualMemoryMetadata(**res.payload), ) except Exception as e: @@ -116,7 +116,7 @@ def get_by_ids_with_collection_name( return [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in res @@ -132,14 +132,14 @@ def get_all(self) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of all memories. """ - all_collections = self.vector_db.list_collections() + all_collections = ["explicit_preference", "implicit_preference"] all_memories = {} for collection_name in all_collections: items = self.vector_db.get_all(collection_name) all_memories[collection_name] = [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in items diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 5dacf0499..cc8909d34 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -457,14 +457,13 @@ def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: return None entity = results[0] - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} return MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), original_text=entity.get("original_text"), vector=entity.get("vector"), - payload=payload, + payload=entity.get("payload", {}), ) def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]: @@ -479,14 +478,13 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt items = [] for entity in results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append( MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), original_text=entity.get("original_text"), vector=entity.get("vector"), - payload=payload, + payload=entity.get("payload", {}), ) )