diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index f1253bd75..65b384c12 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -102,6 +102,24 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # Local codebase / document search (sirchmunk), exposed as the `localsearch` tool + localsearch: + mcp: false + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # Optional: llm_api_key, llm_base_url, llm_model_name (else inherited from `llm`) + # When true, a shallow sirchmunk DirectoryScanner run at tool connect injects file titles/previews + # into the `localsearch` tool description (default: false) + # description_catalog: false + # description_catalog_max_files: 120 + # description_catalog_max_depth: 5 + # description_catalog_max_chars: 10000 + # description_catalog_max_preview_chars: 400 + # description_catalog_cache_ttl_seconds: 300 + # description_catalog_exclude: [] # extra globs / dir names merged with sirchmunk defaults ``` For the complete list of supported tools and custom tools, please refer to [here](./Tools.md) @@ -167,19 +185,19 @@ In addition to yaml configuration, MS-Agent also supports several additional com > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. -- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. +- knowledge_search_paths: Comma-separated local search paths. Merges into `tools.localsearch.paths` and registers the **`localsearch`** tool (sirchmunk) for on-demand use by the model—not automatic per-turn injection. 
LLM settings are inherited from the `llm` module unless you set `tools.localsearch.llm_*` fields. ### Quick Start for Knowledge Search -Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: +Use `--knowledge_search_paths` or define `tools.localsearch` in yaml so the model can call `localsearch` when needed: ```bash # Using default agent.yaml configuration, automatically reuses LLM settings -ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "/path/to/docs" # Specify configuration file ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" ``` LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. -If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. +For a dedicated sirchmunk LLM, set `tools.localsearch.llm_api_key`, `llm_base_url`, and `llm_model_name` in yaml. Legacy top-level `knowledge_search` with the same keys is still read for backward compatibility. 
diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 12849f2a7..a3031d918 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -102,6 +102,24 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # 本地代码库/文档搜索(sirchmunk),对应模型可调用的 `localsearch` 工具 + localsearch: + mcp: false + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # 可选:llm_api_key、llm_base_url、llm_model_name(不填则从 `llm` 继承) + # 为 true 时,在工具连接阶段用 sirchmunk DirectoryScanner 做浅层扫描,把文件标题/预览写入 + # `localsearch` 工具 description,便于模型知道本地知识库里大致有哪些内容(默认 false) + # description_catalog: false + # description_catalog_max_files: 120 + # description_catalog_max_depth: 5 + # description_catalog_max_chars: 10000 + # description_catalog_max_preview_chars: 400 + # description_catalog_cache_ttl_seconds: 300 + # description_catalog_exclude: [] # 额外 glob / 目录名,与 sirchmunk 默认排除合并 ``` 支持的完整工具列表,以及自定义工具请参考 [这里](./tools) @@ -165,13 +183,13 @@ handler: custom_handler } } ``` -- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 +- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用;如未配置 `tools.localsearch.llm_*`, LLM 从 `llm` 模块复用 > agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 ### 知识搜索快速使用 -通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: +通过 `--knowledge_search_paths` 或在 yaml 中配置 `tools.localsearch`,启用本地知识搜索(模型按需调用 `localsearch`): ```bash # 使用默认 agent.yaml 配置,自动复用 LLM 设置 @@ -182,4 +200,4 @@ ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_sea ``` LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 -如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 +若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。 diff --git a/examples/knowledge_search/agent.yaml.example 
b/examples/knowledge_search/agent.yaml.example deleted file mode 100644 index cc11a8a3d..000000000 --- a/examples/knowledge_search/agent.yaml.example +++ /dev/null @@ -1,86 +0,0 @@ -# Sirchmunk Knowledge Search 配置示例 -# Sirchmunk Knowledge Search Configuration Example - -# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: - -llm: - service: modelscope - model: Qwen/Qwen3-235B-A22B-Instruct-2507 - modelscope_api_key: - modelscope_base_url: https://api-inference.modelscope.cn/v1 - -generation_config: - temperature: 0.3 - top_k: 20 - stream: true - -# Knowledge Search 配置(可选) -# 用于在本地代码库中搜索相关信息 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录,用于缓存 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则使用上面 llm 的配置) - llm_api_key: - llm_base_url: https://api.openai.com/v1 - llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:聚类相似度阈值 - cluster_sim_threshold: 0.85 - - # 可选:聚类 TopK - cluster_sim_top_k: 3 - - # 可选:是否重用之前的知识 - reuse_knowledge: true - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:最大循环次数 - max_loops: 10 - - # 可选:最大 token 预算 - max_token_budget: 128000 - -prompt: - system: | - You are an assistant that helps me complete tasks. - -max_chat_round: 9999 - -# 使用说明: -# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 -# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 -# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM -# 4. 
searching_detail 包含搜索日志和元数据,可用于前端展示 -# -# Python 使用示例: -# ```python -# from ms_agent import LLMAgent -# from ms_agent.config import Config -# -# config = Config.from_task('path/to/agent.yaml') -# agent = LLMAgent(config=config) -# result = await agent.run('如何实现用户认证功能?') -# -# # 获取搜索详情(用于前端展示) -# for msg in result: -# if msg.role == 'user': -# print(f"Search logs: {msg.searching_detail}") -# print(f"Search results: {msg.search_result}") -# ``` -# -# CLI 测试命令: -# export LLM_API_KEY="your-api-key" -# export LLM_BASE_URL="https://api.openai.com/v1" -# export LLM_MODEL_NAME="gpt-4o-mini" -# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/agent.yaml b/ms_agent/agent/agent.yaml index a8a65d440..21b01c22f 100644 --- a/ms_agent/agent/agent.yaml +++ b/ms_agent/agent/agent.yaml @@ -13,42 +13,7 @@ generation_config: prompt: system: | - You are an assistant that helps me complete tasks. You need to follow these instructions: - - 1. Analyze whether my requirements need tool-calling. If no tools are needed, you can think directly and provide an answer. - - 2. I will give you many tools, some of which are similar. Please carefully analyze which tool you currently need to invoke. - * If tools need to be invoked, you must call at least one tool in each round until the requirement is completed. - * If you get any useful links or images from the tool calling, output them with your answer as well. - * Check carefully the tool result, what it contains, whether it has information you need. - - 3. You DO NOT have built-in geocode/coordinates/links. Do not output any fake geocode/coordinates/links. Always query geocode/coordinates/links from tools first! - - 4. If you need to complete coding tasks, you need to carefully analyze the original requirements, provide detailed requirement analysis, and then complete the code writing. - - 5. This conversation is NOT for demonstration or testing purposes. Answer it as accurately as you can. - - 6. 
Do not call tools carelessly. Show your thoughts **as detailed as possible**. - - 7. Respond in the same language the user uses. If the user switches, switch accordingly. - - For requests that require performing a specific task or retrieving information, using the following format: - ``` - The user needs to ... - I have analyzed this request in detail and broken it down into the following steps: - ... - ``` - If you have tools which may help you to solve problems, follow this format to answer: - ``` - The user needs to ... - I have analyzed this request in detail and broken it down into the following steps: - ... - First, I should use the [Tool Name] because [explain relevance]. The required input parameters are: ... - ... - I have carefully reviewed the tool's output. The result does/does not fully meet my expectations. Next, I need to ... - ``` - - **Important: Always respond in the same language the user is using.** + you are a helpful assistant. max_chat_round: 9999 diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 5f2ddf2e7..0e32cfc3a 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -13,7 +13,6 @@ import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping @@ -107,7 +106,6 @@ def __init__( self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None - self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -528,6 +526,7 @@ async def parallel_tool_call(self, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, + 
tool_detail=tool_call_result_format.tool_detail, ) if _new_message.tool_call_id is None: @@ -538,6 +537,63 @@ async def parallel_tool_call(self, self.log_output(_new_message.content) return messages + async def parallel_tool_call_streaming( + self, messages: List[Message]) -> AsyncGenerator: + """Streaming variant of parallel_tool_call. + + Yields messages list snapshots during tool execution: + - While tools are running: yields messages with the latest incremental + ``tool_detail`` on a temporary placeholder Message (content='') so the + caller can stream logs to the frontend. + - After all tools finish: yields the final messages list (with proper + tool result Messages appended), same as parallel_tool_call. + """ + tool_calls = messages[-1].tool_calls + + # Map call_id -> tool_call_query for final message construction. + call_id_to_query = {tc['id']: tc for tc in tool_calls} + + # Accumulate final results keyed by call_id. + final_results: dict = {} + + async for call_id, item, is_final in self.tool_manager.parallel_call_tool_streaming( + tool_calls): + if is_final: + # Final result for this call_id (any type; not inferred from content). + final_results[call_id] = item + else: + # Intermediate log line: one incremental chunk in tool_detail. + log_message = Message( + role='tool', + content='', + tool_call_id=call_id, + name=call_id_to_query.get(call_id, + {}).get('tool_name', ''), + tool_detail=item, + ) + yield messages + [log_message] + + # All tools done — build final tool messages and yield. 
+ for tool_call_query in tool_calls: + cid = tool_call_query['id'] + raw_result = final_results.get( + cid, f'Tool call missing result for id {cid}') + tool_call_result_format = ToolResult.from_raw(raw_result) + _new_message = Message( + role='tool', + content=tool_call_result_format.text, + tool_call_id=cid, + name=tool_call_query['tool_name'], + resources=tool_call_result_format.resources, + ) + if _new_message.tool_call_id is None: + _new_message.tool_call_id = str(uuid.uuid4())[:8] + tool_call_query['id'] = _new_message.tool_call_id + messages.append(_new_message) + self.log_output(_new_message.content) + + yield messages + async def prepare_tools(self): """Initialize and connect the tool manager.""" self.tool_manager = ToolManager( @@ -636,11 +692,7 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): - """Process RAG or knowledge search to enrich the user query with context. - - This method handles both traditional RAG and sirchmunk-based knowledge search. - For knowledge search, it also populates searching_detail and search_result - fields in the message for frontend display and next-turn LLM context. + """Process RAG to enrich the user query with context. Args: messages (List[Message]): The message list to process. 
@@ -654,23 +706,6 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search - if self.knowledge_search is not None: - # Perform search and get results - search_result = await self.knowledge_search.query(query) - search_details = self.knowledge_search.get_search_details() - - # Store search details in the message for frontend display - user_message.searching_detail = search_details - user_message.search_result = search_result - - # Build enriched context from search results - if search_result: - # Append search context to user query - context = search_result - user_message.content = ( - f'Relevant context retrieved from codebase search:\n\n{context}\n\n' - f'User question: {query}') async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -757,18 +792,6 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) - async def prepare_knowledge_search(self): - """Load and initialize the knowledge search component from the config.""" - if self.knowledge_search is not None: - # Already initialized (e.g. by caller before run_loop), skip to avoid - # overwriting a configured instance (e.g. one with streaming callbacks set). - return - if hasattr(self.config, 'knowledge_search'): - ks_config = self.config.knowledge_search - if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch( - self.config) - async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -931,7 +954,15 @@ async def step( self.save_history(messages) if _response_message.tool_calls: - messages = await self.parallel_tool_call(messages) + # Use the streaming variant so intermediate tool logs are yielded + # back to the caller while the tools are still running. 
+ async for messages in self.parallel_tool_call_streaming(messages): + _lm = messages[-1] + _progress = ( + _lm.role == 'tool' and _lm.content == '' + and _lm.tool_detail is not None) + if _progress: + yield messages await self.after_tool_call(messages) @@ -1111,7 +1142,6 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() - await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 67ec8d563..1b8701695 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -136,9 +136,7 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default=None, - help= - 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' - ) + help='Comma-separated list of paths for knowledge search.') parser.set_defaults(func=subparser_func) def execute(self): @@ -170,7 +168,6 @@ def execute(self): def _execute_with_config(self): Env.load_dotenv_into_environ(getattr(self.args, 'env', None)) - if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): @@ -218,31 +215,28 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) - # If knowledge_search_paths is provided, configure SirchmunkSearch + # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): paths = [ p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip() ] if paths: - if 'knowledge_search' not in config or not config.knowledge_search: - # No existing knowledge_search config, create minimal config - # LLM settings will be auto-reused from llm module by SirchmunkSearch - knowledge_search_config = { - 'name': 'SirchmunkSearch', + if not hasattr(config, 'tools') or config.tools is None: + 
config['tools'] = OmegaConf.create({}) + tl = getattr(config.tools, 'localsearch', None) + if tl is None or not OmegaConf.is_config(tl): + localsearch_config = { 'paths': paths, 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create( - knowledge_search_config) + config.tools['localsearch'] = OmegaConf.create( + localsearch_config) else: - # Existing knowledge_search config found, only update paths - # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container( - config.knowledge_search, resolve=True) + existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths - config['knowledge_search'] = OmegaConf.create(existing) + config.tools['localsearch'] = OmegaConf.create(existing) if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 2f6175524..58e3ef26a 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -16,6 +16,20 @@ logger = get_logger() +# ``tools.`` entries implemented as in-process ToolBase classes, not MCP transports. 
+_BUILTIN_TOOL_SERVERS = frozenset({ + 'localsearch', + 'web_search', + 'split_task', + 'file_system', + 'code_executor', + 'financial_data_fetcher', + 'agent_tools', + 'todo_list', + 'image_generator', + 'video_generator', +}) + class ConfigLifecycleHandler: @@ -253,6 +267,10 @@ def convert_mcp_servers_to_json( for server, server_config in config.tools.items(): if server == TOOL_PLUGIN_NAME: continue - if getattr(server_config, 'mcp', True): + if server in _BUILTIN_TOOL_SERVERS: + use_mcp = getattr(server_config, 'mcp', False) + else: + use_mcp = getattr(server_config, 'mcp', True) + if use_mcp: servers['mcpServers'][server] = deepcopy(server_config) return servers diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py index 33362beee..f6c7f5143 100644 --- a/ms_agent/knowledge_search/__init__.py +++ b/ms_agent/knowledge_search/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Knowledge search module based on sirchmunk. +"""Backward-compatible re-exports for sirchmunk local search. This module provides integration between sirchmunk's AgenticSearch -and the ms_agent framework, enabling intelligent codebase search -capabilities similar to RAG. +and the ms_agent framework, enabling intelligent local path search +capabilities. 
""" -from .sirchmunk_search import SirchmunkSearch +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch __all__ = ['SirchmunkSearch'] diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 410aa12f0..08131dbd2 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,11 +61,10 @@ class Message: api_calls: int = 1 - # Knowledge search (sirchmunk) related fields - # searching_detail: Search process logs and metadata for frontend display - searching_detail: Dict[str, Any] = field(default_factory=dict) - # search_result: Raw search results to be used as context for next LLM turn - search_result: List[Dict[str, Any]] = field(default_factory=list) + # role=tool: UI / SSE only; omitted from LLM API via to_dict_clean(). + # During parallel_tool_call_streaming, temporary placeholder messages use this for + # one incremental log line each (content=''); completed tool messages omit it here. + tool_detail: Optional[str] = None def to_dict(self): return asdict(self) @@ -88,7 +87,11 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls'] + # Never send UI-only fields to model providers. + rm = [ + 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', + 'searching_detail', 'search_result' + ] return { key: value for key, value in raw_dict.items() @@ -98,20 +101,33 @@ def to_dict_clean(self): @dataclass class ToolResult: + """Tool execution outcome. + + ``text`` is sent to the model as the tool message ``content``. + ``tool_detail`` is optional verbose output for frontends only (SSE, logs). 
+ """ + text: str resources: List[str] = field(default_factory=list) extra: dict = field(default_factory=dict) + tool_detail: Optional[str] = None @staticmethod def from_raw(raw): if isinstance(raw, str): return ToolResult(text=raw) if isinstance(raw, dict): + model_text = raw.get('result') + if model_text is None: + model_text = raw.get('text', '') + td = raw.get('tool_detail') return ToolResult( - text=str(raw.get('text', '')), + text=str(model_text), resources=raw.get('resources', []), + tool_detail=None if td is None else str(td), extra={ k: v - for k, v in raw.items() if k not in ['text', 'resources'] + for k, v in raw.items() + if k not in ['text', 'resources', 'result', 'tool_detail'] }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index e66da954d..08e9a4db7 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -4,6 +4,3 @@ rag_mapping = { 'LlamaIndexRAG': LlamaIndexRAG, } - -# Note: SirchmunkSearch is registered in knowledge_search module -# and integrated directly in LLMAgent, not through rag_mapping diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index 12ece9948..aed0b9768 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -89,3 +89,26 @@ async def call_tool(self, server_name: str, *, tool_name: str, Calling result in string format. """ pass + + async def call_tool_streaming(self, server_name: str, *, tool_name: str, + tool_args: dict): + """Streaming variant of call_tool. + + Contract for overrides: + - Emit zero or more intermediate updates as strings (UI / log lines). + - The **last** value yielded before the async generator finishes must + be the same shape as ``call_tool`` would return: a ``str`` or a + ``dict`` understood by ``ToolResult.from_raw`` (e.g. with a + ``result`` / ``text`` field). 
+ + Callers inside this package (``ToolManager.single_call_tool_streaming``) + do **not** infer \"final vs log\" from types: they mark the terminal + emission explicitly. Short string finals (e.g. ``\"OK\"``) are therefore + valid and unambiguous. + + Default implementation: no intermediate yields; a single final yield + from ``call_tool``. + """ + result = await self.call_tool( + server_name, tool_name=tool_name, tool_args=tool_args) + yield result diff --git a/ms_agent/tools/search/localsearch_catalog.py b/ms_agent/tools/search/localsearch_catalog.py new file mode 100644 index 000000000..3ce840802 --- /dev/null +++ b/ms_agent/tools/search/localsearch_catalog.py @@ -0,0 +1,491 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Build a compact file catalog for localsearch tool descriptions using sirchmunk's DirectoryScanner.""" + +from __future__ import annotations +import hashlib +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + +import json + + +def catalog_cache_path(work_path: Path, fingerprint: str) -> Path: + work_path.mkdir(parents=True, exist_ok=True) + return work_path / f'localsearch_description_catalog.{fingerprint}.json' + + +def catalog_fingerprint( + roots: List[str], + max_files: int, + max_depth: int, + max_preview_chars: int, + max_chars: int, + exclude: Optional[List[str]], +) -> str: + payload = { + 'roots': sorted(roots), + 'max_files': max_files, + 'max_depth': max_depth, + 'max_preview_chars': max_preview_chars, + 'max_chars': max_chars, + 'exclude': sorted(exclude or []), + } + raw = json.dumps( + payload, sort_keys=True, ensure_ascii=False).encode('utf-8') + return hashlib.sha256(raw).hexdigest()[:24] + + +def load_cached_catalog( + path: Path, + ttl_seconds: float, +) -> Optional[str]: + if ttl_seconds <= 0 or not path.is_file(): + return None + try: + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f) + created = float(data.get('created_at', 0)) + if 
time.time() - created > ttl_seconds: + return None + text = data.get('catalog') + return str(text) if text is not None else None + except (OSError, json.JSONDecodeError, TypeError, ValueError): + return None + + +def save_cached_catalog(path: Path, catalog: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + '.tmp') + payload = {'created_at': time.time(), 'catalog': catalog} + with open(tmp, 'w', encoding='utf-8') as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + tmp.replace(path) + + +def _int_from_block(block: Any, key: str, default: int) -> int: + if block is None: + return default + v = block.get(key, default) if hasattr(block, 'get') else getattr( + block, key, default) + try: + return int(v) + except (TypeError, ValueError): + return default + + +def _bool_from_block(block: Any, key: str, default: bool = False) -> bool: + if block is None: + return default + v = block.get(key, default) if hasattr(block, 'get') else getattr( + block, key, default) + if isinstance(v, str): + return v.strip().lower() in ('1', 'true', 'yes', 'on') + return bool(v) + + +def _exclude_from_block(block: Any) -> Optional[List[str]]: + if block is None: + return None + raw = block.get('description_catalog_exclude', None) if hasattr( + block, 'get') else getattr(block, 'description_catalog_exclude', None) + if raw is None: + return None + if isinstance(raw, str): + return [raw] if raw.strip() else None + if isinstance(raw, (list, tuple)): + out = [str(x).strip() for x in raw if str(x).strip()] + return out or None + return None + + +def description_catalog_settings(block: Any) -> Tuple[bool, Dict[str, Any]]: + """Parse ``tools.localsearch`` (or legacy ``knowledge_search``) catalog options.""" + enabled = _bool_from_block(block, 'description_catalog', False) + opts = { + 'max_files': + max( + 1, + min(2000, + _int_from_block(block, 'description_catalog_max_files', 120))), + 'max_depth': + max( + 1, + min(20, 
_int_from_block(block, 'description_catalog_max_depth', + 5))), + 'max_chars': + max( + 500, + min(100_000, + _int_from_block(block, 'description_catalog_max_chars', + 10_000))), + 'max_preview_chars': + max( + 80, + min( + 4000, + _int_from_block(block, 'description_catalog_max_preview_chars', + 400))), + 'cache_ttl_seconds': + max( + 0, + _int_from_block(block, 'description_catalog_cache_ttl_seconds', + 300)), + 'exclude_extra': + _exclude_from_block(block), + # Files larger than this are skipped during catalog scan (default 50 MB). + # Set to 0 to disable the cap. + 'max_file_size_mb': + max(0, + _int_from_block(block, 'description_catalog_max_file_size_mb', + 50)), + # Wall-clock timeout (seconds) for extracting the first page of an + # oversized PDF. Corrupt or pathological PDFs are abandoned after this + # and only their filename + size appear in the catalog. + 'oversized_pdf_timeout_s': + max( + 0.1, + _int_from_block(block, + 'description_catalog_oversized_pdf_timeout_s', 1)), + } + return enabled, opts + + +_DIR_SKIP: Set[str] = { + '.git', + '.svn', + 'node_modules', + '__pycache__', + '.idea', + '.vscode', + '.cache', + '.tox', + '.eggs', + 'dist', + 'build', + '.DS_Store', +} + +# Directories that typically contain generated/compiled artifacts, not source. +# These are skipped by default to keep the tree focused on meaningful content. +_DIR_SKIP_GENERATED: Set[str] = { + '_build', + '_static', + '_templates', + '_sphinx_design_static', + '__pycache__', + '.mypy_cache', + '.pytest_cache', + '.ruff_cache', + 'htmlcov', + 'site-packages', + 'egg-info', + '.egg-info', +} + + +def _build_dir_tree( + root: Path, + max_depth: int, + exclude: Optional[List[str]], + max_chars: int = 4000, +) -> str: + """Fast filesystem-only directory tree (no file content reads). + + Produces a compact indented listing: + 📁 ms_agent/ + 📁 tools/ + 📁 search/ (8 files) + agent_tool.py base.py ... +3 + __init__.py llm_agent.py ... + + Strategy — two passes: + 1. 
DFS pre-scan: collect every (depth, line) pair without any char limit. + 2. Breadth-first selection: sort collected lines by depth, take lines from + shallowest levels first until the char budget is exhausted. This + guarantees all top-level directories appear before any second-level + directories are shown, etc. + + The final output is re-sorted by original DFS order so indentation is + visually coherent. + """ + skip = set(_DIR_SKIP) | _DIR_SKIP_GENERATED + if exclude: + skip.update(exclude) + + # --- Pass 1: DFS, collect (depth, seq, line_text) --- + collected: List[tuple] = [] # (depth, seq, line_text) + seq = [0] + + def _dfs(p: Path, depth: int, indent: str) -> None: + if depth > max_depth: + return + try: + entries = sorted( + p.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) + except PermissionError: + return + + dirs = [ + e for e in entries + if e.is_dir() and not e.name.startswith('.') and e.name not in skip + ] + files = [ + e for e in entries if e.is_file() and not e.name.startswith('.') + and e.name not in skip + ] + + child_indent = indent + ' ' + + for d in dirs: + try: + file_count = sum(1 for _ in d.iterdir() + if _.is_file() and not _.name.startswith('.')) + except PermissionError: + file_count = 0 + count_hint = f' ({file_count} files)' if file_count else '' + collected.append( + (depth, seq[0], f'{indent}📁 {d.name}/{count_hint}')) + seq[0] += 1 + _dfs(d, depth + 1, child_indent) + + if files: + MAX_SHOW = 5 + shown = [f.name for f in files[:MAX_SHOW]] + overflow = len(files) - MAX_SHOW + file_line = f'{indent} ' + ' '.join(shown) + if overflow > 0: + file_line += f' … +{overflow}' + # File hint lines get depth + 0.5 so they sort after their parent + # dir line but before the parent's children directories. + collected.append((depth + 0.5, seq[0], file_line)) + seq[0] += 1 + + _dfs(root, 0, '') + + # --- Pass 2: breadth-first selection within char budget --- + # Sort by (depth, seq) to process shallowest lines first. 
+ by_depth = sorted(collected, key=lambda t: (t[0], t[1])) + budget = max_chars - 40 # reserve for truncation note + selected_seqs: set = set() + used = 0 + truncated = False + for depth, s, line in by_depth: + cost = len(line) + 1 + if used + cost > budget: + truncated = True + break + selected_seqs.add(s) + used += cost + + # Re-sort selected lines by original DFS seq to restore correct indentation + output_lines = [line for (_, s, line) in collected if s in selected_seqs] + + result = '\n'.join(output_lines) + if truncated: + result += '\n… (tree truncated — deeper directories omitted)' + return result + + +def _compact_file_summary(candidate: Any, root_dir: str, + max_preview: int) -> str: + """Single-line or two-line compact summary for a FileCandidate. + + Format: + - path/to/file.py (.py, 12KB) — Title or first-line preview + """ + from pathlib import Path as _Path + try: + rel = _Path(candidate.path).relative_to(_Path(root_dir)).as_posix() + except (ValueError, TypeError): + rel = _Path(candidate.path).as_posix() + + size = candidate.size_bytes + if size < 1024: + size_str = f'{size}B' + elif size < 1024 * 1024: + size_str = f'{size / 1024:.0f}KB' + else: + size_str = f'{size / 1024 / 1024:.1f}MB' + + label = candidate.title or '' + if not label and candidate.preview: + # First sentence / line of preview, capped + label = candidate.preview.replace('\n', ' ').strip() + label = label[:max_preview] if label else '' + + base = f'- {rel} ({candidate.extension or "?"}, {size_str})' + return f'{base} — {label}' if label else base + + +async def build_file_catalog_text( + roots: List[str], + *, + max_files: int, + max_depth: int, + max_preview_chars: int, + exclude_extra: Optional[List[str]], + max_file_size_mb: int = 50, + oversized_pdf_timeout_s: float = 1.0, + max_chars: int = 10_000, +) -> str: + """Build a two-section catalog for the localsearch tool description. 
async def build_file_catalog_text(
    roots: List[str],
    *,
    max_files: int,
    max_depth: int,
    max_preview_chars: int,
    exclude_extra: Optional[List[str]],
    max_file_size_mb: int = 50,
    oversized_pdf_timeout_s: float = 1.0,
    max_chars: int = 10_000,
) -> str:
    """Build the per-root catalog text for the localsearch tool description.

    Two sections are produced for every root:

    1. Directory tree, rendered by ``_build_dir_tree`` within about 60% of
       the root's character budget.
    2. File summaries: one ``_compact_file_summary`` line per scanned file,
       reordered by ``_stratified_sample`` and cut off once the remaining
       budget is used up.

    ``max_chars`` is divided evenly between the roots with a floor of 500
    characters per root, so the combined text can be longer than
    ``max_chars`` when many roots are configured.

    Args:
        roots: Paths to catalog. A path that does not exist yields a
            "(missing on disk)" stub section.
        max_files: Forwarded to ``DirectoryScanner``.
        max_depth: Forwarded to ``DirectoryScanner`` and ``_build_dir_tree``.
        max_preview_chars: Forwarded to ``DirectoryScanner``. The summary
            lines themselves use a fixed 100-character label cap.
        exclude_extra: Extra exclude entries appended to the built-in skip
            sets.
        max_file_size_mb: Size limit handed to the scanner in bytes; a value
            of 0 or less disables the limit.
        oversized_pdf_timeout_s: Forwarded to ``DirectoryScanner``.
        max_chars: Total character budget shared by all roots.

    Returns:
        The root sections joined by a ``---`` separator line.

    Raises:
        ImportError: If ``sirchmunk`` cannot be imported.
    """
    try:
        from sirchmunk.scan.dir_scanner import DirectoryScanner
    except ImportError as e:
        raise ImportError('sirchmunk is required for description_catalog. '
                          f'Import failed: {e}') from e

    # Every root gets an equal share; the tree takes 60% of that share and
    # the file summaries get what is left after an 80-character margin.
    num_roots = max(1, len(roots))
    per_root_budget = max(500, max_chars // num_roots)
    tree_budget = max(300, int(per_root_budget * 0.60))
    files_budget = max(200, per_root_budget - tree_budget - 80)

    max_file_size_bytes = (max_file_size_mb * 1024
                           * 1024) if max_file_size_mb > 0 else None
    # The scanner skips the same generated/artifact directories as the tree.
    scanner_exclude: List[str] = sorted(_DIR_SKIP | _DIR_SKIP_GENERATED)
    if exclude_extra:
        scanner_exclude.extend(exclude_extra)
    scanner = DirectoryScanner(
        llm=None,
        max_depth=max_depth,
        max_files=max_files,
        max_preview_chars=max_preview_chars,
        exclude_patterns=scanner_exclude,
        max_file_size_bytes=max_file_size_bytes,
        oversized_pdf_timeout_s=oversized_pdf_timeout_s,
    )

    sections: List[str] = []
    for root in roots:
        p = Path(root)
        if not p.exists():
            sections.append(f'### `{root}`\n(missing on disk)')
            continue

        # --- Section 1: directory tree ---
        tree = _build_dir_tree(
            p,
            max_depth=max_depth,
            exclude=exclude_extra,
            max_chars=tree_budget)
        tree_block = f'#### Directory structure of `{p}`\n{tree}' if tree else ''

        # --- Section 2: per-file compact summaries ---
        result = await scanner.scan(p)
        # Sort first so the stratified reorder is deterministic.
        all_candidates = sorted(result.candidates, key=lambda c: c.path)
        reordered = _stratified_sample(all_candidates, p, len(all_candidates))

        # Trim the *reordered* list so the diverse ordering decides which
        # entries survive the character budget.
        file_lines: List[str] = []
        used_f = 0
        for c in reordered:
            line = _compact_file_summary(c, str(p), max_preview=100)
            if used_f + len(line) + 1 > files_budget:
                break
            file_lines.append(line)
            used_f += len(line) + 1

        # Count real entries before the placeholder is added; otherwise an
        # empty scan would be reported as "1 of 0 sampled".
        shown = len(file_lines)
        omitted = len(reordered) - shown
        if not reordered:
            file_lines.append('_(no scannable files in depth budget)_')

        total_scanned = len(all_candidates)
        header = f'#### File summaries ({shown} of {total_scanned} sampled)'
        file_block = header + '\n' + '\n'.join(file_lines)
        if omitted:
            file_block += f'\n… ({omitted} more files not shown)'

        root_section = f'### Under `{p}`\n\n{tree_block}\n\n{file_block}'
        sections.append(root_section.strip())

    return '\n\n---\n\n'.join(sections).strip()
+ """ + # Split into root-level vs sub-directory files + root_files: List[Any] = [] + by_subdir: dict = {} + for c in candidates: + try: + rel = Path(c.path).relative_to(root) + except ValueError: + rel = Path(c.path) + parts = rel.parts + if len(parts) == 1: + root_files.append(c) + else: + subdir = parts[0] + by_subdir.setdefault(subdir, []).append(c) + + result: List[Any] = [] + + # Always include all root-level files first (usually just a handful) + result.extend(root_files) + + if not by_subdir: + return result[:max_entries] + + # Round-robin across subdirectories. + # Sort subdirs by file count ascending so smaller directories (which are + # likely to have fewer but more targeted files) get their representative + # entry before large documentation/data directories exhaust the budget. + subdirs = sorted(by_subdir.keys(), key=lambda s: len(by_subdir[s])) + subdir_iters = {s: iter(by_subdir[s]) for s in subdirs} + while len(result) < max_entries and subdir_iters: + exhausted = [] + for s in subdirs: + if len(result) >= max_entries: + break + it = subdir_iters.get(s) + if it is None: + continue + try: + result.append(next(it)) + except StopIteration: + exhausted.append(s) + for s in exhausted: + del subdir_iters[s] + if not subdir_iters: + break + + return result[:max_entries] diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py new file mode 100644 index 000000000..f2a2b7b22 --- /dev/null +++ b/ms_agent/tools/search/localsearch_tool.py @@ -0,0 +1,489 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" + +import asyncio +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional + +import json +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.tools.search.localsearch_catalog import ( + build_file_catalog_text, catalog_cache_path, catalog_fingerprint, + description_catalog_settings, load_cached_catalog, save_cached_catalog) +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, effective_localsearch_settings) +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SERVER = 'localsearch' +_TOOL = 'localsearch' + +# Tool-facing description: aligned with sirchmunk AgenticSearch.search() capabilities. +_LOCALSEARCH_DESCRIPTION = """Search local files, codebases, and documents on disk. + +USE THIS TOOL WHEN: +- The user asks about content in local files or directories +- You need to find information in source code, config files, or documents +- The query references a local path, project structure, or codebase +- You need to search PDF, DOCX, XLSX, PPTX, CSV, JSON, YAML, Markdown, etc. +- Large files or directories should be searched by this tool. + + +DO NOT USE THIS TOOL WHEN: +- The user is asking a general knowledge question +- The user is greeting you or making casual conversation (e.g., "你好", "hello") +- You need information from the internet or recent events +- The query has no relation to local files or code + +Returns: +Search results after summarizing as formatted text with file paths, code snippets, and explanations where +available. Retrieved excerpts and meta are included in the tool output. 
def _resolved_localsearch_paths_from_config(config) -> List[str]:
    """Return the configured search roots as absolute path strings.

    The roots are read from the block chosen by
    ``effective_localsearch_settings``. A single string is treated as a
    one-element list, and ``None`` or blank entries are dropped. Each
    remaining entry is stripped, ``~``-expanded and resolved.
    """
    settings = effective_localsearch_settings(config)
    if not settings:
        return []
    raw = settings.get('paths', [])
    entries = [raw] if isinstance(raw, str) else (raw or [])
    return [
        str(Path(str(entry).strip()).expanduser().resolve())
        for entry in entries
        if entry is not None and str(entry).strip()
    ]


def _work_path_from_config(config) -> Path:
    """Return the resolved sirchmunk work directory for ``config``.

    Falls back to ``.sirchmunk`` under the current directory when local
    search is not configured or no ``work_path`` is set.
    """
    settings = effective_localsearch_settings(config)
    if not settings:
        return Path('.sirchmunk').expanduser().resolve()
    # Mapping-style blocks are read with .get(); anything else by attribute.
    if hasattr(settings, 'get'):
        raw = settings.get('work_path', './.sirchmunk')
    else:
        raw = getattr(settings, 'work_path', './.sirchmunk')
    return Path(str(raw)).expanduser().resolve()
+ header_section_end = text.find('\n', summary_header_m.end()) + if header_section_end == -1: + return text[:max_chars - 24].rstrip() + '\n\n… (truncated)' + + prefix = text[:header_section_end + 1] + body = text[header_section_end + 1:] + + # Split body into individual entry lines (each starts with "- "). + parts = re.split(r'(?=^- )', body, flags=re.MULTILINE) + parts = [p for p in parts if p.strip()] + + budget = max_chars - len(prefix) - 50 # reserve space for trailing note + kept: list[str] = [] + used = 0 + for part in parts: + if used + len(part) > budget: + break + kept.append(part) + used += len(part) + + omitted = len(parts) - len(kept) + suffix = f'\n… ({omitted} more files not shown)' if omitted > 0 else '' + return prefix + ''.join(kept).rstrip() + suffix + + +def _format_configured_roots(paths: List[str]) -> str: + if not paths: + return ('(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') + return '\n'.join(f'- {p}' for p in paths) + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _as_str_list(value: Any, name: str) -> Optional[List[str]]: + if value is None: + return None + if isinstance(value, str): + return [value] if value.strip() else None + if isinstance(value, list): + out = [str(x).strip() for x in value if str(x).strip()] + return out or None + raise TypeError(f'{name} must be a string or list of strings') + + +class LocalSearchTool(ToolBase): + """Expose sirchmunk as a callable tool when ``tools.localsearch`` is configured.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + tools_root = getattr(config, 'tools', None) + tool_cfg = getattr(tools_root, 'localsearch', + None) if tools_root else None + if tool_cfg is not None: + self.exclude_func(tool_cfg) + self._searcher: Optional[SirchmunkSearch] = None + self._configured_roots: List[str] = ( + _resolved_localsearch_paths_from_config(config)) + block = 
effective_localsearch_settings(config) + self._catalog_enabled, self._catalog_opts = description_catalog_settings( + block) + self._work_path = _work_path_from_config(config) + self._catalog_text: str = '' + self._catalog_build_error: Optional[str] = None + + def _file_catalog_section(self) -> str: + if not self._catalog_enabled: + return '' + err = self._catalog_build_error + if err: + return ('\n\n## Local knowledge catalog\n' + f'_(Catalog build failed: {err})_\n') + body = (self._catalog_text or '').strip() + if not body: + return ('\n\n## Local knowledge catalog\n' + '_(No scannable files or catalog empty.)_\n') + shown = _truncate_catalog_text(body, self._catalog_opts['max_chars']) + return ( + '\n\n## Local knowledge catalog (shallow scan)\n' + 'Brief previews of files under the configured roots; call this tool ' + 'with a `query` for full search.\n\n' + shown + '\n') + + def _tool_description(self) -> str: + return _LOCALSEARCH_DESCRIPTION.format( + configured_roots=_format_configured_roots(self._configured_roots), + file_catalog_section=self._file_catalog_section()) + + def _paths_param_description(self) -> str: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Optional. Narrow search to specific files or directories under the ' + 'configured roots below. 
Each path must exist on disk and lie under ' + 'one of these roots (or be exactly one of them).\n' + f'Configured roots:\n{roots}') + + def _ensure_searcher(self) -> SirchmunkSearch: + if self._searcher is None: + self._searcher = SirchmunkSearch(self.config) + return self._searcher + + async def connect(self) -> None: + self._catalog_build_error = None + self._catalog_text = '' + if not self._catalog_enabled: + return + roots = [r for r in self._configured_roots if r] + if not roots: + self._catalog_build_error = 'no configured roots' + return + o = self._catalog_opts + fp = catalog_fingerprint( + roots, + o['max_files'], + o['max_depth'], + o['max_preview_chars'], + o['max_chars'], + o['exclude_extra'], + ) + cache_path = catalog_cache_path(self._work_path, fp) + ttl = float(o['cache_ttl_seconds']) + t0 = time.monotonic() + cached = load_cached_catalog(cache_path, ttl) + if cached is not None: + elapsed = time.monotonic() - t0 + self._catalog_text = cached + logger.info( + f'localsearch catalog: loaded from cache in {elapsed:.3f}s ' + f'({len(cached)} chars) roots={roots}') + return + try: + built = await build_file_catalog_text( + roots, + max_files=o['max_files'], + max_depth=o['max_depth'], + max_preview_chars=o['max_preview_chars'], + exclude_extra=o['exclude_extra'], + max_file_size_mb=o['max_file_size_mb'], + oversized_pdf_timeout_s=o['oversized_pdf_timeout_s'], + max_chars=o['max_chars'], + ) + elapsed = time.monotonic() - t0 + self._catalog_text = built + logger.info(f'localsearch catalog: scanned in {elapsed:.3f}s ' + f'({len(built)} chars) roots={roots}') + if ttl > 0 and built.strip(): + try: + save_cached_catalog(cache_path, built) + except OSError as exc: + logger.debug( + f'localsearch catalog cache write failed: {exc}') + except ImportError as exc: + elapsed = time.monotonic() - t0 + self._catalog_build_error = str(exc) + logger.warning( + f'localsearch description_catalog ({elapsed:.3f}s): {exc}') + except Exception as exc: + elapsed = 
time.monotonic() - t0 + self._catalog_build_error = str(exc) + logger.warning( + f'localsearch description_catalog scan failed ({elapsed:.3f}s): {exc}' + ) + + async def _get_tools_inner(self) -> Dict[str, List[Tool]]: + return { + _SERVER: [ + Tool( + tool_name=_TOOL, + server_name=_SERVER, + description=self._tool_description(), + parameters={ + 'type': 'object', + 'properties': { + 'query': { + 'type': + 'string', + 'description': + 'Search keywords or natural-language question about local content.', + }, + 'paths': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': self._paths_param_description(), + }, + 'mode': { + 'type': + 'string', + 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], + 'description': + 'Search mode; omit to use agent default (usually FAST).', + }, + 'max_depth': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max directory depth for filesystem search.', + }, + 'top_k_files': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max files for evidence / filename hits.', + }, + 'include': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to include (e.g. *.py, *.md).', + }, + 'exclude': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to exclude (e.g. *.pyc).', + }, + }, + 'required': ['query'], + }, + ) + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict): + del server_name + if tool_name != _TOOL: + return f'Unknown tool: {tool_name}' + + args = tool_args or {} + query = str(args.get('query', '')).strip() + if not query: + return 'Error: `query` is required and cannot be empty.' 
+ + try: + paths_arg = _as_str_list(args.get('paths'), 'paths') + mode = args.get('mode') + if mode is not None: + mode = str(mode).strip().upper() or None + + max_depth = args.get('max_depth') + if max_depth is not None: + max_depth = int(max_depth) + max_depth = max(1, min(20, max_depth)) + + top_k = args.get('top_k_files') + if top_k is not None: + top_k = int(top_k) + top_k = max(1, min(20, top_k)) + + include = _as_str_list(args.get('include'), 'include') + exclude = _as_str_list(args.get('exclude'), 'exclude') + + searcher = self._ensure_searcher() + resolved_paths = None + if paths_arg: + resolved_paths = searcher.resolve_tool_paths(paths_arg) + if not resolved_paths: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Error: `paths` are invalid. Each path must exist on disk and lie ' + 'under one of these configured roots:\n' + roots) + + answer = await searcher.query( + query, + paths=resolved_paths, + mode=mode, + max_depth=max_depth, + top_k_files=top_k, + include=include, + exclude=exclude, + ) + details = searcher.get_search_details() + excerpts = searcher.get_last_retrieved_chunks() + + lines = ['## Local search (sirchmunk)', '', str(answer), ''] + + if excerpts: + lines.append('## Retrieved excerpts') + lines.append('') + for i, item in enumerate(excerpts[:12], 1): + meta = item.get('metadata') or {} + src = meta.get('source', '?') + text = (item.get('text') or '')[:4000] + lines.append(f'### [{i}] {src}') + lines.append(text) + lines.append('') + + summary = { + 'mode': details.get('mode'), + 'paths': details.get('paths'), + 'work_path': details.get('work_path'), + 'cluster_cache_hit': details.get('cluster_cache_hit'), + } + lines.append('## Meta') + lines.append(_json_dumps(summary)) + + full_text = '\n'.join(lines) + # Model sees answer + source paths only; UI gets full excerpts + meta. 
+ result_parts = [str(answer).strip()] + if excerpts: + result_parts.append('\nSource paths:') + for item in excerpts[:12]: + meta = item.get('metadata') or {} + result_parts.append(f'- {meta.get("source", "?")}') + result_text = '\n'.join(result_parts) + + return { + 'result': result_text, + 'tool_detail': full_text, + } + except (TypeError, ValueError) as exc: + return f'Invalid tool arguments: {exc}' + except Exception as exc: + logger.warning(f'localsearch failed: {exc}') + return f'Local search failed: {exc}' + + async def call_tool_streaming(self, server_name: str, *, tool_name: str, + tool_args: dict): + """Streaming variant: yield log lines while searching, then yield final result. + + Intermediate yields are plain strings (log lines). + The final yield is the result dict (or error string) from call_tool. + + Timeout semantics: the caller should treat the absence of any yield + within 30 s as a hang and cancel the task. + """ + log_queue: asyncio.Queue = asyncio.Queue() + + # Register the streaming callback on the searcher so sirchmunk pushes + # log lines into our queue as they are emitted. + async def _on_log(entry: str): + await log_queue.put(entry) + + # We need the searcher to exist before we can register the callback. + # _ensure_searcher() is synchronous and cheap if already initialized. + try: + searcher = self._ensure_searcher() + searcher.enable_streaming_logs(_on_log) + except Exception as exc: + yield f'Local search failed: {exc}' + return + + # Sentinel placed in the queue by the search coroutine when done. + _DONE = object() + + async def _run_search(): + try: + result = await self.call_tool( + server_name, tool_name=tool_name, tool_args=tool_args) + except Exception as exc: + result = f'Local search failed: {exc}' + await log_queue.put(_DONE) + await log_queue.put(result) + + search_task = asyncio.create_task(_run_search()) + + try: + while True: + item = await log_queue.get() + if item is _DONE: + # Next item is the final result. 
+ final = await log_queue.get() + yield final + break + # Intermediate log line. + yield item + finally: + search_task.cancel() + try: + await search_task + except asyncio.CancelledError: + pass diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py similarity index 64% rename from ms_agent/knowledge_search/sirchmunk_search.py rename to ms_agent/tools/search/sirchmunk_search.py index e1c76181f..7aa09b5e6 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -1,51 +1,83 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Sirchmunk-based knowledge search integration. +"""Sirchmunk backend for the ``localsearch`` tool. -This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, -providing document retrieval capabilities similar to RAG but optimized for -codebase and documentation search. +Configuration lives under ``tools.localsearch`` (same namespace as other tools). +Legacy top-level ``knowledge_search`` is still accepted for backward compatibility. """ import asyncio from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional +import json from loguru import logger -from ms_agent.rag.base import RAG from omegaconf import DictConfig -class SirchmunkSearch(RAG): - """Sirchmunk-based knowledge search class. +def _paths_from_block(block: Any) -> List[str]: + if block is None: + return [] + paths = block.get('paths', []) if hasattr(block, 'get') else [] + if isinstance(paths, str): + paths = [paths] if str(paths).strip() else [] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(p).strip()) + return out - This class wraps the sirchmunk library to provide intelligent codebase search - capabilities. 
def effective_localsearch_settings(config: DictConfig) -> Optional[Any]:
    """Pick the settings block that configures local search.

    ``tools.localsearch`` wins when it has non-empty ``paths``; otherwise the
    legacy top-level ``knowledge_search`` block is used when it has
    non-empty ``paths``.

    Returns:
        The chosen block, or ``None`` when neither block lists any path.
    """
    tools = getattr(config, 'tools', None)
    if tools is None:
        tool_block = None
    elif hasattr(tools, 'get'):
        tool_block = tools.get('localsearch')
    else:
        tool_block = getattr(tools, 'localsearch', None)
    legacy_block = getattr(config, 'knowledge_search', None)

    # Order encodes the precedence: tools.localsearch before the legacy block.
    for candidate in (tool_block, legacy_block):
        if candidate is not None and _paths_from_block(candidate):
            return candidate
    return None
+ config: Full agent config; sirchmunk options read from the effective + block returned by :func:`effective_localsearch_settings`. """ def __init__(self, config: DictConfig): - super().__init__(config) - self._validate_config(config) + rag_config = effective_localsearch_settings(config) + assert rag_config is not None - # Extract configuration parameters - rag_config = config.get('knowledge_search', {}) - - # Search paths - required paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] @@ -53,11 +85,9 @@ def __init__(self, config: DictConfig): str(Path(p).expanduser().resolve()) for p in paths ] - # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() - # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) @@ -66,13 +96,10 @@ def __init__(self, config: DictConfig): self.max_loops = rag_config.get('max_loops', 10) self.max_token_budget = rag_config.get('max_token_budget', 128000) - # LLM configuration for sirchmunk - # First try knowledge_search.llm_api_key, then fall back to main llm config self.llm_api_key = rag_config.get('llm_api_key', None) self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - # Fall back to main llm config if not specified in knowledge_search if (self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None): llm_config = config.get('llm', {}) @@ -87,39 +114,57 @@ def __init__(self, config: DictConfig): if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) - # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( 'embedding_model_cache_dir', None) - # Runtime state self._searcher = None 
self._initialized = False self._cluster_cache_hit = False self._cluster_cache_hit_time: str | None = None self._last_search_result: List[Dict[str, Any]] | None = None - # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] - # Async queue for streaming logs in real-time self._log_queue: asyncio.Queue | None = None self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): - """Validate configuration parameters.""" - if not hasattr(config, - 'knowledge_search') or config.knowledge_search is None: + block = effective_localsearch_settings(config) + if block is None: raise ValueError( - 'Missing knowledge_search configuration. ' - 'Please add knowledge_search section to your config with at least "paths" specified.' - ) - - rag_config = config.knowledge_search - paths = rag_config.get('paths', []) + 'Missing localsearch configuration. Add ' + '`tools.localsearch` with non-empty `paths` (or legacy ' + '`knowledge_search.paths`).') + paths = _paths_from_block(block) if not paths: raise ValueError( - 'knowledge_search.paths must be specified and non-empty') + 'tools.localsearch.paths (or legacy knowledge_search.paths) ' + 'must be specified and non-empty') + + def resolve_tool_paths(self, + paths: Optional[List[str]]) -> Optional[List[str]]: + """Restrict per-call paths to configured search roots.""" + if not paths: + return None + roots = [Path(p).resolve() for p in self.search_paths] + cleaned: List[str] = [] + for raw in paths: + if raw is None or not str(raw).strip(): + continue + p = Path(str(raw).strip()).expanduser().resolve() + if not p.exists(): + logger.warning( + f'localsearch: path does not exist, skipped: {p}') + continue + allowed = any(p == r or p.is_relative_to(r) for r in roots) + if not allowed: + logger.warning( + f'localsearch: path outside configured search roots, ' + f'skipped: {p}') + continue + cleaned.append(str(p)) + return cleaned or None def _initialize_searcher(self): 
"""Initialize the sirchmunk AgenticSearch instance.""" @@ -131,7 +176,6 @@ def _initialize_searcher(self): from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil - # Create LLM client llm = OpenAIChat( api_key=self.llm_api_key, base_url=self.llm_base_url, @@ -140,8 +184,6 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - # Create embedding util - # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( @@ -150,7 +192,6 @@ def _initialize_searcher(self): embedding = EmbeddingUtil( model_id=embedding_model_id, cache_dir=embedding_cache_dir) - # Create AgenticSearch instance self._searcher = AgenticSearch( llm=llm, embedding=embedding, @@ -191,7 +232,6 @@ def log_callback( ): log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) - # Stream log in real-time if streaming callback is set if self._streaming_callback: asyncio.create_task(self._streaming_callback(log_entry)) @@ -215,7 +255,6 @@ async def add_documents(self, documents: List[str]) -> bool: 'SirchmunkSearch does not support direct document addition. ' 'Documents should be saved to files within the configured search paths.' 
) - # Trigger re-scan of the search paths if self._searcher and hasattr(self._searcher, 'knowledge_base'): try: await self._searcher.knowledge_base.refresh() @@ -274,7 +313,6 @@ async def retrieve(self, max_token_budget = filters.get('max_token_budget', self.max_token_budget) - # Perform search result = await self._searcher.search( query=query, mode=mode, @@ -283,78 +321,130 @@ async def retrieve(self, return_context=True, ) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - # If a similar cluster was found and reused, it's a cache hit - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) - # Get the cluster cache hit time if available - if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit = any( + 'Found similar cluster' in entry + or 'Reused existing knowledge cluster' in entry + for entry in self._search_logs) + if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + result.cluster, 'last_modified', None) - # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] - async def query(self, query: str) -> str: - """Query sirchmunk and return a synthesized answer. - - This method performs a search and returns the LLM-synthesized answer - along with search details that can be used for frontend display. + async def query( + self, + query: str, + *, + paths: Optional[List[str]] = None, + mode: Optional[str] = None, + max_depth: Optional[int] = None, + top_k_files: Optional[int] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> str: + """Query sirchmunk and return a synthesized answer (or filename hits). 
+ + Optional arguments are forwarded to ``AgenticSearch.search`` where supported. + ``paths`` must already be restricted to configured search roots (see + :meth:`resolve_tool_paths`). Args: - query (str): The search query. + query: The search query. + paths: Override search roots (subset of configured paths), or None. + mode: ``FAST``, ``DEEP``, or ``FILENAME_ONLY``; None uses config default. + max_depth: Directory depth cap for filesystem search. + top_k_files: Max files for evidence / filename ranking. + include: Glob patterns to include (e.g. ``*.py``). + exclude: Glob patterns to exclude (e.g. ``node_modules``). Returns: - str: The synthesized answer from sirchmunk. + Answer string, or JSON string for ``FILENAME_ONLY`` list results. """ self._initialize_searcher() self._search_logs.clear() try: - mode = self.search_mode - max_loops = self.max_loops - max_token_budget = self.max_token_budget - - # Single search with context so we get both the synthesized answer and - # source units in one call, avoiding a redundant second search. 
- result = await self._searcher.search( + mode_eff = mode if mode is not None else self.search_mode + if isinstance(mode_eff, str): + mode_eff = mode_eff.strip().upper() + allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') + if mode_eff not in allowed_modes: + return ( + f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') + + kw: Dict[str, Any] = dict( query=query, - mode=mode, - max_loops=max_loops, - max_token_budget=max_token_budget, + paths=paths, + mode=mode_eff, + max_loops=self.max_loops, + max_token_budget=self.max_token_budget, return_context=True, ) + if max_depth is not None: + kw['max_depth'] = max_depth + if top_k_files is not None: + kw['top_k_files'] = top_k_files + if include is not None: + kw['include'] = include + if exclude is not None: + kw['exclude'] = exclude + + result = await self._searcher.search(**kw) + + if isinstance(result, list): + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + self._last_search_result = [] + for item in result[:20]: + if isinstance(item, dict): + src = ( + item.get('path') or item.get('file_path') + or item.get('file') or '') + self._last_search_result.append({ + 'text': + json.dumps(item, ensure_ascii=False), + 'score': + 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + }) + return json.dumps(result, ensure_ascii=False, indent=2) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) - if hasattr(result.cluster, 'updated_at'): + # Detect cluster reuse from search logs: sirchmunk emits + # "[SUCCESS] Found similar cluster: ..." or + # "[SUCCESS] Reused existing knowledge cluster" when a cached + # cluster is reused. KnowledgeCluster has no _reused_from_cache + # attribute, so log-based detection is the correct approach. 
+ self._cluster_cache_hit = any( + 'Found similar cluster' in entry + or 'Reused existing knowledge cluster' in entry + for entry in self._search_logs) + if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + result.cluster, 'last_modified', None) - # Store parsed context for frontend display self._last_search_result = self._parse_search_result( result, score_threshold=0.7, limit=5) - # Extract the synthesized answer from the context result - if hasattr(result, 'answer'): + if hasattr(result, 'answer') and getattr(result, 'answer', + None) is not None: return result.answer - # If result is already a plain string (some modes return str directly) if isinstance(result, str): return result - # Fallback: convert to string return str(result) except Exception as e: @@ -375,14 +465,11 @@ def _parse_search_result(self, result: Any, score_threshold: float, """ results = [] - # Handle SearchContext format (returned when return_context=True) if hasattr(result, 'cluster') and result.cluster is not None: cluster = result.cluster for unit in cluster.evidences: - # Extract score from snippets if available score = getattr(cluster, 'confidence', 1.0) if score >= score_threshold: - # Extract text from snippets text_parts = [] source = str(getattr(unit, 'file_or_url', 'unknown')) for snippet in getattr(unit, 'snippets', []): @@ -406,7 +493,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) @@ -423,7 +509,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle list format elif isinstance(result, list): for item in result: if isinstance(item, dict): @@ -438,7 +523,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, item.get('metadata', {}), }) - # Handle dict format 
elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: @@ -451,10 +535,13 @@ def _parse_search_result(self, result: Any, score_threshold: float, result.get('metadata', {}), }) - # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] + def get_last_retrieved_chunks(self) -> List[Dict[str, Any]]: + """Parsed evidence chunks from the last `query` or `retrieve` call.""" + return list(self._last_search_result or []) + def get_search_logs(self) -> List[str]: """Get the captured search logs. diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..7a43a154e 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -7,7 +7,7 @@ import uuid from copy import copy from types import TracebackType -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional import json from ms_agent.llm.utils import Tool, ToolCall @@ -17,6 +17,9 @@ from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient +from ms_agent.tools.search.localsearch_tool import LocalSearchTool +from ms_agent.tools.search.sirchmunk_search import \ + effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool @@ -88,6 +91,8 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + if effective_localsearch_settings(config) is not None: + self.extra_tools.append(LocalSearchTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, @@ 
-243,11 +248,139 @@ async def single_call_tool(self, tool_info: ToolCall): logger.warning(traceback.format_exc()) return f'Tool calling failed: {brief_info}, details: {str(e)}' + async def single_call_tool_streaming( + self, tool_info: ToolCall) -> AsyncGenerator: + """Streaming variant of single_call_tool. + + Yields (tool_call_id, item, is_final) triples: + - is_final False: intermediate log line (str). + - is_final True: final tool result (same shape as single_call_tool), + including error/timeout strings. + + Callers must use is_final to tell logs from short string results (e.g. + \"OK\"); content-based heuristics are unreliable. + + Timeout: if no yield arrives within self.tool_call_timeout seconds the + tool is considered hung and a timeout error string is yielded as the + final result. + """ + if self._concurrent_limiter is None: + if self._init_lock is None: + self._init_lock = asyncio.Lock() + async with self._init_lock: + if self._concurrent_limiter is None: + self._concurrent_limiter = asyncio.Semaphore( + MAX_CONCURRENT_TOOLS) + + call_id = tool_info.get('id', '') + brief_info = json.dumps(tool_info, ensure_ascii=False) + if len(brief_info) > 1024: + brief_info = brief_info[:1024] + '...' + + async with self._concurrent_limiter: + try: + tool_name = tool_info['tool_name'] + tool_args = tool_info['arguments'] + while isinstance(tool_args, str): + try: + tool_args = json.loads(tool_args) + except Exception: # noqa + yield ( + call_id, + f'The input {tool_args} is not a valid JSON, fix your arguments and try again', + True) + return + assert tool_name in self._tool_index, \ + f'Tool name {tool_name} not found' + tool_ins, server_name, _ = self._tool_index[tool_name] + call_args = tool_args + if isinstance(tool_ins, AgentTool): + call_args = dict(tool_args or {}) + call_args['__call_id'] = call_id or str(uuid.uuid4()) + + # Use the streaming variant; default impl wraps call_tool. 
+ gen = tool_ins.call_tool_streaming( + server_name, + tool_name=tool_name.split(self.TOOL_SPLITER)[1], + tool_args=call_args) + + # Enforce per-item timeout: if no item arrives within + # tool_call_timeout seconds the tool is considered hung. + timeout = self.tool_call_timeout + last_item = None + while True: + try: + item = await asyncio.wait_for( + gen.__anext__(), timeout=timeout) + except StopAsyncIteration: + # Generator exhausted normally; last_item is the result. + if last_item is not None: + yield call_id, last_item, True + return + except asyncio.TimeoutError: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Execute tool call timeout: {brief_info}', True + return + except Exception as e: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}', True + return + + if last_item is not None: + # Emit previous item as an intermediate log. + yield call_id, last_item, False + last_item = item + + # Once we have the first yield (any kind), relax the + # per-item timeout to avoid cutting off long-running but + # active searches. + timeout = max(timeout, 120) + + except Exception as e: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}', True + async def parallel_call_tool(self, tool_list: List[ToolCall]): tasks = [self.single_call_tool(tool) for tool in tool_list] result = await asyncio.gather(*tasks) return result + async def parallel_call_tool_streaming( + self, tool_list: List[ToolCall]) -> AsyncGenerator: + """Run all tools concurrently; yield (call_id, item, is_final) as they arrive. + + Items are interleaved in arrival order. The caller must track call_id + to associate intermediate logs with their tool. is_final distinguishes + the last result for that tool from streaming log lines. + """ + # Shared queue: producers push (call_id, item, is_final); sentinel signals done. 
+ queue: asyncio.Queue = asyncio.Queue() + _DONE = object() + + async def _producer(tool_info: ToolCall): + async for call_id, item, is_final in self.single_call_tool_streaming( + tool_info): + await queue.put((call_id, item, is_final)) + await queue.put(_DONE) + + tasks = [asyncio.create_task(_producer(t)) for t in tool_list] + remaining = len(tasks) + + try: + while remaining > 0: + item = await queue.get() + if item is _DONE: + remaining -= 1 + else: + yield item + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + async def __aenter__(self) -> 'ToolManager': return self diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py index 5a4f43213..9fd1737eb 100644 --- a/tests/knowledge_search/test_sirschmunk.py +++ b/tests/knowledge_search/test_sirschmunk.py @@ -1,23 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. +"""Tests for SirchmunkSearch and localsearch tool integration. -These tests verify the sirchmunk-based knowledge search functionality -through the LLMAgent entry point, including verification that -search_result and searching_detail fields are properly populated. 
- -To run these tests, you need to set the following environment variables: - - TEST_LLM_API_KEY: Your LLM API key - - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) - - TEST_LLM_MODEL_NAME: Your LLM model name (optional) - - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) - - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) - -Example: +Example (full sirchmunk run): export TEST_LLM_API_KEY="your-api-key" - export TEST_LLM_BASE_URL="https://api.openai.com/v1" - export TEST_LLM_MODEL_NAME="gpt-4o" - export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" - export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" python -m pytest tests/knowledge_search/test_sirschmunk.py """ import asyncio @@ -26,92 +11,49 @@ import unittest from pathlib import Path -from ms_agent.knowledge_search import SirchmunkSearch -from ms_agent.agent import LLMAgent -from ms_agent.config import Config -from omegaconf import DictConfig - -from modelscope.utils.test_utils import test_level +def _sirchmunk_dir_scanner_available() -> bool: + try: + import sirchmunk.scan.dir_scanner # noqa: F401 + return True + except ImportError: + return False -class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): - """Test cases for SirchmunkSearch integration with LLMAgent. +from ms_agent.agent import LLMAgent +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch +from ms_agent.llm.utils import Message +from ms_agent.tools.tool_manager import ToolManager +from omegaconf import DictConfig - These tests verify that when LLMAgent runs a query that triggers - knowledge search, the Message objects have search_result and - searching_detail fields properly populated. 
- """ +class SirchmunkKnowledgeSearchTest(unittest.TestCase): + """Sirchmunk config, ToolManager registration""" @classmethod def setUpClass(cls): - """Set up test fixtures.""" - # Create test directory with sample files cls.test_dir = Path('./test_llm_agent_knowledge') cls.test_dir.mkdir(exist_ok=True) - - # Create sample documentation - (cls.test_dir / 'README.md').write_text(''' -# Test Project Documentation - -## Overview -This is a test project for knowledge search integration. - -## API Reference - -### UserManager -The UserManager class handles user operations: -- create_user: Create a new user account -- delete_user: Delete an existing user -- update_user: Update user information -- get_user: Retrieve user details - -### AuthService -The AuthService class handles authentication: -- login: Authenticate user credentials -- logout: End user session -- refresh_token: Refresh authentication token -- verify_token: Validate authentication token -''') - - (cls.test_dir / 'config.py').write_text(''' -"""Configuration module.""" - -class Config: - """Application configuration.""" - - def __init__(self): - self.database_url = "postgresql://localhost:5432/mydb" - self.secret_key = "your-secret-key" - self.debug_mode = False - - def load_from_env(self): - """Load configuration from environment variables.""" - import os - self.database_url = os.getenv("DATABASE_URL", self.database_url) - self.secret_key = os.getenv("SECRET_KEY", self.secret_key) - return self -''') + (cls.test_dir / 'README.md').write_text( + '# Demo\n\nUserManager.create_user creates a user.\n') @classmethod def tearDownClass(cls): - """Clean up test fixtures.""" if cls.test_dir.exists(): shutil.rmtree(cls.test_dir, ignore_errors=True) work_dir = Path('./.sirchmunk') if work_dir.exists(): shutil.rmtree(work_dir, ignore_errors=True) - def _get_agent_config(self): - """Create agent configuration with knowledge search.""" + def _base_config(self) -> DictConfig: llm_api_key = os.getenv('TEST_LLM_API_KEY', 
'test-api-key') - llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', + 'https://api.openai.com/v1') llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') - # Read from TEST_* env vars (for test-specific config) - # These can be set from .env file which uses TEST_* prefix embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') - embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') - - config = DictConfig({ + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', + '') + return DictConfig({ + 'output_dir': + './outputs_knowledge_test', 'llm': { 'service': 'openai', 'model': llm_model_name, @@ -122,81 +64,89 @@ def _get_agent_config(self): 'temperature': 0.3, 'max_tokens': 500, }, - 'knowledge_search': { - 'name': 'SirchmunkSearch', - 'paths': [str(self.test_dir)], - 'work_path': './.sirchmunk', - 'llm_api_key': llm_api_key, - 'llm_base_url': llm_base_url, - 'llm_model_name': llm_model_name, - 'embedding_model': embedding_model_id, - 'embedding_model_cache_dir': embedding_model_cache_dir, - 'mode': 'FAST', - } + 'tools': { + 'localsearch': { + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + }, + }, }) - return config - - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_llm_agent_with_knowledge_search(self): - """Test LLMAgent using knowledge search. - - This test verifies that: - 1. LLMAgent can be initialized with SirchmunkSearch configuration - 2. Running a query produces a valid response - 3. User message has searching_detail and search_result populated - 4. searching_detail contains expected keys (logs, mode, paths) - 5. 
search_result is a list - """ - config = self._get_agent_config() - agent = LLMAgent(config=config, tag='test-knowledge-agent') - - # Test query that should trigger knowledge search - query = 'How do I use UserManager to create a user?' - async def run_agent(): - result = await agent.run(query) - return result - - result = asyncio.run(run_agent()) - - # Verify result - self.assertIsNotNone(result) - self.assertIsInstance(result, list) - self.assertTrue(len(result) > 0) - - # Check that assistant message exists - assistant_message = [m for m in result if m.role == 'assistant'] - self.assertTrue(len(assistant_message) > 0) - - # Check that user message has search_result and searching_detail populated - user_messages = [m for m in result if m.role == 'user'] - self.assertTrue(len(user_messages) > 0, "Expected at least one user message") - - # The first user message should have search details after do_rag processing - user_msg = user_messages[0] - self.assertTrue( - hasattr(user_msg, 'searching_detail'), - "User message should have searching_detail attribute" - ) + def test_does_not_inject_knowledge_search(self): + """Local sirchmunk search is no longer merged into the user message here.""" + config = self._base_config() + agent = LLMAgent(config=config, tag='test-knowledge-agent') + original = 'How do I use UserManager?' 
+ + async def run(): + messages = [ + Message(role='system', content='You are a helper.'), + Message(role='user', content=original), + ] + messages = await agent.run(messages) + return messages + + messages = asyncio.run(run()) + print(f'messages: {messages}') + + def test_tool_manager_registers_localsearch(self): + """When tools.localsearch.paths is set, ToolManager exposes localsearch.""" + + async def run(): + config = self._base_config() + tm = ToolManager(config, trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + names = [t['tool_name'] for t in tools] self.assertTrue( - hasattr(user_msg, 'search_result'), - "User message should have search_result attribute" + any(n.endswith('localsearch') for n in names), + f'Expected localsearch in tools, got: {names}', ) - # Check that searching_detail is a dict with expected keys - self.assertIsInstance( - user_msg.searching_detail, dict, - "searching_detail should be a dictionary" - ) - self.assertIn('logs', user_msg.searching_detail) - self.assertIn('mode', user_msg.searching_detail) - self.assertIn('paths', user_msg.searching_detail) - - # Check that search_result is a list (may be empty if no relevant docs found) - self.assertIsInstance( - user_msg.search_result, list, - "search_result should be a list" - ) + @unittest.skipUnless( + _sirchmunk_dir_scanner_available(), + 'sirchmunk scan not installed', + ) + def test_localsearch_description_catalog_injects_file_preview(self): + """Optional: shallow DirectoryScanner summaries appear in tool description.""" + + async def run(): + config = self._base_config() + config.tools.localsearch['description_catalog'] = True + config.tools.localsearch['description_catalog_cache_ttl_seconds'] = 0 + tm = ToolManager(config, trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + loc = next(t for t in 
tools if t['tool_name'].endswith('localsearch')) + desc = loc.get('description') or '' + self.assertIn('Local knowledge catalog', desc) + self.assertIn('UserManager', desc) + + @unittest.skipUnless( + os.getenv('TEST_SIRCHMUNK_SMOKE', ''), + 'Set TEST_SIRCHMUNK_SMOKE=1 to run sirchmunk API smoke test', + ) + def test_sirchmunk_search_query_smoke(self): + """Optional: run sirchmunk once (needs network / valid API keys).""" + config = self._base_config() + searcher = SirchmunkSearch(config) + result = asyncio.run(searcher.query('UserManager')) + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) if __name__ == '__main__':