From 439ed49e6d6fff5c63ed3576bbb7eaa3c1c915b9 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 17:50:35 +0800 Subject: [PATCH 01/31] hotfix:hotfix --- src/memos/api/product_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 30df150ea..f7f0304c7 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -201,8 +201,8 @@ class APIADDRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) - async_mode: Literal["async", "sync"] = Field( - "async", description="Whether to add memory in async mode" + async_mode: Literal["async", "sync"] | None = Field( + None, description="Whether to add memory in async mode" ) From 39a7b34018c675f3cd5835d99a7cb058bcf60c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 17:00:33 +0800 Subject: [PATCH 02/31] test: add routers api --- tests/api/test_product_router.py | 450 +++++++++++++++++++++++++++++++ tests/api/test_server_router.py | 445 ++++++++++++++++++++++++++++++ 2 files changed, 895 insertions(+) create mode 100644 tests/api/test_product_router.py create mode 100644 tests/api/test_server_router.py diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py new file mode 100644 index 000000000..9ed67d037 --- /dev/null +++ b/tests/api/test_product_router.py @@ -0,0 +1,450 @@ +""" +Unit tests for product_router input/output format validation. + +This module tests that the product_router endpoints correctly validate +input request formats and return properly formatted responses. +""" + +# Mock sklearn before importing any memos modules to avoid import errors +import importlib.util +import sys + +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from fastapi.testclient import TestClient + +# Patch the MOS_PRODUCT_INSTANCE directly after import +import memos.api.routers.product_router as pr_module + + +# Create a proper mock module with __spec__ +sklearn_mock = MagicMock() +sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) +sys.modules["sklearn"] = sklearn_mock + +sklearn_fe_mock = MagicMock() +sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) +sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock + +sklearn_fet_mock = MagicMock() +sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) +sklearn_fet_mock.TfidfVectorizer = MagicMock() +sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock + +# Mock sklearn.metrics as well +sklearn_metrics_mock = MagicMock() +sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) +sklearn_metrics_mock.roc_curve = MagicMock() +sys.modules["sklearn.metrics"] = sklearn_metrics_mock + + +# Create mock instance +_mock_mos_instance = Mock() + +pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance +pr_module.get_mos_product_instance = lambda: _mock_mos_instance + +# Mock MOSProduct class before importing to prevent initialization +with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): + # Import after patching + from memos.api import product_api + + +@pytest.fixture(scope="module") +def mock_mos_product_instance(): + """Mock get_mos_product_instance for all tests.""" + # Ensure the mock is set + pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance + pr_module.get_mos_product_instance = lambda: _mock_mos_instance + yield product_api.app, _mock_mos_instance + + +@pytest.fixture +def client(mock_mos_product_instance): + """Create test client for product_api.""" + app, _ = mock_mos_product_instance + return TestClient(app) + + +@pytest.fixture +def mock_mos_product(mock_mos_product_instance): + """Get the mocked MOSProduct instance.""" + _, mock_instance = mock_mos_product_instance + # Ensure get_mos_product_instance returns this mock + import memos.api.routers.product_router as pr_module + + pr_module.get_mos_product_instance = lambda: mock_instance + pr_module.MOS_PRODUCT_INSTANCE = mock_instance + return mock_instance + + +@pytest.fixture(autouse=True) +def setup_mock_mos_product(mock_mos_product): + """Set up default return values for MOSProduct methods.""" + # Set up default return values for methods + mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} + mock_mos_product.add.return_value = None + mock_mos_product.chat.return_value = ("test response", []) + mock_mos_product.chat_with_references.return_value = iter( + ['data: {"type": "content", "data": "test"}\n\n'] + ) + # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list) + default_memory_result = [{"cube_id": "test_cube", "memories": []}] + mock_mos_product.get_all.return_value = default_memory_result + mock_mos_product.get_subgraph.return_value = default_memory_result + mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"] + # Ensure get_mos_product_instance returns the mock + import memos.api.routers.product_router as pr_module + + pr_module.get_mos_product_instance = lambda: mock_mos_product + + +class TestProductRouterSearch: + """Test /search endpoint input/output format.""" + + def test_search_valid_input_output(self, mock_mos_product, client): + """Test search endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + "top_k": 10, + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + + # Verify MOSProduct.search was called with correct parameters + mock_mos_product.search.assert_called_once() + call_kwargs = mock_mos_product.search.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_search_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test search endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/search", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_search_response_format(self, mock_mos_product, client): + """Test search endpoint returns SearchResponse format.""" + mock_mos_product.search.return_value = { + "text_mem": [{"cube_id": "test_cube", "memories": []}], + "act_mem": [], + "para_mem": [], + } + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Search completed successfully" + assert isinstance(data["data"], dict) + assert "text_mem" in data["data"] + + +class TestProductRouterAdd: + """Test /add endpoint input/output format.""" + + def test_add_valid_input_output(self, mock_mos_product, client): + """Test add endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert data["data"] is None # SimpleResponse has None data + + # Verify MOSProduct.add was called with correct parameters + mock_mos_product.add.assert_called_once() + call_kwargs = mock_mos_product.add.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["memory_content"] == "test memory content" + + def test_add_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test add endpoint with missing required field.""" + request_data = { + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_add_response_format(self, mock_mos_product, client): + """Test add endpoint returns SimpleResponse format.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memory created successfully" + assert data["data"] is None + + +class TestProductRouterChatComplete: + """Test /chat/complete endpoint input/output format.""" + + def test_chat_complete_valid_input_output(self, mock_mos_product, client): + """Test chat/complete endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "message" in data + assert "data" in data + assert isinstance(data["data"], dict) + assert "response" in data["data"] + assert "references" in data["data"] + + # Verify MOSProduct.chat was called with correct parameters + mock_mos_product.chat.assert_called_once() + call_kwargs = mock_mos_product.chat.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test chat/complete endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_chat_complete_response_format(self, mock_mos_product, client): + """Test chat/complete endpoint returns correct format.""" + mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}]) + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Chat completed successfully" + assert isinstance(data["data"]["response"], str) + assert isinstance(data["data"]["references"], list) + + +class TestProductRouterChat: + """Test /chat endpoint input/output format (SSE stream).""" + + def test_chat_valid_input_output(self, mock_mos_product, client): + """Test chat endpoint with valid input returns SSE stream.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat", json=request_data) + + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + # Verify MOSProduct.chat_with_references was called + mock_mos_product.chat_with_references.assert_called_once() + call_kwargs = mock_mos_product.chat_with_references.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test chat endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + +class TestProductRouterSuggestions: + """Test /suggestions endpoint input/output format.""" + + def test_suggestions_valid_input_output(self, mock_mos_product, client): + """Test suggestions endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "zh", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + assert "query" in data["data"] + + # Verify MOSProduct.get_suggestion_query was called + mock_mos_product.get_suggestion_query.assert_called_once() + call_kwargs = mock_mos_product.get_suggestion_query.call_args[1] + assert call_kwargs["user_id"] == "test_user" + + def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test suggestions endpoint with missing required field.""" + request_data = { + "mem_cube_id": "test_cube", + } + + response = client.post("/product/suggestions", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_suggestions_response_format(self, mock_mos_product, client): + """Test suggestions endpoint returns SuggestionResponse format.""" + mock_mos_product.get_suggestion_query.return_value = [ + "suggestion1", + "suggestion2", + "suggestion3", + ] + + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "en", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Suggestions retrieved successfully" + assert isinstance(data["data"], dict) + assert isinstance(data["data"]["query"], list) + + +class TestProductRouterGetAll: + """Test /get_all endpoint input/output format.""" + + def test_get_all_valid_input_output(self, mock_mos_product, client): + """Test get_all endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + # Verify MOSProduct.get_all was called + mock_mos_product.get_all.assert_called_once() + call_kwargs = mock_mos_product.get_all.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["memory_type"] == "text_mem" + + def test_get_all_with_search_query(self, mock_mos_product, client): + """Test get_all endpoint with search_query uses get_subgraph.""" + # Reset mock call counts + mock_mos_product.get_all.reset_mock() + mock_mos_product.get_subgraph.reset_mock() + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "search_query": "test query", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + # Verify get_subgraph was called instead of get_all + mock_mos_product.get_subgraph.assert_called_once() + mock_mos_product.get_all.assert_not_called() + + def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test get_all endpoint with missing required field.""" + request_data = { + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_get_all_response_format(self, mock_mos_product, client): + """Test get_all endpoint returns MemoryResponse format.""" + mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}] + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memories retrieved successfully" + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py new file mode 100644 index 000000000..a4bb198e0 --- /dev/null +++ b/tests/api/test_server_router.py @@ -0,0 +1,445 @@ +""" +Unit tests for server_router input/output format validation. + +This module tests that the server_router endpoints correctly validate +input request formats and return properly formatted responses. +""" + +# Mock sklearn before importing any memos modules to avoid import errors +import importlib.util +import sys + +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from fastapi.testclient import TestClient + +from memos.api.product_models import ( + APIADDRequest, + APIChatCompleteRequest, + APISearchRequest, + MemoryResponse, + SearchResponse, + SuggestionResponse, +) + + +# Create a proper mock module with __spec__ +sklearn_mock = MagicMock() +sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) +sys.modules["sklearn"] = sklearn_mock + +sklearn_fe_mock = MagicMock() +sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) +sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock + +sklearn_metrics_mock = MagicMock() +sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) +sys.modules["sklearn.metrics"] = sklearn_metrics_mock + +sklearn_fet_mock = MagicMock() +sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) +sklearn_fet_mock.TfidfVectorizer = MagicMock() +sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock + + +@pytest.fixture(scope="module") +def mock_init_server(): + """Mock init_server before importing server_api.""" + # Create mock components + mock_components = { + "graph_db": Mock(), + "mem_reader": Mock(), + "llm": Mock(), + "embedder": Mock(), + "reranker": Mock(), + "internet_retriever": Mock(), + "memory_manager": Mock(), + "default_cube_config": Mock(), + "mos_server": Mock(), + "mem_scheduler": Mock(), + "naive_mem_cube": Mock(), + "searcher": Mock(), + "api_module": Mock(), + "vector_db": None, + "pref_extractor": None, + "pref_adder": None, + "pref_retriever": None, + "pref_mem": None, + "online_bot": None, + } + + with patch("memos.api.handlers.init_server", return_value=mock_components): + # Import after patching + from memos.api import server_api + + yield server_api.app + + +@pytest.fixture +def client(mock_init_server): + """Create test client for server_api.""" + return TestClient(mock_init_server) + + +@pytest.fixture +def mock_handlers(): + """Mock all handlers used by server_router.""" + with ( + patch("memos.api.routers.server_router.search_handler") as mock_search, + patch("memos.api.routers.server_router.add_handler") as mock_add, + patch("memos.api.routers.server_router.chat_handler") as mock_chat, + patch("memos.api.routers.server_router.handlers.suggestion_handler") as mock_suggestion, + patch("memos.api.routers.server_router.handlers.memory_handler") as mock_memory, + ): + # Set up default return values + mock_search.handle_search_memories.return_value = SearchResponse( + message="Search completed successfully", + data={"text_mem": [], "act_mem": [], "para_mem": []}, + ) + + mock_add.handle_add_memories.return_value = MemoryResponse( + message="Memory added successfully", data=[] + ) + + mock_chat.handle_chat_complete.return_value = { + "message": "Chat completed successfully", + "data": {"response": "test response", "references": []}, + } + + mock_suggestion.handle_get_suggestion_queries.return_value = SuggestionResponse( + message="Suggestions retrieved successfully", data={"query": ["suggestion1"]} + ) + + mock_memory.handle_get_all_memories.return_value = MemoryResponse( + message="Memories retrieved successfully", data=[] + ) + + mock_memory.handle_get_subgraph.return_value = MemoryResponse( + message="Memories retrieved successfully", data=[] + ) + + yield { + "search": mock_search, + "add": mock_add, + "chat": mock_chat, + "suggestion": mock_suggestion, + "memory": mock_memory, + } + + +class TestServerRouterSearch: + """Test /search endpoint input/output format.""" + + def test_search_valid_input_output(self, mock_handlers, client): + """Test search endpoint with valid input returns correct output format.""" + request_data = { + "query": "test query", + "user_id": "test_user", + "mem_cube_id": "test_cube", + "top_k": 10, + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + + # Verify handler was called with correct request type + mock_handlers["search"].handle_search_memories.assert_called_once() + call_args = mock_handlers["search"].handle_search_memories.call_args[0][0] + assert isinstance(call_args, APISearchRequest) + assert call_args.query == "test query" + assert call_args.user_id == "test_user" + + def test_search_invalid_input_missing_query(self, mock_handlers, client): + """Test search endpoint with missing required field.""" + request_data = { + "user_id": "test_user", + } + + response = client.post("/product/search", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_search_response_format(self, mock_handlers, client): + """Test search endpoint returns SearchResponse format.""" + mock_handlers["search"].handle_search_memories.return_value = SearchResponse( + message="Search completed successfully", + data={ + "text_mem": [{"cube_id": "test_cube", "memories": []}], + "act_mem": [], + "para_mem": [], + }, + ) + + request_data = { + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Search completed successfully" + assert isinstance(data["data"], dict) + assert "text_mem" in data["data"] + + +class TestServerRouterAdd: + """Test /add endpoint input/output format.""" + + def test_add_valid_input_output(self, mock_handlers, client): + """Test add endpoint with valid input returns correct output format.""" + request_data = { + "mem_cube_id": "test_cube", + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + # Verify handler was called with correct request type + mock_handlers["add"].handle_add_memories.assert_called_once() + call_args = mock_handlers["add"].handle_add_memories.call_args[0][0] + assert isinstance(call_args, APIADDRequest) + assert call_args.mem_cube_id == "test_cube" + assert call_args.user_id == "test_user" + + def test_add_invalid_input_missing_cube_id(self, mock_handlers, client): + """Test add endpoint with missing required field.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_add_response_format(self, mock_handlers, client): + """Test add endpoint returns MemoryResponse format.""" + mock_handlers["add"].handle_add_memories.return_value = MemoryResponse( + message="Memory added successfully", + data=[{"cube_id": "test_cube", "memories": []}], + ) + + request_data = { + "mem_cube_id": "test_cube", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memory added successfully" + assert isinstance(data["data"], list) + + +class TestServerRouterChatComplete: + """Test /chat/complete endpoint input/output format.""" + + def test_chat_complete_valid_input_output(self, mock_handlers, client): + """Test chat/complete endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "message" in data + assert "data" in data + assert isinstance(data["data"], dict) + assert "response" in data["data"] + assert "references" in data["data"] + + # Verify handler was called with correct request type + mock_handlers["chat"].handle_chat_complete.assert_called_once() + call_args = mock_handlers["chat"].handle_chat_complete.call_args[0][0] + assert isinstance(call_args, APIChatCompleteRequest) + assert call_args.user_id == "test_user" + assert call_args.query == "test query" + + def test_chat_complete_invalid_input_missing_user_id(self, mock_handlers, client): + """Test chat/complete endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_chat_complete_response_format(self, mock_handlers, client): + """Test chat/complete endpoint returns correct format.""" + mock_handlers["chat"].handle_chat_complete.return_value = { + "message": "Chat completed successfully", + "data": {"response": "test response", "references": [{"id": "ref1"}]}, + } + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Chat completed successfully" + assert isinstance(data["data"]["response"], str) + assert isinstance(data["data"]["references"], list) + + +class TestServerRouterSuggestions: + """Test /suggestions endpoint input/output format.""" + + def test_suggestions_valid_input_output(self, mock_handlers, client): + """Test suggestions endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "zh", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + + # Verify handler was called + mock_handlers["suggestion"].handle_get_suggestion_queries.assert_called_once() + + def test_suggestions_invalid_input_missing_user_id(self, mock_handlers, client): + """Test suggestions endpoint with missing required field.""" + request_data = { + "mem_cube_id": "test_cube", + } + + response = client.post("/product/suggestions", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_suggestions_response_format(self, mock_handlers, client): + """Test suggestions endpoint returns SuggestionResponse format.""" + mock_handlers["suggestion"].handle_get_suggestion_queries.return_value = SuggestionResponse( + message="Suggestions retrieved successfully", + data={"query": ["suggestion1", "suggestion2"]}, + ) + + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "en", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Suggestions retrieved successfully" + assert isinstance(data["data"], dict) + assert "query" in data["data"] + + +class TestServerRouterGetAll: + """Test /get_all endpoint input/output format.""" + + def test_get_all_valid_input_output(self, mock_handlers, client): + """Test get_all endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + def test_get_all_with_search_query(self, mock_handlers, client): + """Test get_all endpoint with search_query uses subgraph handler.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "search_query": "test query", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + # Verify subgraph handler was called + mock_handlers["memory"].handle_get_subgraph.assert_called_once() + + def test_get_all_invalid_input_missing_user_id(self, mock_handlers, client): + """Test get_all endpoint with missing required field.""" + request_data = { + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_get_all_response_format(self, mock_handlers, client): + """Test get_all endpoint returns MemoryResponse format.""" + mock_handlers["memory"].handle_get_all_memories.return_value = MemoryResponse( + message="Memories retrieved successfully", + data=[{"cube_id": "test_cube", "memories": []}], + ) + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memories retrieved successfully" + assert isinstance(data["data"], list) From cbed950c4fcad6ce0d0faa2e6c5ec2b5779f80c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sun, 7 Dec 2025 15:17:58 +0800 Subject: [PATCH 03/31] fix: doc fine mode bug --- src/memos/mem_reader/multi_modal_struct.py | 38 +++++++++++++++------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 0cb4e1542..9a7a3054d 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -377,21 +377,37 @@ def _process_string_fine( except Exception as e: logger.error(f"[MultiModalFine] Error calling LLM: {e}") continue - for m in resp.get("memory list", []): + if resp.get("memory list", []): + for m in resp.get("memory list", []): + try: + # Normalize memory_type (same as simple_struct) + memory_type = ( + m.get("memory_type", "LongTermMemory") + .replace("长期记忆", "LongTermMemory") + .replace("用户记忆", "UserMemory") + ) + # Create fine mode memory item (same as simple_struct) + node = self._make_memory_item( + value=m.get("value", ""), + info=info, + memory_type=memory_type, + tags=m.get("tags", []), + key=m.get("key", ""), + sources=sources, # Preserve sources from fast item + background=resp.get("summary", ""), + ) + fine_memory_items.append(node) + except Exception as e: + logger.error(f"[MultiModalFine] parse error: {e}") + elif isinstance(resp, dict): try: - # Normalize memory_type (same as simple_struct) - memory_type = ( - m.get("memory_type", "LongTermMemory") - .replace("长期记忆", "LongTermMemory") - .replace("用户记忆", "UserMemory") - ) # Create fine mode memory item (same as simple_struct) node = self._make_memory_item( - value=m.get("value", ""), + value=resp.get("value", "").strip(), info=info, - memory_type=memory_type, - tags=m.get("tags", []), - key=m.get("key", ""), + memory_type="LongTermMemory", + tags=resp.get("tags", []), + key=resp.get("key", None), sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), ) From 20e08396b113f169b7e34f405098f1059a93b7f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sun, 7 Dec 2025 15:21:06 +0800 Subject: [PATCH 04/31] fix: doc fine mode bug --- src/memos/mem_reader/multi_modal_struct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 9a7a3054d..3a9aa014b 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -399,7 +399,7 @@ def _process_string_fine( fine_memory_items.append(node) except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") - elif isinstance(resp, dict): + elif resp.get("value") and resp.get("key"): try: # Create fine mode memory item (same as simple_struct) node = self._make_memory_item( From fff0fb290627eb9cc6fbeb242222ec398858bd71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sun, 7 Dec 2025 16:30:19 +0800 Subject: [PATCH 05/31] feat: init longbench_v2 --- evaluation/scripts/longbench/__init__.py | 1 + .../scripts/longbench/longbench_ingestion.py | 306 +++++++++++++++++ .../scripts/longbench/longbench_metric.py | 235 +++++++++++++ .../scripts/longbench/longbench_responses.py | 196 +++++++++++ .../scripts/longbench/longbench_search.py | 309 ++++++++++++++++++ .../scripts/longbench_v2/prepare_data.py | 0 6 files changed, 1047 insertions(+) create mode 100644 evaluation/scripts/longbench/__init__.py create mode 100644 evaluation/scripts/longbench/longbench_ingestion.py create mode 100644 evaluation/scripts/longbench/longbench_metric.py create mode 100644 evaluation/scripts/longbench/longbench_responses.py create mode 100644 evaluation/scripts/longbench/longbench_search.py create mode 100644 evaluation/scripts/longbench_v2/prepare_data.py diff --git a/evaluation/scripts/longbench/__init__.py b/evaluation/scripts/longbench/__init__.py new file mode 100644 index 000000000..38cc006e3 --- /dev/null +++ b/evaluation/scripts/longbench/__init__.py @@ -0,0 +1 @@ +# LongBench evaluation scripts diff --git a/evaluation/scripts/longbench/longbench_ingestion.py b/evaluation/scripts/longbench/longbench_ingestion.py new file mode 100644 index 000000000..e2d2a8e7e --- /dev/null +++ b/evaluation/scripts/longbench/longbench_ingestion.py @@ -0,0 +1,306 @@ +import argparse +import json +import os +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timezone + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +# All LongBench datasets +LONGBENCH_DATASETS = [ + "narrativeqa", + "qasper", + "multifieldqa_en", + "multifieldqa_zh", + "hotpotqa", + "2wikimqa", + "musique", + "dureader", + "gov_report", + "qmsum", + "multi_news", + "vcsum", + "trec", + "triviaqa", + "samsum", + "lsht", + "passage_count", + "passage_retrieval_en", + "passage_retrieval_zh", + "lcc", + "repobench-p", +] + + +def ingest_sample(client, sample, dataset_name, sample_idx, frame, version): + """Ingest a single LongBench sample as memories.""" + user_id = f"longbench_{dataset_name}_{sample_idx}_{version}" + conv_id = f"longbench_{dataset_name}_{sample_idx}_{version}" + + # Get context and convert to messages + context = sample.get("context", "") + # not used now: input_text = sample.get("input", "") + + # For memos, we ingest the context as document content + # Split context into chunks if it's too long (optional, memos handles this internally) + # For now, we'll ingest the full context as a single message + messages = [ + { + "role": "assistant", + "content": context, + "chat_time": datetime.now(timezone.utc).isoformat(), + } + ] + + if "memos-api" in frame: + try: + client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") + return False + elif "mem0" in frame: + timestamp = int(datetime.now(timezone.utc).timestamp()) + try: + client.add(messages=messages, user_id=user_id, timestamp=timestamp, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") + return False + elif frame == "memobase": + for m in messages: + m["created_at"] = messages[0]["chat_time"] + try: + client.add(messages=messages, user_id=user_id, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") + return False + elif frame == "memu": + try: + client.add(messages=messages, user_id=user_id, iso_date=messages[0]["chat_time"]) + print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") + return False + elif frame == "supermemory": + try: + client.add(messages=messages, user_id=user_id) + print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") + return False + + return False + + +def load_dataset_from_local(dataset_name, use_e=False): + """Load LongBench dataset from local JSONL file.""" + # Determine data directory + data_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "data", + "long_bench_v2", + ) + + # Determine filename + filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" + + filepath = os.path.join(data_dir, filename) + + if not os.path.exists(filepath): + raise FileNotFoundError(f"Dataset file not found: {filepath}") + + # Load JSONL file + samples = [] + with open(filepath, encoding="utf-8") as f: + for line in f: + if line.strip(): + samples.append(json.loads(line)) + + return samples + + +def ingest_dataset(dataset_name, frame, version, num_workers=10, max_samples=None, use_e=False): + """Ingest a single LongBench dataset.""" + print(f"\n{'=' * 80}") + print(f"🔄 [INGESTING DATASET: {dataset_name.upper()}]".center(80)) + print(f"{'=' * 80}\n") + + # Load dataset from local files + try: + dataset = load_dataset_from_local(dataset_name, use_e) + print(f"Loaded {len(dataset)} samples from {dataset_name}") + except FileNotFoundError as e: + print(f"❌ Error loading dataset {dataset_name}: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset {dataset_name}: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "mem0" or frame == "mem0_graph": + from utils.client import Mem0Client + + client = Mem0Client(enable_graph="graph" in frame) + elif frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + elif frame == "memobase": + from utils.client import MemobaseClient + + client = MemobaseClient() + elif frame == "memu": + from utils.client import MemuClient + + client = MemuClient() + elif frame == "supermemory": + from utils.client import SupermemoryClient + + client = SupermemoryClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Ingest samples + success_count = 0 + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit( + ingest_sample, client, sample, dataset_name, idx, frame, version + ) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=f"Ingesting {dataset_name}", + ): + try: + if future.result(): + success_count += 1 + except Exception as e: + print(f"Error processing sample: {e}") + + print(f"\n✅ Completed ingesting {dataset_name}: {success_count}/{len(dataset)} samples") + return success_count + + +def main(frame, version="default", num_workers=10, datasets=None, max_samples=None, use_e=False): + """Main ingestion function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH INGESTION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Determine which datasets to process + dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS + + # Filter valid datasets + valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] + if not valid_datasets: + print("❌ No valid datasets specified") + return + + print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") + + # Ingest each dataset + total_success = 0 + total_samples = 0 + for dataset_name in valid_datasets: + success = ingest_dataset(dataset_name, frame, version, num_workers, max_samples, use_e) + if success is not None: + total_success += success + total_samples += max_samples if max_samples else 200 # Approximate + + print(f"\n{'=' * 80}") + print(f"✅ INGESTION COMPLETE: {total_success} samples ingested".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + parser.add_argument( + "--datasets", + type=str, + default=None, + help="Comma-separated list of datasets to process (default: all)", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples per dataset (default: all)", + ) + parser.add_argument( + "--e", + action="store_true", + help="Use LongBench-E variant (uniform length distribution)", + ) + args = parser.parse_args() + + main( + args.lib, + args.version, + args.workers, + args.datasets, + args.max_samples, + args.e, + ) diff --git a/evaluation/scripts/longbench/longbench_metric.py b/evaluation/scripts/longbench/longbench_metric.py new file mode 100644 index 000000000..495a793ab --- /dev/null +++ b/evaluation/scripts/longbench/longbench_metric.py @@ -0,0 +1,235 @@ +import argparse +import json +import os +import sys + +import numpy as np + + +# Import LongBench metrics +# Try to import from the LongBench directory +LONGBENCH_METRICS_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "longbench_v2", + "LongBench-main", + "LongBench", +) + +if os.path.exists(LONGBENCH_METRICS_DIR): + sys.path.insert(0, LONGBENCH_METRICS_DIR) + try: + from metrics import ( + classification_score, + code_sim_score, + count_score, + qa_f1_score, + qa_f1_zh_score, + retrieval_score, + retrieval_zh_score, + rouge_score, + rouge_zh_score, + ) + except ImportError: + print(f"Warning: Could not import metrics from {LONGBENCH_METRICS_DIR}") + print("Please ensure LongBench metrics.py is available") + raise +else: + print(f"Error: LongBench metrics directory not found at {LONGBENCH_METRICS_DIR}") + raise FileNotFoundError("LongBench metrics directory not found") + +# Dataset to metric mapping (from LongBench eval.py) +dataset2metric = { + "narrativeqa": qa_f1_score, + "qasper": qa_f1_score, + "multifieldqa_en": qa_f1_score, + "multifieldqa_zh": qa_f1_zh_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "dureader": rouge_zh_score, + "gov_report": rouge_score, + "qmsum": rouge_score, + "multi_news": rouge_score, + "vcsum": rouge_zh_score, + "trec": classification_score, + "triviaqa": qa_f1_score, + "samsum": rouge_score, + "lsht": classification_score, + "passage_retrieval_en": retrieval_score, + "passage_count": count_score, + "passage_retrieval_zh": retrieval_zh_score, + "lcc": code_sim_score, + "repobench-p": code_sim_score, +} + + +def scorer(dataset, predictions, answers, all_classes): + """Calculate score for a dataset.""" + total_score = 0.0 + for prediction, ground_truths in zip(predictions, answers, strict=False): + score = 0.0 + # For some tasks, only take the first line + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip("\n").split("\n")[0] + + # Calculate max score across all ground truth answers + for ground_truth in ground_truths: + metric_func = dataset2metric.get(dataset) + if metric_func: + if dataset in ["trec", "lsht"]: + # Classification tasks need all_classes + score = max( + score, + metric_func(prediction, ground_truth, all_classes=all_classes), + ) + else: + score = max(score, metric_func(prediction, ground_truth)) + else: + print(f"Warning: No metric function for dataset {dataset}") + + total_score += score + + return round(100 * total_score / len(predictions), 2) if len(predictions) > 0 else 0.0 + + +def scorer_e(dataset, predictions, answers, lengths, all_classes): + """Calculate score for LongBench-E (with length-based analysis).""" + scores = {"0-4k": [], "4-8k": [], "8k+": []} + + for prediction, ground_truths, length in zip(predictions, answers, lengths, strict=False): + score = 0.0 + # For some tasks, only take the first line + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip("\n").split("\n")[0] + + # Calculate max score across all ground truth answers + metric_func = dataset2metric.get(dataset) + if metric_func: + for ground_truth in ground_truths: + if dataset in ["trec", "lsht"]: + score = max( + score, + metric_func(prediction, ground_truth, all_classes=all_classes), + ) + else: + score = max(score, metric_func(prediction, ground_truth)) + + # Categorize by length + if length < 4000: + scores["0-4k"].append(score) + elif length < 8000: + scores["4-8k"].append(score) + else: + scores["8k+"].append(score) + + # Calculate average scores per length category + for key in scores: + if len(scores[key]) > 0: + scores[key] = round(100 * np.mean(scores[key]), 2) + else: + scores[key] = 0.0 + + return scores + + +def main(frame, version="default", use_e=False): + """Main metric calculation function.""" + print("\n" + "=" * 80) + print(f"📊 LONGBENCH METRICS CALCULATION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load responses + responses_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" + if not os.path.exists(responses_path): + print(f"❌ Responses not found: {responses_path}") + print("Please run longbench_responses.py first") + return + + with open(responses_path, encoding="utf-8") as f: + responses = json.load(f) + + # Calculate metrics for each dataset + all_scores = {} + overall_scores = [] + + for dataset_name, samples in responses.items(): + print(f"Calculating metrics for {dataset_name}...") + + predictions = [s.get("answer", "") for s in samples] + answers = [s.get("golden_answer", []) for s in samples] + all_classes = samples[0].get("all_classes") if samples else None + + if use_e: + lengths = [s.get("length", 0) for s in samples] + score = scorer_e(dataset_name, predictions, answers, lengths, all_classes) + else: + score = scorer(dataset_name, predictions, answers, all_classes) + + all_scores[dataset_name] = score + print(f" {dataset_name}: {score}") + + # For overall average, use single score (not length-based) + if use_e: + # Average across length categories + if isinstance(score, dict): + overall_scores.append(np.mean(list(score.values()))) + else: + overall_scores.append(score) + + # Calculate overall average + if overall_scores: + all_scores["average"] = round(np.mean(overall_scores), 2) + print(f"\nOverall Average: {all_scores['average']}") + + # Save metrics + output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_metrics.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(all_scores, f, ensure_ascii=False, indent=4) + + print(f"\n{'=' * 80}") + print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + # Print summary table + print("\n📊 Summary of Results:") + print("-" * 80) + for dataset, score in sorted(all_scores.items()): + if isinstance(score, dict): + print(f"{dataset:30s}: {score}") + else: + print(f"{dataset:30s}: {score:.2f}%") + print("-" * 80) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for loading results", + ) + parser.add_argument( + "--e", + action="store_true", + help="Use LongBench-E variant (uniform length distribution)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.e) diff --git a/evaluation/scripts/longbench/longbench_responses.py b/evaluation/scripts/longbench/longbench_responses.py new file mode 100644 index 000000000..2d160160a --- /dev/null +++ b/evaluation/scripts/longbench/longbench_responses.py @@ -0,0 +1,196 @@ +import argparse +import json +import os +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +# Dataset to prompt mapping (from LongBench config) +DATASET_PROMPTS = { + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:', + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ', + "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:', + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", +} + + +def generate_response(llm_client, dataset_name, context, input_text): + """Generate response using LLM.""" + # Get prompt template for dataset + prompt_template = DATASET_PROMPTS.get(dataset_name, "{context}\n\nQuestion: {input}\nAnswer:") + + # Format prompt + if "{input}" in prompt_template: + prompt = prompt_template.format(context=context, input=input_text) + else: + # Some prompts don't have {input} placeholder (like gov_report, vcsum) + prompt = prompt_template.format(context=context) + + try: + response = llm_client.chat.completions.create( + model=os.getenv("CHAT_MODEL"), + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=0, + ) + result = response.choices[0].message.content or "" + return result + except Exception as e: + print(f"Error generating response: {e}") + return "" + + +def process_sample(search_result, llm_client): + """Process a single sample: generate answer.""" + start = time() + + dataset_name = search_result.get("dataset") + context = search_result.get("context", "") + input_text = search_result.get("input", "") + + # Generate answer + answer = generate_response(llm_client, dataset_name, context, input_text) + + response_duration_ms = (time() - start) * 1000 + + return { + "dataset": dataset_name, + "sample_idx": search_result.get("sample_idx"), + "input": input_text, + "answer": answer, + "golden_answer": search_result.get("answers", []), + "all_classes": search_result.get("all_classes"), + "length": search_result.get("length", 0), + "search_context": context, + "response_duration_ms": response_duration_ms, + "search_duration_ms": search_result.get("search_duration_ms", 0), + } + + +def main(frame, version="default", num_workers=10): + """Main response generation function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load search results + search_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" + if not os.path.exists(search_path): + print(f"❌ Search results not found: {search_path}") + print("Please run longbench_search.py first") + return + + with open(search_path, encoding="utf-8") as f: + search_results = json.load(f) + + # Initialize LLM client + llm_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), + base_url=os.getenv("CHAT_MODEL_BASE_URL"), + ) + print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") + + # Process all samples + all_responses = [] + for dataset_name, samples in search_results.items(): + print(f"\nProcessing {len(samples)} samples from {dataset_name}...") + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(process_sample, sample, llm_client) for sample in samples] + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=f"Generating responses for {dataset_name}", + ): + result = future.result() + if result: + all_responses.append(result) + + # Save responses + output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Group by dataset + responses_by_dataset = {} + for response in all_responses: + dataset = response["dataset"] + if dataset not in responses_by_dataset: + responses_by_dataset[dataset] = [] + responses_by_dataset[dataset].append(response) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(responses_by_dataset, f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 80}") + print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for loading results", + ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers) diff --git a/evaluation/scripts/longbench/longbench_search.py b/evaluation/scripts/longbench/longbench_search.py new file mode 100644 index 000000000..aaf7300e4 --- /dev/null +++ b/evaluation/scripts/longbench/longbench_search.py @@ -0,0 +1,309 @@ +import argparse +import json +import os +import sys + +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +# All LongBench datasets +LONGBENCH_DATASETS = [ + "narrativeqa", + "qasper", + "multifieldqa_en", + "multifieldqa_zh", + "hotpotqa", + "2wikimqa", + "musique", + "dureader", + "gov_report", + "qmsum", + "multi_news", + "vcsum", + "trec", + "triviaqa", + "samsum", + "lsht", + "passage_count", + "passage_retrieval_en", + "passage_retrieval_zh", + "lcc", + "repobench-p", +] + + +def memos_api_search(client, query, user_id, top_k, frame): + """Search using memos API.""" + start = time() + search_results = client.search(query=query, user_id=user_id, top_k=top_k) + + # Format context from search results based on frame type + context = "" + if frame == "memos-api" or frame == "memos-api-online": + if isinstance(search_results, dict) and "text_mem" in search_results: + context = "\n".join([i["memory"] for i in search_results["text_mem"][0]["memories"]]) + if "pref_string" in search_results: + context += f"\n{search_results.get('pref_string', '')}" + elif frame == "mem0" or frame == "mem0_graph": + if isinstance(search_results, dict) and "results" in search_results: + context = "\n".join( + [ + f"{m.get('created_at', '')}: {m.get('memory', '')}" + for m in search_results["results"] + ] + ) + elif frame == "memobase": + context = search_results if isinstance(search_results, str) else "" + elif frame == "memu": + context = "\n".join(search_results) if isinstance(search_results, list) else "" + elif frame == "supermemory": + context = search_results if isinstance(search_results, str) else "" + + duration_ms = (time() - start) * 1000 + return context, duration_ms + + +def process_sample(client, sample, dataset_name, sample_idx, frame, version, top_k): + """Process a single sample: search for relevant memories.""" + user_id = f"longbench_{dataset_name}_{sample_idx}_{version}" + query = sample.get("input", "") + + if not query: + return None + + context, duration_ms = memos_api_search(client, query, user_id, top_k, frame) + + return { + "dataset": dataset_name, + "sample_idx": sample_idx, + "input": query, + "context": context, + "search_duration_ms": duration_ms, + "answers": sample.get("answers", []), + "all_classes": sample.get("all_classes"), + "length": sample.get("length", 0), + } + + +def load_dataset_from_local(dataset_name, use_e=False): + """Load LongBench dataset from local JSONL file.""" + # Determine data directory + data_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "data", + "long_bench_v2", + ) + + # Determine filename + filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" + + filepath = os.path.join(data_dir, filename) + + if not os.path.exists(filepath): + raise FileNotFoundError(f"Dataset file not found: {filepath}") + + # Load JSONL file + samples = [] + with open(filepath, encoding="utf-8") as f: + for line in f: + if line.strip(): + samples.append(json.loads(line)) + + return samples + + +def process_dataset( + dataset_name, frame, version, top_k=20, num_workers=10, max_samples=None, use_e=False +): + """Process a single dataset: search for all samples.""" + print(f"\n{'=' * 80}") + print(f"🔍 [SEARCHING DATASET: {dataset_name.upper()}]".center(80)) + print(f"{'=' * 80}\n") + + # Load dataset from local files + try: + dataset = load_dataset_from_local(dataset_name, use_e) + print(f"Loaded {len(dataset)} samples from {dataset_name}") + except FileNotFoundError as e: + print(f"❌ Error loading dataset {dataset_name}: {e}") + return [] + except Exception as e: + print(f"❌ Error loading dataset {dataset_name}: {e}") + return [] + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "mem0" or frame == "mem0_graph": + from utils.client import Mem0Client + + client = Mem0Client(enable_graph="graph" in frame) + elif frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + elif frame == "memobase": + from utils.client import MemobaseClient + + client = MemobaseClient() + elif frame == "memu": + from utils.client import MemuClient + + client = MemuClient() + elif frame == "supermemory": + from utils.client import SupermemoryClient + + client = SupermemoryClient() + else: + print(f"❌ Unsupported frame: {frame}") + return [] + + # Process samples + search_results = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit( + process_sample, client, sample, dataset_name, idx, frame, version, top_k + ) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=f"Searching {dataset_name}", + ): + result = future.result() + if result: + search_results.append(result) + + print(f"\n✅ Completed searching {dataset_name}: {len(search_results)} samples") + return search_results + + +def main( + frame, version="default", num_workers=10, top_k=20, datasets=None, max_samples=None, use_e=False +): + """Main search function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH SEARCH - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Determine which datasets to process + dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS + + # Filter valid datasets + valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] + if not valid_datasets: + print("❌ No valid datasets specified") + return + + print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") + + # Create output directory + os.makedirs(f"results/longbench/{frame}-{version}/", exist_ok=True) + + # Process each dataset + all_results = defaultdict(list) + for dataset_name in valid_datasets: + results = process_dataset( + dataset_name, frame, version, top_k, num_workers, max_samples, use_e + ) + all_results[dataset_name] = results + + # Save results + output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" + with open(output_path, "w", encoding="utf-8") as f: + json.dump(dict(all_results), f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 80}") + print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + parser.add_argument( + "--top_k", + type=int, + default=20, + help="Number of results to retrieve in search queries", + ) + parser.add_argument( + "--datasets", + type=str, + default=None, + help="Comma-separated list of datasets to process (default: all)", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples per dataset (default: all)", + ) + parser.add_argument( + "--e", + action="store_true", + help="Use LongBench-E variant (uniform length distribution)", + ) + args = parser.parse_args() + + main( + args.lib, + args.version, + args.workers, + args.top_k, + args.datasets, + args.max_samples, + args.e, + ) diff --git a/evaluation/scripts/longbench_v2/prepare_data.py b/evaluation/scripts/longbench_v2/prepare_data.py new file mode 100644 index 000000000..e69de29bb From 9beabbac3fae9ff0bc8c4335aa663e424840e101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sun, 7 Dec 2025 18:03:45 +0800 Subject: [PATCH 06/31] feat: more strict embedder trucation --- src/memos/embedders/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/embedders/base.py b/src/memos/embedders/base.py index d573521f6..22ef0d302 100644 --- a/src/memos/embedders/base.py +++ b/src/memos/embedders/base.py @@ -79,7 +79,7 @@ def __init__(self, config: BaseEmbedderConfig): """Initialize the embedding model with the given configuration.""" self.config = config - def _truncate_texts(self, texts: list[str], approx_char_per_token=1.1) -> (list)[str]: + def _truncate_texts(self, texts: list[str], approx_char_per_token=1.0) -> (list)[str]: """ Truncate texts to fit within max_tokens limit if configured. @@ -98,7 +98,7 @@ def _truncate_texts(self, texts: list[str], approx_char_per_token=1.1) -> (list) if len(t) < max_tokens * approx_char_per_token: truncated.append(t) else: - truncated.append(_truncate_text_to_tokens(t, max_tokens)) + truncated.append(t[:max_tokens]) return truncated @abstractmethod From 8f368bb7b347132d7f93f4365f5180628310106c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sun, 7 Dec 2025 18:41:51 +0800 Subject: [PATCH 07/31] feat: parallel processing fine mode in multi-modal-fine --- src/memos/mem_reader/multi_modal_struct.py | 30 +++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 3a9aa014b..4d4faff30 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -358,13 +358,15 @@ def _process_string_fine( if not fast_memory_items: return [] - fine_memory_items = [] + def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: + """Process a single fast memory item and return a list of fine items.""" + fine_items: list[TextualMemoryItem] = [] - for fast_item in fast_memory_items: # Extract memory text (string content) mem_str = fast_item.memory or "" if not mem_str.strip(): - continue + return fine_items + sources = fast_item.metadata.sources or [] if not isinstance(sources, list): sources = [sources] @@ -376,7 +378,8 @@ def _process_string_fine( resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) except Exception as e: logger.error(f"[MultiModalFine] Error calling LLM: {e}") - continue + return fine_items + if resp.get("memory list", []): for m in resp.get("memory list", []): try: @@ -396,7 +399,7 @@ def _process_string_fine( sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), ) - fine_memory_items.append(node) + fine_items.append(node) except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") elif resp.get("value") and resp.get("key"): @@ -411,10 +414,25 @@ def _process_string_fine( sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), ) - fine_memory_items.append(node) + fine_items.append(node) except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") + return fine_items + + fine_memory_items: list[TextualMemoryItem] = [] + + with ContextThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(_process_one_item, item) for item in fast_memory_items] + + for future in concurrent.futures.as_completed(futures): + try: + result = future.result() + if result: + fine_memory_items.extend(result) + except Exception as e: + logger.error(f"[MultiModalFine] worker error: {e}") + return fine_memory_items def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict: From be293bcd059be12cf4226870b4b7c470d3503de2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 11:54:28 +0800 Subject: [PATCH 08/31] feat: update parsers; add chunk info into source; remove origin_part --- src/memos/mem_reader/multi_modal_struct.py | 2 +- .../read_multi_modal/file_content_parser.py | 96 ++++++++++++++----- .../read_multi_modal/image_parser.py | 5 - .../read_multi_modal/text_content_parser.py | 1 - .../read_multi_modal/tool_parser.py | 3 - .../read_multi_modal/user_parser.py | 5 - 6 files changed, 74 insertions(+), 38 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 4d4faff30..ed139f958 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -422,7 +422,7 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: fine_memory_items: list[TextualMemoryItem] = [] - with ContextThreadPoolExecutor(max_workers=8) as executor: + with ContextThreadPoolExecutor(max_workers=30) as executor: futures = [executor.submit(_process_one_item, item) for item in fast_memory_items] for future in concurrent.futures.as_completed(futures): diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 67de3020d..fe1b44270 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -167,27 +167,40 @@ def create_source( self, message: File, info: dict[str, Any], + chunk_index: int | None = None, + chunk_total: int | None = None, + chunk_content: str | None = None, ) -> SourceMessage: """Create SourceMessage from file content part.""" if isinstance(message, dict): file_info = message.get("file", {}) - return SourceMessage( - type="file", - doc_path=file_info.get("filename") or file_info.get("file_id", ""), - content=file_info.get("file_data", ""), - original_part=message, - ) - return SourceMessage(type="file", doc_path=str(message)) + source_dict = { + "type": "file", + "doc_path": file_info.get("filename") or file_info.get("file_id", ""), + "content": chunk_content + if chunk_content is not None + else file_info.get("file_data", ""), + } + # Add chunk ordering information if provided + if chunk_index is not None: + source_dict["chunk_index"] = chunk_index + if chunk_total is not None: + source_dict["chunk_total"] = chunk_total + return SourceMessage(**source_dict) + source_dict = {"type": "file", "doc_path": str(message)} + if chunk_index is not None: + source_dict["chunk_index"] = chunk_index + if chunk_total is not None: + source_dict["chunk_total"] = chunk_total + if chunk_content is not None: + source_dict["content"] = chunk_content + return SourceMessage(**source_dict) def rebuild_from_source( self, source: SourceMessage, ) -> File: """Rebuild file content part from SourceMessage.""" - # Use original_part if available - if hasattr(source, "original_part") and source.original_part: - return source.original_part - # Rebuild from source fields return { "type": "file", @@ -311,9 +324,6 @@ def parse_fast( # Split content into chunks content_chunks = self._split_text(content) - # Create source - source = self.create_source(message, info) - # Extract info fields info_ = info.copy() if file_id: @@ -325,12 +335,23 @@ def parse_fast( # (since we don't have role information at this level) memory_type = "LongTermMemory" file_ids = [file_id] if file_id else [] + total_chunks = len(content_chunks) + # Create memory items for each chunk memory_items = [] for chunk_idx, chunk_text in enumerate(content_chunks): if not chunk_text.strip(): continue + # Create source for this specific chunk with its index and content + source = self.create_source( + message, + info, + chunk_index=chunk_idx, + chunk_total=total_chunks, + chunk_content=chunk_text, + ) + memory_item = TextualMemoryItem( memory=chunk_text, metadata=TreeNodeTextualMemoryMetadata( @@ -341,7 +362,7 @@ def parse_fast( tags=[ "mode:fast", "multimodal:file", - f"chunk:{chunk_idx + 1}/{len(content_chunks)}", + f"chunk:{chunk_idx + 1}/{total_chunks}", ], key=_derive_key(chunk_text), embedding=self.embedder.embed([chunk_text])[0], @@ -358,6 +379,14 @@ def parse_fast( # If no chunks were created, create a placeholder if not memory_items: + # Create source for placeholder (no chunk index since there are no chunks) + placeholder_source = self.create_source( + message, + info, + chunk_index=None, + chunk_total=0, + chunk_content=content, + ) memory_item = TextualMemoryItem( memory=content, metadata=TreeNodeTextualMemoryMetadata( @@ -369,7 +398,7 @@ def parse_fast( key=_derive_key(content), embedding=self.embedder.embed([content])[0], usage=[], - sources=[source], + sources=[placeholder_source], background="", confidence=0.99, type="fact", @@ -462,7 +491,9 @@ def parse_fine( parsed_text = self._handle_base64(file_data) else: - parsed_text = file_data + # TODO: discuss the proper place for processing + # string file-data + return [] # Priority 2: If file_id is provided but no file_data, try to use file_id as path elif file_id: logger.warning(f"[FileContentParser] File data not provided for file_id: {file_id}") @@ -490,9 +521,6 @@ def parse_fine( f"[FileContentParser] Failed to delete temp file {temp_file_path}: {e}" ) - # Create source - source = self.create_source(message, info) - # Extract info fields if not info: info = {} @@ -520,8 +548,25 @@ def _make_memory_item( mem_type: str = memory_type, tags: list[str] | None = None, key: str | None = None, + chunk_idx: int | None = None, ) -> TextualMemoryItem: - """Construct memory item with common fields.""" + """Construct memory item with common fields. + + Args: + value: Memory content (chunk text) + mem_type: Memory type + tags: Tags for the memory item + key: Key for the memory item + chunk_idx: Index of the chunk in the document (0-based) + """ + # Create source for this specific chunk with its index and content + chunk_source = self.create_source( + message, + info, + chunk_index=chunk_idx, + chunk_total=total_chunks, + chunk_content=value, + ) return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( @@ -533,7 +578,7 @@ def _make_memory_item( key=key if key is not None else _derive_key(value), embedding=self.embedder.embed([value])[0], usage=[], - sources=[source], + sources=[chunk_source], background="", confidence=0.99, type="fact", @@ -555,6 +600,7 @@ def _make_fallback( f"fallback:{reason}", f"chunk:{chunk_idx + 1}/{total_chunks}", ], + chunk_idx=chunk_idx, ) # Handle empty chunks case @@ -563,6 +609,7 @@ def _make_fallback( _make_memory_item( value=parsed_text or "[File: empty content]", tags=["mode:fine", "multimodal:file"], + chunk_idx=None, ) ] @@ -591,6 +638,7 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: mem_type=llm_mem_type, tags=tags, key=response_json.get("key"), + chunk_idx=chunk_idx, ) except Exception as e: logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}") @@ -637,6 +685,8 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: return memory_items or [ _make_memory_item( - value=parsed_text or "[File: empty content]", tags=["mode:fine", "multimodal:file"] + value=parsed_text or "[File: empty content]", + tags=["mode:fine", "multimodal:file"], + chunk_idx=None, ) ] diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 88991fbe7..5a19393a9 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -53,7 +53,6 @@ def create_source( return SourceMessage( type="image", content=url, - original_part=message, url=url, detail=detail, ) @@ -64,10 +63,6 @@ def rebuild_from_source( source: SourceMessage, ) -> ChatCompletionContentPartImageParam: """Rebuild image_url content part from SourceMessage.""" - # Use original_part if available - if hasattr(source, "original_part") and source.original_part: - return source.original_part - # Rebuild from source fields url = getattr(source, "url", "") or (source.content or "").replace("[image_url]: ", "") detail = getattr(source, "detail", "auto") diff --git a/src/memos/mem_reader/read_multi_modal/text_content_parser.py b/src/memos/mem_reader/read_multi_modal/text_content_parser.py index 5ff0a76fd..febc166ec 100644 --- a/src/memos/mem_reader/read_multi_modal/text_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/text_content_parser.py @@ -51,7 +51,6 @@ def create_source( return SourceMessage( type="text", content=text, - original_part=message, ) return SourceMessage(type="text", content=str(message)) diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py index 09bd9e9d0..e13b684a7 100644 --- a/src/memos/mem_reader/read_multi_modal/tool_parser.py +++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py @@ -79,7 +79,6 @@ def create_source( filename=file_info.get("filename", ""), file_id=file_info.get("file_id", ""), tool_call_id=tool_call_id, - original_part=part, ) ) elif part_type == "image_url": @@ -93,7 +92,6 @@ def create_source( content=file_info.get("url", ""), detail=file_info.get("detail", "auto"), tool_call_id=tool_call_id, - original_part=part, ) ) elif part_type == "input_audio": @@ -107,7 +105,6 @@ def create_source( content=file_info.get("data", ""), format=file_info.get("format", "wav"), tool_call_id=tool_call_id, - original_part=part, ) ) else: diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py index c7b8ad4e9..359506e13 100644 --- a/src/memos/mem_reader/read_multi_modal/user_parser.py +++ b/src/memos/mem_reader/read_multi_modal/user_parser.py @@ -68,8 +68,6 @@ def create_source( chat_time=chat_time, message_id=message_id, content=part.get("text", ""), - # Save original part for reconstruction - original_part=part, ) ) elif part_type == "file": @@ -82,7 +80,6 @@ def create_source( message_id=message_id, doc_path=file_info.get("filename") or file_info.get("file_id", ""), content=file_info.get("file_data", ""), - original_part=part, ) ) elif part_type == "image_url": @@ -94,7 +91,6 @@ def create_source( chat_time=chat_time, message_id=message_id, image_path=image_info.get("url"), - original_part=part, ) ) else: @@ -106,7 +102,6 @@ def create_source( chat_time=chat_time, message_id=message_id, content=f"[{part_type}]", - original_part=part, ) ) else: From 2edd0a3082e56d172abcc010fa381cd342002950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 12:09:19 +0800 Subject: [PATCH 09/31] feat: modify chunk_content in file-fine-parser --- src/memos/mem_reader/read_multi_modal/file_content_parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 75b627af3..cce99e76a 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -547,6 +547,7 @@ def _make_memory_item( tags: list[str] | None = None, key: str | None = None, chunk_idx: int | None = None, + chunk_content: str | None = None, ) -> TextualMemoryItem: """Construct memory item with common fields. @@ -563,7 +564,7 @@ def _make_memory_item( info, chunk_index=chunk_idx, chunk_total=total_chunks, - chunk_content=value, + chunk_content=chunk_content, ) return TextualMemoryItem( memory=value, @@ -599,6 +600,7 @@ def _make_fallback( f"chunk:{chunk_idx + 1}/{total_chunks}", ], chunk_idx=chunk_idx, + chunk_content=chunk_text, ) # Handle empty chunks case @@ -637,6 +639,7 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: tags=tags, key=response_json.get("key"), chunk_idx=chunk_idx, + chunk_content=chunk_text, ) except Exception as e: logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}") From f80896e74801832135ab1039a7df25b4c8af6a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 16:08:32 +0800 Subject: [PATCH 10/31] fix: token counter bug --- evaluation/scripts/longbench/__init__.py | 1 - .../scripts/longbench/longbench_ingestion.py | 306 ----------------- .../scripts/longbench/longbench_metric.py | 235 ------------- .../scripts/longbench/longbench_responses.py | 196 ----------- .../scripts/longbench/longbench_search.py | 309 ------------------ .../scripts/longbench_v2/prepare_data.py | 0 src/memos/embedders/base.py | 2 +- src/memos/mem_reader/simple_struct.py | 2 +- .../tree_text_memory/organize/manager.py | 6 +- 9 files changed, 5 insertions(+), 1052 deletions(-) delete mode 100644 evaluation/scripts/longbench/__init__.py delete mode 100644 evaluation/scripts/longbench/longbench_ingestion.py delete mode 100644 evaluation/scripts/longbench/longbench_metric.py delete mode 100644 evaluation/scripts/longbench/longbench_responses.py delete mode 100644 evaluation/scripts/longbench/longbench_search.py delete mode 100644 evaluation/scripts/longbench_v2/prepare_data.py diff --git a/evaluation/scripts/longbench/__init__.py b/evaluation/scripts/longbench/__init__.py deleted file mode 100644 index 38cc006e3..000000000 --- a/evaluation/scripts/longbench/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# LongBench evaluation scripts diff --git a/evaluation/scripts/longbench/longbench_ingestion.py b/evaluation/scripts/longbench/longbench_ingestion.py deleted file mode 100644 index e2d2a8e7e..000000000 --- a/evaluation/scripts/longbench/longbench_ingestion.py +++ /dev/null @@ -1,306 +0,0 @@ -import argparse -import json -import os -import sys - -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime, timezone - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# All LongBench datasets -LONGBENCH_DATASETS = [ - "narrativeqa", - "qasper", - "multifieldqa_en", - "multifieldqa_zh", - "hotpotqa", - "2wikimqa", - "musique", - "dureader", - "gov_report", - "qmsum", - "multi_news", - "vcsum", - "trec", - "triviaqa", - "samsum", - "lsht", - "passage_count", - "passage_retrieval_en", - "passage_retrieval_zh", - "lcc", - "repobench-p", -] - - -def ingest_sample(client, sample, dataset_name, sample_idx, frame, version): - """Ingest a single LongBench sample as memories.""" - user_id = f"longbench_{dataset_name}_{sample_idx}_{version}" - conv_id = f"longbench_{dataset_name}_{sample_idx}_{version}" - - # Get context and convert to messages - context = sample.get("context", "") - # not used now: input_text = sample.get("input", "") - - # For memos, we ingest the context as document content - # Split context into chunks if it's too long (optional, memos handles this internally) - # For now, we'll ingest the full context as a single message - messages = [ - { - "role": "assistant", - "content": context, - "chat_time": datetime.now(timezone.utc).isoformat(), - } - ] - - if "memos-api" in frame: - try: - client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif "mem0" in frame: - timestamp = int(datetime.now(timezone.utc).timestamp()) - try: - client.add(messages=messages, user_id=user_id, timestamp=timestamp, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif frame == "memobase": - for m in messages: - m["created_at"] = messages[0]["chat_time"] - try: - client.add(messages=messages, user_id=user_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif frame == "memu": - try: - client.add(messages=messages, user_id=user_id, iso_date=messages[0]["chat_time"]) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif frame == "supermemory": - try: - client.add(messages=messages, user_id=user_id) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - - return False - - -def load_dataset_from_local(dataset_name, use_e=False): - """Load LongBench dataset from local JSONL file.""" - # Determine data directory - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - # Determine filename - filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" - - filepath = os.path.join(data_dir, filename) - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSONL file - samples = [] - with open(filepath, encoding="utf-8") as f: - for line in f: - if line.strip(): - samples.append(json.loads(line)) - - return samples - - -def ingest_dataset(dataset_name, frame, version, num_workers=10, max_samples=None, use_e=False): - """Ingest a single LongBench dataset.""" - print(f"\n{'=' * 80}") - print(f"🔄 [INGESTING DATASET: {dataset_name.upper()}]".center(80)) - print(f"{'=' * 80}\n") - - # Load dataset from local files - try: - dataset = load_dataset_from_local(dataset_name, use_e) - print(f"Loaded {len(dataset)} samples from {dataset_name}") - except FileNotFoundError as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return - except Exception as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize client - client = None - if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client - - client = Mem0Client(enable_graph="graph" in frame) - elif frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - elif frame == "memos-api-online": - from utils.client import MemosApiOnlineClient - - client = MemosApiOnlineClient() - elif frame == "memobase": - from utils.client import MemobaseClient - - client = MemobaseClient() - elif frame == "memu": - from utils.client import MemuClient - - client = MemuClient() - elif frame == "supermemory": - from utils.client import SupermemoryClient - - client = SupermemoryClient() - else: - print(f"❌ Unsupported frame: {frame}") - return - - # Ingest samples - success_count = 0 - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - ingest_sample, client, sample, dataset_name, idx, frame, version - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Ingesting {dataset_name}", - ): - try: - if future.result(): - success_count += 1 - except Exception as e: - print(f"Error processing sample: {e}") - - print(f"\n✅ Completed ingesting {dataset_name}: {success_count}/{len(dataset)} samples") - return success_count - - -def main(frame, version="default", num_workers=10, datasets=None, max_samples=None, use_e=False): - """Main ingestion function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH INGESTION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Determine which datasets to process - dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS - - # Filter valid datasets - valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] - if not valid_datasets: - print("❌ No valid datasets specified") - return - - print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") - - # Ingest each dataset - total_success = 0 - total_samples = 0 - for dataset_name in valid_datasets: - success = ingest_dataset(dataset_name, frame, version, num_workers, max_samples, use_e) - if success is not None: - total_success += success - total_samples += max_samples if max_samples else 200 # Approximate - - print(f"\n{'=' * 80}") - print(f"✅ INGESTION COMPLETE: {total_success} samples ingested".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - parser.add_argument( - "--datasets", - type=str, - default=None, - help="Comma-separated list of datasets to process (default: all)", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples per dataset (default: all)", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main( - args.lib, - args.version, - args.workers, - args.datasets, - args.max_samples, - args.e, - ) diff --git a/evaluation/scripts/longbench/longbench_metric.py b/evaluation/scripts/longbench/longbench_metric.py deleted file mode 100644 index 495a793ab..000000000 --- a/evaluation/scripts/longbench/longbench_metric.py +++ /dev/null @@ -1,235 +0,0 @@ -import argparse -import json -import os -import sys - -import numpy as np - - -# Import LongBench metrics -# Try to import from the LongBench directory -LONGBENCH_METRICS_DIR = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - "longbench_v2", - "LongBench-main", - "LongBench", -) - -if os.path.exists(LONGBENCH_METRICS_DIR): - sys.path.insert(0, LONGBENCH_METRICS_DIR) - try: - from metrics import ( - classification_score, - code_sim_score, - count_score, - qa_f1_score, - qa_f1_zh_score, - retrieval_score, - retrieval_zh_score, - rouge_score, - rouge_zh_score, - ) - except ImportError: - print(f"Warning: Could not import metrics from {LONGBENCH_METRICS_DIR}") - print("Please ensure LongBench metrics.py is available") - raise -else: - print(f"Error: LongBench metrics directory not found at {LONGBENCH_METRICS_DIR}") - raise FileNotFoundError("LongBench metrics directory not found") - -# Dataset to metric mapping (from LongBench eval.py) -dataset2metric = { - "narrativeqa": qa_f1_score, - "qasper": qa_f1_score, - "multifieldqa_en": qa_f1_score, - "multifieldqa_zh": qa_f1_zh_score, - "hotpotqa": qa_f1_score, - "2wikimqa": qa_f1_score, - "musique": qa_f1_score, - "dureader": rouge_zh_score, - "gov_report": rouge_score, - "qmsum": rouge_score, - "multi_news": rouge_score, - "vcsum": rouge_zh_score, - "trec": classification_score, - "triviaqa": qa_f1_score, - "samsum": rouge_score, - "lsht": classification_score, - "passage_retrieval_en": retrieval_score, - "passage_count": count_score, - "passage_retrieval_zh": retrieval_zh_score, - "lcc": code_sim_score, - "repobench-p": code_sim_score, -} - - -def scorer(dataset, predictions, answers, all_classes): - """Calculate score for a dataset.""" - total_score = 0.0 - for prediction, ground_truths in zip(predictions, answers, strict=False): - score = 0.0 - # For some tasks, only take the first line - if dataset in ["trec", "triviaqa", "samsum", "lsht"]: - prediction = prediction.lstrip("\n").split("\n")[0] - - # Calculate max score across all ground truth answers - for ground_truth in ground_truths: - metric_func = dataset2metric.get(dataset) - if metric_func: - if dataset in ["trec", "lsht"]: - # Classification tasks need all_classes - score = max( - score, - metric_func(prediction, ground_truth, all_classes=all_classes), - ) - else: - score = max(score, metric_func(prediction, ground_truth)) - else: - print(f"Warning: No metric function for dataset {dataset}") - - total_score += score - - return round(100 * total_score / len(predictions), 2) if len(predictions) > 0 else 0.0 - - -def scorer_e(dataset, predictions, answers, lengths, all_classes): - """Calculate score for LongBench-E (with length-based analysis).""" - scores = {"0-4k": [], "4-8k": [], "8k+": []} - - for prediction, ground_truths, length in zip(predictions, answers, lengths, strict=False): - score = 0.0 - # For some tasks, only take the first line - if dataset in ["trec", "triviaqa", "samsum", "lsht"]: - prediction = prediction.lstrip("\n").split("\n")[0] - - # Calculate max score across all ground truth answers - metric_func = dataset2metric.get(dataset) - if metric_func: - for ground_truth in ground_truths: - if dataset in ["trec", "lsht"]: - score = max( - score, - metric_func(prediction, ground_truth, all_classes=all_classes), - ) - else: - score = max(score, metric_func(prediction, ground_truth)) - - # Categorize by length - if length < 4000: - scores["0-4k"].append(score) - elif length < 8000: - scores["4-8k"].append(score) - else: - scores["8k+"].append(score) - - # Calculate average scores per length category - for key in scores: - if len(scores[key]) > 0: - scores[key] = round(100 * np.mean(scores[key]), 2) - else: - scores[key] = 0.0 - - return scores - - -def main(frame, version="default", use_e=False): - """Main metric calculation function.""" - print("\n" + "=" * 80) - print(f"📊 LONGBENCH METRICS CALCULATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load responses - responses_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" - if not os.path.exists(responses_path): - print(f"❌ Responses not found: {responses_path}") - print("Please run longbench_responses.py first") - return - - with open(responses_path, encoding="utf-8") as f: - responses = json.load(f) - - # Calculate metrics for each dataset - all_scores = {} - overall_scores = [] - - for dataset_name, samples in responses.items(): - print(f"Calculating metrics for {dataset_name}...") - - predictions = [s.get("answer", "") for s in samples] - answers = [s.get("golden_answer", []) for s in samples] - all_classes = samples[0].get("all_classes") if samples else None - - if use_e: - lengths = [s.get("length", 0) for s in samples] - score = scorer_e(dataset_name, predictions, answers, lengths, all_classes) - else: - score = scorer(dataset_name, predictions, answers, all_classes) - - all_scores[dataset_name] = score - print(f" {dataset_name}: {score}") - - # For overall average, use single score (not length-based) - if use_e: - # Average across length categories - if isinstance(score, dict): - overall_scores.append(np.mean(list(score.values()))) - else: - overall_scores.append(score) - - # Calculate overall average - if overall_scores: - all_scores["average"] = round(np.mean(overall_scores), 2) - print(f"\nOverall Average: {all_scores['average']}") - - # Save metrics - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_metrics.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(all_scores, f, ensure_ascii=False, indent=4) - - print(f"\n{'=' * 80}") - print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - # Print summary table - print("\n📊 Summary of Results:") - print("-" * 80) - for dataset, score in sorted(all_scores.items()): - if isinstance(score, dict): - print(f"{dataset:30s}: {score}") - else: - print(f"{dataset:30s}: {score:.2f}%") - print("-" * 80) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for loading results", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.e) diff --git a/evaluation/scripts/longbench/longbench_responses.py b/evaluation/scripts/longbench/longbench_responses.py deleted file mode 100644 index 2d160160a..000000000 --- a/evaluation/scripts/longbench/longbench_responses.py +++ /dev/null @@ -1,196 +0,0 @@ -import argparse -import json -import os -import sys - -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from openai import OpenAI -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# Dataset to prompt mapping (from LongBench config) -DATASET_PROMPTS = { - "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", - "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:', - "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", - "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", - "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", - "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", - "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", - "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", - "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", - "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", - "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", - "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", - "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", - "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ', - "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:', - "lcc": "Please complete the code given below. \n{context}Next line of code:\n", - "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", -} - - -def generate_response(llm_client, dataset_name, context, input_text): - """Generate response using LLM.""" - # Get prompt template for dataset - prompt_template = DATASET_PROMPTS.get(dataset_name, "{context}\n\nQuestion: {input}\nAnswer:") - - # Format prompt - if "{input}" in prompt_template: - prompt = prompt_template.format(context=context, input=input_text) - else: - # Some prompts don't have {input} placeholder (like gov_report, vcsum) - prompt = prompt_template.format(context=context) - - try: - response = llm_client.chat.completions.create( - model=os.getenv("CHAT_MODEL"), - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - temperature=0, - ) - result = response.choices[0].message.content or "" - return result - except Exception as e: - print(f"Error generating response: {e}") - return "" - - -def process_sample(search_result, llm_client): - """Process a single sample: generate answer.""" - start = time() - - dataset_name = search_result.get("dataset") - context = search_result.get("context", "") - input_text = search_result.get("input", "") - - # Generate answer - answer = generate_response(llm_client, dataset_name, context, input_text) - - response_duration_ms = (time() - start) * 1000 - - return { - "dataset": dataset_name, - "sample_idx": search_result.get("sample_idx"), - "input": input_text, - "answer": answer, - "golden_answer": search_result.get("answers", []), - "all_classes": search_result.get("all_classes"), - "length": search_result.get("length", 0), - "search_context": context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_result.get("search_duration_ms", 0), - } - - -def main(frame, version="default", num_workers=10): - """Main response generation function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load search results - search_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" - if not os.path.exists(search_path): - print(f"❌ Search results not found: {search_path}") - print("Please run longbench_search.py first") - return - - with open(search_path, encoding="utf-8") as f: - search_results = json.load(f) - - # Initialize LLM client - llm_client = OpenAI( - api_key=os.getenv("CHAT_MODEL_API_KEY"), - base_url=os.getenv("CHAT_MODEL_BASE_URL"), - ) - print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") - - # Process all samples - all_responses = [] - for dataset_name, samples in search_results.items(): - print(f"\nProcessing {len(samples)} samples from {dataset_name}...") - - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(process_sample, sample, llm_client) for sample in samples] - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Generating responses for {dataset_name}", - ): - result = future.result() - if result: - all_responses.append(result) - - # Save responses - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - # Group by dataset - responses_by_dataset = {} - for response in all_responses: - dataset = response["dataset"] - if dataset not in responses_by_dataset: - responses_by_dataset[dataset] = [] - responses_by_dataset[dataset].append(response) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(responses_by_dataset, f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for loading results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers) diff --git a/evaluation/scripts/longbench/longbench_search.py b/evaluation/scripts/longbench/longbench_search.py deleted file mode 100644 index aaf7300e4..000000000 --- a/evaluation/scripts/longbench/longbench_search.py +++ /dev/null @@ -1,309 +0,0 @@ -import argparse -import json -import os -import sys - -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# All LongBench datasets -LONGBENCH_DATASETS = [ - "narrativeqa", - "qasper", - "multifieldqa_en", - "multifieldqa_zh", - "hotpotqa", - "2wikimqa", - "musique", - "dureader", - "gov_report", - "qmsum", - "multi_news", - "vcsum", - "trec", - "triviaqa", - "samsum", - "lsht", - "passage_count", - "passage_retrieval_en", - "passage_retrieval_zh", - "lcc", - "repobench-p", -] - - -def memos_api_search(client, query, user_id, top_k, frame): - """Search using memos API.""" - start = time() - search_results = client.search(query=query, user_id=user_id, top_k=top_k) - - # Format context from search results based on frame type - context = "" - if frame == "memos-api" or frame == "memos-api-online": - if isinstance(search_results, dict) and "text_mem" in search_results: - context = "\n".join([i["memory"] for i in search_results["text_mem"][0]["memories"]]) - if "pref_string" in search_results: - context += f"\n{search_results.get('pref_string', '')}" - elif frame == "mem0" or frame == "mem0_graph": - if isinstance(search_results, dict) and "results" in search_results: - context = "\n".join( - [ - f"{m.get('created_at', '')}: {m.get('memory', '')}" - for m in search_results["results"] - ] - ) - elif frame == "memobase": - context = search_results if isinstance(search_results, str) else "" - elif frame == "memu": - context = "\n".join(search_results) if isinstance(search_results, list) else "" - elif frame == "supermemory": - context = search_results if isinstance(search_results, str) else "" - - duration_ms = (time() - start) * 1000 - return context, duration_ms - - -def process_sample(client, sample, dataset_name, sample_idx, frame, version, top_k): - """Process a single sample: search for relevant memories.""" - user_id = f"longbench_{dataset_name}_{sample_idx}_{version}" - query = sample.get("input", "") - - if not query: - return None - - context, duration_ms = memos_api_search(client, query, user_id, top_k, frame) - - return { - "dataset": dataset_name, - "sample_idx": sample_idx, - "input": query, - "context": context, - "search_duration_ms": duration_ms, - "answers": sample.get("answers", []), - "all_classes": sample.get("all_classes"), - "length": sample.get("length", 0), - } - - -def load_dataset_from_local(dataset_name, use_e=False): - """Load LongBench dataset from local JSONL file.""" - # Determine data directory - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - # Determine filename - filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" - - filepath = os.path.join(data_dir, filename) - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSONL file - samples = [] - with open(filepath, encoding="utf-8") as f: - for line in f: - if line.strip(): - samples.append(json.loads(line)) - - return samples - - -def process_dataset( - dataset_name, frame, version, top_k=20, num_workers=10, max_samples=None, use_e=False -): - """Process a single dataset: search for all samples.""" - print(f"\n{'=' * 80}") - print(f"🔍 [SEARCHING DATASET: {dataset_name.upper()}]".center(80)) - print(f"{'=' * 80}\n") - - # Load dataset from local files - try: - dataset = load_dataset_from_local(dataset_name, use_e) - print(f"Loaded {len(dataset)} samples from {dataset_name}") - except FileNotFoundError as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return [] - except Exception as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return [] - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize client - client = None - if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client - - client = Mem0Client(enable_graph="graph" in frame) - elif frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - elif frame == "memos-api-online": - from utils.client import MemosApiOnlineClient - - client = MemosApiOnlineClient() - elif frame == "memobase": - from utils.client import MemobaseClient - - client = MemobaseClient() - elif frame == "memu": - from utils.client import MemuClient - - client = MemuClient() - elif frame == "supermemory": - from utils.client import SupermemoryClient - - client = SupermemoryClient() - else: - print(f"❌ Unsupported frame: {frame}") - return [] - - # Process samples - search_results = [] - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - process_sample, client, sample, dataset_name, idx, frame, version, top_k - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Searching {dataset_name}", - ): - result = future.result() - if result: - search_results.append(result) - - print(f"\n✅ Completed searching {dataset_name}: {len(search_results)} samples") - return search_results - - -def main( - frame, version="default", num_workers=10, top_k=20, datasets=None, max_samples=None, use_e=False -): - """Main search function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH SEARCH - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Determine which datasets to process - dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS - - # Filter valid datasets - valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] - if not valid_datasets: - print("❌ No valid datasets specified") - return - - print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") - - # Create output directory - os.makedirs(f"results/longbench/{frame}-{version}/", exist_ok=True) - - # Process each dataset - all_results = defaultdict(list) - for dataset_name in valid_datasets: - results = process_dataset( - dataset_name, frame, version, top_k, num_workers, max_samples, use_e - ) - all_results[dataset_name] = results - - # Save results - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" - with open(output_path, "w", encoding="utf-8") as f: - json.dump(dict(all_results), f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - parser.add_argument( - "--top_k", - type=int, - default=20, - help="Number of results to retrieve in search queries", - ) - parser.add_argument( - "--datasets", - type=str, - default=None, - help="Comma-separated list of datasets to process (default: all)", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples per dataset (default: all)", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main( - args.lib, - args.version, - args.workers, - args.top_k, - args.datasets, - args.max_samples, - args.e, - ) diff --git a/evaluation/scripts/longbench_v2/prepare_data.py b/evaluation/scripts/longbench_v2/prepare_data.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/memos/embedders/base.py b/src/memos/embedders/base.py index 22ef0d302..e46611d1a 100644 --- a/src/memos/embedders/base.py +++ b/src/memos/embedders/base.py @@ -23,7 +23,7 @@ def _count_tokens_for_embedding(text: str) -> int: enc = tiktoken.encoding_for_model("gpt-4o-mini") except Exception: enc = tiktoken.get_encoding("cl100k_base") - return len(enc.encode(text or "")) + return len(enc.encode(text or "", disallowed_special=())) except Exception: # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars if not text: diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index f43ad01ba..2dcf75846 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -89,7 +89,7 @@ def from_config(_config): _ENC = tiktoken.get_encoding("cl100k_base") def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "")) + return len(_ENC.encode(s or "", disallowed_special=())) except Exception: # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars def _count_tokens_text(s: str) -> int: diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 3226f7ca0..2a3bae944 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -92,9 +92,9 @@ def add( """ added_ids: list[str] = [] - with ContextThreadPoolExecutor(max_workers=200) as executor: + with ContextThreadPoolExecutor(max_workers=50) as executor: futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} - for future in as_completed(futures, timeout=60): + for future in as_completed(futures, timeout=500): try: ids = future.result() added_ids.extend(ids) @@ -102,7 +102,7 @@ def add( logger.exception("Memory processing error: ", exc_info=e) if mode == "sync": - for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + for mem_type in ["WorkingMemory"]: try: self.graph_store.remove_oldest_memory( memory_type="WorkingMemory", From b375d51ff09bbe92bc89de990c614246020b1837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 16:17:48 +0800 Subject: [PATCH 11/31] feat: enlarge polardb --- evaluation/scripts/long_bench-v2/__init__.py | 1 + .../long_bench-v2/longbench_v2_ingestion.py | 199 +++++++++++++++++ .../longbench_v2_ingestion_async.py | 158 ++++++++++++++ .../long_bench-v2/longbench_v2_metric.py | 142 ++++++++++++ .../long_bench-v2/longbench_v2_responses.py | 206 ++++++++++++++++++ .../long_bench-v2/longbench_v2_search.py | 192 ++++++++++++++++ src/memos/graph_dbs/polardb.py | 2 +- 7 files changed, 899 insertions(+), 1 deletion(-) create mode 100644 evaluation/scripts/long_bench-v2/__init__.py create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_metric.py create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_responses.py create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_search.py diff --git a/evaluation/scripts/long_bench-v2/__init__.py b/evaluation/scripts/long_bench-v2/__init__.py new file mode 100644 index 000000000..786c0ce03 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/__init__.py @@ -0,0 +1 @@ +# LongBench v2 evaluation scripts diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py new file mode 100644 index 000000000..d84a63d93 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -0,0 +1,199 @@ +import argparse +import json +import os +import sys +import threading + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +def ingest_sample( + client, sample, sample_idx, frame, version, success_records, record_file, file_lock +): + """Ingest a single LongBench v2 sample as memories.""" + # Skip if already processed + if str(sample_idx) in success_records: + return True + + user_id = f"longbench_v2_{sample_idx}_{version}" + conv_id = f"longbench_v2_{sample_idx}_{version}" + + # Get context and convert to messages + context = sample.get("context", "") + + # For memos, we ingest the context as document content + messages = [ + { + "type": "file", + "file": { + "file_data": context, + "file_id": str(sample_idx), + }, + } + ] + + if "memos-api" in frame: + try: + client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx}") + # Record successful ingestion (thread-safe) + with file_lock, open(record_file, "a") as f: + f.write(f"{sample_idx}\n") + f.flush() + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {e}") + return False + + return False + + +def load_dataset_from_local(): + """Load LongBench v2 dataset from local JSON file.""" + data_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "data", + "long_bench_v2", + ) + + filepath = os.path.join(data_dir, "data.json") + + if not os.path.exists(filepath): + raise FileNotFoundError(f"Dataset file not found: {filepath}") + + # Load JSON file + with open(filepath, encoding="utf-8") as f: + samples = json.load(f) + + return samples + + +def main(frame, version="default", num_workers=10, max_samples=None): + """Main ingestion function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load dataset from local file + try: + dataset = load_dataset_from_local() + print(f"Loaded {len(dataset)} samples from LongBench v2") + except FileNotFoundError as e: + print(f"❌ Error loading dataset: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize checkpoint file for resume functionality + checkpoint_dir = os.path.join( + ROOT_DIR, "evaluation", "results", "longbench_v2", f"{frame}-{version}" + ) + os.makedirs(checkpoint_dir, exist_ok=True) + record_file = os.path.join(checkpoint_dir, "success_records.txt") + + # Load existing success records for resume + success_records = set() + if os.path.exists(record_file): + with open(record_file) as f: + for line in f: + line = line.strip() + if line: + success_records.add(line) + print(f"📋 Found {len(success_records)} already processed samples (resume mode)") + else: + print("📋 Starting fresh ingestion (no checkpoint found)") + + # Initialize client + client = None + if frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Ingest samples + success_count = len(success_records) # Start with already processed count + file_lock = threading.Lock() # Lock for thread-safe file writing + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit( + ingest_sample, + client, + sample, + idx, + frame, + version, + success_records, + record_file, + file_lock, + ) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Ingesting LongBench v2", + ): + try: + if future.result(): + success_count += 1 + except Exception as e: + print(f"Error processing sample: {e}") + + print(f"\n{'=' * 80}") + print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="long-bench-v2-1208-1556", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=20, + help="Number of parallel workers", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples to process (default: all)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py new file mode 100644 index 000000000..c23d7885f --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py @@ -0,0 +1,158 @@ +import argparse +import json +import os +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +def ingest_sample(client, sample, sample_idx, frame, version): + """Ingest a single LongBench v2 sample as memories.""" + user_id = f"longbench_v2_{sample_idx}_{version}" + conv_id = f"longbench_v2_{sample_idx}_{version}" + + # Get context and convert to messages + context = sample.get("context", "") + + # For memos, we ingest the context as document content + messages = [ + { + "type": "file", + "file": { + "file_data": context, + "file_id": str(sample_idx), + }, + } + ] + + if "memos-api" in frame: + try: + client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx}") + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {e}") + return False + + return False + + +def load_dataset_from_local(): + """Load LongBench v2 dataset from local JSON file.""" + data_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "data", + "long_bench_v2", + ) + + filepath = os.path.join(data_dir, "data.json") + + if not os.path.exists(filepath): + raise FileNotFoundError(f"Dataset file not found: {filepath}") + + # Load JSON file + with open(filepath, encoding="utf-8") as f: + samples = json.load(f) + + return samples + + +def main(frame, version="default", num_workers=10, max_samples=None): + """Main ingestion function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load dataset from local file + try: + dataset = load_dataset_from_local() + print(f"Loaded {len(dataset)} samples from LongBench v2") + except FileNotFoundError as e: + print(f"❌ Error loading dataset: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Ingest samples + success_count = 0 + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit(ingest_sample, client, sample, idx, frame, version) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Ingesting LongBench v2", + ): + try: + if future.result(): + success_count += 1 + except Exception as e: + print(f"Error processing sample: {e}") + + print(f"\n{'=' * 80}") + print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="long-bench-v2-1208-1556-async", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=20, + help="Number of parallel workers", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples to process (default: all)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py new file mode 100644 index 000000000..5fee9a3de --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -0,0 +1,142 @@ +import argparse +import json +import os + + +def calculate_accuracy(responses): + """Calculate accuracy metrics for LongBench v2.""" + total = len(responses) + if total == 0: + return {} + + # Overall accuracy + correct = sum(1 for r in responses if r.get("judge", False)) + overall_acc = round(100 * correct / total, 1) + + # By difficulty + easy_items = [r for r in responses if r.get("difficulty") == "easy"] + hard_items = [r for r in responses if r.get("difficulty") == "hard"] + easy_acc = ( + round(100 * sum(1 for r in easy_items if r.get("judge", False)) / len(easy_items), 1) + if easy_items + else 0.0 + ) + hard_acc = ( + round(100 * sum(1 for r in hard_items if r.get("judge", False)) / len(hard_items), 1) + if hard_items + else 0.0 + ) + + # By length + short_items = [r for r in responses if r.get("length") == "short"] + medium_items = [r for r in responses if r.get("length") == "medium"] + long_items = [r for r in responses if r.get("length") == "long"] + + short_acc = ( + round(100 * sum(1 for r in short_items if r.get("judge", False)) / len(short_items), 1) + if short_items + else 0.0 + ) + medium_acc = ( + round(100 * sum(1 for r in medium_items if r.get("judge", False)) / len(medium_items), 1) + if medium_items + else 0.0 + ) + long_acc = ( + round(100 * sum(1 for r in long_items if r.get("judge", False)) / len(long_items), 1) + if long_items + else 0.0 + ) + + # By domain + domain_stats = {} + for response in responses: + domain = response.get("domain", "Unknown") + if domain not in domain_stats: + domain_stats[domain] = {"total": 0, "correct": 0} + domain_stats[domain]["total"] += 1 + if response.get("judge", False): + domain_stats[domain]["correct"] += 1 + + domain_acc = { + domain: round(100 * stats["correct"] / stats["total"], 1) + for domain, stats in domain_stats.items() + } + + return { + "overall": overall_acc, + "easy": easy_acc, + "hard": hard_acc, + "short": short_acc, + "medium": medium_acc, + "long": long_acc, + "by_domain": domain_acc, + "total_samples": total, + "correct_samples": correct, + } + + +def main(frame, version="default"): + """Main metric calculation function.""" + print("\n" + "=" * 80) + print(f"📊 LONGBENCH V2 METRICS CALCULATION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load responses + responses_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_responses.json" + if not os.path.exists(responses_path): + print(f"❌ Responses not found: {responses_path}") + print("Please run longbench_v2_responses.py first") + return + + with open(responses_path, encoding="utf-8") as f: + responses = json.load(f) + + # Calculate metrics + metrics = calculate_accuracy(responses) + + # Save metrics + output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, ensure_ascii=False, indent=4) + + print(f"\n{'=' * 80}") + print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + # Print summary table + print("\n📊 Summary of Results:") + print("-" * 80) + print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.1f}%") + print(f"{'Easy':<30s}: {metrics['easy']:.1f}%") + print(f"{'Hard':<30s}: {metrics['hard']:.1f}%") + print(f"{'Short':<30s}: {metrics['short']:.1f}%") + print(f"{'Medium':<30s}: {metrics['medium']:.1f}%") + print(f"{'Long':<30s}: {metrics['long']:.1f}%") + print("\nBy Domain:") + for domain, acc in metrics["by_domain"].items(): + print(f" {domain:<28s}: {acc:.1f}%") + print(f"\nTotal Samples: {metrics['total_samples']}") + print(f"Correct: {metrics['correct_samples']}") + print("-" * 80) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for loading results", + ) + args = parser.parse_args() + + main(args.lib, args.version) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py new file mode 100644 index 000000000..3e19dc95f --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -0,0 +1,206 @@ +import argparse +import json +import os +import re +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +# Prompt template from LongBench v2 +LONGBENCH_V2_PROMPT = """Please read the following text and answer the question below. + + +{context} + + +What is the correct answer to this question: {question} +Choices: +(A) {choice_A} +(B) {choice_B} +(C) {choice_C} +(D) {choice_D} + +Format your response as follows: "The correct answer is (insert answer here)".""" + + +def extract_answer(response): + """Extract answer from response (A, B, C, or D).""" + response = response.replace("*", "") + # Try to find "The correct answer is (X)" pattern + match = re.search(r"The correct answer is \(([A-D])\)", response, re.IGNORECASE) + if match: + return match.group(1).upper() + else: + match = re.search(r"The correct answer is ([A-D])", response, re.IGNORECASE) + if match: + return match.group(1).upper() + else: + # Try to find standalone A, B, C, or D + match = re.search(r"\b([A-D])\b", response) + if match: + return match.group(1).upper() + return None + + +def generate_response(llm_client, context, question, choice_a, choice_b, choice_c, choice_d): + """Generate response using LLM.""" + prompt = LONGBENCH_V2_PROMPT.format( + context=context, + question=question, + choice_A=choice_a, + choice_B=choice_b, + choice_C=choice_c, + choice_D=choice_d, + ) + + try: + response = llm_client.chat.completions.create( + model=os.getenv("CHAT_MODEL"), + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=128, + ) + result = response.choices[0].message.content or "" + return result + except Exception as e: + print(f"Error generating response: {e}") + return "" + + +def process_sample(search_result, llm_client): + """Process a single sample: generate answer.""" + start = time() + + context = search_result.get("context", "") + question = search_result.get("question", "") + choice_a = search_result.get("choice_A", "") + choice_b = search_result.get("choice_B", "") + choice_c = search_result.get("choice_C", "") + choice_d = search_result.get("choice_D", "") + + # Generate answer + response = generate_response( + llm_client, context, question, choice_a, choice_b, choice_c, choice_d + ) + + # Extract answer (A, B, C, or D) + pred = extract_answer(response) + + response_duration_ms = (time() - start) * 1000 + + return { + "sample_idx": search_result.get("sample_idx"), + "_id": search_result.get("_id"), + "domain": search_result.get("domain"), + "sub_domain": search_result.get("sub_domain"), + "difficulty": search_result.get("difficulty"), + "length": search_result.get("length"), + "question": question, + "choice_A": choice_a, + "choice_B": choice_b, + "choice_C": choice_c, + "choice_D": choice_d, + "answer": search_result.get("answer"), + "pred": pred, + "response": response, + "judge": pred == search_result.get("answer") if pred else False, + "search_context": context, + "response_duration_ms": response_duration_ms, + "search_duration_ms": search_result.get("search_duration_ms", 0), + } + + +def main(frame, version="default", num_workers=10): + """Main response generation function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load search results + search_path = ( + f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" + ) + if not os.path.exists(search_path): + print(f"❌ Search results not found: {search_path}") + print("Please run longbench_v2_search.py first") + return + + with open(search_path, encoding="utf-8") as f: + search_results = json.load(f) + + # Initialize LLM client + llm_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), + base_url=os.getenv("CHAT_MODEL_BASE_URL"), + ) + print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") + + # Process all samples + all_responses = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(process_sample, sample, llm_client) for sample in search_results] + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Generating responses", + ): + result = future.result() + if result: + all_responses.append(result) + + # Save responses + output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_responses.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(all_responses, f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 80}") + print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for loading results", + ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py new file mode 100644 index 000000000..f46928498 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -0,0 +1,192 @@ +import argparse +import json +import os +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +def memos_api_search(client, query, user_id, top_k, frame): + """Search using memos API.""" + start = time() + search_results = client.search(query=query, user_id=user_id, top_k=top_k) + + # Format context from search results based on frame type + context = "" + if ( + (frame == "memos-api" or frame == "memos-api-online") + and isinstance(search_results, dict) + and "text_mem" in search_results + ): + context = "\n".join([i["memory"] for i in search_results["text_mem"][0]["memories"]]) + if "pref_string" in search_results: + context += f"\n{search_results.get('pref_string', '')}" + + duration_ms = (time() - start) * 1000 + return context, duration_ms + + +def process_sample(client, sample, sample_idx, frame, version, top_k): + """Process a single sample: search for relevant memories.""" + user_id = f"longbench_v2_{sample_idx}_{version}" + query = sample.get("question", "") + + if not query: + return None + + context, duration_ms = memos_api_search(client, query, user_id, top_k, frame) + + return { + "sample_idx": sample_idx, + "_id": sample.get("_id"), + "domain": sample.get("domain"), + "sub_domain": sample.get("sub_domain"), + "difficulty": sample.get("difficulty"), + "length": sample.get("length"), + "question": query, + "choice_A": sample.get("choice_A"), + "choice_B": sample.get("choice_B"), + "choice_C": sample.get("choice_C"), + "choice_D": sample.get("choice_D"), + "answer": sample.get("answer"), + "context": context, + "search_duration_ms": duration_ms, + } + + +def load_dataset_from_local(): + """Load LongBench v2 dataset from local JSON file.""" + data_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "data", + "long_bench_v2", + ) + + filepath = os.path.join(data_dir, "data.json") + + if not os.path.exists(filepath): + raise FileNotFoundError(f"Dataset file not found: {filepath}") + + # Load JSON file + with open(filepath, encoding="utf-8") as f: + samples = json.load(f) + + return samples + + +def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): + """Main search function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH V2 SEARCH - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load dataset from local file + try: + dataset = load_dataset_from_local() + print(f"Loaded {len(dataset)} samples from LongBench v2") + except FileNotFoundError as e: + print(f"❌ Error loading dataset: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Process samples + search_results = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit(process_sample, client, sample, idx, frame, version, top_k) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Searching LongBench v2", + ): + result = future.result() + if result: + search_results.append(result) + + # Save results + os.makedirs(f"results/long_bench-v2/{frame}-{version}/", exist_ok=True) + output_path = ( + f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" + ) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(search_results, f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 80}") + print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + parser.add_argument( + "--top_k", + type=int, + default=20, + help="Number of results to retrieve in search queries", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples to process (default: all)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers, args.top_k, args.max_samples) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ddcbfe285..85e5d14f8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -151,7 +151,7 @@ def __init__(self, config: PolarDBGraphDBConfig): # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( minconn=5, - maxconn=100, + maxconn=2000, host=host, port=port, user=user, From 69dd3a8bc9695cf0d1deda329f100cdba1f1a718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 20:20:40 +0800 Subject: [PATCH 12/31] feat: derease parallrl --- src/memos/memories/textual/tree_text_memory/organize/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 2a3bae944..470d2c483 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -92,7 +92,7 @@ def add( """ added_ids: list[str] = [] - with ContextThreadPoolExecutor(max_workers=50) as executor: + with ContextThreadPoolExecutor(max_workers=10) as executor: futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=500): try: From ac38046ff22e96c6d84265236434fc9befdc4244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 20:28:04 +0800 Subject: [PATCH 13/31] feat: add image parser in file --- .../read_multi_modal/file_content_parser.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index cce99e76a..8edcbfe52 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -2,6 +2,7 @@ import concurrent.futures import os +import re import tempfile from typing import Any @@ -13,6 +14,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_reader.read_multi_modal.base import BaseMessageParser, _derive_key +from memos.mem_reader.read_multi_modal.image_parser import ImageParser from memos.mem_reader.read_multi_modal.utils import ( detect_lang, get_parser, @@ -129,6 +131,91 @@ def _handle_local(self, data: str) -> str: logger.info("[FileContentParser] Local file paths are not supported in fine mode.") return "" + def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) -> str: + """ + Extract all images from markdown text and process them using ImageParser. + Replaces image references with extracted text content. + + Args: + text: Markdown text containing image references + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters for ImageParser + + Returns: + Text with image references replaced by extracted content + """ + if not text or not self.image_parser: + return text + + # Pattern to match markdown images: ![](url) or ![alt](url) + image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)" + + # Find all image matches first + image_matches = list(re.finditer(image_pattern, text)) + if not image_matches: + return text + + logger.info(f"[FileContentParser] Found {len(image_matches)} images to process") + + # Process images and build replacement map + replacements = {} + for idx, match in enumerate(image_matches, 1): + image_url = match.group(2) + + try: + # Construct image message format for ImageParser + image_message = { + "type": "image_url", + "image_url": { + "url": image_url, + "detail": "auto", + }, + } + + # Process image using ImageParser + logger.info( + f"[FileContentParser] Processing image {idx}/{len(image_matches)}: {image_url}" + ) + memory_items = self.image_parser.parse_fine(image_message, info, **kwargs) + + # Extract text content from memory items (only strings as requested) + extracted_texts = [] + for item in memory_items: + if hasattr(item, "memory") and item.memory: + extracted_texts.append(str(item.memory)) + + if extracted_texts: + # Combine all extracted texts + extracted_content = "\n".join(extracted_texts) + # Replace image with extracted content + replacements[match.group(0)] = ( + f"\n[Image Content from {image_url}]:\n{extracted_content}\n" + ) + else: + # If no content extracted, keep original with a note + logger.warning( + f"[FileContentParser] No content extracted from image: {image_url}" + ) + replacements[match.group(0)] = ( + f"\n[Image: {image_url} - No content extracted]\n" + ) + + except Exception as e: + logger.error(f"[FileContentParser] Error processing image {image_url}: {e}") + # On error, keep original image reference + replacements[match.group(0)] = match.group(0) + + # Replace all images in the text + processed_text = text + for original, replacement in replacements.items(): + processed_text = processed_text.replace(original, replacement, 1) + + logger.info( + f"[FileContentParser] Processed {len(image_matches)} images, " + f"extracted content for {sum(1 for r in replacements.values() if 'Image Content' in r)} images" + ) + return processed_text + def __init__( self, embedder: BaseEmbedder, @@ -149,6 +236,8 @@ def __init__( """ super().__init__(embedder, llm) self.parser = parser + # Initialize ImageParser for processing images in markdown + self.image_parser = ImageParser(embedder, llm) if llm else None # Get inner markdown hostnames from config or environment if direct_markdown_hostnames is not None: @@ -519,6 +608,10 @@ def parse_fine( f"[FileContentParser] Failed to delete temp file {temp_file_path}: {e}" ) + # Extract and process images from parsed_text + if is_markdown and parsed_text and self.image_parser: + parsed_text = self._extract_and_process_images(parsed_text, info, **kwargs) + # Extract info fields if not info: info = {} From 37bcc904e4b79932f63bc60e5d665acc5e062ef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 8 Dec 2025 20:42:20 +0800 Subject: [PATCH 14/31] feat: update file_content_parser --- .../read_multi_modal/file_content_parser.py | 142 ++++++++++++------ 1 file changed, 94 insertions(+), 48 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 972758c42..408736d2f 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -131,9 +131,65 @@ def _handle_local(self, data: str) -> str: logger.info("[FileContentParser] Local file paths are not supported in fine mode.") return "" + def _process_single_image( + self, image_url: str, original_ref: str, info: dict[str, Any], **kwargs + ) -> tuple[str, str]: + """ + Process a single image and return (original_ref, replacement_text). + + Args: + image_url: URL of the image to process + original_ref: Original markdown image reference to replace + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters for ImageParser + + Returns: + Tuple of (original_ref, replacement_text) + """ + try: + # Construct image message format for ImageParser + image_message = { + "type": "image_url", + "image_url": { + "url": image_url, + "detail": "auto", + }, + } + + # Process image using ImageParser + logger.debug(f"[FileContentParser] Processing image: {image_url}") + memory_items = self.image_parser.parse_fine(image_message, info, **kwargs) + + # Extract text content from memory items (only strings as requested) + extracted_texts = [] + for item in memory_items: + if hasattr(item, "memory") and item.memory: + extracted_texts.append(str(item.memory)) + + if extracted_texts: + # Combine all extracted texts + extracted_content = "\n".join(extracted_texts) + # Replace image with extracted content + return ( + original_ref, + f"\n[Image Content from {image_url}]:\n{extracted_content}\n", + ) + else: + # If no content extracted, keep original with a note + logger.warning(f"[FileContentParser] No content extracted from image: {image_url}") + return ( + original_ref, + f"\n[Image: {image_url} - No content extracted]\n", + ) + + except Exception as e: + logger.error(f"[FileContentParser] Error processing image {image_url}: {e}") + # On error, keep original image reference + return (original_ref, original_ref) + def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) -> str: """ - Extract all images from markdown text and process them using ImageParser. + Extract all images from markdown text and process them using ImageParser in parallel. Replaces image references with extracted text content. Args: @@ -155,64 +211,54 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) if not image_matches: return text - logger.info(f"[FileContentParser] Found {len(image_matches)} images to process") + logger.info(f"[FileContentParser] Found {len(image_matches)} images to process in parallel") - # Process images and build replacement map - replacements = {} - for idx, match in enumerate(image_matches, 1): + # Prepare tasks for parallel processing + tasks = [] + for match in image_matches: image_url = match.group(2) + original_ref = match.group(0) + tasks.append((image_url, original_ref)) - try: - # Construct image message format for ImageParser - image_message = { - "type": "image_url", - "image_url": { - "url": image_url, - "detail": "auto", - }, - } - - # Process image using ImageParser - logger.info( - f"[FileContentParser] Processing image {idx}/{len(image_matches)}: {image_url}" - ) - memory_items = self.image_parser.parse_fine(image_message, info, **kwargs) - - # Extract text content from memory items (only strings as requested) - extracted_texts = [] - for item in memory_items: - if hasattr(item, "memory") and item.memory: - extracted_texts.append(str(item.memory)) - - if extracted_texts: - # Combine all extracted texts - extracted_content = "\n".join(extracted_texts) - # Replace image with extracted content - replacements[match.group(0)] = ( - f"\n[Image Content from {image_url}]:\n{extracted_content}\n" - ) - else: - # If no content extracted, keep original with a note - logger.warning( - f"[FileContentParser] No content extracted from image: {image_url}" - ) - replacements[match.group(0)] = ( - f"\n[Image: {image_url} - No content extracted]\n" - ) + # Process images in parallel + replacements = {} + max_workers = min(len(tasks), 10) # Limit concurrent image processing - except Exception as e: - logger.error(f"[FileContentParser] Error processing image {image_url}: {e}") - # On error, keep original image reference - replacements[match.group(0)] = match.group(0) + with ContextThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit( + self._process_single_image, image_url, original_ref, info, **kwargs + ): (image_url, original_ref) + for image_url, original_ref in tasks + } + + # Collect results with progress tracking + for future in tqdm( + concurrent.futures.as_completed(futures), + total=len(futures), + desc="[FileContentParser] Processing images", + ): + try: + original_ref, replacement = future.result() + replacements[original_ref] = replacement + except Exception as e: + image_url, original_ref = futures[future] + logger.error(f"[FileContentParser] Future failed for image {image_url}: {e}") + # On error, keep original image reference + replacements[original_ref] = original_ref # Replace all images in the text processed_text = text for original, replacement in replacements.items(): processed_text = processed_text.replace(original, replacement, 1) + # Count successfully extracted images + success_count = sum( + 1 for replacement in replacements.values() if "Image Content from" in replacement + ) logger.info( - f"[FileContentParser] Processed {len(image_matches)} images, " - f"extracted content for {sum(1 for r in replacements.values() if 'Image Content' in r)} images" + f"[FileContentParser] Processed {len(image_matches)} images in parallel, " + f"extracted content for {success_count} images" ) return processed_text From 20af5d0204af0feb1dc88e2d814bcc7e75541f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 12:34:09 +0800 Subject: [PATCH 15/31] feat: modify long_bench_v2 --- .../long_bench-v2/longbench_v2_ingestion.py | 4 +- .../long_bench-v2/longbench_v2_metric.py | 5 +- .../long_bench-v2/longbench_v2_responses.py | 85 ++++++++++++-- .../long_bench-v2/longbench_v2_search.py | 92 ++++++++++++--- .../scripts/long_bench-v2/wait_scheduler.py | 67 +++++++++++ evaluation/scripts/run_longbench_v2_eval.sh | 110 ++++++++++++++++++ 6 files changed, 334 insertions(+), 29 deletions(-) create mode 100644 evaluation/scripts/long_bench-v2/wait_scheduler.py create mode 100755 evaluation/scripts/run_longbench_v2_eval.sh diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py index d84a63d93..72a02397d 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -179,13 +179,13 @@ def main(frame, version="default", num_workers=10, max_samples=None): parser.add_argument( "--version", type=str, - default="long-bench-v2-1208-1556", + default="default", help="Version identifier for saving results", ) parser.add_argument( "--workers", type=int, - default=20, + default=3, help="Number of parallel workers", ) parser.add_argument( diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py index 5fee9a3de..6489dc401 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -92,8 +92,11 @@ def main(frame, version="default"): with open(responses_path, encoding="utf-8") as f: responses = json.load(f) + # Only keep entries with non-empty context (search_context) to align with response generation + filtered = [r for r in responses if str(r.get("search_context", "")).strip() != ""] + # Calculate metrics - metrics = calculate_accuracy(responses) + metrics = calculate_accuracy(filtered) # Save metrics output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py index 3e19dc95f..8c4439436 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -3,6 +3,7 @@ import os import re import sys +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from time import time @@ -85,8 +86,13 @@ def generate_response(llm_client, context, question, choice_a, choice_b, choice_ return "" -def process_sample(search_result, llm_client): +def process_sample(search_result, llm_client, success_records, record_file, file_lock): """Process a single sample: generate answer.""" + sample_idx = search_result.get("sample_idx") + # Skip if already processed + if sample_idx is not None and str(sample_idx) in success_records: + return None + start = time() context = search_result.get("context", "") @@ -96,6 +102,10 @@ def process_sample(search_result, llm_client): choice_c = search_result.get("choice_C", "") choice_d = search_result.get("choice_D", "") + # Skip empty/placeholder contexts (e.g., "\n" or whitespace-only) + if not context or context.strip() == "": + return None + # Generate answer response = generate_response( llm_client, context, question, choice_a, choice_b, choice_c, choice_d @@ -106,7 +116,7 @@ def process_sample(search_result, llm_client): response_duration_ms = (time() - start) * 1000 - return { + result = { "sample_idx": search_result.get("sample_idx"), "_id": search_result.get("_id"), "domain": search_result.get("domain"), @@ -123,10 +133,20 @@ def process_sample(search_result, llm_client): "response": response, "judge": pred == search_result.get("answer") if pred else False, "search_context": context, + # Preserve full search results payload (e.g., list of memories) + "search_results": search_result.get("search_results"), "response_duration_ms": response_duration_ms, "search_duration_ms": search_result.get("search_duration_ms", 0), } + # Record successful processing (thread-safe) + if sample_idx is not None: + with file_lock, open(record_file, "a") as f: + f.write(f"{sample_idx}\n") + f.flush() + + return result + def main(frame, version="default", num_workers=10): """Main response generation function.""" @@ -136,10 +156,16 @@ def main(frame, version="default", num_workers=10): print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) print("=" * 80 + "\n") - # Load search results - search_path = ( - f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" + # Initialize checkpoint file for resume functionality + checkpoint_dir = os.path.join( + ROOT_DIR, "evaluation", "results", "long_bench-v2", f"{frame}-{version}" ) + os.makedirs(checkpoint_dir, exist_ok=True) + record_file = os.path.join(checkpoint_dir, "response_success_records.txt") + search_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json") + output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_responses.json") + + # Load search results if not os.path.exists(search_path): print(f"❌ Search results not found: {search_path}") print("Please run longbench_v2_search.py first") @@ -148,6 +174,30 @@ def main(frame, version="default", num_workers=10): with open(search_path, encoding="utf-8") as f: search_results = json.load(f) + # Load existing results and success records for resume + existing_results = {} + success_records = set() + if os.path.exists(output_path): + with open(output_path, encoding="utf-8") as f: + existing_results_list = json.load(f) + for result in existing_results_list: + sample_idx = result.get("sample_idx") + if sample_idx is not None: + existing_results[sample_idx] = result + success_records.add(str(sample_idx)) + print(f"📋 Found {len(existing_results)} existing responses (resume mode)") + else: + print("📋 Starting fresh response generation (no checkpoint found)") + + # Load additional success records from checkpoint file + if os.path.exists(record_file): + with open(record_file) as f: + for line in f: + line = line.strip() + if line and line not in success_records: + success_records.add(line) + print(f"📋 Total {len(success_records)} samples already processed") + # Initialize LLM client llm_client = OpenAI( api_key=os.getenv("CHAT_MODEL_API_KEY"), @@ -156,9 +206,15 @@ def main(frame, version="default", num_workers=10): print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") # Process all samples - all_responses = [] + new_results = [] + file_lock = threading.Lock() # Lock for thread-safe file writing with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(process_sample, sample, llm_client) for sample in search_results] + futures = [ + executor.submit( + process_sample, sample, llm_client, success_records, record_file, file_lock + ) + for sample in search_results + ] for future in tqdm( as_completed(futures), @@ -167,11 +223,16 @@ def main(frame, version="default", num_workers=10): ): result = future.result() if result: - all_responses.append(result) - - # Save responses - output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_responses.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) + new_results.append(result) + # Update existing results with new result + sample_idx = result.get("sample_idx") + if sample_idx is not None: + existing_results[sample_idx] = result + + # Merge and save all results + all_responses = list(existing_results.values()) + # Sort by sample_idx to maintain order + all_responses.sort(key=lambda x: x.get("sample_idx", 0)) with open(output_path, "w", encoding="utf-8") as f: json.dump(all_responses, f, ensure_ascii=False, indent=2) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py index f46928498..686ff0ba1 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -2,6 +2,7 @@ import json import os import sys +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from time import time @@ -24,7 +25,7 @@ def memos_api_search(client, query, user_id, top_k, frame): start = time() search_results = client.search(query=query, user_id=user_id, top_k=top_k) - # Format context from search results based on frame type + # Format context from search results based on frame type for backward compatibility context = "" if ( (frame == "memos-api" or frame == "memos-api-online") @@ -36,20 +37,26 @@ def memos_api_search(client, query, user_id, top_k, frame): context += f"\n{search_results.get('pref_string', '')}" duration_ms = (time() - start) * 1000 - return context, duration_ms + return context, duration_ms, search_results -def process_sample(client, sample, sample_idx, frame, version, top_k): +def process_sample( + client, sample, sample_idx, frame, version, top_k, success_records, record_file, file_lock +): """Process a single sample: search for relevant memories.""" + # Skip if already processed + if str(sample_idx) in success_records: + return None + user_id = f"longbench_v2_{sample_idx}_{version}" query = sample.get("question", "") if not query: return None - context, duration_ms = memos_api_search(client, query, user_id, top_k, frame) + context, duration_ms, search_results = memos_api_search(client, query, user_id, top_k, frame) - return { + result = { "sample_idx": sample_idx, "_id": sample.get("_id"), "domain": sample.get("domain"), @@ -63,9 +70,18 @@ def process_sample(client, sample, sample_idx, frame, version, top_k): "choice_D": sample.get("choice_D"), "answer": sample.get("answer"), "context": context, + # Preserve full search results instead of only the concatenated context + "search_results": search_results, "search_duration_ms": duration_ms, } + # Record successful processing (thread-safe) + with file_lock, open(record_file, "a") as f: + f.write(f"{sample_idx}\n") + f.flush() + + return result + def load_dataset_from_local(): """Load LongBench v2 dataset from local JSON file.""" @@ -111,6 +127,38 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): dataset = dataset[:max_samples] print(f"Limited to {len(dataset)} samples") + # Initialize checkpoint file for resume functionality + checkpoint_dir = os.path.join( + ROOT_DIR, "evaluation", "results", "long_bench-v2", f"{frame}-{version}" + ) + os.makedirs(checkpoint_dir, exist_ok=True) + record_file = os.path.join(checkpoint_dir, "search_success_records.txt") + output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json") + + # Load existing results and success records for resume + existing_results = {} + success_records = set() + if os.path.exists(output_path): + with open(output_path, encoding="utf-8") as f: + existing_results_list = json.load(f) + for result in existing_results_list: + sample_idx = result.get("sample_idx") + if sample_idx is not None: + existing_results[sample_idx] = result + success_records.add(str(sample_idx)) + print(f"📋 Found {len(existing_results)} existing search results (resume mode)") + else: + print("📋 Starting fresh search (no checkpoint found)") + + # Load additional success records from checkpoint file + if os.path.exists(record_file): + with open(record_file) as f: + for line in f: + line = line.strip() + if line and line not in success_records: + success_records.add(line) + print(f"📋 Total {len(success_records)} samples already processed") + # Initialize client client = None if frame == "memos-api": @@ -126,11 +174,23 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): return # Process samples - search_results = [] + new_results = [] + file_lock = threading.Lock() # Lock for thread-safe file writing with ThreadPoolExecutor(max_workers=num_workers) as executor: futures = [] for idx, sample in enumerate(dataset): - future = executor.submit(process_sample, client, sample, idx, frame, version, top_k) + future = executor.submit( + process_sample, + client, + sample, + idx, + frame, + version, + top_k, + success_records, + record_file, + file_lock, + ) futures.append(future) for future in tqdm( @@ -140,13 +200,17 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): ): result = future.result() if result: - search_results.append(result) + new_results.append(result) + # Update existing results with new result + sample_idx = result.get("sample_idx") + if sample_idx is not None: + existing_results[sample_idx] = result + + # Merge and save all results + search_results = list(existing_results.values()) + # Sort by sample_idx to maintain order + search_results.sort(key=lambda x: x.get("sample_idx", 0)) - # Save results - os.makedirs(f"results/long_bench-v2/{frame}-{version}/", exist_ok=True) - output_path = ( - f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" - ) with open(output_path, "w", encoding="utf-8") as f: json.dump(search_results, f, ensure_ascii=False, indent=2) @@ -172,7 +236,7 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): parser.add_argument( "--workers", type=int, - default=10, + default=1, help="Number of parallel workers", ) parser.add_argument( diff --git a/evaluation/scripts/long_bench-v2/wait_scheduler.py b/evaluation/scripts/long_bench-v2/wait_scheduler.py new file mode 100644 index 000000000..716869a11 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/wait_scheduler.py @@ -0,0 +1,67 @@ +import os +import time + +import requests + +from dotenv import load_dotenv + + +def wait_until_completed(params: dict, interval: float = 2.0, timeout: float = 600.0): + """ + Keep polling /product/scheduler/status until status == 'completed' (or terminal). + + params: dict passed as query params, e.g. {"user_id": "xxx"} or {"user_id": "xxx", "task_id": "..."} + interval: seconds between polls + timeout: max seconds to wait before raising TimeoutError + """ + load_dotenv() + base_url = os.getenv("MEMOS_URL") + if not base_url: + raise RuntimeError("MEMOS_URL not set in environment") + + url = f"{base_url}/product/scheduler/status" + start = time.time() + active_states = {"waiting", "pending", "in_progress"} + + while True: + resp = requests.get(url, params=params, timeout=10) + resp.raise_for_status() + data = resp.json() + + items = data.get("data", []) if isinstance(data, dict) else [] + statuses = [item.get("status") for item in items if isinstance(item, dict)] + status_set = set(statuses) + + # Print current status snapshot + print(f"Current status: {status_set or 'empty'}") + + # Completed if no active states remain + if not status_set or status_set.isdisjoint(active_states): + print("Task completed!") + return data + + if (time.time() - start) > timeout: + raise TimeoutError(f"Timeout after {timeout}s; last statuses={status_set or 'empty'}") + + time.sleep(interval) + + +if __name__ == "__main__": + import argparse + import json + + parser = argparse.ArgumentParser() + parser.add_argument( + "--user_id", default="longbench_v2_0_long-bench-v2-1208-2119-async", help="User ID to query" + ) + parser.add_argument("--task_id", help="Optional task_id to query") + parser.add_argument("--interval", type=float, default=2.0, help="Poll interval seconds") + parser.add_argument("--timeout", type=float, default=600.0, help="Timeout seconds") + args = parser.parse_args() + + params = {"user_id": args.user_id} + if args.task_id: + params["task_id"] = args.task_id + + result = wait_until_completed(params, interval=args.interval, timeout=args.timeout) + print(json.dumps(result, indent=2, ensure_ascii=False)) diff --git a/evaluation/scripts/run_longbench_v2_eval.sh b/evaluation/scripts/run_longbench_v2_eval.sh new file mode 100755 index 000000000..917c57bfb --- /dev/null +++ b/evaluation/scripts/run_longbench_v2_eval.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Common parameters for all scripts +LIB="memos-api" +VERSION="long-bench-v2-1208-1556-async" +WORKERS=10 +TOPK=20 +MAX_SAMPLES="" # Empty means all samples +WAIT_INTERVAL=2 # seconds between polls +WAIT_TIMEOUT=900 # seconds per user + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --lib) + LIB="$2" + shift 2 + ;; + --version) + VERSION="$2" + shift 2 + ;; + --workers) + WORKERS="$2" + shift 2 + ;; + --top_k) + TOPK="$2" + shift 2 + ;; + --max_samples) + MAX_SAMPLES="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Build max_samples argument +MAX_SAMPLES_ARG="" +if [ -n "$MAX_SAMPLES" ]; then + MAX_SAMPLES_ARG="--max_samples $MAX_SAMPLES" +fi + +echo "Running LongBench v2 evaluation with:" +echo " LIB: $LIB" +echo " VERSION: $VERSION" +echo " WORKERS: $WORKERS" +echo " TOPK: $TOPK" +echo " MAX_SAMPLES: ${MAX_SAMPLES:-all}" +echo "" + +# Step 2: Search +echo "" +echo "==========================================" +echo "Step 2: Running longbench_v2_search.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_search.py \ + --lib $LIB \ + --version $VERSION \ + --top_k $TOPK \ + --workers $WORKERS \ + $MAX_SAMPLES_ARG + +if [ $? -ne 0 ]; then + echo "Error running longbench_v2_search.py" + exit 1 +fi + +# Step 3: Response Generation +echo "" +echo "==========================================" +echo "Step 3: Running longbench_v2_responses.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_responses.py \ + --lib $LIB \ + --version $VERSION \ + --workers $WORKERS + +if [ $? -ne 0 ]; then + echo "Error running longbench_v2_responses.py" + exit 1 +fi + +# Step 4: Metrics Calculation +echo "" +echo "==========================================" +echo "Step 4: Running longbench_v2_metric.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_metric.py \ + --lib $LIB \ + --version $VERSION + +if [ $? -ne 0 ]; then + echo "Error running longbench_v2_metric.py" + exit 1 +fi + +echo "" +echo "==========================================" +echo "All steps completed successfully!" +echo "==========================================" +echo "" +echo "Results are saved in: results/long_bench-v2/$LIB-$VERSION/" +echo " - Search results: ${LIB}_longbench_v2_search_results.json" +echo " - Responses: ${LIB}_longbench_v2_responses.json" +echo " - Metrics: ${LIB}_longbench_v2_metrics.json" From 0ef1bb54c5173632733852dbc9eef4ff0d348004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 15:06:40 +0800 Subject: [PATCH 16/31] feat: modify long_bench_v2 --- .../long_bench-v2/longbench_v2_ingestion.py | 2 +- .../longbench_v2_ingestion_async.py | 158 ------------------ .../long_bench-v2/longbench_v2_metric.py | 4 +- .../long_bench-v2/longbench_v2_responses.py | 2 +- .../long_bench-v2/longbench_v2_search.py | 48 +++++- 5 files changed, 50 insertions(+), 164 deletions(-) delete mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py index 72a02397d..fc65e4975 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -106,7 +106,7 @@ def main(frame, version="default", num_workers=10, max_samples=None): # Initialize checkpoint file for resume functionality checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "longbench_v2", f"{frame}-{version}" + ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" ) os.makedirs(checkpoint_dir, exist_ok=True) record_file = os.path.join(checkpoint_dir, "success_records.txt") diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py deleted file mode 100644 index c23d7885f..000000000 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion_async.py +++ /dev/null @@ -1,158 +0,0 @@ -import argparse -import json -import os -import sys - -from concurrent.futures import ThreadPoolExecutor, as_completed - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -def ingest_sample(client, sample, sample_idx, frame, version): - """Ingest a single LongBench v2 sample as memories.""" - user_id = f"longbench_v2_{sample_idx}_{version}" - conv_id = f"longbench_v2_{sample_idx}_{version}" - - # Get context and convert to messages - context = sample.get("context", "") - - # For memos, we ingest the context as document content - messages = [ - { - "type": "file", - "file": { - "file_data": context, - "file_id": str(sample_idx), - }, - } - ] - - if "memos-api" in frame: - try: - client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {e}") - return False - - return False - - -def load_dataset_from_local(): - """Load LongBench v2 dataset from local JSON file.""" - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - filepath = os.path.join(data_dir, "data.json") - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSON file - with open(filepath, encoding="utf-8") as f: - samples = json.load(f) - - return samples - - -def main(frame, version="default", num_workers=10, max_samples=None): - """Main ingestion function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load dataset from local file - try: - dataset = load_dataset_from_local() - print(f"Loaded {len(dataset)} samples from LongBench v2") - except FileNotFoundError as e: - print(f"❌ Error loading dataset: {e}") - return - except Exception as e: - print(f"❌ Error loading dataset: {e}") - return - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize client - client = None - if frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - else: - print(f"❌ Unsupported frame: {frame}") - return - - # Ingest samples - success_count = 0 - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit(ingest_sample, client, sample, idx, frame, version) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Ingesting LongBench v2", - ): - try: - if future.result(): - success_count += 1 - except Exception as e: - print(f"Error processing sample: {e}") - - print(f"\n{'=' * 80}") - print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=["memos-api", "memos-api-online"], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="long-bench-v2-1208-1556-async", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=20, - help="Number of parallel workers", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples to process (default: all)", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py index 6489dc401..6a4fc2b7f 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -83,7 +83,7 @@ def main(frame, version="default"): print("=" * 80 + "\n") # Load responses - responses_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_responses.json" + responses_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_responses.json" if not os.path.exists(responses_path): print(f"❌ Responses not found: {responses_path}") print("Please run longbench_v2_responses.py first") @@ -99,7 +99,7 @@ def main(frame, version="default"): metrics = calculate_accuracy(filtered) # Save metrics - output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" + output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py index 8c4439436..cc1586112 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -158,7 +158,7 @@ def main(frame, version="default", num_workers=10): # Initialize checkpoint file for resume functionality checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "long_bench-v2", f"{frame}-{version}" + ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" ) os.makedirs(checkpoint_dir, exist_ok=True) record_file = os.path.join(checkpoint_dir, "response_success_records.txt") diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py index 686ff0ba1..9730e937e 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -25,6 +25,46 @@ def memos_api_search(client, query, user_id, top_k, frame): start = time() search_results = client.search(query=query, user_id=user_id, top_k=top_k) + def _reorder_memories_by_sources(sr: dict) -> list: + """ + Reorder text_mem[0].memories using sources' chunk_index (ascending). + Falls back to original order if no chunk_index is found. + """ + if not isinstance(sr, dict): + return [] + text_mem = sr.get("text_mem") or [] + if not text_mem or not text_mem[0].get("memories"): + return [] + memories = list(text_mem[0]["memories"]) + + def _first_source(mem: dict): + if not isinstance(mem, dict): + return None + # Prefer top-level sources, else metadata.sources + return (mem.get("sources") or mem.get("metadata", {}).get("sources") or []) or None + + def _chunk_index(mem: dict): + srcs = _first_source(mem) + if not srcs or not isinstance(srcs, list): + return None + for s in srcs: + if isinstance(s, dict) and s.get("chunk_index") is not None: + return s.get("chunk_index") + return None + + # Collect keys + keyed = [] + for i, mem in enumerate(memories): + ci = _chunk_index(mem) + keyed.append((ci, i, mem)) # keep original order as tie-breaker + + # If no chunk_index present at all, return original + if all(ci is None for ci, _, _ in keyed): + return memories + + keyed.sort(key=lambda x: (float("inf") if x[0] is None else x[0], x[1])) + return [k[2] for k in keyed] + # Format context from search results based on frame type for backward compatibility context = "" if ( @@ -32,7 +72,11 @@ def memos_api_search(client, query, user_id, top_k, frame): and isinstance(search_results, dict) and "text_mem" in search_results ): - context = "\n".join([i["memory"] for i in search_results["text_mem"][0]["memories"]]) + ordered_memories = _reorder_memories_by_sources(search_results) + if not ordered_memories and search_results["text_mem"][0].get("memories"): + ordered_memories = search_results["text_mem"][0]["memories"] + + context = "\n".join([i.get("memory", "") for i in ordered_memories]) if "pref_string" in search_results: context += f"\n{search_results.get('pref_string', '')}" @@ -129,7 +173,7 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): # Initialize checkpoint file for resume functionality checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "long_bench-v2", f"{frame}-{version}" + ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" ) os.makedirs(checkpoint_dir, exist_ok=True) record_file = os.path.join(checkpoint_dir, "search_success_records.txt") From b58ee88db79efdb448a9783a69068070ce8d807c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 15:46:34 +0800 Subject: [PATCH 17/31] fix: image bug --- .../mem_reader/read_multi_modal/file_content_parser.py | 3 ++- src/memos/mem_reader/read_multi_modal/image_parser.py | 6 +++++- src/memos/mem_reader/read_multi_modal/tool_parser.py | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 408736d2f..20fc03ec2 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -471,6 +471,7 @@ def parse_fast( total_chunks = len(content_chunks) # Create memory items for each chunk + content_chunk_embeddings = self.embedder.embed(content_chunks) memory_items = [] for chunk_idx, chunk_text in enumerate(content_chunks): if not chunk_text.strip(): @@ -499,7 +500,7 @@ def parse_fast( f"chunk:{chunk_idx + 1}/{total_chunks}", ], key=_derive_key(chunk_text), - embedding=self.embedder.embed([chunk_text])[0], + embedding=content_chunk_embeddings[chunk_idx], usage=[], sources=[source], background="", diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 5a19393a9..741295089 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -64,7 +64,11 @@ def rebuild_from_source( ) -> ChatCompletionContentPartImageParam: """Rebuild image_url content part from SourceMessage.""" # Rebuild from source fields - url = getattr(source, "url", "") or (source.content or "").replace("[image_url]: ", "") + url = ( + getattr(source, "url", "") + or getattr(source, "image_path", "") + or (source.content or "").replace("[image_url]: ", "") + ) detail = getattr(source, "detail", "auto") return { "type": "image_url", diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py index e13b684a7..705896489 100644 --- a/src/memos/mem_reader/read_multi_modal/tool_parser.py +++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py @@ -79,6 +79,7 @@ def create_source( filename=file_info.get("filename", ""), file_id=file_info.get("file_id", ""), tool_call_id=tool_call_id, + file_info=file_info, ) ) elif part_type == "image_url": From f94b0012f813ac23002da4600ec1853d7f0c9557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 15:48:42 +0800 Subject: [PATCH 18/31] feat: increase playground depth --- src/memos/memories/textual/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 7f022b439..75eae30e8 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -210,7 +210,7 @@ def search( def get_relevant_subgraph( self, query: str, - top_k: int = 5, + top_k: int = 20, depth: int = 2, center_status: str = "activated", user_name: str | None = None, From eba9e96216975495591ccd62b730b88e09dd0449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 17:48:01 +0800 Subject: [PATCH 19/31] feat: set parsed_text None in file parser --- .../read_multi_modal/file_content_parser.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 20fc03ec2..8fa0f2454 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -612,8 +612,6 @@ def parse_fine( # Use parser from utils if parser: parsed_text = parser.parse(temp_file_path) - else: - parsed_text = "[File parsing error: Parser not available]" except Exception as e: logger.error( f"[FileContentParser] Error parsing downloaded file: {e}" @@ -633,18 +631,9 @@ def parse_fine( # Priority 2: If file_id is provided but no file_data, try to use file_id as path elif file_id: logger.warning(f"[FileContentParser] File data not provided for file_id: {file_id}") - parsed_text = f"[File ID: {file_id}]: File data not provided" - - # If no content could be parsed, create a placeholder - if not parsed_text: - if filename: - parsed_text = f"[File: {filename}] File data not provided" - else: - parsed_text = "[File: unknown] File data not provided" except Exception as e: logger.error(f"[FileContentParser] Error in parse_fine: {e}") - parsed_text = f"[File parsing error: {e!s}]" finally: # Clean up temporary file @@ -656,7 +645,8 @@ def parse_fine( logger.warning( f"[FileContentParser] Failed to delete temp file {temp_file_path}: {e}" ) - + if not parsed_text: + return [] # Extract and process images from parsed_text if is_markdown and parsed_text and self.image_parser: parsed_text = self._extract_and_process_images(parsed_text, info, **kwargs) From 918bc6acd944453e38b9930375de8d8745be5ebd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 9 Dec 2025 20:38:39 +0800 Subject: [PATCH 20/31] fix: file_ids bug in file-mode --- src/memos/mem_reader/multi_modal_struct.py | 33 ++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index ed139f958..88ef56b7c 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -206,6 +206,7 @@ def _build_window_from_items( memory_texts = [] all_sources = [] roles = set() + aggregated_file_ids: list[str] = [] for item in items: if item.memory: @@ -226,6 +227,15 @@ def _build_window_from_items( elif isinstance(source, dict) and source.get("role"): roles.add(source.get("role")) + # Aggregate file_ids from metadata + metadata = getattr(item, "metadata", None) + if metadata is not None: + item_file_ids = getattr(metadata, "file_ids", None) + if isinstance(item_file_ids, list): + for fid in item_file_ids: + if fid and fid not in aggregated_file_ids: + aggregated_file_ids.append(fid) + # Determine memory_type based on roles (same logic as simple_struct) # UserMemory if only user role, else LongTermMemory memory_type = "UserMemory" if roles == {"user"} else "LongTermMemory" @@ -238,12 +248,16 @@ def _build_window_from_items( return None # Create aggregated memory item (similar to _build_fast_node in simple_struct) + extra_kwargs: dict[str, Any] = {} + if aggregated_file_ids: + extra_kwargs["file_ids"] = aggregated_file_ids aggregated_item = self._make_memory_item( value=merged_text, info=info, memory_type=memory_type, tags=["mode:fast"], sources=all_sources, + **extra_kwargs, ) return aggregated_item @@ -371,6 +385,19 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: if not isinstance(sources, list): sources = [sources] + # Extract file_ids from fast item metadata for propagation + metadata = getattr(fast_item, "metadata", None) + file_ids = getattr(metadata, "file_ids", None) if metadata is not None else None + file_ids = [fid for fid in file_ids if fid] if isinstance(file_ids, list) else [] + + # Build per-item info copy and kwargs for _make_memory_item + info_per_item = info.copy() + if file_ids and "file_id" not in info_per_item: + info_per_item["file_id"] = file_ids[0] + extra_kwargs: dict[str, Any] = {} + if file_ids: + extra_kwargs["file_ids"] = file_ids + # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) @@ -392,12 +419,13 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: # Create fine mode memory item (same as simple_struct) node = self._make_memory_item( value=m.get("value", ""), - info=info, + info=info_per_item, memory_type=memory_type, tags=m.get("tags", []), key=m.get("key", ""), sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), + **extra_kwargs, ) fine_items.append(node) except Exception as e: @@ -407,12 +435,13 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: # Create fine mode memory item (same as simple_struct) node = self._make_memory_item( value=resp.get("value", "").strip(), - info=info, + info=info_per_item, memory_type="LongTermMemory", tags=resp.get("tags", []), key=resp.get("key", None), sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), + **extra_kwargs, ) fine_items.append(node) except Exception as e: From 56e0d6d167f96ff138f4d318115ff8417e263af0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 15 Dec 2025 18:45:11 +0800 Subject: [PATCH 21/31] feat: update evaluation --- .../long_bench-v2/longbench_v2_ingestion.py | 11 +- .../long_bench-v2/longbench_v2_metric.py | 118 +++++++++-------- .../long_bench-v2/longbench_v2_responses.py | 123 ++++++++++-------- .../long_bench-v2/longbench_v2_search.py | 68 ++-------- 4 files changed, 146 insertions(+), 174 deletions(-) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py index fc65e4975..7dddd7ffc 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -36,11 +36,8 @@ def ingest_sample( # For memos, we ingest the context as document content messages = [ { - "type": "file", - "file": { - "file_data": context, - "file_id": str(sample_idx), - }, + "role": "user", + "content": context, } ] @@ -179,13 +176,13 @@ def main(frame, version="default", num_workers=10, max_samples=None): parser.add_argument( "--version", type=str, - default="default", + default="longbench_v2_20251215_1838", help="Version identifier for saving results", ) parser.add_argument( "--workers", type=int, - default=3, + default=2, help="Number of parallel workers", ) parser.add_argument( diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py index 6a4fc2b7f..66c11134e 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -4,75 +4,73 @@ def calculate_accuracy(responses): - """Calculate accuracy metrics for LongBench v2.""" + """Calculate accuracy metrics for LongBench v2. + + Logic is aligned with longbench_stx.print_metrics, but returns a dict + and additionally computes by_domain statistics. + """ total = len(responses) if total == 0: return {} - # Overall accuracy - correct = sum(1 for r in responses if r.get("judge", False)) - overall_acc = round(100 * correct / total, 1) - - # By difficulty - easy_items = [r for r in responses if r.get("difficulty") == "easy"] - hard_items = [r for r in responses if r.get("difficulty") == "hard"] - easy_acc = ( - round(100 * sum(1 for r in easy_items if r.get("judge", False)) / len(easy_items), 1) - if easy_items - else 0.0 - ) - hard_acc = ( - round(100 * sum(1 for r in hard_items if r.get("judge", False)) / len(hard_items), 1) - if hard_items - else 0.0 - ) - - # By length - short_items = [r for r in responses if r.get("length") == "short"] - medium_items = [r for r in responses if r.get("length") == "medium"] - long_items = [r for r in responses if r.get("length") == "long"] - - short_acc = ( - round(100 * sum(1 for r in short_items if r.get("judge", False)) / len(short_items), 1) - if short_items - else 0.0 - ) - medium_acc = ( - round(100 * sum(1 for r in medium_items if r.get("judge", False)) / len(medium_items), 1) - if medium_items - else 0.0 - ) - long_acc = ( - round(100 * sum(1 for r in long_items if r.get("judge", False)) / len(long_items), 1) - if long_items - else 0.0 - ) - - # By domain + # Counters (aligned with longbench_stx.print_metrics) + easy = hard = short = medium = long = 0 + easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0 + + for pred in responses: + acc = int(pred.get("judge", False)) + diff = pred.get("difficulty", "easy") + length = pred.get("length", "short") + + if diff == "easy": + easy += 1 + easy_acc += acc + else: + hard += 1 + hard_acc += acc + + if length == "short": + short += 1 + short_acc += acc + elif length == "medium": + medium += 1 + medium_acc += acc + else: + long += 1 + long_acc += acc + + o_acc = round(100 * (easy_acc + hard_acc) / total, 2) + e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0.0 + h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0.0 + s_acc = round(100 * short_acc / short, 2) if short > 0 else 0.0 + m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0.0 + l_acc = round(100 * long_acc / long, 2) if long > 0 else 0.0 + + # Additional by-domain stats (extra vs. stx) domain_stats = {} - for response in responses: - domain = response.get("domain", "Unknown") + for r in responses: + domain = r.get("domain", "Unknown") if domain not in domain_stats: domain_stats[domain] = {"total": 0, "correct": 0} domain_stats[domain]["total"] += 1 - if response.get("judge", False): + if r.get("judge", False): domain_stats[domain]["correct"] += 1 domain_acc = { - domain: round(100 * stats["correct"] / stats["total"], 1) + domain: round(100 * stats["correct"] / stats["total"], 2) for domain, stats in domain_stats.items() } return { - "overall": overall_acc, - "easy": easy_acc, - "hard": hard_acc, - "short": short_acc, - "medium": medium_acc, - "long": long_acc, + "overall": o_acc, + "easy": e_acc, + "hard": h_acc, + "short": s_acc, + "medium": m_acc, + "long": l_acc, "by_domain": domain_acc, "total_samples": total, - "correct_samples": correct, + "correct_samples": easy_acc + hard_acc, } @@ -92,8 +90,8 @@ def main(frame, version="default"): with open(responses_path, encoding="utf-8") as f: responses = json.load(f) - # Only keep entries with non-empty context (search_context) to align with response generation - filtered = [r for r in responses if str(r.get("search_context", "")).strip() != ""] + # Use all responses (aligned with longbench_stx.print_metrics behavior) + filtered = responses # Calculate metrics metrics = calculate_accuracy(filtered) @@ -112,12 +110,12 @@ def main(frame, version="default"): # Print summary table print("\n📊 Summary of Results:") print("-" * 80) - print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.1f}%") - print(f"{'Easy':<30s}: {metrics['easy']:.1f}%") - print(f"{'Hard':<30s}: {metrics['hard']:.1f}%") - print(f"{'Short':<30s}: {metrics['short']:.1f}%") - print(f"{'Medium':<30s}: {metrics['medium']:.1f}%") - print(f"{'Long':<30s}: {metrics['long']:.1f}%") + print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%") + print(f"{'Easy':<30s}: {metrics['easy']:.2f}%") + print(f"{'Hard':<30s}: {metrics['hard']:.2f}%") + print(f"{'Short':<30s}: {metrics['short']:.2f}%") + print(f"{'Medium':<30s}: {metrics['medium']:.2f}%") + print(f"{'Long':<30s}: {metrics['long']:.2f}%") print("\nBy Domain:") for domain, acc in metrics["by_domain"].items(): print(f" {domain:<28s}: {acc:.1f}%") diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py index cc1586112..48a341933 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -22,72 +22,73 @@ sys.path.insert(0, EVAL_SCRIPTS_DIR) -# Prompt template from LongBench v2 -LONGBENCH_V2_PROMPT = """Please read the following text and answer the question below. +# RAG-style prompt template aligned with longbench_stx.TEMPLATE_RAG +TEMPLATE_RAG = """Please read the following retrieved text chunks and answer the question below. -{context} +$DOC$ -What is the correct answer to this question: {question} +What is the correct answer to this question: $Q$ Choices: -(A) {choice_A} -(B) {choice_B} -(C) {choice_C} -(D) {choice_D} +(A) $C_A$ +(B) $C_B$ +(C) $C_C$ +(D) $C_D$ Format your response as follows: "The correct answer is (insert answer here)".""" def extract_answer(response): - """Extract answer from response (A, B, C, or D).""" + """Extract answer from response (A, B, C, or D). + + Logic is kept consistent with longbench_stx.extract_answer. + """ response = response.replace("*", "") # Try to find "The correct answer is (X)" pattern - match = re.search(r"The correct answer is \(([A-D])\)", response, re.IGNORECASE) + match = re.search(r"The correct answer is \(([A-D])\)", response) if match: - return match.group(1).upper() + return match.group(1) else: - match = re.search(r"The correct answer is ([A-D])", response, re.IGNORECASE) + match = re.search(r"The correct answer is ([A-D])", response) if match: - return match.group(1).upper() - else: - # Try to find standalone A, B, C, or D - match = re.search(r"\b([A-D])\b", response) - if match: - return match.group(1).upper() - return None - - -def generate_response(llm_client, context, question, choice_a, choice_b, choice_c, choice_d): - """Generate response using LLM.""" - prompt = LONGBENCH_V2_PROMPT.format( - context=context, - question=question, - choice_A=choice_a, - choice_B=choice_b, - choice_C=choice_c, - choice_D=choice_d, + return match.group(1) + return None + + +def llm_answer(llm_client, memories, question, choices): + """Generate response using RAG-style prompt, aligned with longbench_stx.llm_answer.""" + # Join memories to form the retrieved context document + doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) + + prompt = ( + TEMPLATE_RAG.replace("$DOC$", doc_content) + .replace("$Q$", question) + .replace("$C_A$", choices.get("A", "")) + .replace("$C_B$", choices.get("B", "")) + .replace("$C_C$", choices.get("C", "")) + .replace("$C_D$", choices.get("D", "")) ) try: response = llm_client.chat.completions.create( model=os.getenv("CHAT_MODEL"), - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], + messages=[{"role": "user", "content": prompt}], temperature=0.1, - max_tokens=128, + max_tokens=12800, ) - result = response.choices[0].message.content or "" - return result + return response.choices[0].message.content or "" except Exception as e: print(f"Error generating response: {e}") return "" def process_sample(search_result, llm_client, success_records, record_file, file_lock): - """Process a single sample: generate answer.""" + """Process a single sample: generate answer. + + This mirrors longbench_stx.evaluate_sample but consumes precomputed search results + produced by longbench_v2_search.py. + """ sample_idx = search_result.get("sample_idx") # Skip if already processed if sample_idx is not None and str(sample_idx) in success_records: @@ -95,21 +96,36 @@ def process_sample(search_result, llm_client, success_records, record_file, file start = time() - context = search_result.get("context", "") question = search_result.get("question", "") - choice_a = search_result.get("choice_A", "") - choice_b = search_result.get("choice_B", "") - choice_c = search_result.get("choice_C", "") - choice_d = search_result.get("choice_D", "") + choices = { + "A": search_result.get("choice_A", "") or "", + "B": search_result.get("choice_B", "") or "", + "C": search_result.get("choice_C", "") or "", + "D": search_result.get("choice_D", "") or "", + } - # Skip empty/placeholder contexts (e.g., "\n" or whitespace-only) - if not context or context.strip() == "": + # Prefer memories saved by longbench_v2_search; fall back to reconstructing + # from raw search_results if needed (for old search jsons). + memories = search_result.get("memories_used") + if memories is None: + raw = search_result.get("search_results") or {} + memories = [] + if isinstance(raw, dict) and raw.get("text_mem"): + text_mem = raw["text_mem"] + if text_mem and text_mem[0].get("memories"): + memories = [ + m.get("memory", "") for m in text_mem[0]["memories"] if isinstance(m, dict) + ] + + # Ensure we have a list, even if empty + memories = memories or [] + + # Skip if no retrieved memories and no question + if not question: return None # Generate answer - response = generate_response( - llm_client, context, question, choice_a, choice_b, choice_c, choice_d - ) + response = llm_answer(llm_client, memories, str(question), choices) # Extract answer (A, B, C, or D) pred = extract_answer(response) @@ -124,15 +140,16 @@ def process_sample(search_result, llm_client, success_records, record_file, file "difficulty": search_result.get("difficulty"), "length": search_result.get("length"), "question": question, - "choice_A": choice_a, - "choice_B": choice_b, - "choice_C": choice_c, - "choice_D": choice_d, + "choice_A": choices["A"], + "choice_B": choices["B"], + "choice_C": choices["C"], + "choice_D": choices["D"], "answer": search_result.get("answer"), "pred": pred, "response": response, "judge": pred == search_result.get("answer") if pred else False, - "search_context": context, + # Keep full retrieved memories list for inspection / debugging + "memories_used": memories, # Preserve full search results payload (e.g., list of memories) "search_results": search_result.get("search_results"), "response_duration_ms": response_duration_ms, diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py index 9730e937e..dc61bf5c8 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -25,63 +25,20 @@ def memos_api_search(client, query, user_id, top_k, frame): start = time() search_results = client.search(query=query, user_id=user_id, top_k=top_k) - def _reorder_memories_by_sources(sr: dict) -> list: - """ - Reorder text_mem[0].memories using sources' chunk_index (ascending). - Falls back to original order if no chunk_index is found. - """ - if not isinstance(sr, dict): - return [] - text_mem = sr.get("text_mem") or [] - if not text_mem or not text_mem[0].get("memories"): - return [] - memories = list(text_mem[0]["memories"]) - - def _first_source(mem: dict): - if not isinstance(mem, dict): - return None - # Prefer top-level sources, else metadata.sources - return (mem.get("sources") or mem.get("metadata", {}).get("sources") or []) or None - - def _chunk_index(mem: dict): - srcs = _first_source(mem) - if not srcs or not isinstance(srcs, list): - return None - for s in srcs: - if isinstance(s, dict) and s.get("chunk_index") is not None: - return s.get("chunk_index") - return None - - # Collect keys - keyed = [] - for i, mem in enumerate(memories): - ci = _chunk_index(mem) - keyed.append((ci, i, mem)) # keep original order as tie-breaker - - # If no chunk_index present at all, return original - if all(ci is None for ci, _, _ in keyed): - return memories - - keyed.sort(key=lambda x: (float("inf") if x[0] is None else x[0], x[1])) - return [k[2] for k in keyed] - - # Format context from search results based on frame type for backward compatibility - context = "" + # Extract raw memory texts in the same way as longbench_stx.memos_search + memories_texts: list[str] = [] if ( (frame == "memos-api" or frame == "memos-api-online") and isinstance(search_results, dict) and "text_mem" in search_results ): - ordered_memories = _reorder_memories_by_sources(search_results) - if not ordered_memories and search_results["text_mem"][0].get("memories"): - ordered_memories = search_results["text_mem"][0]["memories"] - - context = "\n".join([i.get("memory", "") for i in ordered_memories]) - if "pref_string" in search_results: - context += f"\n{search_results.get('pref_string', '')}" + text_mem = search_results.get("text_mem") or [] + if text_mem and text_mem[0].get("memories"): + memories = text_mem[0]["memories"] + memories_texts = [m.get("memory", "") for m in memories if isinstance(m, dict)] duration_ms = (time() - start) * 1000 - return context, duration_ms, search_results + return memories_texts, duration_ms, search_results def process_sample( @@ -98,7 +55,9 @@ def process_sample( if not query: return None - context, duration_ms, search_results = memos_api_search(client, query, user_id, top_k, frame) + memories_used, duration_ms, search_results = memos_api_search( + client, query, user_id, top_k, frame + ) result = { "sample_idx": sample_idx, @@ -113,8 +72,9 @@ def process_sample( "choice_C": sample.get("choice_C"), "choice_D": sample.get("choice_D"), "answer": sample.get("answer"), - "context": context, - # Preserve full search results instead of only the concatenated context + # Raw memories used for RAG answering (aligned with longbench_stx) + "memories_used": memories_used, + # Preserve full search results payload for debugging / analysis "search_results": search_results, "search_duration_ms": duration_ms, } @@ -274,7 +234,7 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): parser.add_argument( "--version", type=str, - default="default", + default="long-bench-v2-1208-1639", help="Version identifier for saving results", ) parser.add_argument( From c64fd26895a1a49c805ef68efb02ff413df799ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 15 Dec 2025 18:46:34 +0800 Subject: [PATCH 22/31] feat: update evaluation --- .../scripts/long_bench-v2/longbench_v2_metric.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py index 66c11134e..fc212ef7b 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -90,8 +90,17 @@ def main(frame, version="default"): with open(responses_path, encoding="utf-8") as f: responses = json.load(f) - # Use all responses (aligned with longbench_stx.print_metrics behavior) - filtered = responses + # Only keep entries that actually have search results: + # - For new pipeline: non-empty memories_used list + # - For older runs: non-empty search_context string + def _has_search_results(r: dict) -> bool: + mems = r.get("memories_used") + if isinstance(mems, list) and any(str(m).strip() for m in mems): + return True + ctx = str(r.get("search_context", "")).strip() + return ctx != "" + + filtered = [r for r in responses if _has_search_results(r)] # Calculate metrics metrics = calculate_accuracy(filtered) From 41ac6c26cc83957208322506cd47405f3b3b54f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 16 Dec 2025 15:15:53 +0800 Subject: [PATCH 23/31] feat: add general string prompt --- src/memos/mem_reader/multi_modal_struct.py | 8 +- src/memos/mem_reader/simple_struct.py | 6 + src/memos/templates/mem_reader_prompts.py | 205 ++++++++++++++++++++- 3 files changed, 216 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 88ef56b7c..10bac319e 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -304,6 +304,10 @@ def _get_llm_response( template = PROMPT_DICT["doc"][lang] examples = "" # doc prompts don't have examples prompt = template.replace("{chunk_text}", mem_str) + elif prompt_type == "general_string": + template = PROMPT_DICT["general_string"][lang] + examples = "" + prompt = template.replace("{chunk_text}", mem_str) else: template = PROMPT_DICT["chat"][lang] examples = PROMPT_DICT["chat"][f"{lang}_example"] @@ -316,7 +320,7 @@ def _get_llm_response( ) # Replace custom_tags_prompt placeholder (different for doc vs chat) - if prompt_type == "doc": + if prompt_type in ["doc", "general_string"]: prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt) else: prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt) @@ -348,7 +352,7 @@ def _determine_prompt_type(self, sources: list) -> str: """ if not sources: return "chat" - prompt_type = "doc" + prompt_type = "general_string" for source in sources: source_role = None if hasattr(source, "role"): diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 555f1f110..0c3645b49 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -26,6 +26,8 @@ from memos.templates.mem_reader_prompts import ( CUSTOM_TAGS_INSTRUCTION, CUSTOM_TAGS_INSTRUCTION_ZH, + GENERAL_STRUCT_STRING_READER_PROMPT, + GENERAL_STRUCT_STRING_READER_PROMPT_ZH, PROMPT_MAPPING, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, @@ -79,6 +81,10 @@ def from_config(_config): "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, }, "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "general_string": { + "en": GENERAL_STRUCT_STRING_READER_PROMPT, + "zh": GENERAL_STRUCT_STRING_READER_PROMPT_ZH, + }, "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index cf8456c80..4ac12eb70 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -223,7 +223,6 @@ Your Output:""" - SIMPLE_STRUCT_DOC_READER_PROMPT_ZH = """您是搜索与检索系统的文本分析专家。 您的任务是处理文档片段,并生成一个结构化的 JSON 对象。 @@ -258,11 +257,215 @@ {custom_tags_prompt} +示例: +输入的文本片段: +在Kalamang语中,亲属名词在所有格构式中的行为并不一致。名词 esa“父亲”和 ema“母亲”只能在技术称谓(teknonym)中与第三人称所有格后缀共现,而在非技术称谓用法中,带有所有格后缀是不合语法的。相比之下,大多数其他亲属名词并不允许所有格构式,只有极少数例外。 +语料中还发现一种“双重所有格标记”的现象,即名词同时带有所有格后缀和独立的所有格代词。这种构式在语料中极为罕见,其语用功能尚不明确,且多出现在马来语借词中,但也偶尔见于Kalamang本族词。 +此外,黏着词 =kin 可用于表达多种关联关系,包括目的性关联、空间关联以及泛指的群体所有关系。在此类构式中,被标记的通常是施事或关联方,而非被拥有物本身。这一用法显示出 =kin 可能处于近期语法化阶段。 + +输出: +{ + "memory list": [ + { + "key": "亲属名词在所有格构式中的不一致行为", + "memory_type": "LongTermMemory", + "value": "Kalamang语中的亲属名词在所有格构式中的行为存在显著差异,其中“父亲”(esa)和“母亲”(ema)仅能在技术称谓用法中与第三人称所有格后缀共现,而在非技术称谓中带所有格后缀是不合语法的。", + "tags": ["亲属名词", "所有格", "语法限制"] + }, + { + "key": "双重所有格标记现象", + "memory_type": "LongTermMemory", + "value": "语料中存在名词同时带有所有格后缀和独立所有格代词的双重所有格标记构式,但该现象出现频率极低,其具体语用功能尚不明确。", + "tags": ["双重所有格", "罕见构式", "语用功能"] + }, + { + "key": "双重所有格与借词的关系", + "memory_type": "LongTermMemory", + "value": "双重所有格标记多见于马来语借词中,但也偶尔出现在Kalamang本族词中,显示该构式并非完全由语言接触触发。", + "tags": ["语言接触", "借词", "构式分布"] + }, + { + "key": "=kin 的关联功能与语法地位", + "memory_type": "LongTermMemory", + "value": "黏着词 =kin 用于表达目的性、空间或群体性的关联关系,其标记对象通常为关联方而非被拥有物,这表明 =kin 可能处于近期语法化过程中。", + "tags": ["=kin", "关联关系", "语法化"] + } + ], + "summary": "该文本描述了Kalamang语中所有格构式的多样性与不对称性。亲属名词在所有格标记上的限制显示出语义类别内部的分化,而罕见的双重所有格构式则反映了构式层面的不稳定性。同时,=kin 的多功能关联用法及其分布特征为理解该语言的语法化路径提供了重要线索。" +} + +文档片段: +{chunk_text} + +您的输出:""" + +GENERAL_STRUCT_STRING_READER_PROMPT = """You are a text analysis expert for search and retrieval systems. +Your task is to parse a text chunk into multiple structured memories for long-term storage and precise future retrieval. The text chunk may contain information from various sources, including conversations, plain text, speech-to-text transcripts, tables, tool documentation, and more. + +Please perform the following steps: + +1. Decompose the text chunk into multiple memories that are mutually independent, minimally redundant, and each fully expresses a single information point. Together, these memories should cover different aspects of the document so that a reader can understand all core content without reading the original text. + +2. Memory splitting and deduplication rules (very important): +2.1 Each memory must express only one primary information point, such as: + - A fact + - A clear conclusion or judgment + - A decision or action + - An important background or condition + - A notable emotional tone or attitude + - A plan, risk, or downstream impact + +2.2 Do not force multiple information points into a single memory. + +2.3 Do not generate memories that are semantically repetitive or highly overlapping: + - If two memories describe the same fact or judgment, retain only the one with more complete information. + - Do not create “different” memories solely by rephrasing. + +2.4 There is no fixed upper or lower limit on the number of memories; the count should be determined naturally by the information density of the text. + +3. Information parsing requirements: +3.1 Identify and clearly specify all important: + - Times (distinguishing event time from document recording time) + - People (resolving pronouns and aliases to explicit identities) + - Organizations, locations, and events + +3.2 Explicitly resolve all references to time, people, locations, and events: + - When context allows, convert relative time expressions (e.g., “last year,” “next quarter”) into absolute dates. + - If uncertainty exists, explicitly state it (e.g., “around 2024,” “exact date unknown”). + - Include specific locations when mentioned. + - Resolve all pronouns, aliases, and ambiguous references to full names or clear identities. + - Disambiguate entities with the same name when necessary. + +4. Writing and perspective rules: + - Always write in the third person, clearly referring to subjects or content, and avoid first-person expressions (“I,” “we,” “my”). + - Use precise, neutral language and do not infer or introduce information not explicitly stated in the text. + +Return a valid JSON object with the following structure: + +{ + "memory list": [ + { + "key": , + "memory_type": "LongTermMemory", + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must use the same primary language as the input document. **If the input is Chinese, output must be in Chinese.** +- `memory_type` must remain in English. + +{custom_tags_prompt} + +Example: +Text chunk: + +In Kalamang, kinship terms show uneven behavior in possessive constructions. The nouns esa ‘father’ and ema ‘mother’ can only co-occur with a third-person possessive suffix when used as teknonyms; outside of such contexts, possessive marking is ungrammatical. Most other kinship terms do not allow possessive constructions, with only a few marginal exceptions. + +The corpus also contains rare cases of double possessive marking, in which a noun bears both a possessive suffix and a free possessive pronoun. This construction is infrequent and its discourse function remains unclear. While it appears more often with Malay loanwords, it is not restricted to borrowed vocabulary. + +In addition, the clitic =kin encodes a range of associative relations, including purposive, spatial, and collective ownership. In such constructions, the marked element typically corresponds to the possessor or associated entity rather than the possessed item, suggesting that =kin may be undergoing recent grammaticalization. + +Output: +{ + "memory list": [ + { + "key": "Asymmetric possessive behavior of kinship terms", + "memory_type": "LongTermMemory", + "value": "In Kalamang, kinship terms do not behave uniformly in possessive constructions: ‘father’ (esa) and ‘mother’ (ema) require a teknonymic context to appear with a third-person possessive suffix, whereas possessive marking is otherwise ungrammatical.", + "tags": ["kinship terms", "possessive constructions", "grammatical constraints"] + }, + { + "key": "Rare double possessive marking", + "memory_type": "LongTermMemory", + "value": "The language exhibits a rare construction in which a noun carries both a possessive suffix and a free possessive pronoun, though the pragmatic function of this double marking remains unclear.", + "tags": ["double possessive", "rare constructions", "pragmatics"] + }, + { + "key": "Distribution of double possessives across lexicon", + "memory_type": "LongTermMemory", + "value": "Double possessive constructions occur more frequently with Malay loanwords but are also attested with indigenous Kalamang vocabulary, indicating that the pattern is not solely contact-induced.", + "tags": ["loanwords", "language contact", "distribution"] + }, + { + "key": "Associative clitic =kin", + "memory_type": "LongTermMemory", + "value": "The clitic =kin marks various associative relations, including purposive, spatial, and collective ownership, typically targeting the possessor or associated entity, and appears to reflect an ongoing process of grammaticalization.", + "tags": ["=kin", "associative relations", "grammaticalization"] + } + ], + "summary": "The text outlines key properties of possessive and associative constructions in Kalamang. Kinship terms exhibit asymmetric grammatical behavior, rare double possessive patterns suggest constructional instability, and the multifunctional clitic =kin provides evidence for evolving associative marking within the language’s grammar." +} + +Text chunk: +{chunk_text} + +Your output: +""" + +GENERAL_STRUCT_STRING_READER_PROMPT_ZH = """您是搜索与检索系统的文本分析专家。 +您的任务是将一个文本片段解析为【多条结构化记忆】,用于长期存储和后续精准检索,这里的文本片段可能包含各种对话、纯文本、语音转录的文字、表格、工具说明等等的信息。 + +请执行以下操作: +1. 将文档片段拆解为若干条【相互独立、尽量不重复、各自完整表达单一信息点】的记忆。这些记忆应共同覆盖文档的不同方面,使读者无需阅读原文即可理解该文档的全部核心内容。 +2. 记忆拆分与去重规则(非常重要): +2.1 每一条记忆应只表达【一个主要信息点】: + - 一个事实 + - 一个明确结论或判断 + - 一个决定或行动 + - 一个重要背景或条件 + - 一个显著的情感基调或态度 + - 一个计划、风险或后续影响 +2.2 不要将多个信息点强行合并到同一条记忆中。 +2.3 不要生成语义重复或高度重叠的记忆: + - 如果两条记忆表达的是同一事实或同一判断,只保留信息更完整的一条。 + - 不允许仅通过措辞变化来制造“不同”的记忆。 +2.4 记忆条数不设固定上限或下限,应由文档信息密度自然决定。 +3. 信息解析要求 +3.1 识别并明确所有重要的: + - 时间(区分事件发生时间与文档记录时间) + - 人物(解析代词、别名为明确身份) + - 组织、地点、事件 +3.2 清晰解析所有时间、人物、地点和事件的指代: + - 如果上下文允许,将相对时间表达(如“去年”、“下一季度”)转换为绝对日期。 + - 如果存在不确定性,需明确说明(例如,“约2024年”,“具体日期不详”)。 + - 若提及具体地点,请包含在内。 + - 将所有代词、别名和模糊指代解析为全名或明确身份。 + - 如有同名实体,需加以区分。 +4. 写作与视角规则 + - 始终以第三人称视角撰写,清晰指代主题或内容,避免使用第一人称(“我”、“我们”、“我的”)。 + - 语言应准确、中性,不自行引申文档未明确表达的内容。 + +返回一个有效的 JSON 对象,结构如下: +{ + "memory list": [ + { + "key": <字符串,简洁且唯一的记忆标题>, + "memory_type": "LongTermMemory", + "value": <一段完整、清晰、可独立理解的记忆描述;若输入为中文则使用中文,若为英文则使用英文>, + "tags": <与该记忆高度相关的主题关键词列表> + }, + ... + ], + "summary": <一段整体性总结,概括这些记忆如何共同反映文档的核心内容与重点,语言与输入文档一致> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入文档摘要的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +{custom_tags_prompt} + 文档片段: {chunk_text} 您的输出:""" + SIMPLE_STRUCT_MEM_READER_EXAMPLE = """Example: Conversation: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. From eaedc9a044549f30e1dc05444530c25e12317665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 16 Dec 2025 18:00:55 +0800 Subject: [PATCH 24/31] fix: test server router --- tests/api/test_server_router.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 559c5cd35..4786ef08f 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -231,18 +231,6 @@ def test_add_valid_input_output(self, mock_handlers, client): assert call_args.mem_cube_id == "test_cube" assert call_args.user_id == "test_user" - def test_add_invalid_input_missing_cube_id(self, mock_handlers, client): - """Test add endpoint with missing required field.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - # Should return validation error - assert response.status_code == 422 - def test_add_response_format(self, mock_handlers, client): """Test add endpoint returns MemoryResponse format.""" mock_handlers["add"].handle_add_memories.return_value = MemoryResponse( From 7674eccdaa8847d62d940330f35e7ad170acf4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 16 Dec 2025 19:25:16 +0800 Subject: [PATCH 25/31] feat: update evluation --- .../long_bench-v2/longbench_v2_ingestion.py | 9 ++- .../long_bench-v2/longbench_v2_metric.py | 28 +++++++- .../long_bench-v2/longbench_v2_responses.py | 71 ++++++++++++++----- .../long_bench-v2/longbench_v2_search.py | 15 +++- 4 files changed, 99 insertions(+), 24 deletions(-) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py index 7dddd7ffc..f70f24531 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -36,8 +36,11 @@ def ingest_sample( # For memos, we ingest the context as document content messages = [ { - "role": "user", - "content": context, + "type": "file", + "file": { + "file_data": context, + "file_id": str(sample_idx), + }, } ] @@ -176,7 +179,7 @@ def main(frame, version="default", num_workers=10, max_samples=None): parser.add_argument( "--version", type=str, - default="longbench_v2_20251215_1838", + default="default", help="Version identifier for saving results", ) parser.add_argument( diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py index fc212ef7b..af324c9c7 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py @@ -16,12 +16,17 @@ def calculate_accuracy(responses): # Counters (aligned with longbench_stx.print_metrics) easy = hard = short = medium = long = 0 easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0 + total_prompt_tokens = 0 for pred in responses: acc = int(pred.get("judge", False)) diff = pred.get("difficulty", "easy") length = pred.get("length", "short") + pt = pred.get("prompt_tokens") + if isinstance(pt, int | float): + total_prompt_tokens += int(pt) + if diff == "easy": easy += 1 easy_acc += acc @@ -71,6 +76,8 @@ def calculate_accuracy(responses): "by_domain": domain_acc, "total_samples": total, "correct_samples": easy_acc + hard_acc, + "total_prompt_tokens": total_prompt_tokens, + "avg_prompt_tokens": round(total_prompt_tokens / total, 2) if total > 0 else 0.0, } @@ -102,8 +109,24 @@ def _has_search_results(r: dict) -> bool: filtered = [r for r in responses if _has_search_results(r)] - # Calculate metrics - metrics = calculate_accuracy(filtered) + # Calculate metrics (handle case where no samples have search results) + if not filtered: + print("⚠️ No responses with valid search results were found. Metrics will be zeroed.") + metrics = { + "overall": 0.0, + "easy": 0.0, + "hard": 0.0, + "short": 0.0, + "medium": 0.0, + "long": 0.0, + "by_domain": {}, + "total_samples": 0, + "correct_samples": 0, + "total_prompt_tokens": 0, + "avg_prompt_tokens": 0.0, + } + else: + metrics = calculate_accuracy(filtered) # Save metrics output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" @@ -125,6 +148,7 @@ def _has_search_results(r: dict) -> bool: print(f"{'Short':<30s}: {metrics['short']:.2f}%") print(f"{'Medium':<30s}: {metrics['medium']:.2f}%") print(f"{'Long':<30s}: {metrics['long']:.2f}%") + print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}") print("\nBy Domain:") for domain, acc in metrics["by_domain"].items(): print(f" {domain:<28s}: {acc:.1f}%") diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py index 48a341933..686062c5f 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -57,7 +57,11 @@ def extract_answer(response): def llm_answer(llm_client, memories, question, choices): - """Generate response using RAG-style prompt, aligned with longbench_stx.llm_answer.""" + """Generate response using RAG-style prompt, aligned with longbench_stx.llm_answer. + + Returns: + tuple[str, int | None]: (response_text, prompt_tokens) + """ # Join memories to form the retrieved context document doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) @@ -77,10 +81,24 @@ def llm_answer(llm_client, memories, question, choices): temperature=0.1, max_tokens=12800, ) - return response.choices[0].message.content or "" + text = response.choices[0].message.content or "" + prompt_tokens = None + usage = getattr(response, "usage", None) + if usage is not None: + # openai>=1.x style: usage.prompt_tokens + pt = getattr(usage, "prompt_tokens", None) + if isinstance(pt, int): + prompt_tokens = pt + else: + # fallback for dict-like usage + try: + prompt_tokens = int(usage.get("prompt_tokens")) # type: ignore[call-arg] + except Exception: + prompt_tokens = None + return text, prompt_tokens except Exception as e: print(f"Error generating response: {e}") - return "" + return "", None def process_sample(search_result, llm_client, success_records, record_file, file_lock): @@ -89,9 +107,13 @@ def process_sample(search_result, llm_client, success_records, record_file, file This mirrors longbench_stx.evaluate_sample but consumes precomputed search results produced by longbench_v2_search.py. """ + # Use sample_idx when available, otherwise fall back to _id so that + # we can work with stx-style search results that only have _id. sample_idx = search_result.get("sample_idx") + sample_key = str(sample_idx) if sample_idx is not None else str(search_result.get("_id", "")) + # Skip if already processed - if sample_idx is not None and str(sample_idx) in success_records: + if sample_key and sample_key in success_records: return None start = time() @@ -123,9 +145,11 @@ def process_sample(search_result, llm_client, success_records, record_file, file # Skip if no retrieved memories and no question if not question: return None + if not memories: + return None # Generate answer - response = llm_answer(llm_client, memories, str(question), choices) + response, prompt_tokens = llm_answer(llm_client, memories, str(question), choices) # Extract answer (A, B, C, or D) pred = extract_answer(response) @@ -133,6 +157,7 @@ def process_sample(search_result, llm_client, success_records, record_file, file response_duration_ms = (time() - start) * 1000 result = { + # Preserve sample_idx if present for backward compatibility "sample_idx": search_result.get("sample_idx"), "_id": search_result.get("_id"), "domain": search_result.get("domain"), @@ -148,6 +173,7 @@ def process_sample(search_result, llm_client, success_records, record_file, file "pred": pred, "response": response, "judge": pred == search_result.get("answer") if pred else False, + "prompt_tokens": prompt_tokens, # Keep full retrieved memories list for inspection / debugging "memories_used": memories, # Preserve full search results payload (e.g., list of memories) @@ -157,9 +183,9 @@ def process_sample(search_result, llm_client, success_records, record_file, file } # Record successful processing (thread-safe) - if sample_idx is not None: + if sample_key: with file_lock, open(record_file, "a") as f: - f.write(f"{sample_idx}\n") + f.write(f"{sample_key}\n") f.flush() return result @@ -192,16 +218,18 @@ def main(frame, version="default", num_workers=10): search_results = json.load(f) # Load existing results and success records for resume - existing_results = {} - success_records = set() + existing_results: dict[str, dict] = {} + success_records: set[str] = set() if os.path.exists(output_path): with open(output_path, encoding="utf-8") as f: existing_results_list = json.load(f) for result in existing_results_list: + # Use sample_idx if present, otherwise _id as the unique key sample_idx = result.get("sample_idx") - if sample_idx is not None: - existing_results[sample_idx] = result - success_records.add(str(sample_idx)) + key = str(sample_idx) if sample_idx is not None else str(result.get("_id", "")) + if key: + existing_results[key] = result + success_records.add(key) print(f"📋 Found {len(existing_results)} existing responses (resume mode)") else: print("📋 Starting fresh response generation (no checkpoint found)") @@ -222,7 +250,7 @@ def main(frame, version="default", num_workers=10): ) print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") - # Process all samples + # Process all samples concurrently using ThreadPoolExecutor new_results = [] file_lock = threading.Lock() # Lock for thread-safe file writing with ThreadPoolExecutor(max_workers=num_workers) as executor: @@ -241,15 +269,22 @@ def main(frame, version="default", num_workers=10): result = future.result() if result: new_results.append(result) - # Update existing results with new result + # Update existing results with new result (keyed by sample_idx or _id) sample_idx = result.get("sample_idx") - if sample_idx is not None: - existing_results[sample_idx] = result + key = str(sample_idx) if sample_idx is not None else str(result.get("_id", "")) + if key: + existing_results[key] = result # Merge and save all results all_responses = list(existing_results.values()) - # Sort by sample_idx to maintain order - all_responses.sort(key=lambda x: x.get("sample_idx", 0)) + + # Sort by sample_idx when available, otherwise by _id for stability + def _sort_key(x: dict): + if x.get("sample_idx") is not None: + return ("0", int(x.get("sample_idx"))) + return ("1", str(x.get("_id", ""))) + + all_responses.sort(key=_sort_key) with open(output_path, "w", encoding="utf-8") as f: json.dump(all_responses, f, ensure_ascii=False, indent=2) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py index dc61bf5c8..4d74d7083 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -35,7 +35,17 @@ def memos_api_search(client, query, user_id, top_k, frame): text_mem = search_results.get("text_mem") or [] if text_mem and text_mem[0].get("memories"): memories = text_mem[0]["memories"] - memories_texts = [m.get("memory", "") for m in memories if isinstance(m, dict)] + for m in memories: + if not isinstance(m, dict): + continue + # tags may be at top-level or inside metadata + tags = m.get("tags") or m.get("metadata", {}).get("tags") or [] + # Skip fast-mode memories + if any(isinstance(t, str) and "mode:fast" in t for t in tags): + continue + mem_text = m.get("memory", "") + if str(mem_text).strip(): + memories_texts.append(mem_text) duration_ms = (time() - start) * 1000 return memories_texts, duration_ms, search_results @@ -59,6 +69,9 @@ def process_sample( client, query, user_id, top_k, frame ) + if not (isinstance(memories_used, list) and any(str(m).strip() for m in memories_used)): + return None + result = { "sample_idx": sample_idx, "_id": sample.get("_id"), From 66e9325ca057c8dc74ff66a351696e46ceda4ef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Dec 2025 14:23:33 +0800 Subject: [PATCH 26/31] feat: decrease graph-db batch size to 5 --- .../tree_text_memory/organize/manager.py | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 95f4e780d..c96d5a12a 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -135,7 +135,7 @@ def _add_memories_parallel( return added_ids def _add_memories_batch( - self, memories: list[TextualMemoryItem], user_name: str | None = None, batch_size: int = 50 + self, memories: list[TextualMemoryItem], user_name: str | None = None, batch_size: int = 5 ) -> list[str]: """ Add memories using batch database operations (more efficient for large batches). @@ -200,25 +200,31 @@ def _add_memories_batch( graph_node_ids.append(graph_node_id) added_ids.append(graph_node_id) - for i in range(0, len(working_nodes), batch_size): - batch = working_nodes[i : i + batch_size] - try: - self.graph_store.add_nodes_batch(batch, user_name=user_name) - except Exception as e: - logger.exception( - f"Batch add WorkingMemory nodes error (batch {i // batch_size + 1}): ", - exc_info=e, - ) - - for i in range(0, len(graph_nodes), batch_size): - batch = graph_nodes[i : i + batch_size] - try: - self.graph_store.add_nodes_batch(batch, user_name=user_name) - except Exception as e: - logger.exception( - f"Batch add graph memory nodes error (batch {i // batch_size + 1}): ", - exc_info=e, - ) + def _submit_batches(nodes: list[dict], node_kind: str) -> None: + if not nodes: + return + + max_workers = min(8, max(1, len(nodes) // max(1, batch_size))) + with ContextThreadPoolExecutor(max_workers=max_workers) as executor: + futures: list[tuple[int, int, object]] = [] + for batch_index, i in enumerate(range(0, len(nodes), batch_size), start=1): + batch = nodes[i : i + batch_size] + fut = executor.submit( + self.graph_store.add_nodes_batch, batch, user_name=user_name + ) + futures.append((batch_index, len(batch), fut)) + + for idx, size, fut in futures: + try: + fut.result() + except Exception as e: + logger.exception( + f"Batch add {node_kind} nodes error (batch {idx}, size {size}): ", + exc_info=e, + ) + + _submit_batches(working_nodes, "WorkingMemory") + _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: self.reorganizer.add_message(QueueMessage(op="add", after_node=graph_node_ids)) From e10365cb57e42f85d94e747b0bfbb14264738b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Dec 2025 14:27:21 +0800 Subject: [PATCH 27/31] fix: default name in long_bench-v2/longbench_v2_search --- evaluation/scripts/long_bench-v2/longbench_v2_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py index 4d74d7083..2347e5d66 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py @@ -247,7 +247,7 @@ def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): parser.add_argument( "--version", type=str, - default="long-bench-v2-1208-1639", + default="default", help="Version identifier for saving results", ) parser.add_argument( From 316e1476306d62806733c5899cff46c7129e0a75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Dec 2025 14:40:39 +0800 Subject: [PATCH 28/31] fix: test bug --- tests/api/test_product_router.py | 34 +------------------------------- tests/api/test_server_router.py | 25 +++-------------------- 2 files changed, 4 insertions(+), 55 deletions(-) diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py index 76bda41ab..27a4fbe03 100644 --- a/tests/api/test_product_router.py +++ b/tests/api/test_product_router.py @@ -6,10 +6,7 @@ """ # Mock sklearn before importing any memos modules to avoid import errors -import importlib.util -import sys - -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest @@ -26,36 +23,7 @@ pr_module.get_mos_product_instance = lambda: _mock_mos_instance -# Create a proper mock module with __spec__ -sklearn_mock = MagicMock() -sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) -sys.modules["sklearn"] = sklearn_mock - -sklearn_fe_mock = MagicMock() -sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) -sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock - -sklearn_fet_mock = MagicMock() -sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) -sklearn_fet_mock.TfidfVectorizer = MagicMock() -sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock - -# Mock sklearn.metrics as well -sklearn_metrics_mock = MagicMock() -sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) -sklearn_metrics_mock.roc_curve = MagicMock() -sys.modules["sklearn.metrics"] = sklearn_metrics_mock - - -# Create mock instance -_mock_mos_instance = Mock() - -pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance -pr_module.get_mos_product_instance = lambda: _mock_mos_instance - -# Mock MOSProduct class before importing to prevent initialization with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): - # Import after patching from memos.api import product_api diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 4786ef08f..b6dd5078a 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -5,11 +5,7 @@ input request formats and return properly formatted responses. """ -# Mock sklearn before importing any memos modules to avoid import errors -import importlib.util -import sys - -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest @@ -25,23 +21,8 @@ ) -# Create a proper mock module with __spec__ -sklearn_mock = MagicMock() -sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) -sys.modules["sklearn"] = sklearn_mock - -sklearn_fe_mock = MagicMock() -sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) -sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock - -sklearn_metrics_mock = MagicMock() -sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) -sys.modules["sklearn.metrics"] = sklearn_metrics_mock - -sklearn_fet_mock = MagicMock() -sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) -sklearn_fet_mock.TfidfVectorizer = MagicMock() -sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock +# Patch init_server so we can import server_api without starting the full MemOS stack, +# and keep sklearn and other core dependencies untouched for other tests. @pytest.fixture(scope="module") From e8e29f8624d5e8a3be9d341268549cfb1ce439c3 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 17 Dec 2025 14:41:34 +0800 Subject: [PATCH 29/31] Update test_server_router.py --- tests/api/test_server_router.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index b6dd5078a..5906697d9 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -23,8 +23,6 @@ # Patch init_server so we can import server_api without starting the full MemOS stack, # and keep sklearn and other core dependencies untouched for other tests. - - @pytest.fixture(scope="module") def mock_init_server(): """Mock init_server before importing server_api.""" From ce70121504db55fcbbb2cf2f583678824b197eda Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 17 Dec 2025 14:42:01 +0800 Subject: [PATCH 30/31] Update test_product_router.py --- tests/api/test_product_router.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py index 27a4fbe03..857b290c5 100644 --- a/tests/api/test_product_router.py +++ b/tests/api/test_product_router.py @@ -5,7 +5,6 @@ input request formats and return properly formatted responses. """ -# Mock sklearn before importing any memos modules to avoid import errors from unittest.mock import Mock, patch import pytest @@ -21,8 +20,6 @@ _mock_mos_instance = Mock() pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance pr_module.get_mos_product_instance = lambda: _mock_mos_instance - - with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): from memos.api import product_api From 9e7ca00c1dc62507176fff7bbfe80a3497a99f3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Dec 2025 14:49:35 +0800 Subject: [PATCH 31/31] feat: comment --- evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py index f70f24531..5a5c11968 100644 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -33,7 +33,7 @@ def ingest_sample( # Get context and convert to messages context = sample.get("context", "") - # For memos, we ingest the context as document content + # For memos, we ingest the context as a raw document content messages = [ { "type": "file",