diff --git a/evaluation/scripts/long_bench-v2/__init__.py b/evaluation/scripts/long_bench-v2/__init__.py new file mode 100644 index 000000000..786c0ce03 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/__init__.py @@ -0,0 +1 @@ +# LongBench v2 evaluation scripts diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py new file mode 100644 index 000000000..d84a63d93 --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py @@ -0,0 +1,199 @@ +import argparse +import json +import os +import sys +import threading + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +def ingest_sample( + client, sample, sample_idx, frame, version, success_records, record_file, file_lock +): + """Ingest a single LongBench v2 sample as memories.""" + # Skip if already processed + if str(sample_idx) in success_records: + return True + + user_id = f"longbench_v2_{sample_idx}_{version}" + conv_id = f"longbench_v2_{sample_idx}_{version}" + + # Get context and convert to messages + context = sample.get("context", "") + + # For memos, we ingest the context as document content + messages = [ + { + "type": "file", + "file": { + "file_data": context, + "file_id": str(sample_idx), + }, + } + ] + + if "memos-api" in frame: + try: + client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) + print(f"✅ [{frame}] Ingested sample {sample_idx}") + # Record successful ingestion (thread-safe) + with file_lock, open(record_file, "a") as f: + f.write(f"{sample_idx}\n") + f.flush() + return True + except Exception as e: + print(f"❌ [{frame}] Error ingesting sample 
def main(frame, version="default", num_workers=10, max_samples=None):
    """Ingest the LongBench v2 dataset into the selected memory backend.

    The dataset is loaded from the local JSON file, optionally truncated to
    ``max_samples`` entries, and each sample is handed to ``ingest_sample`` on
    a thread pool. Indices of successfully ingested samples are appended to a
    ``success_records.txt`` checkpoint so an interrupted run can be resumed.

    Args:
        frame: Backend identifier; only ``"memos-api"`` gets a client here.
        version: Tag embedded in user ids and in the checkpoint directory name.
        num_workers: Size of the thread pool.
        max_samples: If truthy, only the first ``max_samples`` samples are used.

    Returns:
        None. Progress and the final summary are printed to stdout.
    """
    load_dotenv()

    print("\n" + "=" * 80)
    print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80))
    print("=" * 80 + "\n")

    # Load dataset from local file. Any failure (missing file, bad JSON, ...)
    # is reported the same way, so a single handler is enough.
    try:
        dataset = load_dataset_from_local()
        print(f"Loaded {len(dataset)} samples from LongBench v2")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Limit samples if specified
    if max_samples:
        dataset = dataset[:max_samples]
        print(f"Limited to {len(dataset)} samples")

    # Initialize checkpoint file for resume functionality
    checkpoint_dir = os.path.join(
        ROOT_DIR, "evaluation", "results", "longbench_v2", f"{frame}-{version}"
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    record_file = os.path.join(checkpoint_dir, "success_records.txt")

    # Load existing success records for resume (one sample index per line)
    success_records = set()
    if os.path.exists(record_file):
        with open(record_file) as f:
            success_records = {line.strip() for line in f if line.strip()}
        print(f"📋 Found {len(success_records)} already processed samples (resume mode)")
    else:
        print("📋 Starting fresh ingestion (no checkpoint found)")

    # Initialize client
    # NOTE(review): the CLI also offers "memos-api-online", which is rejected
    # here — confirm whether the online client should be wired in as well.
    if frame == "memos-api":
        from utils.client import MemosApiClient

        client = MemosApiClient()
    else:
        print(f"❌ Unsupported frame: {frame}")
        return

    # Every sample is submitted, and ingest_sample() returns True right away
    # for indices already present in success_records. Counting therefore has to
    # start at zero: seeding it with len(success_records) would count resumed
    # samples twice and could report more successes than there are samples.
    success_count = 0
    file_lock = threading.Lock()  # Lock for thread-safe file writing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(
                ingest_sample,
                client,
                sample,
                idx,
                frame,
                version,
                success_records,
                record_file,
                file_lock,
            )
            for idx, sample in enumerate(dataset)
        ]

        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Ingesting LongBench v2",
        ):
            try:
                if future.result():
                    success_count += 1
            except Exception as e:
                print(f"Error processing sample: {e}")

    print(f"\n{'=' * 80}")
    print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80))
    print(f"{'=' * 80}\n")
