diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py index a4aeca498..a500438b6 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py @@ -154,26 +154,122 @@ def __init__( reader: MemReader instance for processing internet content max_results: Maximum number of search results to retrieve """ - import nltk - - try: - nltk.download("averaged_perceptron_tagger_eng") - except Exception as err: - raise Exception("Failed to download nltk averaged_perceptron_tagger_eng") from err - try: - nltk.download("stopwords") - except Exception as err: - raise Exception("Failed to download nltk stopwords") from err from jieba.analyse import TextRank - from rake_nltk import Rake self.bocha_api = BochaAISearchAPI(access_key, max_results=max_results) self.embedder = embedder self.reader = reader - self.en_fast_keywords_extractor = Rake() self.zh_fast_keywords_extractor = TextRank() + def _extract_tags(self, title: str, content: str, summary: str, parsed_goal=None) -> list[str]: + """ + Extract tags from title, content and summary + + Args: + title: Article title + content: Article content + summary: Article summary + parsed_goal: Parsed task goal (optional) + + Returns: + List of extracted tags + """ + tags = [] + + # Add source-based tags + tags.append("bocha_search") + tags.append("news") + + # Add content-based tags + text = f"{title} {content} {summary}".lower() + + # Simple keyword-based tagging + keywords = { + "economy": [ + "economy", + "GDP", + "growth", + "production", + "industry", + "investment", + "consumption", + "market", + "trade", + "finance", + ], + "politics": [ + "politics", + "government", + "policy", + "meeting", + "leader", + "election", + "parliament", + "ministry", + ], + "technology": [ + "technology", + "tech", + "innovation", + "digital", + "internet", + "AI", + "artificial intelligence", + "software", + "hardware", + ], + "sports": [ + "sports", + "game", + "athlete", + "olympic", + "championship", + "tournament", + "team", + "player", + ], + "culture": [ + "culture", + "education", + "art", + "history", + "literature", + "music", + "film", + "museum", + ], + "health": [ + "health", + "medical", + "pandemic", + "hospital", + "doctor", + "medicine", + "disease", + "treatment", + ], + "environment": [ + "environment", + "ecology", + "pollution", + "green", + "climate", + "sustainability", + "renewable", + ], + } + + for category, words in keywords.items(): + if any(word in text for word in words): + tags.append(category) + + # Add goal-based tags if available + if parsed_goal and hasattr(parsed_goal, "tags"): + tags.extend(parsed_goal.tags) + + return list(set(tags))[:15] # Limit to 15 tags + def retrieve_from_internet( self, query: str, top_k: int = 10, parsed_goal=None, info=None, mode="fast" ) -> list[TextualMemoryItem]: @@ -259,9 +355,9 @@ def _process_result( session_id = info_.pop("session_id", "") lang = detect_lang(summary) tags = ( - self.zh_fast_keywords_extractor.textrank(summary)[:3] + self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3] if lang == "zh" - else self.en_fast_keywords_extractor.extract_keywords_from_text(summary)[:3] + else self._extract_tags(title, content, summary)[:3] ) return [