diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index d05da19db..941b59106 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -192,14 +192,26 @@ def handle_get_memories( del memories["total_edges"] preferences: list[TextualMemoryItem] = [] + total_explicit_nodes, total_implicit_nodes = 0, 0 if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id if get_mem_req.mem_cube_id is not None: filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) - preferences = [format_memory_item(mem) for mem in preferences] + preferences = naive_mem_cube.pref_mem.get_memory_by_filter( + filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + ) + + for key, value_list in preferences.items(): + if key in ["explicit_preference", "implicit_preference"]: + formatted_list = [format_memory_item(item) for item in value_list] + preferences[key] = formatted_list + + total_explicit_nodes = preferences["total_explicit_nodes"] + total_implicit_nodes = preferences["total_implicit_nodes"] + del preferences["total_explicit_nodes"] + del preferences["total_implicit_nodes"] return GetMemoryResponse( message="Memories retrieved successfully", @@ -216,7 +228,8 @@ def handle_get_memories( { "cube_id": get_mem_req.mem_cube_id, "memories": preferences, - "total_nodes": len(preferences), + "total_explicit_nodes": total_explicit_nodes, + "total_implicit_nodes": total_implicit_nodes, } ], }, diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 9e521158d..75d7d2a4c 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -261,7 +261,9 @@ def get_all(self) -> list[TextualMemoryItem]: ] return all_memories - def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]: + def get_memory_by_filter( + self, filter: dict[str, Any] | None = None, **kwargs + ) -> list[TextualMemoryItem]: """Get memories by filter. Args: filter (dict[str, Any]): Filter criteria. @@ -269,18 +271,32 @@ def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[Tex list[TextualMemoryItem]: List of memories that match the filter. """ collection_list = self.vector_db.config.collection_name - all_db_items = [] + + memories = {} + total_explicit_nodes = 0 + total_implicit_nodes = 0 for collection_name in collection_list: - db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) - all_db_items.extend(db_items) - memories = [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), + memories[collection_name] = [] + db_items, total_count = self.vector_db.get_by_filter( + collection_name=collection_name, filter=filter, count_total=True, **kwargs ) - for memo in all_db_items - ] + db_items_memory = [ + TextualMemoryItem( + id=memo.id, + memory=memo.memory, + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in db_items + ] + memories[collection_name].extend(db_items_memory) + + if collection_name == "explicit_preference": + total_explicit_nodes = total_count + if collection_name == "implicit_preference": + total_implicit_nodes = total_count + memories["total_explicit_nodes"] = total_explicit_nodes + memories["total_implicit_nodes"] = total_implicit_nodes + return memories def delete(self, memory_ids: list[str]) -> None: diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index ecbca5815..b0753b31d 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -493,7 +493,14 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt return items def get_by_filter( - self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 + self, + collection_name: str, + filter: dict[str, Any], + scroll_limit: int = 100, + page: int | None = None, + page_size: int | None = None, + count_total=False, + **kwargs, ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. @@ -506,47 +513,74 @@ def get_by_filter( List of items including vectors and payload that match the filter """ expr = self._dict_to_expr(filter) if filter else "" - all_items = [] + if count_total: + total_count = 0 + count_iterator = self.client.query_iterator( + collection_name=collection_name, + filter=expr, + batch_size=scroll_limit, + output_fields=["id"], + ) + try: + while True: + batch = count_iterator.next() + if not batch: + break + total_count += len(batch) + finally: + count_iterator.close() + + result = [] + skipped = 0 + needed = page_size - # Use query_iterator for efficient pagination iterator = self.client.query_iterator( collection_name=collection_name, filter=expr, batch_size=scroll_limit, - output_fields=["*"], # Include all fields including payload + output_fields=["*"], ) - # Iterate through all batches try: - while True: - batch_results = iterator.next() - - if not batch_results: + while needed > 0: + batch = iterator.next() + if not batch: break - # Convert batch results to MilvusVecDBItem objects - for entity in batch_results: - # Extract the actual payload from Milvus entity + for entity in batch: + skipped += 1 + + if skipped <= (page - 1) * page_size: + continue + payload = entity.get("payload", {}) - all_items.append( - MilvusVecDBItem( - id=entity["id"], - memory=entity.get("memory"), - original_text=entity.get("original_text"), - vector=entity.get("vector"), - payload=payload, - ) + item = MilvusVecDBItem( + id=entity["id"], + memory=entity.get("memory"), + original_text=entity.get("original_text"), + vector=entity.get("vector"), + payload=payload, ) + result.append(item) + needed -= 1 + + if needed <= 0: + if count_total: + return result, total_count + return result + except Exception as e: - logger.warning( - f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." - ) + logger.warning(f"Error during iteration: {e}") finally: - # Close the iterator iterator.close() - logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") - return all_items + logger.info( + f"Milvus retrieve by filter completed - " + f"page {page}, page_size {page_size}, got {len(result)} items." + ) + if count_total: + return result, total_count + return result def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database."""