def ingest_sample(client, sample, sample_idx, frame, version):
    """Send one LongBench v2 sample's context to the memory backend.

    The sample's ``context`` (empty string when absent) is wrapped in a single
    file-type message and passed to ``client.add``. The same identifier,
    ``longbench_v2_<sample_idx>_<version>``, is used as user id and
    conversation id.

    Returns:
        True when ``client.add`` completed without raising; False when the
        frame name does not contain ``"memos-api"`` or when the call failed.
    """
    identifier = f"longbench_v2_{sample_idx}_{version}"

    # The whole context is ingested as one document-style message.
    document = {
        "file_data": sample.get("context", ""),
        "file_id": str(sample_idx),
    }
    payload = [{"type": "file", "file": document}]

    # Frames other than the memos API family are not handled here.
    if "memos-api" not in frame:
        return False

    try:
        client.add(messages=payload, user_id=identifier, conv_id=identifier, batch_size=1)
        print(f"✅ [{frame}] Ingested sample {sample_idx}")
    except Exception as err:
        print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {err}")
        return False
    return True
print(f"❌ Error loading dataset: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Ingest samples + success_count = 0 + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit(ingest_sample, client, sample, idx, frame, version) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Ingesting LongBench v2", + ): + try: + if future.result(): + success_count += 1 + except Exception as e: + print(f"Error processing sample: {e}") + + print(f"\n{'=' * 80}") + print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="long-bench-v2-1208-1556-async", + help="Version identifier for saving results", + ) + parser.add_argument( + "--workers", + type=int, + default=20, + help="Number of parallel workers", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples to process (default: all)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py new file mode 100644 index 000000000..5fee9a3de --- /dev/null +++ 
def _accuracy_pct(items):
    """Return the percentage of ``items`` with a truthy ``judge``, to one decimal.

    An empty collection yields ``0.0`` instead of dividing by zero.
    """
    if not items:
        return 0.0
    hits = sum(1 for item in items if item.get("judge", False))
    return round(100 * hits / len(items), 1)


def calculate_accuracy(responses):
    """Calculate accuracy metrics for LongBench v2.

    Args:
        responses: List of response dicts. The keys read are ``judge``
            (truthy means correct), ``difficulty`` (``"easy"``/``"hard"``),
            ``length`` (``"short"``/``"medium"``/``"long"``) and ``domain``
            (defaults to ``"Unknown"`` when the key is absent).

    Returns:
        A dict with percentage accuracies under ``overall``, ``easy``,
        ``hard``, ``short``, ``medium`` and ``long``, a ``by_domain`` mapping
        of domain to accuracy, plus ``total_samples`` and ``correct_samples``.
        For empty input every accuracy is ``0.0`` and both counts are ``0``,
        so callers can index the result without special-casing.
    """

    def _where(key, value):
        # Subset of responses whose ``key`` equals ``value``.
        return [r for r in responses if r.get(key) == value]

    # Group by domain, keeping first-seen order of the domains.
    by_domain = {}
    for response in responses:
        by_domain.setdefault(response.get("domain", "Unknown"), []).append(response)

    return {
        "overall": _accuracy_pct(responses),
        "easy": _accuracy_pct(_where("difficulty", "easy")),
        "hard": _accuracy_pct(_where("difficulty", "hard")),
        "short": _accuracy_pct(_where("length", "short")),
        "medium": _accuracy_pct(_where("length", "medium")),
        "long": _accuracy_pct(_where("length", "long")),
        "by_domain": {domain: _accuracy_pct(items) for domain, items in by_domain.items()},
        "total_samples": len(responses),
        "correct_samples": sum(1 for r in responses if r.get("judge", False)),
    }
"memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for loading results", + ) + args = parser.parse_args() + + main(args.lib, args.version) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py new file mode 100644 index 000000000..3e19dc95f --- /dev/null +++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py @@ -0,0 +1,206 @@ +import argparse +import json +import os +import re +import sys + +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + +# Prompt template from LongBench v2 +LONGBENCH_V2_PROMPT = """Please read the following text and answer the question below. 
def extract_answer(response):
    """Extract the chosen option letter (A, B, C, or D) from a model response.

    Asterisks are stripped first so markdown emphasis does not get in the way.
    Three patterns are then tried in order, most specific first:

    1. ``The correct answer is (X)`` (case-insensitive)
    2. ``The correct answer is X`` where X is a whole word (case-insensitive)
    3. the first standalone uppercase ``A``-``D`` anywhere in the text

    Args:
        response: Raw model output; ``None`` or an empty string is accepted.

    Returns:
        The uppercase letter, or ``None`` when no pattern matches.
    """
    if not response:
        return None

    cleaned = response.replace("*", "")
    patterns = (
        (r"The correct answer is \(([A-D])\)", re.IGNORECASE),
        # The trailing \b is essential: with IGNORECASE, "[A-D]" would otherwise
        # match the first letter of an ordinary word ("is clearly ..." -> "C").
        (r"The correct answer is ([A-D])\b", re.IGNORECASE),
        # Last resort: a standalone capital letter, deliberately case-sensitive.
        (r"\b([A-D])\b", 0),
    )
    for pattern, flags in patterns:
        match = re.search(pattern, cleaned, flags)
        if match:
            return match.group(1).upper()
    return None
llm_client, context, question, choice_a, choice_b, choice_c, choice_d + ) + + # Extract answer (A, B, C, or D) + pred = extract_answer(response) + + response_duration_ms = (time() - start) * 1000 + + return { + "sample_idx": search_result.get("sample_idx"), + "_id": search_result.get("_id"), + "domain": search_result.get("domain"), + "sub_domain": search_result.get("sub_domain"), + "difficulty": search_result.get("difficulty"), + "length": search_result.get("length"), + "question": question, + "choice_A": choice_a, + "choice_B": choice_b, + "choice_C": choice_c, + "choice_D": choice_d, + "answer": search_result.get("answer"), + "pred": pred, + "response": response, + "judge": pred == search_result.get("answer") if pred else False, + "search_context": context, + "response_duration_ms": response_duration_ms, + "search_duration_ms": search_result.get("search_duration_ms", 0), + } + + +def main(frame, version="default", num_workers=10): + """Main response generation function.""" + load_dotenv() + + print("\n" + "=" * 80) + print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) + print("=" * 80 + "\n") + + # Load search results + search_path = ( + f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" + ) + if not os.path.exists(search_path): + print(f"❌ Search results not found: {search_path}") + print("Please run longbench_v2_search.py first") + return + + with open(search_path, encoding="utf-8") as f: + search_results = json.load(f) + + # Initialize LLM client + llm_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), + base_url=os.getenv("CHAT_MODEL_BASE_URL"), + ) + print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") + + # Process all samples + all_responses = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(process_sample, sample, llm_client) for sample in search_results] + + for future in tqdm( + as_completed(futures), + 
def memos_api_search(client, query, user_id, top_k, frame):
    """Run a memory search and flatten the hits into one context string.

    Args:
        client: Object exposing ``search(query=..., user_id=..., top_k=...)``.
        query: Search text.
        user_id: Identifier of the user whose memories are searched.
        top_k: Number of results requested from the backend.
        frame: Backend name; results are only formatted for ``"memos-api"``
            and ``"memos-api-online"``.

    Returns:
        A ``(context, duration_ms)`` tuple. ``context`` joins the ``memory``
        field of the first ``text_mem`` entry's ``memories`` with newlines and
        appends ``pref_string`` on its own line when that key is present. It is
        an empty string when the frame is not handled or the result carries no
        ``text_mem`` key. ``duration_ms`` covers the search call and the
        formatting.
    """
    start = time()
    search_results = client.search(query=query, user_id=user_id, top_k=top_k)

    context = ""
    if (
        frame in ("memos-api", "memos-api-online")
        and isinstance(search_results, dict)
        and "text_mem" in search_results
    ):
        # An empty "text_mem" list means nothing was retrieved. Indexing [0]
        # blindly would raise and abort the caller's whole run, so fall back to
        # an empty context instead.
        text_mem = search_results["text_mem"]
        memories = text_mem[0].get("memories", []) if text_mem else []
        context = "\n".join(entry["memory"] for entry in memories)
        if "pref_string" in search_results:
            context += f"\n{search_results.get('pref_string', '')}"

    duration_ms = (time() - start) * 1000
    return context, duration_ms
v{version}".center(80)) + print("=" * 80 + "\n") + + # Load dataset from local file + try: + dataset = load_dataset_from_local() + print(f"Loaded {len(dataset)} samples from LongBench v2") + except FileNotFoundError as e: + print(f"❌ Error loading dataset: {e}") + return + except Exception as e: + print(f"❌ Error loading dataset: {e}") + return + + # Limit samples if specified + if max_samples: + dataset = dataset[:max_samples] + print(f"Limited to {len(dataset)} samples") + + # Initialize client + client = None + if frame == "memos-api": + from utils.client import MemosApiClient + + client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + else: + print(f"❌ Unsupported frame: {frame}") + return + + # Process samples + search_results = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for idx, sample in enumerate(dataset): + future = executor.submit(process_sample, client, sample, idx, frame, version, top_k) + futures.append(future) + + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Searching LongBench v2", + ): + result = future.result() + if result: + search_results.append(result) + + # Save results + os.makedirs(f"results/long_bench-v2/{frame}-{version}/", exist_ok=True) + output_path = ( + f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json" + ) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(search_results, f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 80}") + print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) + print(f"{'=' * 80}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + type=str, + choices=["memos-api", "memos-api-online"], + default="memos-api", + ) + parser.add_argument( + "--version", + type=str, + default="default", + help="Version identifier for saving results", 
+ ) + parser.add_argument( + "--workers", + type=int, + default=10, + help="Number of parallel workers", + ) + parser.add_argument( + "--top_k", + type=int, + default=20, + help="Number of results to retrieve in search queries", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Maximum number of samples to process (default: all)", + ) + args = parser.parse_args() + + main(args.lib, args.version, args.workers, args.top_k, args.max_samples) diff --git a/evaluation/scripts/longbench/__init__.py b/evaluation/scripts/longbench/__init__.py deleted file mode 100644 index 38cc006e3..000000000 --- a/evaluation/scripts/longbench/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# LongBench evaluation scripts diff --git a/evaluation/scripts/longbench/longbench_ingestion.py b/evaluation/scripts/longbench/longbench_ingestion.py deleted file mode 100644 index e2d2a8e7e..000000000 --- a/evaluation/scripts/longbench/longbench_ingestion.py +++ /dev/null @@ -1,306 +0,0 @@ -import argparse -import json -import os -import sys - -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime, timezone - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# All LongBench datasets -LONGBENCH_DATASETS = [ - "narrativeqa", - "qasper", - "multifieldqa_en", - "multifieldqa_zh", - "hotpotqa", - "2wikimqa", - "musique", - "dureader", - "gov_report", - "qmsum", - "multi_news", - "vcsum", - "trec", - "triviaqa", - "samsum", - "lsht", - "passage_count", - "passage_retrieval_en", - "passage_retrieval_zh", - "lcc", - "repobench-p", -] - - -def ingest_sample(client, sample, dataset_name, sample_idx, frame, version): - """Ingest a single LongBench sample as memories.""" - user_id = 
f"longbench_{dataset_name}_{sample_idx}_{version}" - conv_id = f"longbench_{dataset_name}_{sample_idx}_{version}" - - # Get context and convert to messages - context = sample.get("context", "") - # not used now: input_text = sample.get("input", "") - - # For memos, we ingest the context as document content - # Split context into chunks if it's too long (optional, memos handles this internally) - # For now, we'll ingest the full context as a single message - messages = [ - { - "role": "assistant", - "content": context, - "chat_time": datetime.now(timezone.utc).isoformat(), - } - ] - - if "memos-api" in frame: - try: - client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif "mem0" in frame: - timestamp = int(datetime.now(timezone.utc).timestamp()) - try: - client.add(messages=messages, user_id=user_id, timestamp=timestamp, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif frame == "memobase": - for m in messages: - m["created_at"] = messages[0]["chat_time"] - try: - client.add(messages=messages, user_id=user_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - elif frame == "memu": - try: - client.add(messages=messages, user_id=user_id, iso_date=messages[0]["chat_time"]) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return 
False - elif frame == "supermemory": - try: - client.add(messages=messages, user_id=user_id) - print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}") - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}") - return False - - return False - - -def load_dataset_from_local(dataset_name, use_e=False): - """Load LongBench dataset from local JSONL file.""" - # Determine data directory - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - # Determine filename - filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" - - filepath = os.path.join(data_dir, filename) - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSONL file - samples = [] - with open(filepath, encoding="utf-8") as f: - for line in f: - if line.strip(): - samples.append(json.loads(line)) - - return samples - - -def ingest_dataset(dataset_name, frame, version, num_workers=10, max_samples=None, use_e=False): - """Ingest a single LongBench dataset.""" - print(f"\n{'=' * 80}") - print(f"🔄 [INGESTING DATASET: {dataset_name.upper()}]".center(80)) - print(f"{'=' * 80}\n") - - # Load dataset from local files - try: - dataset = load_dataset_from_local(dataset_name, use_e) - print(f"Loaded {len(dataset)} samples from {dataset_name}") - except FileNotFoundError as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return - except Exception as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize client - client = None - if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client - - client = Mem0Client(enable_graph="graph" in frame) - elif frame == "memos-api": - from 
utils.client import MemosApiClient - - client = MemosApiClient() - elif frame == "memos-api-online": - from utils.client import MemosApiOnlineClient - - client = MemosApiOnlineClient() - elif frame == "memobase": - from utils.client import MemobaseClient - - client = MemobaseClient() - elif frame == "memu": - from utils.client import MemuClient - - client = MemuClient() - elif frame == "supermemory": - from utils.client import SupermemoryClient - - client = SupermemoryClient() - else: - print(f"❌ Unsupported frame: {frame}") - return - - # Ingest samples - success_count = 0 - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - ingest_sample, client, sample, dataset_name, idx, frame, version - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Ingesting {dataset_name}", - ): - try: - if future.result(): - success_count += 1 - except Exception as e: - print(f"Error processing sample: {e}") - - print(f"\n✅ Completed ingesting {dataset_name}: {success_count}/{len(dataset)} samples") - return success_count - - -def main(frame, version="default", num_workers=10, datasets=None, max_samples=None, use_e=False): - """Main ingestion function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH INGESTION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Determine which datasets to process - dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS - - # Filter valid datasets - valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] - if not valid_datasets: - print("❌ No valid datasets specified") - return - - print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") - - # Ingest each dataset - total_success = 0 - total_samples = 0 - for dataset_name in valid_datasets: - success = ingest_dataset(dataset_name, frame, version, 
num_workers, max_samples, use_e) - if success is not None: - total_success += success - total_samples += max_samples if max_samples else 200 # Approximate - - print(f"\n{'=' * 80}") - print(f"✅ INGESTION COMPLETE: {total_success} samples ingested".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - parser.add_argument( - "--datasets", - type=str, - default=None, - help="Comma-separated list of datasets to process (default: all)", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples per dataset (default: all)", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main( - args.lib, - args.version, - args.workers, - args.datasets, - args.max_samples, - args.e, - ) diff --git a/evaluation/scripts/longbench/longbench_metric.py b/evaluation/scripts/longbench/longbench_metric.py deleted file mode 100644 index 495a793ab..000000000 --- a/evaluation/scripts/longbench/longbench_metric.py +++ /dev/null @@ -1,235 +0,0 @@ -import argparse -import json -import os -import sys - -import numpy as np - - -# Import LongBench metrics -# Try to import from the LongBench directory -LONGBENCH_METRICS_DIR = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - "longbench_v2", - "LongBench-main", - "LongBench", -) - -if os.path.exists(LONGBENCH_METRICS_DIR): - sys.path.insert(0, LONGBENCH_METRICS_DIR) - try: - from metrics import ( - 
classification_score, - code_sim_score, - count_score, - qa_f1_score, - qa_f1_zh_score, - retrieval_score, - retrieval_zh_score, - rouge_score, - rouge_zh_score, - ) - except ImportError: - print(f"Warning: Could not import metrics from {LONGBENCH_METRICS_DIR}") - print("Please ensure LongBench metrics.py is available") - raise -else: - print(f"Error: LongBench metrics directory not found at {LONGBENCH_METRICS_DIR}") - raise FileNotFoundError("LongBench metrics directory not found") - -# Dataset to metric mapping (from LongBench eval.py) -dataset2metric = { - "narrativeqa": qa_f1_score, - "qasper": qa_f1_score, - "multifieldqa_en": qa_f1_score, - "multifieldqa_zh": qa_f1_zh_score, - "hotpotqa": qa_f1_score, - "2wikimqa": qa_f1_score, - "musique": qa_f1_score, - "dureader": rouge_zh_score, - "gov_report": rouge_score, - "qmsum": rouge_score, - "multi_news": rouge_score, - "vcsum": rouge_zh_score, - "trec": classification_score, - "triviaqa": qa_f1_score, - "samsum": rouge_score, - "lsht": classification_score, - "passage_retrieval_en": retrieval_score, - "passage_count": count_score, - "passage_retrieval_zh": retrieval_zh_score, - "lcc": code_sim_score, - "repobench-p": code_sim_score, -} - - -def scorer(dataset, predictions, answers, all_classes): - """Calculate score for a dataset.""" - total_score = 0.0 - for prediction, ground_truths in zip(predictions, answers, strict=False): - score = 0.0 - # For some tasks, only take the first line - if dataset in ["trec", "triviaqa", "samsum", "lsht"]: - prediction = prediction.lstrip("\n").split("\n")[0] - - # Calculate max score across all ground truth answers - for ground_truth in ground_truths: - metric_func = dataset2metric.get(dataset) - if metric_func: - if dataset in ["trec", "lsht"]: - # Classification tasks need all_classes - score = max( - score, - metric_func(prediction, ground_truth, all_classes=all_classes), - ) - else: - score = max(score, metric_func(prediction, ground_truth)) - else: - print(f"Warning: No 
metric function for dataset {dataset}") - - total_score += score - - return round(100 * total_score / len(predictions), 2) if len(predictions) > 0 else 0.0 - - -def scorer_e(dataset, predictions, answers, lengths, all_classes): - """Calculate score for LongBench-E (with length-based analysis).""" - scores = {"0-4k": [], "4-8k": [], "8k+": []} - - for prediction, ground_truths, length in zip(predictions, answers, lengths, strict=False): - score = 0.0 - # For some tasks, only take the first line - if dataset in ["trec", "triviaqa", "samsum", "lsht"]: - prediction = prediction.lstrip("\n").split("\n")[0] - - # Calculate max score across all ground truth answers - metric_func = dataset2metric.get(dataset) - if metric_func: - for ground_truth in ground_truths: - if dataset in ["trec", "lsht"]: - score = max( - score, - metric_func(prediction, ground_truth, all_classes=all_classes), - ) - else: - score = max(score, metric_func(prediction, ground_truth)) - - # Categorize by length - if length < 4000: - scores["0-4k"].append(score) - elif length < 8000: - scores["4-8k"].append(score) - else: - scores["8k+"].append(score) - - # Calculate average scores per length category - for key in scores: - if len(scores[key]) > 0: - scores[key] = round(100 * np.mean(scores[key]), 2) - else: - scores[key] = 0.0 - - return scores - - -def main(frame, version="default", use_e=False): - """Main metric calculation function.""" - print("\n" + "=" * 80) - print(f"📊 LONGBENCH METRICS CALCULATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load responses - responses_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" - if not os.path.exists(responses_path): - print(f"❌ Responses not found: {responses_path}") - print("Please run longbench_responses.py first") - return - - with open(responses_path, encoding="utf-8") as f: - responses = json.load(f) - - # Calculate metrics for each dataset - all_scores = {} - overall_scores = [] - - for 
dataset_name, samples in responses.items(): - print(f"Calculating metrics for {dataset_name}...") - - predictions = [s.get("answer", "") for s in samples] - answers = [s.get("golden_answer", []) for s in samples] - all_classes = samples[0].get("all_classes") if samples else None - - if use_e: - lengths = [s.get("length", 0) for s in samples] - score = scorer_e(dataset_name, predictions, answers, lengths, all_classes) - else: - score = scorer(dataset_name, predictions, answers, all_classes) - - all_scores[dataset_name] = score - print(f" {dataset_name}: {score}") - - # For overall average, use single score (not length-based) - if use_e: - # Average across length categories - if isinstance(score, dict): - overall_scores.append(np.mean(list(score.values()))) - else: - overall_scores.append(score) - - # Calculate overall average - if overall_scores: - all_scores["average"] = round(np.mean(overall_scores), 2) - print(f"\nOverall Average: {all_scores['average']}") - - # Save metrics - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_metrics.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(all_scores, f, ensure_ascii=False, indent=4) - - print(f"\n{'=' * 80}") - print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - # Print summary table - print("\n📊 Summary of Results:") - print("-" * 80) - for dataset, score in sorted(all_scores.items()): - if isinstance(score, dict): - print(f"{dataset:30s}: {score}") - else: - print(f"{dataset:30s}: {score:.2f}%") - print("-" * 80) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version 
identifier for loading results", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.e) diff --git a/evaluation/scripts/longbench/longbench_responses.py b/evaluation/scripts/longbench/longbench_responses.py deleted file mode 100644 index 2d160160a..000000000 --- a/evaluation/scripts/longbench/longbench_responses.py +++ /dev/null @@ -1,196 +0,0 @@ -import argparse -import json -import os -import sys - -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from openai import OpenAI -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# Dataset to prompt mapping (from LongBench config) -DATASET_PROMPTS = { - "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", - "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. 
If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:', - "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", - "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", - "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", - "qmsum": "You are given a meeting transcript and a query containing a question or instruction. 
Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", - "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", - "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", - "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", - "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", - "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", - "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", - "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", - "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ', - "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:', - "lcc": "Please complete the code given below. 
\n{context}Next line of code:\n", - "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", -} - - -def generate_response(llm_client, dataset_name, context, input_text): - """Generate response using LLM.""" - # Get prompt template for dataset - prompt_template = DATASET_PROMPTS.get(dataset_name, "{context}\n\nQuestion: {input}\nAnswer:") - - # Format prompt - if "{input}" in prompt_template: - prompt = prompt_template.format(context=context, input=input_text) - else: - # Some prompts don't have {input} placeholder (like gov_report, vcsum) - prompt = prompt_template.format(context=context) - - try: - response = llm_client.chat.completions.create( - model=os.getenv("CHAT_MODEL"), - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - temperature=0, - ) - result = response.choices[0].message.content or "" - return result - except Exception as e: - print(f"Error generating response: {e}") - return "" - - -def process_sample(search_result, llm_client): - """Process a single sample: generate answer.""" - start = time() - - dataset_name = search_result.get("dataset") - context = search_result.get("context", "") - input_text = search_result.get("input", "") - - # Generate answer - answer = generate_response(llm_client, dataset_name, context, input_text) - - response_duration_ms = (time() - start) * 1000 - - return { - "dataset": dataset_name, - "sample_idx": search_result.get("sample_idx"), - "input": input_text, - "answer": answer, - "golden_answer": search_result.get("answers", []), - "all_classes": search_result.get("all_classes"), - "length": search_result.get("length", 0), - "search_context": context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_result.get("search_duration_ms", 0), - } - - -def main(frame, version="default", num_workers=10): - """Main response generation function.""" - load_dotenv() - - print("\n" + "=" * 80) - 
print(f"🚀 LONGBENCH RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load search results - search_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" - if not os.path.exists(search_path): - print(f"❌ Search results not found: {search_path}") - print("Please run longbench_search.py first") - return - - with open(search_path, encoding="utf-8") as f: - search_results = json.load(f) - - # Initialize LLM client - llm_client = OpenAI( - api_key=os.getenv("CHAT_MODEL_API_KEY"), - base_url=os.getenv("CHAT_MODEL_BASE_URL"), - ) - print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") - - # Process all samples - all_responses = [] - for dataset_name, samples in search_results.items(): - print(f"\nProcessing {len(samples)} samples from {dataset_name}...") - - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(process_sample, sample, llm_client) for sample in samples] - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Generating responses for {dataset_name}", - ): - result = future.result() - if result: - all_responses.append(result) - - # Save responses - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_responses.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - # Group by dataset - responses_by_dataset = {} - for response in all_responses: - dataset = response["dataset"] - if dataset not in responses_by_dataset: - responses_by_dataset[dataset] = [] - responses_by_dataset[dataset].append(response) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(responses_by_dataset, f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - 
"mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for loading results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers) diff --git a/evaluation/scripts/longbench/longbench_search.py b/evaluation/scripts/longbench/longbench_search.py deleted file mode 100644 index aaf7300e4..000000000 --- a/evaluation/scripts/longbench/longbench_search.py +++ /dev/null @@ -1,309 +0,0 @@ -import argparse -import json -import os -import sys - -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# All LongBench datasets -LONGBENCH_DATASETS = [ - "narrativeqa", - "qasper", - "multifieldqa_en", - "multifieldqa_zh", - "hotpotqa", - "2wikimqa", - "musique", - "dureader", - "gov_report", - "qmsum", - "multi_news", - "vcsum", - "trec", - "triviaqa", - "samsum", - "lsht", - "passage_count", - "passage_retrieval_en", - "passage_retrieval_zh", - "lcc", - "repobench-p", -] - - -def memos_api_search(client, query, user_id, top_k, frame): - """Search using memos API.""" - start = time() - search_results = client.search(query=query, user_id=user_id, top_k=top_k) - - # Format context from search results based on frame type - context = "" - if frame == "memos-api" or frame == "memos-api-online": - if isinstance(search_results, dict) and "text_mem" in search_results: - context = "\n".join([i["memory"] for i 
in search_results["text_mem"][0]["memories"]]) - if "pref_string" in search_results: - context += f"\n{search_results.get('pref_string', '')}" - elif frame == "mem0" or frame == "mem0_graph": - if isinstance(search_results, dict) and "results" in search_results: - context = "\n".join( - [ - f"{m.get('created_at', '')}: {m.get('memory', '')}" - for m in search_results["results"] - ] - ) - elif frame == "memobase": - context = search_results if isinstance(search_results, str) else "" - elif frame == "memu": - context = "\n".join(search_results) if isinstance(search_results, list) else "" - elif frame == "supermemory": - context = search_results if isinstance(search_results, str) else "" - - duration_ms = (time() - start) * 1000 - return context, duration_ms - - -def process_sample(client, sample, dataset_name, sample_idx, frame, version, top_k): - """Process a single sample: search for relevant memories.""" - user_id = f"longbench_{dataset_name}_{sample_idx}_{version}" - query = sample.get("input", "") - - if not query: - return None - - context, duration_ms = memos_api_search(client, query, user_id, top_k, frame) - - return { - "dataset": dataset_name, - "sample_idx": sample_idx, - "input": query, - "context": context, - "search_duration_ms": duration_ms, - "answers": sample.get("answers", []), - "all_classes": sample.get("all_classes"), - "length": sample.get("length", 0), - } - - -def load_dataset_from_local(dataset_name, use_e=False): - """Load LongBench dataset from local JSONL file.""" - # Determine data directory - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - # Determine filename - filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl" - - filepath = os.path.join(data_dir, filename) - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSONL file - samples = [] - with open(filepath, 
encoding="utf-8") as f: - for line in f: - if line.strip(): - samples.append(json.loads(line)) - - return samples - - -def process_dataset( - dataset_name, frame, version, top_k=20, num_workers=10, max_samples=None, use_e=False -): - """Process a single dataset: search for all samples.""" - print(f"\n{'=' * 80}") - print(f"🔍 [SEARCHING DATASET: {dataset_name.upper()}]".center(80)) - print(f"{'=' * 80}\n") - - # Load dataset from local files - try: - dataset = load_dataset_from_local(dataset_name, use_e) - print(f"Loaded {len(dataset)} samples from {dataset_name}") - except FileNotFoundError as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return [] - except Exception as e: - print(f"❌ Error loading dataset {dataset_name}: {e}") - return [] - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize client - client = None - if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client - - client = Mem0Client(enable_graph="graph" in frame) - elif frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - elif frame == "memos-api-online": - from utils.client import MemosApiOnlineClient - - client = MemosApiOnlineClient() - elif frame == "memobase": - from utils.client import MemobaseClient - - client = MemobaseClient() - elif frame == "memu": - from utils.client import MemuClient - - client = MemuClient() - elif frame == "supermemory": - from utils.client import SupermemoryClient - - client = SupermemoryClient() - else: - print(f"❌ Unsupported frame: {frame}") - return [] - - # Process samples - search_results = [] - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - process_sample, client, sample, dataset_name, idx, frame, version, top_k - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), 
- total=len(futures), - desc=f"Searching {dataset_name}", - ): - result = future.result() - if result: - search_results.append(result) - - print(f"\n✅ Completed searching {dataset_name}: {len(search_results)} samples") - return search_results - - -def main( - frame, version="default", num_workers=10, top_k=20, datasets=None, max_samples=None, use_e=False -): - """Main search function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH SEARCH - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Determine which datasets to process - dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS - - # Filter valid datasets - valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS] - if not valid_datasets: - print("❌ No valid datasets specified") - return - - print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n") - - # Create output directory - os.makedirs(f"results/longbench/{frame}-{version}/", exist_ok=True) - - # Process each dataset - all_results = defaultdict(list) - for dataset_name in valid_datasets: - results = process_dataset( - dataset_name, frame, version, top_k, num_workers, max_samples, use_e - ) - all_results[dataset_name] = results - - # Save results - output_path = f"results/longbench/{frame}-{version}/{frame}_longbench_search_results.json" - with open(output_path, "w", encoding="utf-8") as f: - json.dump(dict(all_results), f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=[ - "mem0", - "mem0_graph", - "memos-api", - "memos-api-online", - "memobase", - "memu", - "supermemory", - ], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) 
- parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - parser.add_argument( - "--top_k", - type=int, - default=20, - help="Number of results to retrieve in search queries", - ) - parser.add_argument( - "--datasets", - type=str, - default=None, - help="Comma-separated list of datasets to process (default: all)", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples per dataset (default: all)", - ) - parser.add_argument( - "--e", - action="store_true", - help="Use LongBench-E variant (uniform length distribution)", - ) - args = parser.parse_args() - - main( - args.lib, - args.version, - args.workers, - args.top_k, - args.datasets, - args.max_samples, - args.e, - ) diff --git a/evaluation/scripts/longbench_v2/prepare_data.py b/evaluation/scripts/longbench_v2/prepare_data.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/memos/embedders/base.py b/src/memos/embedders/base.py index 22ef0d302..e46611d1a 100644 --- a/src/memos/embedders/base.py +++ b/src/memos/embedders/base.py @@ -23,7 +23,7 @@ def _count_tokens_for_embedding(text: str) -> int: enc = tiktoken.encoding_for_model("gpt-4o-mini") except Exception: enc = tiktoken.get_encoding("cl100k_base") - return len(enc.encode(text or "")) + return len(enc.encode(text or "", disallowed_special=())) except Exception: # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars if not text: diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 1d8a25b67..603adbd7d 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -152,7 +152,7 @@ def __init__(self, config: PolarDBGraphDBConfig): # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( minconn=5, - maxconn=500, + maxconn=2000, host=host, port=port, user=user, diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py 
index f43ad01ba..2dcf75846 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -89,7 +89,7 @@ def from_config(_config): _ENC = tiktoken.get_encoding("cl100k_base") def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "")) + return len(_ENC.encode(s or "", disallowed_special=())) except Exception: # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars def _count_tokens_text(s: str) -> int: diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 3226f7ca0..2a3bae944 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -92,9 +92,9 @@ def add( """ added_ids: list[str] = [] - with ContextThreadPoolExecutor(max_workers=200) as executor: + with ContextThreadPoolExecutor(max_workers=50) as executor: futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} - for future in as_completed(futures, timeout=60): + for future in as_completed(futures, timeout=500): try: ids = future.result() added_ids.extend(ids) @@ -102,7 +102,7 @@ def add( logger.exception("Memory processing error: ", exc_info=e) if mode == "sync": - for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + for mem_type in ["WorkingMemory"]: try: self.graph_store.remove_oldest_memory( memory_type="WorkingMemory",