From 6baa994b14f103cb4d6033e756760de9586c3894 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 6 Feb 2026 15:03:04 +0800 Subject: [PATCH 01/12] fix: video gen exclude edit_file --- projects/singularity_cinema/agent.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index d171fc7fc..dc1756486 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -279,6 +279,7 @@ tools: mcp: false allow_read_all_files: true exclude: + - edit_file - list_files - search_file_content - search_file_name From 693d9e95d5f9d013f8af32789de5ffc381cd10cb Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 15 Mar 2026 17:48:32 +0800 Subject: [PATCH 02/12] feat: support search local paths through sirchmunk --- docs/en/Components/Config.md | 17 + docs/zh/Components/config.md | 80 ++-- examples/knowledge_search/agent.yaml.example | 86 ++++ ms_agent/agent/llm_agent.py | 70 ++- ms_agent/cli/run.py | 55 +++ ms_agent/knowledge_search/README.md | 277 ++++++++++++ ms_agent/knowledge_search/__init__.py | 11 + ms_agent/knowledge_search/sirchmunk_search.py | 401 ++++++++++++++++++ ms_agent/llm/dashscope_llm.py | 2 +- ms_agent/llm/utils.py | 6 + ms_agent/rag/utils.py | 3 + tests/knowledge_search/__init__.py | 2 + tests/knowledge_search/test_sirschmunk.py | 203 +++++++++ 13 files changed, 1179 insertions(+), 34 deletions(-) create mode 100644 examples/knowledge_search/agent.yaml.example create mode 100644 ms_agent/knowledge_search/README.md create mode 100644 ms_agent/knowledge_search/__init__.py create mode 100644 ms_agent/knowledge_search/sirchmunk_search.py create mode 100644 tests/knowledge_search/__init__.py create mode 100644 tests/knowledge_search/test_sirschmunk.py diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index d40f03898..f1253bd75 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -166,3 +166,20 @@ In addition to 
yaml configuration, MS-Agent also supports several additional com ``` > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. + +- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. + +### Quick Start for Knowledge Search + +Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: + +```bash +# Using default agent.yaml configuration, automatically reuses LLM settings +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" + +# Specify configuration file +ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" +``` + +LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. +If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. 
diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 041820b93..12849f2a7 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -1,12 +1,12 @@ --- slug: config title: 配置与参数 -description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM配置、推理配置、system和query、callbacks、工具配置、其他、config_handler、命令行配置 +description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM 配置、推理配置、system 和 query、callbacks、工具配置、其他、config_handler、命令行配置 --- # 配置与参数 -MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名为`agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: +MS-Agent 使用一个 yaml 文件进行配置管理,通常这个文件被命名为 `agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: ## 类型配置 @@ -17,31 +17,31 @@ MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名 type: llmagent ``` -标识本配置对应的agent类型,支持`llmagent`和`codeagent`两类。默认为`llmagent`。如果yaml中包含了code_file字段,则code_file优先生效。 +标识本配置对应的 agent 类型,支持 `llmagent` 和 `codeagent` 两类。默认为 `llmagent`。如果 yaml 中包含了 code_file 字段,则 code_file 优先生效。 ## 自定义代码 -> 可选,在需要自定义LLMAgent时使用 +> 可选,在需要自定义 LLMAgent 时使用 ```yaml code_file: custom_agent ``` -可以使用一个外部agent类,该类需要继承自`LLMAgent`。可以复写其中的若干方法,如果code_file有值,则`type`字段不生效。 +可以使用一个外部 agent 类,该类需要继承自 `LLMAgent`。可以复写其中的若干方法,如果 code_file 有值,则 `type` 字段不生效。 -## LLM配置 +## LLM 配置 > 必须存在 ```yaml llm: - # 大模型服务backend + # 大模型服务 backend service: modelscope - # 模型id + # 模型 id model: Qwen/Qwen3-235B-A22B-Instruct-2507 - # 模型api_key + # 模型 api_key modelscope_api_key: - # 模型base_url + # 模型 base_url modelscope_base_url: https://api-inference.modelscope.cn/v1 ``` @@ -51,7 +51,7 @@ llm: ```yaml generation_config: - # 下面的字段均为OpenAI sdk的标准参数,你也可以配置OpenAI支持的其他参数在这里。 + # 下面的字段均为 OpenAI sdk 的标准参数,你也可以配置 OpenAI 支持的其他参数在这里。 top_p: 0.6 temperature: 0.2 top_k: 20 @@ -60,25 +60,25 @@ generation_config: enable_thinking: false ``` -## system和query +## system 和 query -> 可选,但推荐传入system +> 可选,但推荐传入 system ```yaml prompt: - # LLM system,如果不传递则使用默认的`you are a helpful assistant.` + # LLM system,如果不传递则使用默认的 `you are a helpful assistant.` system: - # LLM初始query,通常来说可以不使用 + # LLM 初始 query,通常来说可以不使用 query: ``` ## 
callbacks -> 可选,推荐自定义callbacks +> 可选,推荐自定义 callbacks ```yaml callbacks: - # 用户输入callback,该callback在assistant回复后自动等待用户输入 + # 用户输入 callback,该 callback 在 assistant 回复后自动等待用户输入 - input_callback ``` @@ -90,9 +90,9 @@ callbacks: tools: # 工具名称 file_system: - # 是否是mcp + # 是否是 mcp mcp: false - # 排除的function,可以为空 + # 排除的 function,可以为空 exclude: - create_directory - write_file @@ -104,20 +104,20 @@ tools: - map_geo ``` -支持的完整工具列表,以及自定义工具请参考[这里](./tools) +支持的完整工具列表,以及自定义工具请参考 [这里](./tools) ## 其他 > 可选,按需配置 ```yaml -# 自动对话轮数,默认为20轮 +# 自动对话轮数,默认为 20 轮 max_chat_round: 9999 # 工具调用超时时间,单位秒 tool_call_timeout: 30000 -# 输出artifact目录 +# 输出 artifact 目录 output_dir: output # 帮助信息,通常在运行错误后出现 @@ -127,13 +127,13 @@ help: | ## config_handler -为了便于在任务开始时对config进行定制化,MS-Agent构建了一个名为`ConfigLifecycleHandler`的机制。这是一个callback类,开发者可以在yaml文件中增加这样一个配置: +为了便于在任务开始时对 config 进行定制化,MS-Agent 构建了一个名为 `ConfigLifecycleHandler` 的机制。这是一个 callback 类,开发者可以在 yaml 文件中增加这样一个配置: ```yaml handler: custom_handler ``` -这代表和yaml文件同级有一个custom_handler.py文件,该文件的类继承自`ConfigLifecycleHandler`,分别有两个方法: +这代表和 yaml 文件同级有一个 custom_handler.py 文件,该文件的类继承自 `ConfigLifecycleHandler`,分别有两个方法: ```python def task_begin(self, config: DictConfig, tag: str) -> DictConfig: @@ -143,18 +143,18 @@ handler: custom_handler return config ``` -`task_begin`在LLMAgent类构造时生效,在该方法中可以对config进行一些修改。如果你的工作流中下游任务会继承上游的yaml配置,这个机制会有帮助。值得注意的是`tag`参数,该参数会传入当前LLMAgent的名字,方便分辨当前工作流的节点。 +`task_begin` 在 LLMAgent 类构造时生效,在该方法中可以对 config 进行一些修改。如果你的工作流中下游任务会继承上游的 yaml 配置,这个机制会有帮助。值得注意的是 `tag` 参数,该参数会传入当前 LLMAgent 的名字,方便分辨当前工作流的节点。 ## 命令行配置 -在yaml配置之外,MS-Agent还支持若干额外的命令行参数。 +在 yaml 配置之外,MS-Agent 还支持若干额外的命令行参数。 -- query: 初始query,这个query的优先级高于yaml中的prompt.query -- config: 配置文件路径,支持modelscope model-id -- trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为true才会生效 -- load_cache: 从历史messages继续对话。cache会被自动存储在`output`配置中。默认为`False` -- mcp_server_file: 可以读取一个外部的mcp工具配置,格式为: +- query: 初始 query,这个 query 的优先级高于 yaml 中的 prompt.query +- config: 配置文件路径,支持 modelscope model-id +- 
trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为 true 才会生效 +- load_cache: 从历史 messages 继续对话。cache 会被自动存储在 `output` 配置中。默认为 `False` +- mcp_server_file: 可以读取一个外部的 mcp 工具配置,格式为: ```json { "mcpServers": { @@ -165,5 +165,21 @@ handler: custom_handler } } ``` +- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 -> agent.yaml中的任意一个配置,都可以使用命令行传入新的值, 也支持从同名(大小写不敏感)环境变量中读取,例如`--llm.modelscope_api_key xxx-xxx`。 +> agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 + +### 知识搜索快速使用 + +通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: + +```bash +# 使用默认 agent.yaml 配置,自动复用 LLM 设置 +ms-agent run --query "如何实现用户认证?" --knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 +如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example new file mode 100644 index 000000000..cc11a8a3d --- /dev/null +++ b/examples/knowledge_search/agent.yaml.example @@ -0,0 +1,86 @@ +# Sirchmunk Knowledge Search 配置示例 +# Sirchmunk Knowledge Search Configuration Example + +# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: + +llm: + service: modelscope + model: Qwen/Qwen3-235B-A22B-Instruct-2507 + modelscope_api_key: + modelscope_base_url: https://api-inference.modelscope.cn/v1 + +generation_config: + temperature: 0.3 + top_k: 20 + stream: true + +# Knowledge Search 配置(可选) +# 用于在本地代码库中搜索相关信息 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录,用于缓存 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则使用上面 llm 的配置) + llm_api_key: + llm_base_url: https://api.openai.com/v1 + llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:聚类相似度阈值 + 
cluster_sim_threshold: 0.85 + + # 可选:聚类 TopK + cluster_sim_top_k: 3 + + # 可选:是否重用之前的知识 + reuse_knowledge: true + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:最大循环次数 + max_loops: 10 + + # 可选:最大 token 预算 + max_token_budget: 128000 + +prompt: + system: | + You are an assistant that helps me complete tasks. + +max_chat_round: 9999 + +# 使用说明: +# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 +# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 +# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM +# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 +# +# Python 使用示例: +# ```python +# from ms_agent import LLMAgent +# from ms_agent.config import Config +# +# config = Config.from_task('path/to/agent.yaml') +# agent = LLMAgent(config=config) +# result = await agent.run('如何实现用户认证功能?') +# +# # 获取搜索详情(用于前端展示) +# for msg in result: +# if msg.role == 'user': +# print(f"Search logs: {msg.searching_detail}") +# print(f"Search results: {msg.search_result}") +# ``` +# +# CLI 测试命令: +# export LLM_API_KEY="your-api-key" +# export LLM_BASE_URL="https://api.openai.com/v1" +# export LLM_MODEL_NAME="gpt-4o-mini" +# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 740eab690..3bc40c1fc 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -19,6 +19,7 @@ from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER @@ -104,6 +105,7 @@ def __init__(self, self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None + self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: 
async def do_rag(self, messages: List['Message']):
    """Enrich the user query with RAG and/or knowledge-search context.

    The method only acts when ``messages[1]`` exists and has role ``'user'``.
    Traditional RAG (``self.rag``) rewrites the message content first. If
    ``self.knowledge_search`` is configured, its results are stored on the
    message as ``search_result`` and its details as ``searching_detail``
    (for frontend display), and a formatted context block is prepended to
    the message content.

    Args:
        messages (List[Message]): The message list to process; the user
            message is modified in place.
    """
    if len(messages) < 2:
        return
    user_message = messages[1]
    if user_message is None or user_message.role != 'user':
        return

    # Both retrieval backends are queried with the raw user text.
    query = user_message.content

    # Handle traditional RAG
    if self.rag is not None:
        user_message.content = await self.rag.query(query)

    # Handle sirchmunk knowledge search
    if self.knowledge_search is None:
        return

    search_result = await self.knowledge_search.retrieve(query)
    search_details = self.knowledge_search.get_search_details()

    # Store search details in the message for frontend display
    user_message.searching_detail = search_details
    user_message.search_result = search_result

    if not search_result:
        return

    context_parts = []
    for i, result in enumerate(search_result, 1):
        text = result.get('text', '')
        # ``metadata`` / ``score`` may be present but null; fall back to
        # neutral values instead of failing while formatting.
        source = (result.get('metadata') or {}).get('source', 'unknown')
        score = result.get('score') or 0
        context_parts.append(
            f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n")

    context = '\n'.join(context_parts)
    # Build on the current content rather than the raw query so that a
    # preceding RAG enrichment is kept. Without RAG the content still
    # equals the raw query, so the output is unchanged in that case.
    user_message.content = (
        f"Relevant context retrieved from codebase search:\n\n{context}\n\n"
        f"User question: {user_message.content}")
async def prepare_knowledge_search(self):
    """Load and initialize the knowledge search component from the config.

    Does nothing when the config has no ``knowledge_search`` section (or it
    is null). Otherwise, any of ``llm_api_key`` / ``llm_base_url`` /
    ``llm_model_name`` that is missing or null in ``knowledge_search`` is
    filled from the ``llm`` section using ``{service}_api_key``,
    ``{service}_base_url`` and ``model``, where ``service`` is
    ``llm.service``. Then ``self.knowledge_search`` is set to a
    ``SirchmunkSearch`` built from the whole config.
    """
    ks_config = getattr(self.config, 'knowledge_search', None)
    if ks_config is None:
        return

    llm_config = getattr(self.config, 'llm', None)
    if llm_config is not None:
        # Fall back to 'modelscope' when no service is named, which is the
        # key prefix this method used before it became service-aware.
        service = getattr(llm_config, 'service', None) or 'modelscope'
        inherited = {
            'llm_api_key': f'{service}_api_key',
            'llm_base_url': f'{service}_base_url',
            'llm_model_name': 'model',
        }
        for target, source in inherited.items():
            # A key that is present but null (e.g. a blank `llm_api_key:`
            # line in YAML) counts as "not specified".
            if getattr(ks_config, target, None) is not None:
                continue
            value = getattr(llm_config, source, None)
            if value is not None:
                OmegaConf.update(
                    self.config,
                    f'knowledge_search.{target}',
                    value,
                    merge=True)

    self.knowledge_search = SirchmunkSearch(self.config)
def load_env_file(self):
    """Load environment variables from a ``.env`` file in the current directory.

    Each non-empty, non-comment line containing ``=`` is split on the first
    ``=`` into a key and a value. An optional leading ``export `` on the key
    and one pair of matching surrounding quotes on the value are removed.
    Variables that already exist in ``os.environ`` are never overwritten.
    Nothing happens when no ``.env`` file exists.
    """
    env_file = os.path.join(os.getcwd(), '.env')
    if not os.path.isfile(env_file):
        return

    # utf-8-sig reads plain UTF-8 and also drops a BOM, which would
    # otherwise become part of the first key.
    with open(env_file, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue

            key, value = line.split('=', 1)
            key = key.strip()
            if key.startswith('export '):
                key = key[len('export '):].strip()

            value = value.strip()
            # KEY="value" / KEY='value' -> value
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]

            # Only set if not already set in environment
            if not key or key in os.environ:
                continue
            os.environ[key] = value
            logger.debug(f'Loaded {key} from .env file')
+ ) parser.set_defaults(func=subparser_func) def execute(self): @@ -150,10 +176,18 @@ def execute(self): return self._execute_with_config() def _execute_with_config(self): + # Load environment variables from .env file if exists + self.load_env_file() + if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) + else: + # Use built-in default agent.yaml from package + default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file(default_config_path) as config_file: + self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download self.args.config = snapshot_download(self.args.config) @@ -190,6 +224,27 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) + # If knowledge_search_paths is provided, configure SirchmunkSearch + if getattr(self.args, 'knowledge_search_paths', None): + paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + if paths: + if 'knowledge_search' not in config or not config.knowledge_search: + # No existing knowledge_search config, create minimal config + # LLM settings will be auto-reused from llm module by SirchmunkSearch + knowledge_search_config = { + 'name': 'SirchmunkSearch', + 'paths': paths, + 'work_path': './.sirchmunk', + 'mode': 'FAST', + } + config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + else: + # Existing knowledge_search config found, only update paths + # LLM settings are already handled by SirchmunkSearch internally + existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing['paths'] = paths + config['knowledge_search'] = OmegaConf.create(existing) + if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader engine = WorkflowLoader.build( diff --git 
a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md new file mode 100644 index 000000000..00743601e --- /dev/null +++ b/ms_agent/knowledge_search/README.md @@ -0,0 +1,277 @@ +# Sirchmunk Knowledge Search 集成 + +本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 + +## 功能特性 + +- **智能代码搜索**: 使用 LLM 和 embedding 模型对代码库进行语义搜索 +- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 +- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 +- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 +- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 + +## 安装 + +```bash +pip install sirchmunk +``` + +## 配置 + +在您的 `agent.yaml` 或 `workflow.yaml` 中添加以下配置: + +```yaml +llm: + service: dashscope + model: qwen3.5-plus + dashscope_api_key: + dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + +generation_config: + temperature: 0.3 + stream: true + +# Knowledge Search 配置 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) + # llm_api_key: + # llm_base_url: https://api.openai.com/v1 + # llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:是否重用之前的知识 + reuse_knowledge: true +``` + +**LLM 配置自动复用机制**: + +`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: +- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` +- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` +- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` + +其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 + +通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 + +## 使用方式 + +### 1. 通过 CLI 使用(推荐) + +从命令行直接运行,无需编写代码: + +```bash +# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 +ms-agent run --query "如何实现用户认证功能?" 
--knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +**说明**: +- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 +- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 +- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 + +### 2. 通过 LLMAgent 使用 + +```python +from ms_agent import LLMAgent +from ms_agent.config import Config + +# 从配置文件加载 +config = Config.from_task('path/to/agent.yaml') +agent = LLMAgent(config=config) + +# 运行查询 - 会自动触发知识搜索 +result = await agent.run('如何实现用户认证功能?') + +# 获取搜索结果 +for msg in result: + if msg.role == 'user': + # 搜索详情(用于前端展示) + print(f"Search logs: {msg.searching_detail}") + # 搜索结果(作为 LLM 上下文) + print(f"Search results: {msg.search_result}") +``` + +### 3. 单独使用 SirchmunkSearch + +```python +from ms_agent.knowledge_search import SirchmunkSearch +from omegaconf import DictConfig + +config = DictConfig({ + 'knowledge_search': { + 'paths': ['./src', './docs'], + 'work_path': './.sirchmunk', + 'llm_api_key': 'your-api-key', + 'llm_model_name': 'gpt-4o-mini', + 'mode': 'FAST', + } +}) + +searcher = SirchmunkSearch(config) + +# 查询(返回合成答案) +answer = await searcher.query('如何实现用户认证?') + +# 检索(返回原始搜索结果) +results = await searcher.retrieve( + query='用户认证', + limit=5, + score_threshold=0.7 +) + +# 获取搜索日志 +logs = searcher.get_search_logs() + +# 获取搜索详情 +details = searcher.get_search_details() +``` + +## 环境变量 + +可以通过环境变量配置: + +```bash +# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" + +# Embedding 模型配置 +export EMBEDDING_MODEL_ID="text-embedding-3-small" +export SIRCHMUNK_WORK_PATH="./.sirchmunk" +``` + +**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 + +## 测试 + +### 单元测试 + +```bash +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" +
+python -m unittest tests/knowledge_search/test_sirschmunk.py +``` + +### CLI 测试 + +```bash +# 基本测试 +python tests/knowledge_search/test_cli.py + +# 指定查询 +python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" + +# 仅测试 standalone 模式 +python tests/knowledge_search/test_cli.py -m standalone + +# 仅测试 agent 模式 +python tests/knowledge_search/test_cli.py -m agent +``` + +## 配置参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | +| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | +| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | +| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | +| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | +| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | +| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | +| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | +| reuse_knowledge | bool | true | 是否重用之前的知识 | +| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | +| max_loops | int | 10 | 最大搜索循环次数 | +| max_token_budget | int | 128000 | 最大 token 预算 | + +## 搜索模式 + +- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 +- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 +- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 + +## Message 字段扩展 + +为了支持知识搜索,`Message` 类增加了两个字段: + +- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 + - `logs`: 搜索日志列表 + - `mode`: 使用的搜索模式 + - `paths`: 搜索的路径 + - `work_path`: 工作目录 + - `reuse_knowledge`: 是否重用知识 + +- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 + - `text`: 文档内容 + - `score`: 相关性分数 + - `metadata`: 元数据(如源文件、类型等) + +## 工作原理 + +1. 用户发送查询 +2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch +3. `do_rag()` 方法执行知识搜索: + - 调用 `searcher.retrieve()` 获取相关文档 + - 将搜索结果存入 `message.search_result` + - 将搜索日志存入 `message.searching_detail` + - 将搜索结果格式化为上下文,附加到用户查询 +4. LLM 接收 enriched query 并生成回答 +5. 前端可以通过 `searching_detail` 展示搜索过程 + +## 故障排除 + +### 常见问题 + +1. 
**ImportError: No module named 'sirchmunk'** + ```bash + pip install sirchmunk + ``` + +2. **搜索结果为空** + - 检查 `paths` 配置是否正确 + - 确保路径下有可搜索的文件 + - 尝试降低 `cluster_sim_threshold` 值 + +3. **LLM API 调用失败** + - 检查 API key 是否正确 + - 检查 base URL 是否正确 + - 查看搜索日志了解详细错误 + +### 日志查看 + +```python +# 查看搜索日志 +logs = searcher.get_search_logs() +for log in logs: + print(log) + +# 或在配置中启用 verbose +knowledge_search: + verbose: true +``` + +## 参考资源 + +- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) +- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py new file mode 100644 index 000000000..33362beee --- /dev/null +++ b/ms_agent/knowledge_search/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search module based on sirchmunk. + +This module provides integration between sirchmunk's AgenticSearch +and the ms_agent framework, enabling intelligent codebase search +capabilities similar to RAG. +""" + +from .sirchmunk_search import SirchmunkSearch + +__all__ = ['SirchmunkSearch'] diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py new file mode 100644 index 000000000..4e1e322a5 --- /dev/null +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -0,0 +1,401 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sirchmunk-based knowledge search integration. + +This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, +providing document retrieval capabilities similar to RAG but optimized for +codebase and documentation search. +""" + +import asyncio +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from loguru import logger + +from ms_agent.rag.base import RAG +from omegaconf import DictConfig + + +class SirchmunkSearch(RAG): + """Sirchmunk-based knowledge search class. 
def __init__(self, config: DictConfig):
    """Read the ``knowledge_search`` section of ``config`` into attributes.

    The config is validated with ``_validate_config`` first. No sirchmunk
    object is created here; ``self._searcher`` stays ``None`` and
    ``self._initialized`` stays ``False``.

    LLM settings (``llm_api_key``, ``llm_base_url``, ``llm_model_name``)
    are taken from ``knowledge_search`` when set to a non-empty value.
    Otherwise they fall back to ``llm.{service}_api_key``,
    ``llm.{service}_base_url`` and ``llm.model``, with ``service`` read
    from ``llm.service`` (``'dashscope'`` when absent or null).

    Args:
        config (DictConfig): Full agent configuration; must contain a
            ``knowledge_search`` section with non-empty ``paths``.
    """
    super().__init__(config)

    self._validate_config(config)

    # Extract configuration parameters
    rag_config = config.get('knowledge_search', {})

    # Search paths - required; a single string is accepted as one path.
    paths = rag_config.get('paths', [])
    if isinstance(paths, str):
        paths = [paths]
    self.search_paths: List[str] = [
        str(Path(p).expanduser().resolve()) for p in paths
    ]

    # Work path for sirchmunk cache. A blank `work_path:` in YAML yields
    # None, which Path() rejects, so treat it like a missing key.
    _work_path = rag_config.get('work_path', None) or './.sirchmunk'
    self.work_path: Path = Path(_work_path).expanduser().resolve()

    # Sirchmunk search parameters
    self.reuse_knowledge = rag_config.get('reuse_knowledge', True)
    self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85)
    self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3)
    self.search_mode = rag_config.get('mode', 'FAST')
    self.max_loops = rag_config.get('max_loops', 10)
    self.max_token_budget = rag_config.get('max_token_budget', 128000)
    # Optional `knowledge_search.verbose`; defaults to True.
    self.verbose = bool(rag_config.get('verbose', True))

    # LLM configuration for sirchmunk. Empty strings count as "not set" so
    # that a blank value still falls back to the main llm section.
    self.llm_api_key = rag_config.get('llm_api_key', None) or None
    self.llm_base_url = rag_config.get('llm_base_url', None) or None
    self.llm_model_name = rag_config.get('llm_model_name', None) or None

    # Fall back to main llm config if not specified in knowledge_search
    if (self.llm_api_key is None or self.llm_base_url is None
            or self.llm_model_name is None):
        llm_config = config.get('llm', None)
        if llm_config:
            service = getattr(llm_config, 'service', None) or 'dashscope'
            if self.llm_api_key is None:
                self.llm_api_key = getattr(llm_config, f'{service}_api_key', None)
            if self.llm_base_url is None:
                self.llm_base_url = getattr(llm_config, f'{service}_base_url', None)
            if self.llm_model_name is None:
                self.llm_model_name = getattr(llm_config, 'model', None)

    # Embedding model configuration
    self.embedding_model_id = rag_config.get('embedding_model', None)
    self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None)

    # Runtime state: the searcher is built later by _initialize_searcher()
    self._searcher = None
    self._initialized = False

    # Callback for capturing logs
    self._log_callback = None
    self._search_logs: List[str] = []
+ ) + + rag_config = config.knowledge_search + paths = rag_config.get('paths', []) + if not paths: + raise ValueError('knowledge_search.paths must be specified and non-empty') + + def _initialize_searcher(self): + """Initialize the sirchmunk AgenticSearch instance.""" + if self._initialized: + return + + try: + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.utils.embedding_util import EmbeddingUtil + + # Create LLM client + llm = OpenAIChat( + api_key=self.llm_api_key, + base_url=self.llm_base_url, + model=self.llm_model_name, + max_retries=3, + log_callback=self._log_callback_wrapper(), + ) + + # Create embedding util + # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) + embedding_model_id = self.embedding_model_id if self.embedding_model_id else None + embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None + embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + + # Create AgenticSearch instance + self._searcher = AgenticSearch( + llm=llm, + embedding=embedding, + work_path=str(self.work_path), + paths=self.search_paths, + verbose=True, + reuse_knowledge=self.reuse_knowledge, + cluster_sim_threshold=self.cluster_sim_threshold, + cluster_sim_top_k=self.cluster_sim_top_k, + log_callback=self._log_callback_wrapper(), + ) + + self._initialized = True + logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + + except ImportError as e: + raise ImportError( + f'Failed to import sirchmunk: {e}. 
' + 'Please install sirchmunk: pip install sirchmunk' + ) + except Exception as e: + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + + def _log_callback_wrapper(self): + """Create a callback wrapper to capture search logs.""" + def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): + self._search_logs.append(f'[{level}] {message}') + + return log_callback + + async def add_documents(self, documents: List[str]) -> bool: + """Add documents to the search index. + + Note: Sirchmunk works by scanning existing files in the specified paths. + This method is provided for RAG interface compatibility but doesn't + directly add documents. Instead, documents should be saved to files + within the search paths. + + Args: + documents (List[str]): List of document contents to add. + + Returns: + bool: True if successful (for interface compatibility). + """ + logger.warning( + 'SirchmunkSearch does not support direct document addition. ' + 'Documents should be saved to files within the configured search paths.' + ) + # Trigger re-scan of the search paths + if self._searcher and hasattr(self._searcher, 'knowledge_base'): + try: + await self._searcher.knowledge_base.refresh() + return True + except Exception as e: + logger.error(f'Failed to refresh knowledge base: {e}') + return False + return True + + async def add_documents_from_files(self, file_paths: List[str]) -> bool: + """Add documents from file paths. + + Args: + file_paths (List[str]): List of file paths to scan. + + Returns: + bool: True if successful. 
+ """ + self._initialize_searcher() + + if self._searcher and hasattr(self._searcher, 'scan_directory'): + try: + for file_path in file_paths: + if Path(file_path).exists(): + await self._searcher.scan_directory(str(Path(file_path).parent)) + return True + except Exception as e: + logger.error(f'Failed to scan files: {e}') + return False + return True + + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: + """Retrieve relevant documents using sirchmunk. + + Args: + query (str): The search query. + limit (int): Maximum number of results to return. + score_threshold (float): Minimum relevance score threshold. + **filters: Additional filters (mode, max_loops, etc.). + + Returns: + List[Dict[str, Any]]: List of search results with 'text', 'score', + 'metadata' fields. + """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = filters.get('mode', self.search_mode) + max_loops = filters.get('max_loops', self.max_loops) + max_token_budget = filters.get('max_token_budget', self.max_token_budget) + + # Perform search + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=True, + ) + + # Parse results into standard format + return self._parse_search_result(result, score_threshold, limit) + + except Exception as e: + logger.error(f'SirschmunkSearch retrieve failed: {e}') + return [] + + async def query(self, query: str) -> str: + """Query sirchmunk and return a synthesized answer. + + This method performs a search and returns the LLM-synthesized answer + along with search details that can be used for frontend display. + + Args: + query (str): The search query. + + Returns: + str: The synthesized answer from sirchmunk. 
+ """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = self.search_mode + max_loops = self.max_loops + max_token_budget = self.max_token_budget + + # Perform search and get answer + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=False, + ) + + # Result is already a synthesized answer string + if isinstance(result, str): + return result + + # If we got SearchContext or other format, extract the answer + if hasattr(result, 'answer'): + return result.answer + + # Fallback: convert to string + return str(result) + + except Exception as e: + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' + + def _parse_search_result(self, + result: Any, + score_threshold: float, + limit: int) -> List[Dict[str, Any]]: + """Parse sirchmunk search result into standard format. + + Args: + result: The raw search result from sirchmunk. + score_threshold: Minimum score threshold. + limit: Maximum number of results. + + Returns: + List[Dict[str, Any]]: Parsed results. 
+ """ + results = [] + + # Handle SearchContext format (returned when return_context=True) + if hasattr(result, 'cluster') and result.cluster is not None: + cluster = result.cluster + for unit in cluster.evidences: + # Extract score from snippets if available + score = getattr(cluster, 'confidence', 1.0) + if score >= score_threshold: + # Extract text from snippets + text_parts = [] + source = str(getattr(unit, 'file_or_url', 'unknown')) + for snippet in getattr(unit, 'snippets', []): + if isinstance(snippet, dict): + text_parts.append(snippet.get('snippet', '')) + else: + text_parts.append(str(snippet)) + + results.append({ + 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + } + }) + + # Handle format with evidence_units attribute directly + elif hasattr(result, 'evidence_units'): + for unit in result.evidence_units: + score = getattr(unit, 'confidence', 1.0) + if score >= score_threshold: + results.append({ + 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + } + }) + + # Handle list format + elif isinstance(result, list): + for item in result: + if isinstance(item, dict): + score = item.get('score', item.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': item.get('content', item.get('text', str(item))), + 'score': score, + 'metadata': item.get('metadata', {}), + }) + + # Handle dict format + elif isinstance(result, dict): + score = result.get('score', result.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + }) + + # Sort by score and 
limit results + results.sort(key=lambda x: x.get('score', 0), reverse=True) + return results[:limit] + + def get_search_logs(self) -> List[str]: + """Get the captured search logs. + + Returns: + List[str]: List of log messages from the search operation. + """ + return self._search_logs.copy() + + def get_search_details(self) -> Dict[str, Any]: + """Get detailed search information including logs and metadata. + + Returns: + Dict[str, Any]: Search details including logs, mode, and paths. + """ + return { + 'logs': self._search_logs.copy(), + 'mode': self.search_mode, + 'paths': self.search_paths, + 'work_path': str(self.work_path), + 'reuse_knowledge': self.reuse_knowledge, + } diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index af766f679..b4a6ddaa8 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -12,7 +12,7 @@ class DashScope(OpenAI): def __init__(self, config: DictConfig): super().__init__( config, - base_url=config.llm.modelscope_base_url + base_url=config.llm.dashscope_base_url or get_service_config('dashscope').base_url, api_key=config.llm.dashscope_api_key) diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 6a336ca6e..410aa12f0 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,6 +61,12 @@ class Message: api_calls: int = 1 + # Knowledge search (sirchmunk) related fields + # searching_detail: Search process logs and metadata for frontend display + searching_detail: Dict[str, Any] = field(default_factory=dict) + # search_result: Raw search results to be used as context for next LLM turn + search_result: List[Dict[str, Any]] = field(default_factory=list) + def to_dict(self): return asdict(self) diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index 08e9a4db7..e66da954d 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -4,3 +4,6 @@ rag_mapping = { 'LlamaIndexRAG': LlamaIndexRAG, } + +# Note: SirchmunkSearch is registered in knowledge_search 
module +# and integrated directly in LLMAgent, not through rag_mapping diff --git a/tests/knowledge_search/__init__.py b/tests/knowledge_search/__init__.py new file mode 100644 index 000000000..0cc40e613 --- /dev/null +++ b/tests/knowledge_search/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search tests.""" diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py new file mode 100644 index 000000000..5a4f43213 --- /dev/null +++ b/tests/knowledge_search/test_sirschmunk.py @@ -0,0 +1,203 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. + +These tests verify the sirchmunk-based knowledge search functionality +through the LLMAgent entry point, including verification that +search_result and searching_detail fields are properly populated. + +To run these tests, you need to set the following environment variables: + - TEST_LLM_API_KEY: Your LLM API key + - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) + - TEST_LLM_MODEL_NAME: Your LLM model name (optional) + - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) + - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) + +Example: + export TEST_LLM_API_KEY="your-api-key" + export TEST_LLM_BASE_URL="https://api.openai.com/v1" + export TEST_LLM_MODEL_NAME="gpt-4o" + export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" + python -m pytest tests/knowledge_search/test_sirschmunk.py +""" +import asyncio +import os +import shutil +import unittest +from pathlib import Path + +from ms_agent.knowledge_search import SirchmunkSearch +from ms_agent.agent import LLMAgent +from ms_agent.config import Config +from omegaconf import DictConfig + +from modelscope.utils.test_utils import test_level + + 
+class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): + """Test cases for SirchmunkSearch integration with LLMAgent. + + These tests verify that when LLMAgent runs a query that triggers + knowledge search, the Message objects have search_result and + searching_detail fields properly populated. + """ + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + # Create test directory with sample files + cls.test_dir = Path('./test_llm_agent_knowledge') + cls.test_dir.mkdir(exist_ok=True) + + # Create sample documentation + (cls.test_dir / 'README.md').write_text(''' +# Test Project Documentation + +## Overview +This is a test project for knowledge search integration. + +## API Reference + +### UserManager +The UserManager class handles user operations: +- create_user: Create a new user account +- delete_user: Delete an existing user +- update_user: Update user information +- get_user: Retrieve user details + +### AuthService +The AuthService class handles authentication: +- login: Authenticate user credentials +- logout: End user session +- refresh_token: Refresh authentication token +- verify_token: Validate authentication token +''') + + (cls.test_dir / 'config.py').write_text(''' +"""Configuration module.""" + +class Config: + """Application configuration.""" + + def __init__(self): + self.database_url = "postgresql://localhost:5432/mydb" + self.secret_key = "your-secret-key" + self.debug_mode = False + + def load_from_env(self): + """Load configuration from environment variables.""" + import os + self.database_url = os.getenv("DATABASE_URL", self.database_url) + self.secret_key = os.getenv("SECRET_KEY", self.secret_key) + return self +''') + + @classmethod + def tearDownClass(cls): + """Clean up test fixtures.""" + if cls.test_dir.exists(): + shutil.rmtree(cls.test_dir, ignore_errors=True) + work_dir = Path('./.sirchmunk') + if work_dir.exists(): + shutil.rmtree(work_dir, ignore_errors=True) + + def _get_agent_config(self): + """Create agent 
configuration with knowledge search.""" + llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') + # Read from TEST_* env vars (for test-specific config) + # These can be set from .env file which uses TEST_* prefix + embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') + + config = DictConfig({ + 'llm': { + 'service': 'openai', + 'model': llm_model_name, + 'openai_api_key': llm_api_key, + 'openai_base_url': llm_base_url, + }, + 'generation_config': { + 'temperature': 0.3, + 'max_tokens': 500, + }, + 'knowledge_search': { + 'name': 'SirchmunkSearch', + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + } + }) + return config + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_llm_agent_with_knowledge_search(self): + """Test LLMAgent using knowledge search. + + This test verifies that: + 1. LLMAgent can be initialized with SirchmunkSearch configuration + 2. Running a query produces a valid response + 3. User message has searching_detail and search_result populated + 4. searching_detail contains expected keys (logs, mode, paths) + 5. search_result is a list + """ + config = self._get_agent_config() + agent = LLMAgent(config=config, tag='test-knowledge-agent') + + # Test query that should trigger knowledge search + query = 'How do I use UserManager to create a user?' 
+ + async def run_agent(): + result = await agent.run(query) + return result + + result = asyncio.run(run_agent()) + + # Verify result + self.assertIsNotNone(result) + self.assertIsInstance(result, list) + self.assertTrue(len(result) > 0) + + # Check that assistant message exists + assistant_message = [m for m in result if m.role == 'assistant'] + self.assertTrue(len(assistant_message) > 0) + + # Check that user message has search_result and searching_detail populated + user_messages = [m for m in result if m.role == 'user'] + self.assertTrue(len(user_messages) > 0, "Expected at least one user message") + + # The first user message should have search details after do_rag processing + user_msg = user_messages[0] + self.assertTrue( + hasattr(user_msg, 'searching_detail'), + "User message should have searching_detail attribute" + ) + self.assertTrue( + hasattr(user_msg, 'search_result'), + "User message should have search_result attribute" + ) + + # Check that searching_detail is a dict with expected keys + self.assertIsInstance( + user_msg.searching_detail, dict, + "searching_detail should be a dictionary" + ) + self.assertIn('logs', user_msg.searching_detail) + self.assertIn('mode', user_msg.searching_detail) + self.assertIn('paths', user_msg.searching_detail) + + # Check that search_result is a list (may be empty if no relevant docs found) + self.assertIsInstance( + user_msg.search_result, list, + "search_result should be a list" + ) + + +if __name__ == '__main__': + unittest.main() From 2977ba09204169b6fe02492da5089e56ad06573e Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:58:32 +0800 Subject: [PATCH 03/12] Update ms_agent/agent/llm_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/agent/llm_agent.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 
3bc40c1fc..515446381 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -757,20 +757,6 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - # Extract LLM config for sirchmunk - if hasattr(self.config, 'llm'): - llm_config = self.config.llm - # Update knowledge_search config with LLM settings if not specified - if not hasattr(ks_config, 'llm_api_key') and hasattr(llm_config, 'modelscope_api_key'): - OmegaConf.update(self.config, 'knowledge_search.llm_api_key', - getattr(llm_config, 'modelscope_api_key', None), merge=True) - if not hasattr(ks_config, 'llm_base_url') and hasattr(llm_config, 'modelscope_base_url'): - OmegaConf.update(self.config, 'knowledge_search.llm_base_url', - getattr(llm_config, 'modelscope_base_url', None), merge=True) - if not hasattr(ks_config, 'llm_model_name') and hasattr(llm_config, 'model'): - OmegaConf.update(self.config, 'knowledge_search.llm_model_name', - getattr(llm_config, 'model', None), merge=True) - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: From 879dba4190c819ee7b70ce7d4f511aba21e44715 Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:59:07 +0800 Subject: [PATCH 04/12] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/knowledge_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md index 00743601e..ef86df0da 100644 --- a/ms_agent/knowledge_search/README.md +++ b/ms_agent/knowledge_search/README.md @@ -108,7 +108,7 @@ for msg in result: print(f"Search results: {msg.search_result}") ``` -### 2. 单独使用 SirchmunkSearch +### 3. 
单独使用 SirchmunkSearch ```python from ms_agent.knowledge_search import SirchmunkSearch From 520282384734a76a4a1eec047ff1e4f962c0ca59 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 18:27:13 +0800 Subject: [PATCH 05/12] full modify? --- ms_agent/agent/llm_agent.py | 352 +++++++++--------- ms_agent/knowledge_search/sirchmunk_search.py | 208 ++++++++--- 2 files changed, 330 insertions(+), 230 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 515446381..76289a403 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,30 +2,29 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy +from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig, OmegaConf - from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -90,14 +89,17 @@ class LLMAgent(Agent): TOTAL_CACHE_CREATION_INPUT_TOKENS = 0 TOKEN_LOCK = asyncio.Lock() - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = 
False, - **kwargs): + def __init__( + self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs, + ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' + ) llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -113,7 +115,8 @@ def __init__(self, self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {})) + kwargs.get('mcp_config', {}) + ) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -161,37 +164,34 @@ def _ensure_auto_skills(self) -> bool: use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: from ms_agent.utils.docker_utils import is_docker_daemon_running + if not is_docker_daemon_running(): - logger.warning( - 'Docker not running, disabling sandbox for skills') + logger.warning('Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container( - skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', - None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, - 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) 
logger.info( - f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' + f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f'Failed to initialize AutoSkills: {e}') + logger.warning(f"Failed to initialize AutoSkills: {e}") self._auto_skills_initialized = True return False @@ -233,7 +233,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f'Skill analysis error: {e}') + logger.error(f"Skill analysis error: {e}") return False async def get_skill_dag(self, query: str): @@ -265,13 +265,15 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = getattr(skills_config, 'stop_on_failure', - True) if skills_config else True + stop_on_failure = ( + getattr(skills_config, 'stop_on_failure', True) if skills_config else True + ) result = await self._auto_skills.run( query=query, execution_input=execution_input, - stop_on_failure=stop_on_failure) + stop_on_failure=stop_on_failure, + ) self._last_skill_result = result return result @@ -289,15 +291,14 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append( - Message(role='assistant', content=dag_result.chat_response)) + messages.append(Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." 
if dag_result.clarification: - content += f'\n\n{dag_result.clarification}' + content += f"\n\n{dag_result.clarification}" messages.append(Message(role='assistant', content=content)) return messages @@ -317,28 +318,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f'**{skill_id} output:**\n{stdout_preview}\n\n' + content += f"**{skill_id} output:**\n{stdout_preview}\n\n" if output.output_files: - content += f'**Generated files:** {list(output.output_files.values())}\n\n' + content += f"**Generated files:** {list(output.output_files.values())}\n\n" - content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + content += ( + f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f'**{skill_id} failed:** {result.error}\n' + content += f"**{skill_id} failed:** {result.error}\n" messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' + content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' 
- content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' - content += f'\nExecution order: {dag_result.execution_order}' + content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" + content += f"\nExecution order: {dag_result.execution_order}" messages.append(Message(role='assistant', content=content)) @@ -364,8 +367,7 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile( - self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -394,26 +396,32 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f'[External Code]A Config Lifecycle handler ' - f'registered in the config: {handler_file}. ' - f'\nThis is external code, if you trust this workflow, ' - f'please specify `--trust_remote_code true`') - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + f"[External Code]A Config Lifecycle handler " + f"registered in the config: {handler_file}. " + f"\nThis is external code, if you trust this workflow, " + f"please specify `--trust_remote_code true`" + ) + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' 
if local_dir not in sys.path: sys.path.insert(0, local_dir) handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, - inspect.isclass) + for name, cls in inspect.getmembers(handler_module, inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if handler_cls.__bases__[ - 0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: + if ( + handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file + ): handler = handler_cls() - assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' + assert ( + handler is not None + ), f"Config Lifecycle handler class cannot be found in {handler_file}" return handler return None @@ -424,13 +432,14 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, - 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' 
if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -451,23 +460,22 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass) + callback_file, inspect.isclass + ) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass( - cls, Callback) and cls.__module__ == _callback: + if issubclass(cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback]( - self.config)) + self.callbacks.append(callbacks_mapping[_callback](self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task beginning.') + self.log_output(f"Agent {self.tag} task beginning.") await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task finished.') + self.log_output(f"Agent {self.tag} task finished.") await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -492,8 +500,7 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, - messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -504,17 +511,20 @@ async def parallel_tool_call(self, List[Message]: Updated message list including tool responses. 
""" tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls) + messages[-1].tool_calls + ) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip(tool_call_result, - messages[-1].tool_calls): + for tool_call_result, tool_call_query in zip( + tool_call_result, messages[-1].tool_calls + ): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', content=tool_call_result_format.text, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], - resources=tool_call_result_format.resources) + resources=tool_call_result_format.resources, + ) if _new_message.tool_call_id is None: # If tool call id is None, add a random one @@ -530,7 +540,8 @@ async def prepare_tools(self): self.config, self.mcp_config, self.mcp_client, - trust_remote_code=self.trust_remote_code) + trust_remote_code=self.trust_remote_code, + ) await self.tool_manager.connect() async def cleanup_tools(self): @@ -539,8 +550,7 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -551,8 +561,7 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). 
""" - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -563,8 +572,7 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -580,19 +588,18 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr( - getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr( - getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str]) -> List[Message]: + self, messages: Union[List[Message], str] + ) -> List[Message]: """ Convert input into a standardized list of messages. 
@@ -604,18 +611,19 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if system is not None and messages[ - 0].role == 'system' and system != messages[0].content: + if ( + system is not None + and messages[0].role == 'system' + and system != messages[0].content + ): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f'inputs can be either a list or a string, but current is {type(messages)}' + ), f"inputs can be either a list or a string, but current is {type(messages)}" messages = [ - Message( - role='system', - content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -639,11 +647,10 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search if self.knowledge_search is not None: # Perform search and get results - search_result = await self.knowledge_search.retrieve(query) + search_result = await self.knowledge_search.query(query) search_details = self.knowledge_search.get_search_details() # Store search details in the message for frontend display @@ -652,24 +659,14 @@ async def do_rag(self, messages: List[Message]): # Build enriched context from search results if search_result: - context_parts = [] - for i, result in enumerate(search_result, 1): - text = result.get('text', '') - source = result.get('metadata', {}).get('source', 'unknown') - score = result.get('score', 0) - context_parts.append( - f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n" - ) - # Append search context to user query - context = '\n'.join(context_parts) + context = search_result user_message.content = ( f"Relevant context retrieved from codebase search:\n\n{context}\n\n" f"User question: {query}" ) - async def do_skill(self, - 
messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -686,7 +683,9 @@ async def do_skill(self, # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' else None) + if len(messages) > 1 and messages[1].role == 'user' + else None + ) if not query: return None @@ -700,8 +699,9 @@ async def do_skill(self, try: skills_config = self._get_skills_config() - auto_execute = getattr(skills_config, 'auto_execute', - True) if skills_config else True + auto_execute = ( + getattr(skills_config, 'auto_execute', True) if skills_config else True + ) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,8 +709,7 @@ async def do_skill(self, dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages( - dag_result) + skill_messages = self._format_skill_result_as_messages(dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -721,7 +720,8 @@ async def do_skill(self, except Exception as e: logger.warning( - f'Skill execution failed: {e}, falling back to standard agent') + f"Skill execution failed: {e}, falling back to standard agent" + ) self._skill_mode_active = False return None @@ -735,11 +735,13 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f'{mem_instance_type} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') + f"{mem_instance_type} not in memory_mapping, " + f"which supports: {list(memory_mapping.keys())}" + ) shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type) + self.config, mem_instance_type + ) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -748,12 +750,17 @@ 
async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f'{rag.name} not in rag_mapping, ' - f'which supports: {list(rag_mapping.keys())}') + f"{rag.name} not in rag_mapping, " + f"which supports: {list(rag_mapping.keys())}" + ) self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): """Load and initialize the knowledge search component from the config.""" + if self.knowledge_search is not None: + # Already initialized (e.g. by caller before run_loop), skip to avoid + # overwriting a configured instance (e.g. one with streaming callbacks set). + return if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: @@ -790,7 +797,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f'[Image: {img_url[:50]}...]') + text_parts.append(f"[Image: {img_url[:50]}...]") content = ' '.join(text_parts) # Ensure content is a string @@ -801,10 +808,9 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f'[{self.tag}] {_line}') + logger.info(f"[{self.tag}] {_line}") - def handle_new_response(self, messages: List[Message], - response_message: Message): + def handle_new_response(self, messages: List[Message], response_message: Message): assert response_message is not None, 'No response message generated from LLM.' 
if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -812,24 +818,23 @@ def handle_new_response(self, messages: List[Message], tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads( - tool_call['arguments']) + tool_call['arguments'] = json.loads(tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output( - json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if messages[-1].role == 'assistant' and not messages[ - -1].content and response_message.tool_calls: + if ( + messages[-1].role == 'assistant' + and not messages[-1].content + and response_message.tool_calls + ): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step( - self, messages: List[Message] - ) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -865,20 +870,20 @@ async def step( is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate( - messages, tools=tools): + for _response_message in self.llm.generate(messages, tools=tools): if is_first: messages.append(_response_message) is_first = False # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) # Some providers may reset / shorten content across chunks. 
if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning):] + new_reasoning = reasoning_text[len(_reasoning) :] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -886,7 +891,7 @@ async def step( self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content):] + new_content = _response_message.content[len(_content) :] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -898,8 +903,9 @@ async def step( else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -927,8 +933,9 @@ async def step( prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = getattr( - _response_message, 'cache_creation_input_tokens', 0) or 0 + cache_creation_input_tokens = ( + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 + ) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -938,20 +945,21 @@ async def step( # tokens in the current step self.log_output( - f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' + f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' + f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" ) # total tokens for the process so far self.log_output( - 
f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' - f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') + f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " + f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" + ) if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' - f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' + f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " + f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" ) yield messages @@ -964,8 +972,9 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history(self, messages: List[Message], - **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history( + self, messages: List[Message], **kwargs + ) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. 
@@ -1009,9 +1018,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step') - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step' + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1020,9 +1029,9 @@ def _get_run_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( memory_config, 'add_after_task', - default_user_id=getattr(memory_config, 'user_id', None)) - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + default_user_id=getattr(memory_config, 'user_id', None), + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1033,24 +1042,29 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for idx, (mem_instance_type, memory_config) in enumerate( + self.config.memory.items() + ): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config) + memory_config + ) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config) + memory_config + ) if idx < tools_num: - if any(v is not None - for v in [user_id, agent_id, run_id, memory_type]): + if any( + v is not None for v in [user_id, agent_id, run_id, memory_type] + ): 
await self.memory_tools[idx].add( messages, user_id=user_id, agent_id=agent_id, run_id=run_id, - memory_type=memory_type) + memory_type=memory_type, + ) def save_history(self, messages: List[Message], **kwargs): """ @@ -1072,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history( - self.output_dir, task=self.tag, config=config, messages=messages) + save_history(self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop(self, messages: Union[List[Message], str], - **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop( + self, messages: Union[List[Message], str], **kwargs + ) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1089,8 +1103,9 @@ async def run_loop(self, messages: Union[List[Message], str], List[Message]: A list of message objects representing the agent's response or interaction history. 
""" try: - self.max_chat_round = getattr(self.config, 'max_chat_round', - LLMAgent.DEFAULT_MAX_CHAT_ROUND) + self.max_chat_round = getattr( + self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND + ) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1132,8 +1147,7 @@ async def run_loop(self, messages: Union[List[Message], str], yield messages self.runtime.round += 1 # save memory and history - await self.add_memory( - messages, add_type='add_after_step', **kwargs) + await self.add_memory(messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1142,9 +1156,10 @@ async def run_loop(self, messages: Union[List[Message], str], messages.append( Message( role='assistant', - content= - f'Task {messages[1].content} was cutted off, because ' - f'max round({self.max_chat_round}) exceeded.')) + content=f"Task {messages[1].content} was cutted off, because " + f"max round({self.max_chat_round}) exceeded.", + ) + ) self.runtime.should_stop = True yield messages @@ -1155,32 +1170,33 @@ async def run_loop(self, messages: Union[List[Message], str], def _add_memory(): asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + self.add_memory(messages, add_type='add_after_task', **kwargs) + ) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) except Exception as e: import traceback + logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' + f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: 
OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True) + self.config, 'generation_config.stream', True, merge=True + ) async def stream_generator(): - async for _chunk in self.run_loop( - messages=messages, **kwargs): + async for _chunk in self.run_loop(messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index 4e1e322a5..dd80738f9 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from pathlib import Path -from typing import Any, Dict, List, Optional, Union from loguru import logger +from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, Callable from ms_agent.rag.base import RAG -from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -49,7 +49,9 @@ def __init__(self, config: DictConfig): paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] - self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + self.search_paths: List[str] = [ + str(Path(p).expanduser().resolve()) for p in paths + ] # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') @@ -70,28 +72,40 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + if ( + self.llm_api_key is None + or self.llm_base_url is None + or self.llm_model_name is None + ): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + self.llm_api_key = getattr(llm_config, 
f"{service}_api_key", None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) - self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + self.embedding_model_cache_dir = rag_config.get( + 'embedding_model_cache_dir', None + ) # Runtime state self._searcher = None self._initialized = False + self._cluster_cache_hit = False + self._cluster_cache_hit_time: str | None = None + self._last_search_result: List[Dict[str, Any]] | None = None # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] + # Async queue for streaming logs in real-time + self._log_queue: asyncio.Queue | None = None + self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" @@ -112,8 +126,8 @@ def _initialize_searcher(self): return try: - from sirchmunk.search import AgenticSearch from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil # Create LLM client @@ -127,9 +141,17 @@ def _initialize_searcher(self): # Create embedding util # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) - embedding_model_id = self.embedding_model_id if self.embedding_model_id else None - embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None - embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + embedding_model_id = ( + self.embedding_model_id if self.embedding_model_id else None + ) + embedding_cache_dir = ( + self.embedding_model_cache_dir + if self.embedding_model_cache_dir + else None + 
) + embedding = EmbeddingUtil( + model_id=embedding_model_id, cache_dir=embedding_cache_dir + ) # Create AgenticSearch instance self._searcher = AgenticSearch( @@ -145,20 +167,35 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}") except ImportError as e: raise ImportError( - f'Failed to import sirchmunk: {e}. ' + f"Failed to import sirchmunk: {e}. " 'Please install sirchmunk: pip install sirchmunk' ) except Exception as e: - raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}") def _log_callback_wrapper(self): - """Create a callback wrapper to capture search logs.""" - def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): - self._search_logs.append(f'[{level}] {message}') + """Create a callback wrapper to capture search logs. + + The sirchmunk LogCallback signature is: + (level: str, message: str, end: str, flush: bool) -> None + See sirchmunk/utils/log_utils.py for reference. 
+ """ + + def log_callback( + level: str, + message: str, + end: str = '\n', + flush: bool = False, + ): + log_entry = f"[{level.upper()}] {message}" + self._search_logs.append(log_entry) + # Stream log in real-time if streaming callback is set + if self._streaming_callback: + asyncio.create_task(self._streaming_callback(log_entry)) return log_callback @@ -186,7 +223,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f'Failed to refresh knowledge base: {e}') + logger.error(f"Failed to refresh knowledge base: {e}") return False return True @@ -208,15 +245,13 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: await self._searcher.scan_directory(str(Path(file_path).parent)) return True except Exception as e: - logger.error(f'Failed to scan files: {e}') + logger.error(f"Failed to scan files: {e}") return False return True - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.7, - **filters) -> List[Dict[str, Any]]: + async def retrieve( + self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters + ) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. 
Args: @@ -246,11 +281,21 @@ async def retrieve(self, return_context=True, ) + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + # If a similar cluster was found and reused, it's a cache hit + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + # Get the cluster cache hit time if available + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f'SirschmunkSearch retrieve failed: {e}') + logger.error(f"SirschmunkSearch retrieve failed: {e}") return [] async def query(self, query: str) -> str: @@ -273,34 +318,45 @@ async def query(self, query: str) -> str: max_loops = self.max_loops max_token_budget = self.max_token_budget - # Perform search and get answer + # Single search with context so we get both the synthesized answer and + # source units in one call, avoiding a redundant second search. 
result = await self._searcher.search( query=query, mode=mode, max_loops=max_loops, max_token_budget=max_token_budget, - return_context=False, + return_context=True, ) - # Result is already a synthesized answer string - if isinstance(result, str): - return result + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + + # Store parsed context for frontend display + self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) - # If we got SearchContext or other format, extract the answer + # Extract the synthesized answer from the context result if hasattr(result, 'answer'): return result.answer + # If result is already a plain string (some modes return str directly) + if isinstance(result, str): + return result + # Fallback: convert to string return str(result) except Exception as e: - logger.error(f'SirschmunkSearch query failed: {e}') - return f'Query failed: {e}' + logger.error(f"SirschmunkSearch query failed: {e}") + return f"Query failed: {e}" - def _parse_search_result(self, - result: Any, - score_threshold: float, - limit: int) -> List[Dict[str, Any]]: + def _parse_search_result( + self, result: Any, score_threshold: float, limit: int + ) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. 
Args: @@ -329,28 +385,38 @@ def _parse_search_result(self, else: text_parts.append(str(snippet)) - results.append({ - 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + results.append( + { + 'text': '\n'.join(text_parts) + if text_parts + else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') + else 'text', + }, } - }) + ) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append({ - 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), + results.append( + { + 'text': str(unit.content) + if hasattr(unit, 'content') + else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, } - }) + ) # Handle list format elif isinstance(result, list): @@ -358,21 +424,27 @@ def _parse_search_result(self, if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': item.get('content', item.get('text', str(item))), - 'score': score, - 'metadata': item.get('metadata', {}), - }) + results.append( + { + 'text': item.get( + 'content', item.get('text', str(item)) + ), + 'score': score, + 'metadata': item.get('metadata', {}), + } + ) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 
'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - }) + results.append( + { + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + } + ) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) @@ -398,4 +470,16 @@ def get_search_details(self) -> Dict[str, Any]: 'paths': self.search_paths, 'work_path': str(self.work_path), 'reuse_knowledge': self.reuse_knowledge, + 'cluster_cache_hit': self._cluster_cache_hit, + 'cluster_cache_hit_time': self._cluster_cache_hit_time, } + + def enable_streaming_logs(self, callback: Callable): + """Enable streaming mode for search logs. + + Args: + callback: Async callback function to receive log entries in real-time. + Signature: async def callback(log_entry: str) -> None + """ + self._streaming_callback = callback + self._search_logs.clear() From 24eb500b75b79941a77eb6172d3683df848f76e6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 20:03:13 +0800 Subject: [PATCH 06/12] fix lint --- ms_agent/agent/llm_agent.py | 294 +++++++++--------- ms_agent/cli/run.py | 20 +- ms_agent/knowledge_search/sirchmunk_search.py | 181 +++++------ 3 files changed, 253 insertions(+), 242 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 76289a403..5f2ddf2e7 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,16 +2,15 @@ import asyncio import importlib import inspect -import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy -from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.knowledge_search import SirchmunkSearch @@ -25,6 
+24,8 @@ from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig, OmegaConf + from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -98,8 +99,7 @@ def __init__( ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' - ) + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -115,8 +115,7 @@ def __init__( self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {}) - ) + kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -166,32 +165,36 @@ def _ensure_auto_skills(self) -> bool: from ms_agent.utils.docker_utils import is_docker_daemon_running if not is_docker_daemon_running(): - logger.warning('Docker not running, disabling sandbox for skills') + logger.warning( + 'Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container( + skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', + None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, + 'max_candidate_skills', 10), max_retries=getattr(skills_config, 
'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) logger.info( - f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" + f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f"Failed to initialize AutoSkills: {e}") + logger.warning(f'Failed to initialize AutoSkills: {e}') self._auto_skills_initialized = True return False @@ -233,7 +236,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f"Skill analysis error: {e}") + logger.error(f'Skill analysis error: {e}') return False async def get_skill_dag(self, query: str): @@ -266,8 +269,8 @@ async def execute_skills(self, query: str, execution_input=None): skills_config = self._get_skills_config() stop_on_failure = ( - getattr(skills_config, 'stop_on_failure', True) if skills_config else True - ) + getattr(skills_config, 'stop_on_failure', True) + if skills_config else True) result = await self._auto_skills.run( query=query, @@ -291,14 +294,15 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append(Message(role='assistant', content=dag_result.chat_response)) + messages.append( + Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." 
if dag_result.clarification: - content += f"\n\n{dag_result.clarification}" + content += f'\n\n{dag_result.clarification}' messages.append(Message(role='assistant', content=content)) return messages @@ -318,30 +322,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f"**{skill_id} output:**\n{stdout_preview}\n\n" + content += f'**{skill_id} output:**\n{stdout_preview}\n\n' if output.output_files: - content += f"**Generated files:** {list(output.output_files.values())}\n\n" + content += f'**Generated files:** {list(output.output_files.values())}\n\n' content += ( - f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + f'Total execution time: {exec_result.total_duration_ms:.2f}ms' ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f"**{skill_id} failed:** {result.error}\n" + content += f'**{skill_id} failed:** {result.error}\n' messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" + content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' 
- content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" - content += f"\nExecution order: {dag_result.execution_order}" + content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' + content += f'\nExecution order: {dag_result.execution_order}' messages.append(Message(role='assistant', content=content)) @@ -367,7 +371,8 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile( + self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -396,11 +401,10 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f"[External Code]A Config Lifecycle handler " - f"registered in the config: {handler_file}. " - f"\nThis is external code, if you trust this workflow, " - f"please specify `--trust_remote_code true`" - ) + f'[External Code]A Config Lifecycle handler ' + f'registered in the config: {handler_file}. ' + f'\nThis is external code, if you trust this workflow, ' + f'please specify `--trust_remote_code true`') assert ( local_dir is not None ), 'Using external py files, but local_dir cannot be found.' 
@@ -410,18 +414,17 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, inspect.isclass) + for name, cls in inspect.getmembers(handler_module, + inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if ( - handler_cls.__bases__[0] is ConfigLifecycleHandler - and handler_cls.__module__ == handler_file - ): + if (handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file): handler = handler_cls() assert ( handler is not None - ), f"Config Lifecycle handler class cannot be found in {handler_file}" + ), f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -432,7 +435,8 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, + 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: @@ -460,22 +464,23 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass - ) + callback_file, inspect.isclass) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass(cls, Callback) and cls.__module__ == _callback: + if issubclass( + cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback](self.config)) + self.callbacks.append(callbacks_mapping[_callback]( + self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task beginning.") + self.log_output(f'Agent 
{self.tag} task beginning.') await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task finished.") + self.log_output(f'Agent {self.tag} task finished.') await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -500,7 +505,8 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, + messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -511,12 +517,10 @@ async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: List[Message]: Updated message list including tool responses. """ tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls - ) + messages[-1].tool_calls) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip( - tool_call_result, messages[-1].tool_calls - ): + for tool_call_result, tool_call_query in zip(tool_call_result, + messages[-1].tool_calls): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', @@ -550,7 +554,8 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -561,7 +566,8 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). 
""" - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -572,7 +578,8 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -588,18 +595,19 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr( + getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr( + getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str] - ) -> List[Message]: + self, messages: Union[List[Message], str]) -> List[Message]: """ Convert input into a standardized list of messages. 
@@ -611,19 +619,18 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if ( - system is not None - and messages[0].role == 'system' - and system != messages[0].content - ): + if (system is not None and messages[0].role == 'system' + and system != messages[0].content): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f"inputs can be either a list or a string, but current is {type(messages)}" + ), f'inputs can be either a list or a string, but current is {type(messages)}' messages = [ - Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message( + role='system', + content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -662,11 +669,11 @@ async def do_rag(self, messages: List[Message]): # Append search context to user query context = search_result user_message.content = ( - f"Relevant context retrieved from codebase search:\n\n{context}\n\n" - f"User question: {query}" - ) + f'Relevant context retrieved from codebase search:\n\n{context}\n\n' + f'User question: {query}') - async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, + messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. 
@@ -683,9 +690,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' - else None - ) + if len(messages) > 1 and messages[1].role == 'user' else None) if not query: return None @@ -700,8 +705,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: try: skills_config = self._get_skills_config() auto_execute = ( - getattr(skills_config, 'auto_execute', True) if skills_config else True - ) + getattr(skills_config, 'auto_execute', True) + if skills_config else True) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,7 +714,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages(dag_result) + skill_messages = self._format_skill_result_as_messages( + dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -720,8 +726,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: except Exception as e: logger.warning( - f"Skill execution failed: {e}, falling back to standard agent" - ) + f'Skill execution failed: {e}, falling back to standard agent') self._skill_mode_active = False return None @@ -735,13 +740,11 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f"{mem_instance_type} not in memory_mapping, " - f"which supports: {list(memory_mapping.keys())}" - ) + f'{mem_instance_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type - ) + self.config, mem_instance_type) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ 
-750,9 +753,8 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f"{rag.name} not in rag_mapping, " - f"which supports: {list(rag_mapping.keys())}" - ) + f'{rag.name} not in rag_mapping, ' + f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): @@ -764,7 +766,8 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) + self.knowledge_search: SirchmunkSearch = SirchmunkSearch( + self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: """ @@ -797,7 +800,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f"[Image: {img_url[:50]}...]") + text_parts.append(f'[Image: {img_url[:50]}...]') content = ' '.join(text_parts) # Ensure content is a string @@ -808,9 +811,10 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f"[{self.tag}] {_line}") + logger.info(f'[{self.tag}] {_line}') - def handle_new_response(self, messages: List[Message], response_message: Message): + def handle_new_response(self, messages: List[Message], + response_message: Message): assert response_message is not None, 'No response message generated from LLM.' 
if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -818,23 +822,24 @@ def handle_new_response(self, messages: List[Message], response_message: Message tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads(tool_call['arguments']) + tool_call['arguments'] = json.loads( + tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output( + json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if ( - messages[-1].role == 'assistant' - and not messages[-1].content - and response_message.tool_calls - ): + if (messages[-1].role == 'assistant' and not messages[-1].content + and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step( + self, messages: List[Message] + ) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -870,7 +875,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate(messages, tools=tools): + for _response_message in self.llm.generate( + messages, tools=tools): if is_first: messages.append(_response_message) is_first = False @@ -878,12 +884,12 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # Optional: stream model "thinking/reasoning" if available. 
if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning) :] + new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -891,7 +897,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content) :] + new_content = _response_message.content[len(_content):] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -904,8 +910,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -934,8 +940,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 cache_creation_input_tokens = ( - getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 - ) + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -945,21 +950,20 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # tokens in the current step self.log_output( - f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + f'[usage] 
prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" + f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' ) # total tokens for the process so far self.log_output( - f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " - f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" - ) + f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' + f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " - f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" + f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' + f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' ) yield messages @@ -972,9 +976,8 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history( - self, messages: List[Message], **kwargs - ) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history(self, messages: List[Message], + **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. 
@@ -1018,9 +1021,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step' - ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step') + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1031,7 +1034,8 @@ def _get_run_memory_info(self, memory_config: DictConfig): 'add_after_task', default_user_id=getattr(memory_config, 'user_id', None), ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1042,22 +1046,18 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, memory_config) in enumerate( - self.config.memory.items() - ): + for idx, (mem_instance_type, + memory_config) in enumerate(self.config.memory.items()): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config - ) + memory_config) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config - ) + memory_config) if idx < tools_num: - if any( - v is not None for v in [user_id, agent_id, run_id, memory_type] - ): + if any(v is not None + for v in [user_id, agent_id, run_id, memory_type]): await self.memory_tools[idx].add( messages, user_id=user_id, @@ -1086,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): 
config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history(self.output_dir, task=self.tag, config=config, messages=messages) + save_history( + self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop( - self, messages: Union[List[Message], str], **kwargs - ) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], + **kwargs) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1103,9 +1103,8 @@ async def run_loop( List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr( - self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND - ) + self.max_chat_round = getattr(self.config, 'max_chat_round', + LLMAgent.DEFAULT_MAX_CHAT_ROUND) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1147,7 +1146,8 @@ async def run_loop( yield messages self.runtime.round += 1 # save memory and history - await self.add_memory(messages, add_type='add_after_step', **kwargs) + await self.add_memory( + messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1156,10 +1156,10 @@ async def run_loop( messages.append( Message( role='assistant', - content=f"Task {messages[1].content} was cutted off, because " - f"max round({self.max_chat_round}) exceeded.", - ) - ) + content= + f'Task {messages[1].content} was cutted off, because ' + f'max round({self.max_chat_round}) exceeded.', + )) self.runtime.should_stop = True yield messages @@ -1170,8 +1170,8 @@ async def run_loop( def _add_memory(): asyncio.run( - self.add_memory(messages, add_type='add_after_task', **kwargs) - ) + self.add_memory( + messages, add_type='add_after_task', **kwargs)) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) @@ -1181,22 +1181,22 @@ def 
_add_memory(): logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" + f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True - ) + self.config, 'generation_config.stream', True, merge=True) async def stream_generator(): - async for _chunk in self.run_loop(messages=messages, **kwargs): + async for _chunk in self.run_loop( + messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index cfe387e5a..a07397c95 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,11 +4,10 @@ import os from importlib import resources as importlib_resources -from omegaconf import OmegaConf - from ms_agent.config import Config from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII +from omegaconf import OmegaConf from .base import CLICommand @@ -185,8 +184,10 @@ def _execute_with_config(self): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) else: # Use built-in default agent.yaml from package - default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) - with importlib_resources.as_file(default_config_path) as config_file: + default_config_path = importlib_resources.files( + 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file( + default_config_path) as config_file: self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import 
snapshot_download @@ -226,7 +227,10 @@ def _execute_with_config(self): # If knowledge_search_paths is provided, configure SirchmunkSearch if getattr(self.args, 'knowledge_search_paths', None): - paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + paths = [ + p.strip() for p in self.args.knowledge_search_paths.split(',') + if p.strip() + ] if paths: if 'knowledge_search' not in config or not config.knowledge_search: # No existing knowledge_search config, create minimal config @@ -237,11 +241,13 @@ def _execute_with_config(self): 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + config['knowledge_search'] = OmegaConf.create( + knowledge_search_config) else: # Existing knowledge_search config found, only update paths # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing = OmegaConf.to_container( + config.knowledge_search, resolve=True) existing['paths'] = paths config['knowledge_search'] = OmegaConf.create(existing) diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index dd80738f9..e1c76181f 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from loguru import logger -from omegaconf import DictConfig from pathlib import Path -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union +from loguru import logger from ms_agent.rag.base import RAG +from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -59,7 +59,8 @@ def __init__(self, config: DictConfig): # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) - self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + 
self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', + 0.85) self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) self.search_mode = rag_config.get('mode', 'FAST') self.max_loops = rag_config.get('max_loops', 10) @@ -72,26 +73,24 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if ( - self.llm_api_key is None - or self.llm_base_url is None - or self.llm_model_name is None - ): + if (self.llm_api_key is None or self.llm_base_url is None + or self.llm_model_name is None): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f"{service}_api_key", None) + self.llm_api_key = getattr(llm_config, + f'{service}_api_key', None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) + self.llm_base_url = getattr(llm_config, + f'{service}_base_url', None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( - 'embedding_model_cache_dir', None - ) + 'embedding_model_cache_dir', None) # Runtime state self._searcher = None @@ -109,7 +108,8 @@ def __init__(self, config: DictConfig): def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" - if not hasattr(config, 'knowledge_search') or config.knowledge_search is None: + if not hasattr(config, + 'knowledge_search') or config.knowledge_search is None: raise ValueError( 'Missing knowledge_search configuration. ' 'Please add knowledge_search section to your config with at least "paths" specified.' 
@@ -118,7 +118,8 @@ def _validate_config(self, config: DictConfig): rag_config = config.knowledge_search paths = rag_config.get('paths', []) if not paths: - raise ValueError('knowledge_search.paths must be specified and non-empty') + raise ValueError( + 'knowledge_search.paths must be specified and non-empty') def _initialize_searcher(self): """Initialize the sirchmunk AgenticSearch instance.""" @@ -142,16 +143,12 @@ def _initialize_searcher(self): # Create embedding util # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( - self.embedding_model_id if self.embedding_model_id else None - ) + self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( self.embedding_model_cache_dir - if self.embedding_model_cache_dir - else None - ) + if self.embedding_model_cache_dir else None) embedding = EmbeddingUtil( - model_id=embedding_model_id, cache_dir=embedding_cache_dir - ) + model_id=embedding_model_id, cache_dir=embedding_cache_dir) # Create AgenticSearch instance self._searcher = AgenticSearch( @@ -167,15 +164,16 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}") + logger.info( + f'SirschmunkSearch initialized with paths: {self.search_paths}' + ) except ImportError as e: raise ImportError( - f"Failed to import sirchmunk: {e}. " - 'Please install sirchmunk: pip install sirchmunk' - ) + f'Failed to import sirchmunk: {e}. ' + 'Please install sirchmunk: pip install sirchmunk') except Exception as e: - raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}") + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') def _log_callback_wrapper(self): """Create a callback wrapper to capture search logs. 
@@ -191,7 +189,7 @@ def log_callback( end: str = '\n', flush: bool = False, ): - log_entry = f"[{level.upper()}] {message}" + log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) # Stream log in real-time if streaming callback is set if self._streaming_callback: @@ -223,7 +221,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f"Failed to refresh knowledge base: {e}") + logger.error(f'Failed to refresh knowledge base: {e}') return False return True @@ -242,16 +240,19 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: try: for file_path in file_paths: if Path(file_path).exists(): - await self._searcher.scan_directory(str(Path(file_path).parent)) + await self._searcher.scan_directory( + str(Path(file_path).parent)) return True except Exception as e: - logger.error(f"Failed to scan files: {e}") + logger.error(f'Failed to scan files: {e}') return False return True - async def retrieve( - self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters - ) -> List[Dict[str, Any]]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. 
Args: @@ -270,7 +271,8 @@ async def retrieve( try: mode = filters.get('mode', self.search_mode) max_loops = filters.get('max_loops', self.max_loops) - max_token_budget = filters.get('max_token_budget', self.max_token_budget) + max_token_budget = filters.get('max_token_budget', + self.max_token_budget) # Perform search result = await self._searcher.search( @@ -286,16 +288,18 @@ async def retrieve( self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: # If a similar cluster was found and reused, it's a cache hit - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f"SirschmunkSearch retrieve failed: {e}") + logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] async def query(self, query: str) -> str: @@ -332,12 +336,15 @@ async def query(self, query: str) -> str: self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Store parsed context for frontend display - self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) + self._last_search_result = 
self._parse_search_result( + result, score_threshold=0.7, limit=5) # Extract the synthesized answer from the context result if hasattr(result, 'answer'): @@ -351,12 +358,11 @@ async def query(self, query: str) -> str: return str(result) except Exception as e: - logger.error(f"SirschmunkSearch query failed: {e}") - return f"Query failed: {e}" + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' - def _parse_search_result( - self, result: Any, score_threshold: float, limit: int - ) -> List[Dict[str, Any]]: + def _parse_search_result(self, result: Any, score_threshold: float, + limit: int) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -385,38 +391,37 @@ def _parse_search_result( else: text_parts.append(str(snippet)) - results.append( - { - 'text': '\n'.join(text_parts) - if text_parts - else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') - if hasattr(unit, 'abstraction_level') - else 'text', - }, - } - ) + results.append({ + 'text': + '\n'.join(text_parts) if text_parts else getattr( + unit, 'summary', ''), + 'score': + score, + 'metadata': { + 'source': + source, + 'type': + getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') else 'text', + }, + }) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append( - { - 'text': str(unit.content) - if hasattr(unit, 'content') - else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), - }, - } - ) + results.append({ + 'text': + str(unit.content) + if hasattr(unit, 'content') else str(unit), + 'score': + score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 
'type': getattr(unit, 'abstraction_level', 'text'), + }, + }) # Handle list format elif isinstance(result, list): @@ -424,27 +429,27 @@ def _parse_search_result( if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': item.get( - 'content', item.get('text', str(item)) - ), - 'score': score, - 'metadata': item.get('metadata', {}), - } - ) + results.append({ + 'text': + item.get('content', item.get('text', str(item))), + 'score': + score, + 'metadata': + item.get('metadata', {}), + }) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - } - ) + results.append({ + 'text': + result.get('content', result.get('text', str(result))), + 'score': + score, + 'metadata': + result.get('metadata', {}), + }) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) From fec8bfe4f305a08e8b540b9715e543895c083764 Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 24 Mar 2026 11:36:46 +0800 Subject: [PATCH 07/12] mv localsearch to tools --- docs/en/Components/Config.md | 16 +- docs/zh/Components/config.md | 14 +- examples/knowledge_search/agent.yaml.example | 86 ------ ms_agent/agent/llm_agent.py | 39 +-- ms_agent/cli/run.py | 25 +- ms_agent/knowledge_search/README.md | 277 ----------------- ms_agent/knowledge_search/__init__.py | 9 +- ms_agent/llm/utils.py | 30 +- ms_agent/rag/utils.py | 4 +- ms_agent/tools/search/localsearch_tool.py | 282 ++++++++++++++++++ .../search}/sirchmunk_search.py | 253 ++++++++++------ ms_agent/tools/tool_manager.py | 4 + tests/knowledge_search/test_sirschmunk.py | 223 +++++--------- 13 files changed, 586 insertions(+), 676 deletions(-) delete mode 100644 examples/knowledge_search/agent.yaml.example 
delete mode 100644 ms_agent/knowledge_search/README.md create mode 100644 ms_agent/tools/search/localsearch_tool.py rename ms_agent/{knowledge_search => tools/search}/sirchmunk_search.py (68%) diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index f1253bd75..2154add8e 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # Local codebase / document search (sirchmunk), exposed as the `localsearch` tool + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # Optional: llm_api_key, llm_base_url, llm_model_name (else inherited from `llm`) ``` For the complete list of supported tools and custom tools, please refer to [here](./Tools.md) @@ -167,19 +175,19 @@ In addition to yaml configuration, MS-Agent also supports several additional com > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. -- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. +- knowledge_search_paths: Comma-separated local search paths. Merges into `tools.localsearch.paths` and registers the **`localsearch`** tool (sirchmunk) for on-demand use by the model—not automatic per-turn injection. LLM settings are inherited from the `llm` module unless you set `tools.localsearch.llm_*` fields. 
### Quick Start for Knowledge Search -Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: +Use `--knowledge_search_paths` or define `tools.localsearch` in yaml so the model can call `localsearch` when needed: ```bash # Using default agent.yaml configuration, automatically reuses LLM settings -ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "/path/to/docs" # Specify configuration file ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" ``` LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. -If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. +For a dedicated sirchmunk LLM, set `tools.localsearch.llm_api_key`, `llm_base_url`, and `llm_model_name` in yaml. Legacy top-level `knowledge_search` with the same keys is still read for backward compatibility. 
diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 12849f2a7..0baa4bf18 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # 本地代码库/文档搜索(sirchmunk),对应模型可调用的 `localsearch` 工具 + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # 可选:llm_api_key、llm_base_url、llm_model_name(不填则从 `llm` 继承) ``` 支持的完整工具列表,以及自定义工具请参考 [这里](./tools) @@ -165,13 +173,13 @@ handler: custom_handler } } ``` -- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 +- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用,不再在每轮自动注入上下文;除非配置 `tools.localsearch.llm_*`,否则 LLM 从 `llm` 模块复用 > agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 ### 知识搜索快速使用 -通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: +通过 `--knowledge_search_paths` 或在 yaml 中配置 `tools.localsearch`,启用本地知识搜索(模型按需调用 `localsearch`): ```bash # 使用默认 agent.yaml 配置,自动复用 LLM 设置 @@ -182,4 +190,4 @@ ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_sea ``` LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 -如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 +若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。仍支持旧版顶层 `knowledge_search` 相同字段,以便迁移。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example deleted file mode 100644 index cc11a8a3d..000000000 --- a/examples/knowledge_search/agent.yaml.example +++ /dev/null @@ -1,86 +0,0 @@ -# Sirchmunk Knowledge Search 配置示例 -# Sirchmunk Knowledge Search Configuration Example - -# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: - -llm: - service: modelscope - model: Qwen/Qwen3-235B-A22B-Instruct-2507 - 
modelscope_api_key: - modelscope_base_url: https://api-inference.modelscope.cn/v1 - -generation_config: - temperature: 0.3 - top_k: 20 - stream: true - -# Knowledge Search 配置(可选) -# 用于在本地代码库中搜索相关信息 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录,用于缓存 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则使用上面 llm 的配置) - llm_api_key: - llm_base_url: https://api.openai.com/v1 - llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:聚类相似度阈值 - cluster_sim_threshold: 0.85 - - # 可选:聚类 TopK - cluster_sim_top_k: 3 - - # 可选:是否重用之前的知识 - reuse_knowledge: true - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:最大循环次数 - max_loops: 10 - - # 可选:最大 token 预算 - max_token_budget: 128000 - -prompt: - system: | - You are an assistant that helps me complete tasks. - -max_chat_round: 9999 - -# 使用说明: -# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 -# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 -# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM -# 4. 
searching_detail 包含搜索日志和元数据,可用于前端展示 -# -# Python 使用示例: -# ```python -# from ms_agent import LLMAgent -# from ms_agent.config import Config -# -# config = Config.from_task('path/to/agent.yaml') -# agent = LLMAgent(config=config) -# result = await agent.run('如何实现用户认证功能?') -# -# # 获取搜索详情(用于前端展示) -# for msg in result: -# if msg.role == 'user': -# print(f"Search logs: {msg.searching_detail}") -# print(f"Search results: {msg.search_result}") -# ``` -# -# CLI 测试命令: -# export LLM_API_KEY="your-api-key" -# export LLM_BASE_URL="https://api.openai.com/v1" -# export LLM_MODEL_NAME="gpt-4o-mini" -# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 5f2ddf2e7..36e46a672 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -13,7 +13,6 @@ import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping @@ -107,7 +106,6 @@ def __init__( self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None - self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -528,6 +526,7 @@ async def parallel_tool_call(self, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, + tool_detail=tool_call_result_format.tool_detail, ) if _new_message.tool_call_id is None: @@ -636,11 +635,7 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): - """Process RAG or knowledge search to enrich the user query with context. 
- - This method handles both traditional RAG and sirchmunk-based knowledge search. - For knowledge search, it also populates searching_detail and search_result - fields in the message for frontend display and next-turn LLM context. + """Process RAG to enrich the user query with context. Args: messages (List[Message]): The message list to process. @@ -654,23 +649,6 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search - if self.knowledge_search is not None: - # Perform search and get results - search_result = await self.knowledge_search.query(query) - search_details = self.knowledge_search.get_search_details() - - # Store search details in the message for frontend display - user_message.searching_detail = search_details - user_message.search_result = search_result - - # Build enriched context from search results - if search_result: - # Append search context to user query - context = search_result - user_message.content = ( - f'Relevant context retrieved from codebase search:\n\n{context}\n\n' - f'User question: {query}') async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -757,18 +735,6 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) - async def prepare_knowledge_search(self): - """Load and initialize the knowledge search component from the config.""" - if self.knowledge_search is not None: - # Already initialized (e.g. by caller before run_loop), skip to avoid - # overwriting a configured instance (e.g. one with streaming callbacks set). 
- return - if hasattr(self.config, 'knowledge_search'): - ks_config = self.config.knowledge_search - if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch( - self.config) - async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -1111,7 +1077,6 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() - await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 84c93a97e..2accdb40e 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -153,7 +153,7 @@ def define_args(parsers: argparse.ArgumentParser): type=str, default=None, help= - 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' + 'Comma-separated list of paths for knowledge search.' 
) parser.set_defaults(func=subparser_func) @@ -233,31 +233,28 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) - # If knowledge_search_paths is provided, configure SirchmunkSearch + # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): paths = [ p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip() ] if paths: - if 'knowledge_search' not in config or not config.knowledge_search: - # No existing knowledge_search config, create minimal config - # LLM settings will be auto-reused from llm module by SirchmunkSearch - knowledge_search_config = { - 'name': 'SirchmunkSearch', + if not hasattr(config, 'tools') or config.tools is None: + config['tools'] = OmegaConf.create({}) + tl = getattr(config.tools, 'localsearch', None) + if tl is None or not OmegaConf.is_config(tl): + localsearch_config = { 'paths': paths, 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create( - knowledge_search_config) + config.tools['localsearch'] = OmegaConf.create( + localsearch_config) else: - # Existing knowledge_search config found, only update paths - # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container( - config.knowledge_search, resolve=True) + existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths - config['knowledge_search'] = OmegaConf.create(existing) + config.tools['localsearch'] = OmegaConf.create(existing) if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md deleted file mode 100644 index ef86df0da..000000000 --- a/ms_agent/knowledge_search/README.md +++ /dev/null @@ -1,277 +0,0 @@ -# Sirchmunk Knowledge Search 集成 - -本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 - -## 功能特性 - -- **智能代码搜索**: 使用 
LLM 和 embedding 模型对代码库进行语义搜索 -- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 -- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 -- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 -- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 - -## 安装 - -```bash -pip install sirchmunk -``` - -## 配置 - -在您的 `agent.yaml` 或 `workflow.yaml` 中添加以下配置: - -```yaml -llm: - service: dashscope - model: qwen3.5-plus - dashscope_api_key: - dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 - -generation_config: - temperature: 0.3 - stream: true - -# Knowledge Search 配置 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) - # llm_api_key: - # llm_base_url: https://api.openai.com/v1 - # llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:是否重用之前的知识 - reuse_knowledge: true -``` - -**LLM 配置自动复用机制**: - -`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: -- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` -- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` -- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` - -其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 - -通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 - -## 使用方式 - -### 1. 通过 CLI 使用(推荐) - -从命令行直接运行,无需编写代码: - -```bash -# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 -ms-agent run --query "如何实现用户认证功能?" --knowledge_search_paths "./src,./docs" - -# 指定配置文件 -ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" -``` - -**说明**: -- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 -- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 -- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 - -### 2. 
通过 LLMAgent 使用 - -```python -from ms_agent import LLMAgent -from ms_agent.config import Config - -# 从配置文件加载 -config = Config.from_task('path/to/agent.yaml') -agent = LLMAgent(config=config) - -# 运行查询 - 会自动触发知识搜索 -result = await agent.run('如何实现用户认证功能?') - -# 获取搜索结果 -for msg in result: - if msg.role == 'user': - # 搜索详情(用于前端展示) - print(f"Search logs: {msg.searching_detail}") - # 搜索结果(作为 LLM 上下文) - print(f"Search results: {msg.search_result}") -``` - -### 3. 单独使用 SirchmunkSearch - -```python -from ms_agent.knowledge_search import SirchmunkSearch -from omegaconf import DictConfig - -config = DictConfig({ - 'knowledge_search': { - 'paths': ['./src', './docs'], - 'work_path': './.sirchmunk', - 'llm_api_key': 'your-api-key', - 'llm_model_name': 'gpt-4o-mini', - 'mode': 'FAST', - } -}) - -searcher = SirchmunkSearch(config) - -# 查询(返回合成答案) -answer = await searcher.query('如何实现用户认证?') - -# 检索(返回原始搜索结果) -results = await searcher.retrieve( - query='用户认证', - limit=5, - score_threshold=0.7 -) - -# 获取搜索日志 -logs = searcher.get_search_logs() - -# 获取搜索详情 -details = searcher.get_search_details() -``` - -## 环境变量 - -可以通过环境变量配置: - -```bash -# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -# Embedding 模型配置 -export EMBEDDING_MODEL_ID="text-embedding-3-small" -export SIRCHMUNK_WORK_PATH="./.sirchmunk" -``` - -**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 - -## 测试 - -### 单元测试 - -```bash -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -python -m unittest tests/knowledge_search/test_sirschmunk.py -``` - -### CLI 测试 - -```bash -# 基本测试 -python tests/knowledge_search/test_cli.py - -# 指定查询 -python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" 
- -# 仅测试 standalone 模式 -python tests/knowledge_search/test_cli.py -m standalone - -# 仅测试 agent 模式 -python tests/knowledge_search/test_cli.py -m agent -``` - -## 配置参数说明 - -| 参数 | 类型 | 默认值 | 说明 | -|------|------|--------|------| -| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | -| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | -| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | -| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | -| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | -| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | -| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | -| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | -| reuse_knowledge | bool | true | 是否重用之前的知识 | -| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | -| max_loops | int | 10 | 最大搜索循环次数 | -| max_token_budget | int | 128000 | 最大 token 预算 | - -## 搜索模式 - -- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 -- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 -- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 - -## Message 字段扩展 - -为了支持知识搜索,`Message` 类增加了两个字段: - -- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 - - `logs`: 搜索日志列表 - - `mode`: 使用的搜索模式 - - `paths`: 搜索的路径 - - `work_path`: 工作目录 - - `reuse_knowledge`: 是否重用知识 - -- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 - - `text`: 文档内容 - - `score`: 相关性分数 - - `metadata`: 元数据(如源文件、类型等) - -## 工作原理 - -1. 用户发送查询 -2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch -3. `do_rag()` 方法执行知识搜索: - - 调用 `searcher.retrieve()` 获取相关文档 - - 将搜索结果存入 `message.search_result` - - 将搜索日志存入 `message.searching_detail` - - 将搜索结果格式化为上下文,附加到用户查询 -4. LLM 接收 enriched query 并生成回答 -5. 前端可以通过 `searching_detail` 展示搜索过程 - -## 故障排除 - -### 常见问题 - -1. **ImportError: No module named 'sirchmunk'** - ```bash - pip install sirchmunk - ``` - -2. **搜索结果为空** - - 检查 `paths` 配置是否正确 - - 确保路径下有可搜索的文件 - - 尝试降低 `cluster_sim_threshold` 值 - -3. 
**LLM API 调用失败** - - 检查 API key 是否正确 - - 检查 base URL 是否正确 - - 查看搜索日志了解详细错误 - -### 日志查看 - -```python -# 查看搜索日志 -logs = searcher.get_search_logs() -for log in logs: - print(log) - -# 或在配置中启用 verbose -knowledge_search: - verbose: true -``` - -## 参考资源 - -- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) -- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py index 33362beee..8746f27c9 100644 --- a/ms_agent/knowledge_search/__init__.py +++ b/ms_agent/knowledge_search/__init__.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Knowledge search module based on sirchmunk. +"""Backward-compatible re-exports for sirchmunk local search. -This module provides integration between sirchmunk's AgenticSearch -and the ms_agent framework, enabling intelligent codebase search -capabilities similar to RAG. +Implementation lives in :mod:`ms_agent.tools.search.sirchmunk_search`; prefer +importing ``SirchmunkSearch`` from there in new code. """ -from .sirchmunk_search import SirchmunkSearch +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch __all__ = ['SirchmunkSearch'] diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 410aa12f0..c108170cb 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,11 +61,8 @@ class Message: api_calls: int = 1 - # Knowledge search (sirchmunk) related fields - # searching_detail: Search process logs and metadata for frontend display - searching_detail: Dict[str, Any] = field(default_factory=dict) - # search_result: Raw search results to be used as context for next LLM turn - search_result: List[Dict[str, Any]] = field(default_factory=list) + # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). 
+ tool_detail: Optional[str] = None def to_dict(self): return asdict(self) @@ -88,7 +85,11 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls'] + # Never send UI-only fields to model providers. + rm = [ + 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', + 'searching_detail', 'search_result' + ] return { key: value for key, value in raw_dict.items() @@ -98,20 +99,33 @@ def to_dict_clean(self): @dataclass class ToolResult: + """Tool execution outcome. + + ``text`` is sent to the model as the tool message ``content``. + ``tool_detail`` is optional verbose output for frontends only (SSE, logs). + """ + text: str resources: List[str] = field(default_factory=list) extra: dict = field(default_factory=dict) + tool_detail: Optional[str] = None @staticmethod def from_raw(raw): if isinstance(raw, str): return ToolResult(text=raw) if isinstance(raw, dict): + model_text = raw.get('result') + if model_text is None: + model_text = raw.get('text', '') + td = raw.get('tool_detail') return ToolResult( - text=str(raw.get('text', '')), + text=str(model_text), resources=raw.get('resources', []), + tool_detail=None if td is None else str(td), extra={ k: v - for k, v in raw.items() if k not in ['text', 'resources'] + for k, v in raw.items() + if k not in ['text', 'resources', 'result', 'tool_detail'] }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index e66da954d..cf52aaa7c 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -5,5 +5,5 @@ 'LlamaIndexRAG': LlamaIndexRAG, } -# Note: SirchmunkSearch is registered in knowledge_search module -# and integrated directly in LLMAgent, not through rag_mapping +# Note: Sirchmunk local search is the ``localsearch`` tool +# (ms_agent.tools.search); it is not wired through rag_mapping. 
diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py new file mode 100644 index 000000000..d8ab34ef7 --- /dev/null +++ b/ms_agent/tools/search/localsearch_tool.py @@ -0,0 +1,282 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, + effective_localsearch_settings, +) +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SERVER = 'localsearch' +_TOOL = 'localsearch' + +# Tool-facing description: aligned with sirchmunk AgenticSearch.search() capabilities. +_LOCALSEARCH_DESCRIPTION = """Search local files, codebases, and documents on disk. + +USE THIS TOOL WHEN: +- The user asks about content in local files or directories +- You need to find information in source code, config files, or documents +- The query references a local path, project structure, or codebase +- You need to search PDF, DOCX, XLSX, PPTX, CSV, JSON, YAML, Markdown, etc. + +DO NOT USE THIS TOOL WHEN: +- The user is asking a general knowledge question +- The user is greeting you or making casual conversation (e.g., "你好", "hello") +- You need information from the internet or recent events +- The query has no relation to local files or code + +Returns: +Search results after summarizing as formatted text with file paths, code snippets, and explanations where +available. Retrieved excerpts and meta are included in the tool output. 
+ +Configured search roots for this agent (absolute paths; default search scope when `paths` is omitted): +{configured_roots} +""" + + +def _resolved_localsearch_paths_from_config(config) -> List[str]: + """Match ``SirchmunkSearch`` path resolution for consistent tool text and checks.""" + block = effective_localsearch_settings(config) + if not block: + return [] + paths = block.get('paths', []) + if isinstance(paths, str): + paths = [paths] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(Path(str(p).strip()).expanduser().resolve())) + return out + + +def _format_configured_roots(paths: List[str]) -> str: + if not paths: + return ( + '(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') + return '\n'.join(f'- {p}' for p in paths) + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _as_str_list(value: Any, name: str) -> Optional[List[str]]: + if value is None: + return None + if isinstance(value, str): + return [value] if value.strip() else None + if isinstance(value, list): + out = [str(x).strip() for x in value if str(x).strip()] + return out or None + raise TypeError(f'{name} must be a string or list of strings') + + +class LocalSearchTool(ToolBase): + """Expose sirchmunk as a callable tool when ``tools.localsearch`` is configured.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + tools_root = getattr(config, 'tools', None) + tool_cfg = getattr(tools_root, 'localsearch', None) if tools_root else None + if tool_cfg is not None: + self.exclude_func(tool_cfg) + self._searcher: Optional[SirchmunkSearch] = None + self._configured_roots: List[str] = ( + _resolved_localsearch_paths_from_config(config)) + + def _tool_description(self) -> str: + return _LOCALSEARCH_DESCRIPTION.format( + configured_roots=_format_configured_roots(self._configured_roots)) + + def _paths_param_description(self) -> 
str: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Optional. Narrow search to specific files or directories under the ' + 'configured roots below. Each path must exist on disk and lie under ' + 'one of these roots (or be exactly one of them).\n' + f'Configured roots:\n{roots}') + + def _ensure_searcher(self) -> SirchmunkSearch: + if self._searcher is None: + self._searcher = SirchmunkSearch(self.config) + return self._searcher + + async def connect(self) -> None: + return None + + async def _get_tools_inner(self) -> Dict[str, List[Tool]]: + return { + _SERVER: [ + Tool( + tool_name=_TOOL, + server_name=_SERVER, + description=self._tool_description(), + parameters={ + 'type': + 'object', + 'properties': { + 'query': { + 'type': + 'string', + 'description': + 'Search keywords or natural-language question about local content.', + }, + 'paths': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + self._paths_param_description(), + }, + 'mode': { + 'type': + 'string', + 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], + 'description': + 'Search mode; omit to use agent default (usually FAST).', + }, + 'max_depth': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 20, + 'description': + 'Max directory depth for filesystem search.', + }, + 'top_k_files': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 20, + 'description': + 'Max files for evidence / filename hits.', + }, + 'include': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to include (e.g. *.py, *.md).', + }, + 'exclude': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to exclude (e.g. 
*.pyc).', + }, + }, + 'required': ['query'], + }, + ) + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict): + del server_name + if tool_name != _TOOL: + return f'Unknown tool: {tool_name}' + + args = tool_args or {} + query = str(args.get('query', '')).strip() + if not query: + return 'Error: `query` is required and cannot be empty.' + + try: + paths_arg = _as_str_list(args.get('paths'), 'paths') + mode = args.get('mode') + if mode is not None: + mode = str(mode).strip().upper() or None + + max_depth = args.get('max_depth') + if max_depth is not None: + max_depth = int(max_depth) + max_depth = max(1, min(20, max_depth)) + + top_k = args.get('top_k_files') + if top_k is not None: + top_k = int(top_k) + top_k = max(1, min(20, top_k)) + + include = _as_str_list(args.get('include'), 'include') + exclude = _as_str_list(args.get('exclude'), 'exclude') + + searcher = self._ensure_searcher() + resolved_paths = None + if paths_arg: + resolved_paths = searcher.resolve_tool_paths(paths_arg) + if not resolved_paths: + roots = _format_configured_roots( + self._configured_roots) + return ( + 'Error: `paths` are invalid. 
Each path must exist on disk and lie ' + 'under one of these configured roots:\n' + roots) + + answer = await searcher.query( + query, + paths=resolved_paths, + mode=mode, + max_depth=max_depth, + top_k_files=top_k, + include=include, + exclude=exclude, + ) + details = searcher.get_search_details() + excerpts = searcher.get_last_retrieved_chunks() + + lines = ['## Local search (sirchmunk)', '', str(answer), ''] + + if excerpts: + lines.append('## Retrieved excerpts') + lines.append('') + for i, item in enumerate(excerpts[:12], 1): + meta = item.get('metadata') or {} + src = meta.get('source', '?') + text = (item.get('text') or '')[:4000] + lines.append(f'### [{i}] {src}') + lines.append(text) + lines.append('') + + summary = { + 'mode': details.get('mode'), + 'paths': details.get('paths'), + 'work_path': details.get('work_path'), + 'cluster_cache_hit': details.get('cluster_cache_hit'), + } + lines.append('## Meta') + lines.append(_json_dumps(summary)) + + full_text = '\n'.join(lines) + # Model sees answer + source paths only; UI gets full excerpts + meta. 
+ result_parts = [str(answer).strip()] + if excerpts: + result_parts.append('\nSource paths:') + for item in excerpts[:12]: + meta = item.get('metadata') or {} + result_parts.append( + f'- {meta.get("source", "?")}') + result_text = '\n'.join(result_parts) + + return { + 'result': result_text, + 'tool_detail': full_text, + } + except (TypeError, ValueError) as exc: + return f'Invalid tool arguments: {exc}' + except Exception as exc: + logger.warning(f'localsearch failed: {exc}') + return f'Local search failed: {exc}' + diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py similarity index 68% rename from ms_agent/knowledge_search/sirchmunk_search.py rename to ms_agent/tools/search/sirchmunk_search.py index e1c76181f..cd86f9d63 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -1,51 +1,83 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Sirchmunk-based knowledge search integration. +"""Sirchmunk backend for the ``localsearch`` tool. -This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, -providing document retrieval capabilities similar to RAG but optimized for -codebase and documentation search. +Configuration lives under ``tools.localsearch`` (same namespace as other tools). +Legacy top-level ``knowledge_search`` is still accepted for backward compatibility. """ import asyncio +import json from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from loguru import logger -from ms_agent.rag.base import RAG from omegaconf import DictConfig -class SirchmunkSearch(RAG): - """Sirchmunk-based knowledge search class. 
+def _paths_from_block(block: Any) -> List[str]: + if block is None: + return [] + paths = block.get('paths', []) if hasattr(block, 'get') else [] + if isinstance(paths, str): + paths = [paths] if str(paths).strip() else [] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(p).strip()) + return out - This class wraps the sirchmunk library to provide intelligent codebase search - capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk - uses a combination of keyword search, semantic clustering, and LLM-powered - analysis to find relevant information from codebases. - The configuration needed in the config yaml: - - name: SirchmunkSearch - - paths: List of paths to search, required - - work_path: Working directory for sirchmunk cache, default './.sirchmunk' - - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' - - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 - - cluster_sim_top_k: Top K clusters to consider, default 3 - - reuse_knowledge: Whether to reuse previous search results, default True - - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' +def effective_localsearch_settings(config: DictConfig) -> Optional[Any]: + """Resolve the active localsearch / sirchmunk settings node. + + Precedence: ``tools.localsearch`` with non-empty ``paths``, else legacy + ``knowledge_search`` with non-empty ``paths``. Returns ``None`` if local + search is not configured. + """ + tools = getattr(config, 'tools', None) + tl = None + if tools is not None: + tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr( + tools, 'localsearch', None) + ks = getattr(config, 'knowledge_search', None) + + if tl is not None and _paths_from_block(tl): + return tl + if ks is not None and _paths_from_block(ks): + return ks + return None + + +class SirchmunkSearch: + """Sirchmunk-based local search (used by :class:`LocalSearchTool`). 
+ + Configure in yaml under ``tools.localsearch`` (recommended), for example:: + + tools: + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + embedding_model: text-embedding-3-small + cluster_sim_threshold: 0.85 + cluster_sim_top_k: 3 + reuse_knowledge: true + mode: FAST + + Legacy: the same keys may be placed under top-level ``knowledge_search``. Args: - config (DictConfig): Configuration object containing sirchmunk settings. + config: Full agent config; sirchmunk options read from the effective + block returned by :func:`effective_localsearch_settings`. """ def __init__(self, config: DictConfig): - super().__init__(config) - self._validate_config(config) + rag_config = effective_localsearch_settings(config) + assert rag_config is not None - # Extract configuration parameters - rag_config = config.get('knowledge_search', {}) - - # Search paths - required paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] @@ -53,11 +85,9 @@ def __init__(self, config: DictConfig): str(Path(p).expanduser().resolve()) for p in paths ] - # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() - # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) @@ -66,13 +96,10 @@ def __init__(self, config: DictConfig): self.max_loops = rag_config.get('max_loops', 10) self.max_token_budget = rag_config.get('max_token_budget', 128000) - # LLM configuration for sirchmunk - # First try knowledge_search.llm_api_key, then fall back to main llm config self.llm_api_key = rag_config.get('llm_api_key', None) self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - # Fall back to main llm config if not specified in knowledge_search if (self.llm_api_key is None or self.llm_base_url is None or 
self.llm_model_name is None): llm_config = config.get('llm', {}) @@ -87,39 +114,57 @@ def __init__(self, config: DictConfig): if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) - # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( 'embedding_model_cache_dir', None) - # Runtime state self._searcher = None self._initialized = False self._cluster_cache_hit = False self._cluster_cache_hit_time: str | None = None self._last_search_result: List[Dict[str, Any]] | None = None - # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] - # Async queue for streaming logs in real-time self._log_queue: asyncio.Queue | None = None self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): - """Validate configuration parameters.""" - if not hasattr(config, - 'knowledge_search') or config.knowledge_search is None: + block = effective_localsearch_settings(config) + if block is None: raise ValueError( - 'Missing knowledge_search configuration. ' - 'Please add knowledge_search section to your config with at least "paths" specified.' - ) - - rag_config = config.knowledge_search - paths = rag_config.get('paths', []) + 'Missing localsearch configuration. 
Add ' + '`tools.localsearch` with non-empty `paths` (or legacy ' + '`knowledge_search.paths`).') + paths = _paths_from_block(block) if not paths: raise ValueError( - 'knowledge_search.paths must be specified and non-empty') + 'tools.localsearch.paths (or legacy knowledge_search.paths) ' + 'must be specified and non-empty') + + def resolve_tool_paths( + self, paths: Optional[List[str]]) -> Optional[List[str]]: + """Restrict per-call paths to configured search roots.""" + if not paths: + return None + roots = [Path(p).resolve() for p in self.search_paths] + cleaned: List[str] = [] + for raw in paths: + if raw is None or not str(raw).strip(): + continue + p = Path(str(raw).strip()).expanduser().resolve() + if not p.exists(): + logger.warning(f'localsearch: path does not exist, skipped: {p}') + continue + allowed = any( + p == r or p.is_relative_to(r) for r in roots) + if not allowed: + logger.warning( + f'localsearch: path outside configured search roots, ' + f'skipped: {p}') + continue + cleaned.append(str(p)) + return cleaned or None def _initialize_searcher(self): """Initialize the sirchmunk AgenticSearch instance.""" @@ -131,7 +176,6 @@ def _initialize_searcher(self): from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil - # Create LLM client llm = OpenAIChat( api_key=self.llm_api_key, base_url=self.llm_base_url, @@ -140,8 +184,6 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - # Create embedding util - # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( @@ -150,7 +192,6 @@ def _initialize_searcher(self): embedding = EmbeddingUtil( model_id=embedding_model_id, cache_dir=embedding_cache_dir) - # Create AgenticSearch instance self._searcher = AgenticSearch( llm=llm, embedding=embedding, @@ -191,7 +232,6 @@ def log_callback( ): log_entry = f'[{level.upper()}] 
{message}' self._search_logs.append(log_entry) - # Stream log in real-time if streaming callback is set if self._streaming_callback: asyncio.create_task(self._streaming_callback(log_entry)) @@ -215,7 +255,6 @@ async def add_documents(self, documents: List[str]) -> bool: 'SirchmunkSearch does not support direct document addition. ' 'Documents should be saved to files within the configured search paths.' ) - # Trigger re-scan of the search paths if self._searcher and hasattr(self._searcher, 'knowledge_base'): try: await self._searcher.knowledge_base.refresh() @@ -274,7 +313,6 @@ async def retrieve(self, max_token_budget = filters.get('max_token_budget', self.max_token_budget) - # Perform search result = await self._searcher.search( query=query, mode=mode, @@ -283,56 +321,101 @@ async def retrieve(self, return_context=True, ) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - # If a similar cluster was found and reused, it's a cache hit self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) - # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] - async def query(self, query: str) -> str: - """Query sirchmunk and return a synthesized answer. - - This method performs a search and returns the LLM-synthesized answer - along with search details that can be used for frontend display. 
+ async def query( + self, + query: str, + *, + paths: Optional[List[str]] = None, + mode: Optional[str] = None, + max_depth: Optional[int] = None, + top_k_files: Optional[int] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> str: + """Query sirchmunk and return a synthesized answer (or filename hits). + + Optional arguments are forwarded to ``AgenticSearch.search`` where supported. + ``paths`` must already be restricted to configured search roots (see + :meth:`resolve_tool_paths`). Args: - query (str): The search query. + query: The search query. + paths: Override search roots (subset of configured paths), or None. + mode: ``FAST``, ``DEEP``, or ``FILENAME_ONLY``; None uses config default. + max_depth: Directory depth cap for filesystem search. + top_k_files: Max files for evidence / filename ranking. + include: Glob patterns to include (e.g. ``*.py``). + exclude: Glob patterns to exclude (e.g. ``node_modules``). Returns: - str: The synthesized answer from sirchmunk. + Answer string, or JSON string for ``FILENAME_ONLY`` list results. """ self._initialize_searcher() self._search_logs.clear() try: - mode = self.search_mode - max_loops = self.max_loops - max_token_budget = self.max_token_budget - - # Single search with context so we get both the synthesized answer and - # source units in one call, avoiding a redundant second search. 
- result = await self._searcher.search( + mode_eff = mode if mode is not None else self.search_mode + if isinstance(mode_eff, str): + mode_eff = mode_eff.strip().upper() + allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') + if mode_eff not in allowed_modes: + return ( + f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') + + kw: Dict[str, Any] = dict( query=query, - mode=mode, - max_loops=max_loops, - max_token_budget=max_token_budget, + paths=paths, + mode=mode_eff, + max_loops=self.max_loops, + max_token_budget=self.max_token_budget, return_context=True, ) + if max_depth is not None: + kw['max_depth'] = max_depth + if top_k_files is not None: + kw['top_k_files'] = top_k_files + if include is not None: + kw['include'] = include + if exclude is not None: + kw['exclude'] = exclude + + result = await self._searcher.search(**kw) + + if isinstance(result, list): + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + self._last_search_result = [] + for item in result[:20]: + if isinstance(item, dict): + src = (item.get('path') or item.get('file_path') + or item.get('file') or '') + self._last_search_result.append({ + 'text': + json.dumps(item, ensure_ascii=False), + 'score': + 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + }) + return json.dumps(result, ensure_ascii=False, indent=2) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: @@ -342,19 +425,16 @@ async def query(self, query: str) -> str: self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Store parsed context for frontend display self._last_search_result = self._parse_search_result( result, score_threshold=0.7, limit=5) - # Extract the synthesized answer from the context result - if hasattr(result, 'answer'): + if hasattr(result, 'answer') and getattr(result, 'answer', + None) is not None: return result.answer - # If 
result is already a plain string (some modes return str directly) if isinstance(result, str): return result - # Fallback: convert to string return str(result) except Exception as e: @@ -375,14 +455,11 @@ def _parse_search_result(self, result: Any, score_threshold: float, """ results = [] - # Handle SearchContext format (returned when return_context=True) if hasattr(result, 'cluster') and result.cluster is not None: cluster = result.cluster for unit in cluster.evidences: - # Extract score from snippets if available score = getattr(cluster, 'confidence', 1.0) if score >= score_threshold: - # Extract text from snippets text_parts = [] source = str(getattr(unit, 'file_or_url', 'unknown')) for snippet in getattr(unit, 'snippets', []): @@ -406,7 +483,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) @@ -423,7 +499,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle list format elif isinstance(result, list): for item in result: if isinstance(item, dict): @@ -438,7 +513,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, item.get('metadata', {}), }) - # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: @@ -451,10 +525,13 @@ def _parse_search_result(self, result: Any, score_threshold: float, result.get('metadata', {}), }) - # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] + def get_last_retrieved_chunks(self) -> List[Dict[str, Any]]: + """Parsed evidence chunks from the last `query` or `retrieve` call.""" + return list(self._last_search_result or []) + def get_search_logs(self) -> List[str]: """Get the captured search logs. 
diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..e7885b92d 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -17,6 +17,8 @@ from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient +from ms_agent.tools.search.localsearch_tool import LocalSearchTool +from ms_agent.tools.search.sirchmunk_search import effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool @@ -88,6 +90,8 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + if effective_localsearch_settings(config) is not None: + self.extra_tools.append(LocalSearchTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py index 5a4f43213..705c7184e 100644 --- a/tests/knowledge_search/test_sirschmunk.py +++ b/tests/knowledge_search/test_sirschmunk.py @@ -1,23 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. +"""Tests for SirchmunkSearch and localsearch tool integration. -These tests verify the sirchmunk-based knowledge search functionality -through the LLMAgent entry point, including verification that -search_result and searching_detail fields are properly populated. 
- -To run these tests, you need to set the following environment variables: - - TEST_LLM_API_KEY: Your LLM API key - - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) - - TEST_LLM_MODEL_NAME: Your LLM model name (optional) - - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) - - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) - -Example: +Example (full sirchmunk run): export TEST_LLM_API_KEY="your-api-key" - export TEST_LLM_BASE_URL="https://api.openai.com/v1" - export TEST_LLM_MODEL_NAME="gpt-4o" - export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" - export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" python -m pytest tests/knowledge_search/test_sirschmunk.py """ import asyncio @@ -26,92 +11,41 @@ import unittest from pathlib import Path -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.agent import LLMAgent -from ms_agent.config import Config +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch +from ms_agent.llm.utils import Message +from ms_agent.tools.tool_manager import ToolManager from omegaconf import DictConfig -from modelscope.utils.test_utils import test_level - - -class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): - """Test cases for SirchmunkSearch integration with LLMAgent. - - These tests verify that when LLMAgent runs a query that triggers - knowledge search, the Message objects have search_result and - searching_detail fields properly populated. 
- """ +class SirchmunkKnowledgeSearchTest(unittest.TestCase): + """Sirchmunk config, ToolManager registration""" @classmethod def setUpClass(cls): - """Set up test fixtures.""" - # Create test directory with sample files cls.test_dir = Path('./test_llm_agent_knowledge') cls.test_dir.mkdir(exist_ok=True) - - # Create sample documentation - (cls.test_dir / 'README.md').write_text(''' -# Test Project Documentation - -## Overview -This is a test project for knowledge search integration. - -## API Reference - -### UserManager -The UserManager class handles user operations: -- create_user: Create a new user account -- delete_user: Delete an existing user -- update_user: Update user information -- get_user: Retrieve user details - -### AuthService -The AuthService class handles authentication: -- login: Authenticate user credentials -- logout: End user session -- refresh_token: Refresh authentication token -- verify_token: Validate authentication token -''') - - (cls.test_dir / 'config.py').write_text(''' -"""Configuration module.""" - -class Config: - """Application configuration.""" - - def __init__(self): - self.database_url = "postgresql://localhost:5432/mydb" - self.secret_key = "your-secret-key" - self.debug_mode = False - - def load_from_env(self): - """Load configuration from environment variables.""" - import os - self.database_url = os.getenv("DATABASE_URL", self.database_url) - self.secret_key = os.getenv("SECRET_KEY", self.secret_key) - return self -''') + (cls.test_dir / 'README.md').write_text( + '# Demo\n\nUserManager.create_user creates a user.\n') @classmethod def tearDownClass(cls): - """Clean up test fixtures.""" if cls.test_dir.exists(): shutil.rmtree(cls.test_dir, ignore_errors=True) work_dir = Path('./.sirchmunk') if work_dir.exists(): shutil.rmtree(work_dir, ignore_errors=True) - def _get_agent_config(self): - """Create agent configuration with knowledge search.""" + def _base_config(self) -> DictConfig: llm_api_key = os.getenv('TEST_LLM_API_KEY', 
'test-api-key') - llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', + 'https://api.openai.com/v1') llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') - # Read from TEST_* env vars (for test-specific config) - # These can be set from .env file which uses TEST_* prefix embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') - embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') - - config = DictConfig({ + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', + '') + return DictConfig({ + 'output_dir': + './outputs_knowledge_test', 'llm': { 'service': 'openai', 'model': llm_model_name, @@ -122,81 +56,66 @@ def _get_agent_config(self): 'temperature': 0.3, 'max_tokens': 500, }, - 'knowledge_search': { - 'name': 'SirchmunkSearch', - 'paths': [str(self.test_dir)], - 'work_path': './.sirchmunk', - 'llm_api_key': llm_api_key, - 'llm_base_url': llm_base_url, - 'llm_model_name': llm_model_name, - 'embedding_model': embedding_model_id, - 'embedding_model_cache_dir': embedding_model_cache_dir, - 'mode': 'FAST', - } + 'tools': { + 'localsearch': { + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + }, + }, }) - return config - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_llm_agent_with_knowledge_search(self): - """Test LLMAgent using knowledge search. - - This test verifies that: - 1. LLMAgent can be initialized with SirchmunkSearch configuration - 2. Running a query produces a valid response - 3. User message has searching_detail and search_result populated - 4. searching_detail contains expected keys (logs, mode, paths) - 5. 
search_result is a list - """ - config = self._get_agent_config() + def test_does_not_inject_knowledge_search(self): + """Local sirchmunk search is no longer merged into the user message here.""" + config = self._base_config() agent = LLMAgent(config=config, tag='test-knowledge-agent') - - # Test query that should trigger knowledge search - query = 'How do I use UserManager to create a user?' - - async def run_agent(): - result = await agent.run(query) - return result - - result = asyncio.run(run_agent()) - - # Verify result - self.assertIsNotNone(result) - self.assertIsInstance(result, list) - self.assertTrue(len(result) > 0) - - # Check that assistant message exists - assistant_message = [m for m in result if m.role == 'assistant'] - self.assertTrue(len(assistant_message) > 0) - - # Check that user message has search_result and searching_detail populated - user_messages = [m for m in result if m.role == 'user'] - self.assertTrue(len(user_messages) > 0, "Expected at least one user message") - - # The first user message should have search details after do_rag processing - user_msg = user_messages[0] - self.assertTrue( - hasattr(user_msg, 'searching_detail'), - "User message should have searching_detail attribute" - ) + original = 'How do I use UserManager?' 
+ + async def run(): + messages = [ + Message(role='system', content='You are a helper.'), + Message(role='user', content=original), + ] + await agent.run(messages) + return messages + + messages = asyncio.run(run()) + print(f'messages: {messages}') + + def test_tool_manager_registers_localsearch(self): + """When tools.localsearch.paths is set, ToolManager exposes localsearch.""" + + async def run(): + config = self._base_config() + tm = ToolManager(config, trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + names = [t['tool_name'] for t in tools] self.assertTrue( - hasattr(user_msg, 'search_result'), - "User message should have search_result attribute" + any(n.endswith('localsearch') for n in names), + f'Expected localsearch in tools, got: {names}', ) - # Check that searching_detail is a dict with expected keys - self.assertIsInstance( - user_msg.searching_detail, dict, - "searching_detail should be a dictionary" - ) - self.assertIn('logs', user_msg.searching_detail) - self.assertIn('mode', user_msg.searching_detail) - self.assertIn('paths', user_msg.searching_detail) - - # Check that search_result is a list (may be empty if no relevant docs found) - self.assertIsInstance( - user_msg.search_result, list, - "search_result should be a list" - ) + @unittest.skipUnless( + os.getenv('TEST_SIRCHMUNK_SMOKE', ''), + 'Set TEST_SIRCHMUNK_SMOKE=1 to run sirchmunk API smoke test', + ) + def test_sirchmunk_search_query_smoke(self): + """Optional: run sirchmunk once (needs network / valid API keys).""" + config = self._base_config() + searcher = SirchmunkSearch(config) + result = asyncio.run(searcher.query('UserManager')) + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) if __name__ == '__main__': From 76a821e1ea2877b162bc38d8386959218bf0f118 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 27 Mar 2026 11:42:21 +0800 Subject: [PATCH 08/12] feat: local search tool 
--- docs/en/Components/Config.md | 10 + docs/zh/Components/config.md | 12 +- ms_agent/agent/agent.yaml | 37 +- ms_agent/agent/llm_agent.py | 79 +++- ms_agent/config/config.py | 20 +- ms_agent/knowledge_search/__init__.py | 5 +- ms_agent/llm/utils.py | 4 + ms_agent/rag/utils.py | 3 - ms_agent/tools/base.py | 14 + ms_agent/tools/search/localsearch_catalog.py | 436 +++++++++++++++++++ ms_agent/tools/search/localsearch_tool.py | 219 +++++++++- ms_agent/tools/search/sirchmunk_search.py | 27 +- ms_agent/tools/tool_manager.py | 129 +++++- tests/knowledge_search/test_sirschmunk.py | 33 +- 14 files changed, 970 insertions(+), 58 deletions(-) create mode 100644 ms_agent/tools/search/localsearch_catalog.py diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index 2154add8e..65b384c12 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -104,12 +104,22 @@ tools: - map_geo # Local codebase / document search (sirchmunk), exposed as the `localsearch` tool localsearch: + mcp: false paths: - ./src - ./docs work_path: ./.sirchmunk mode: FAST # Optional: llm_api_key, llm_base_url, llm_model_name (else inherited from `llm`) + # When true, a shallow sirchmunk DirectoryScanner run at tool connect injects file titles/previews + # into the `localsearch` tool description (default: false) + # description_catalog: false + # description_catalog_max_files: 120 + # description_catalog_max_depth: 5 + # description_catalog_max_chars: 10000 + # description_catalog_max_preview_chars: 400 + # description_catalog_cache_ttl_seconds: 300 + # description_catalog_exclude: [] # extra globs / dir names merged with sirchmunk defaults ``` For the complete list of supported tools and custom tools, please refer to [here](./Tools.md) diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 0baa4bf18..49e7a133e 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -104,12 +104,22 @@ tools: - map_geo # 
本地代码库/文档搜索(sirchmunk),对应模型可调用的 `localsearch` 工具 localsearch: + mcp: false paths: - ./src - ./docs work_path: ./.sirchmunk mode: FAST # 可选:llm_api_key、llm_base_url、llm_model_name(不填则从 `llm` 继承) + # 为 true 时,在工具连接阶段用 sirchmunk DirectoryScanner 做浅层扫描,把文件标题/预览写入 + # `localsearch` 工具 description,便于模型知道本地知识库里大致有哪些内容(默认 false) + # description_catalog: false + # description_catalog_max_files: 120 + # description_catalog_max_depth: 5 + # description_catalog_max_chars: 10000 + # description_catalog_max_preview_chars: 400 + # description_catalog_cache_ttl_seconds: 300 + # description_catalog_exclude: [] # 额外 glob / 目录名,与 sirchmunk 默认排除合并 ``` 支持的完整工具列表,以及自定义工具请参考 [这里](./tools) @@ -190,4 +200,4 @@ ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_sea ``` LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 -若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。仍支持旧版顶层 `knowledge_search` 相同字段,以便迁移。 +若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。 diff --git a/ms_agent/agent/agent.yaml b/ms_agent/agent/agent.yaml index a8a65d440..21b01c22f 100644 --- a/ms_agent/agent/agent.yaml +++ b/ms_agent/agent/agent.yaml @@ -13,42 +13,7 @@ generation_config: prompt: system: | - You are an assistant that helps me complete tasks. You need to follow these instructions: - - 1. Analyze whether my requirements need tool-calling. If no tools are needed, you can think directly and provide an answer. - - 2. I will give you many tools, some of which are similar. Please carefully analyze which tool you currently need to invoke. - * If tools need to be invoked, you must call at least one tool in each round until the requirement is completed. - * If you get any useful links or images from the tool calling, output them with your answer as well. - * Check carefully the tool result, what it contains, whether it has information you need. - - 3. 
You DO NOT have built-in geocode/coordinates/links. Do not output any fake geocode/coordinates/links. Always query geocode/coordinates/links from tools first! - - 4. If you need to complete coding tasks, you need to carefully analyze the original requirements, provide detailed requirement analysis, and then complete the code writing. - - 5. This conversation is NOT for demonstration or testing purposes. Answer it as accurately as you can. - - 6. Do not call tools carelessly. Show your thoughts **as detailed as possible**. - - 7. Respond in the same language the user uses. If the user switches, switch accordingly. - - For requests that require performing a specific task or retrieving information, using the following format: - ``` - The user needs to ... - I have analyzed this request in detail and broken it down into the following steps: - ... - ``` - If you have tools which may help you to solve problems, follow this format to answer: - ``` - The user needs to ... - I have analyzed this request in detail and broken it down into the following steps: - ... - First, I should use the [Tool Name] because [explain relevance]. The required input parameters are: ... - ... - I have carefully reviewed the tool's output. The result does/does not fully meet my expectations. Next, I need to ... - ``` - - **Important: Always respond in the same language the user is using.** + you are a helpful assistant. max_chat_round: 9999 diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 36e46a672..606427de8 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -537,6 +537,76 @@ async def parallel_tool_call(self, self.log_output(_new_message.content) return messages + async def parallel_tool_call_streaming( + self, messages: List[Message]) -> AsyncGenerator: + """Streaming variant of parallel_tool_call. 
+ + Yields messages list snapshots during tool execution: + - While tools are running: yields messages with the latest + searching_detail_delta set on a temporary placeholder Message so the + caller can stream intermediate logs to the frontend. + - After all tools finish: yields the final messages list (with proper + tool result Messages appended), same as parallel_tool_call. + """ + tool_calls = messages[-1].tool_calls + + # Map call_id -> tool_call_query for final message construction. + call_id_to_query = {tc['id']: tc for tc in tool_calls} + + # Accumulate final results keyed by call_id. + final_results: dict = {} + + async for call_id, item in self.tool_manager.parallel_call_tool_streaming( + tool_calls): + if isinstance(item, dict) or not isinstance(item, str) or \ + not self._is_log_line(item): + # Final result for this call_id. + final_results[call_id] = item + else: + # Intermediate log line: emit a snapshot with searching_detail_delta. + log_message = Message( + role='tool', + content='', + tool_call_id=call_id, + name=call_id_to_query.get(call_id, {}).get( + 'tool_name', ''), + searching_detail_delta=item, + ) + yield messages + [log_message] + + # All tools done — build final tool messages and yield. 
+ for tool_call_query in tool_calls: + cid = tool_call_query['id'] + raw_result = final_results.get( + cid, f'Tool call missing result for id {cid}') + tool_call_result_format = ToolResult.from_raw(raw_result) + _new_message = Message( + role='tool', + content=tool_call_result_format.text, + tool_call_id=cid, + name=tool_call_query['tool_name'], + resources=tool_call_result_format.resources, + tool_detail=tool_call_result_format.tool_detail, + ) + if _new_message.tool_call_id is None: + _new_message.tool_call_id = str(uuid.uuid4())[:8] + tool_call_query['id'] = _new_message.tool_call_id + messages.append(_new_message) + self.log_output(_new_message.content) + + yield messages + + @staticmethod + def _is_log_line(item: str) -> bool: + """Heuristic: a plain log string from sirchmunk is not a timeout/error result.""" + if item.startswith('Execute tool call timeout:'): + return False + if item.startswith('Tool calling failed:'): + return False + if item.startswith('The input ') and 'is not a valid JSON' in item: + return False + return True + async def prepare_tools(self): """Initialize and connect the tool manager.""" self.tool_manager = ToolManager( @@ -897,7 +967,14 @@ async def step( self.save_history(messages) if _response_message.tool_calls: - messages = await self.parallel_tool_call(messages) + # Use the streaming variant so intermediate tool logs are yielded + # back to the caller while the tools are still running. + async for messages in self.parallel_tool_call_streaming(messages): + is_final = ( + not messages[-1].searching_detail_delta + if hasattr(messages[-1], 'searching_detail_delta') else True) + if not is_final: + yield messages await self.after_tool_call(messages) diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 2f6175524..58e3ef26a 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -16,6 +16,20 @@ logger = get_logger() +# ``tools.`` entries implemented as in-process ToolBase classes, not MCP transports. 
+_BUILTIN_TOOL_SERVERS = frozenset({ + 'localsearch', + 'web_search', + 'split_task', + 'file_system', + 'code_executor', + 'financial_data_fetcher', + 'agent_tools', + 'todo_list', + 'image_generator', + 'video_generator', +}) + class ConfigLifecycleHandler: @@ -253,6 +267,10 @@ def convert_mcp_servers_to_json( for server, server_config in config.tools.items(): if server == TOOL_PLUGIN_NAME: continue - if getattr(server_config, 'mcp', True): + if server in _BUILTIN_TOOL_SERVERS: + use_mcp = getattr(server_config, 'mcp', False) + else: + use_mcp = getattr(server_config, 'mcp', True) + if use_mcp: servers['mcpServers'][server] = deepcopy(server_config) return servers diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py index 8746f27c9..f6c7f5143 100644 --- a/ms_agent/knowledge_search/__init__.py +++ b/ms_agent/knowledge_search/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Backward-compatible re-exports for sirchmunk local search. -Implementation lives in :mod:`ms_agent.tools.search.sirchmunk_search`; prefer -importing ``SirchmunkSearch`` from there in new code. +This module provides integration between sirchmunk's AgenticSearch +and the ms_agent framework, enabling intelligent local path search +capabilities. """ from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index c108170cb..05fdbb02b 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -64,6 +64,10 @@ class Message: # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). tool_detail: Optional[str] = None + # role=tool (streaming): incremental log line emitted while the tool is still running. + # Non-empty only on intermediate yields; final yield has tool_detail set instead. 
+ searching_detail_delta: Optional[str] = None + def to_dict(self): return asdict(self) diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index cf52aaa7c..08e9a4db7 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -4,6 +4,3 @@ rag_mapping = { 'LlamaIndexRAG': LlamaIndexRAG, } - -# Note: Sirchmunk local search is the ``localsearch`` tool -# (ms_agent.tools.search); it is not wired through rag_mapping. diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index 12ece9948..f3fd18515 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -89,3 +89,17 @@ async def call_tool(self, server_name: str, *, tool_name: str, Calling result in string format. """ pass + + async def call_tool_streaming(self, server_name: str, *, tool_name: str, + tool_args: dict): + """Streaming variant of call_tool. + + Yields incremental log strings (str) during execution, then yields the + final result (str or dict) as the last item. + + Default implementation simply delegates to call_tool (no streaming). + Override in subclasses that support real-time log emission. + """ + result = await self.call_tool( + server_name, tool_name=tool_name, tool_args=tool_args) + yield result diff --git a/ms_agent/tools/search/localsearch_catalog.py b/ms_agent/tools/search/localsearch_catalog.py new file mode 100644 index 000000000..9dd1b2bf7 --- /dev/null +++ b/ms_agent/tools/search/localsearch_catalog.py @@ -0,0 +1,436 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Build a compact file catalog for localsearch tool descriptions using sirchmunk's DirectoryScanner.""" + +from __future__ import annotations + +import hashlib +import json +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + + +def catalog_cache_path(work_path: Path, fingerprint: str) -> Path: + work_path.mkdir(parents=True, exist_ok=True) + return work_path / f'localsearch_description_catalog.{fingerprint}.json' + + +def catalog_fingerprint( + roots: List[str], + max_files: int, + max_depth: int, + max_preview_chars: int, + max_chars: int, + exclude: Optional[List[str]], +) -> str: + payload = { + 'roots': sorted(roots), + 'max_files': max_files, + 'max_depth': max_depth, + 'max_preview_chars': max_preview_chars, + 'max_chars': max_chars, + 'exclude': sorted(exclude or []), + } + raw = json.dumps(payload, sort_keys=True, ensure_ascii=False).encode('utf-8') + return hashlib.sha256(raw).hexdigest()[:24] + + +def load_cached_catalog( + path: Path, + ttl_seconds: float, +) -> Optional[str]: + if ttl_seconds <= 0 or not path.is_file(): + return None + try: + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f) + created = float(data.get('created_at', 0)) + if time.time() - created > ttl_seconds: + return None + text = data.get('catalog') + return str(text) if text is not None else None + except (OSError, json.JSONDecodeError, TypeError, ValueError): + return None + + +def save_cached_catalog(path: Path, catalog: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + '.tmp') + payload = {'created_at': time.time(), 'catalog': catalog} + with open(tmp, 'w', encoding='utf-8') as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + tmp.replace(path) + + +def _int_from_block(block: Any, key: str, default: int) -> int: + if block is None: + return default + v = block.get(key, default) if hasattr(block, 'get') else getattr( + block, key, default) + try: + return 
int(v) + except (TypeError, ValueError): + return default + + +def _bool_from_block(block: Any, key: str, default: bool = False) -> bool: + if block is None: + return default + v = block.get(key, default) if hasattr(block, 'get') else getattr( + block, key, default) + if isinstance(v, str): + return v.strip().lower() in ('1', 'true', 'yes', 'on') + return bool(v) + + +def _exclude_from_block(block: Any) -> Optional[List[str]]: + if block is None: + return None + raw = block.get('description_catalog_exclude', None) if hasattr( + block, 'get') else getattr(block, 'description_catalog_exclude', None) + if raw is None: + return None + if isinstance(raw, str): + return [raw] if raw.strip() else None + if isinstance(raw, (list, tuple)): + out = [str(x).strip() for x in raw if str(x).strip()] + return out or None + return None + + +def description_catalog_settings(block: Any) -> Tuple[bool, Dict[str, Any]]: + """Parse ``tools.localsearch`` (or legacy ``knowledge_search``) catalog options.""" + enabled = _bool_from_block(block, 'description_catalog', True) + opts = { + 'max_files': max(1, min(2000, _int_from_block( + block, 'description_catalog_max_files', 120))), + 'max_depth': max(1, min(20, _int_from_block( + block, 'description_catalog_max_depth', 5))), + 'max_chars': max(500, min(100_000, _int_from_block( + block, 'description_catalog_max_chars', 3_000))), + 'max_preview_chars': max(80, min(4000, _int_from_block( + block, 'description_catalog_max_preview_chars', 400))), + 'cache_ttl_seconds': max(0, _int_from_block( + block, 'description_catalog_cache_ttl_seconds', 300)), + 'exclude_extra': _exclude_from_block(block), + # Files larger than this are skipped during catalog scan (default 50 MB). + # Set to 0 to disable the cap. + 'max_file_size_mb': max(0, _int_from_block( + block, 'description_catalog_max_file_size_mb', 50)), + # Wall-clock timeout (seconds) for extracting the first page of an + # oversized PDF. 
Corrupt or pathological PDFs are abandoned after this + # and only their filename + size appear in the catalog. + 'oversized_pdf_timeout_s': max(0.1, _int_from_block( + block, 'description_catalog_oversized_pdf_timeout_s', 1)), + } + return enabled, opts + + +_DIR_SKIP: Set[str] = { + '.git', '.svn', 'node_modules', '__pycache__', '.idea', '.vscode', + '.cache', '.tox', '.eggs', 'dist', 'build', '.DS_Store', +} + +# Directories that typically contain generated/compiled artifacts, not source. +# These are skipped by default to keep the tree focused on meaningful content. +_DIR_SKIP_GENERATED: Set[str] = { + '_build', '_static', '_templates', '_sphinx_design_static', + '__pycache__', '.mypy_cache', '.pytest_cache', '.ruff_cache', + 'htmlcov', 'site-packages', 'egg-info', '.egg-info', +} + + +def _build_dir_tree( + root: Path, + max_depth: int, + exclude: Optional[List[str]], + max_chars: int = 4000, +) -> str: + """Fast filesystem-only directory tree (no file content reads). + + Produces a compact indented listing: + 📁 ms_agent/ + 📁 tools/ + 📁 search/ (8 files) + agent_tool.py base.py ... +3 + __init__.py llm_agent.py ... + + Strategy — two passes: + 1. DFS pre-scan: collect every (depth, line) pair without any char limit. + 2. Breadth-first selection: sort collected lines by depth, take lines from + shallowest levels first until the char budget is exhausted. This + guarantees all top-level directories appear before any second-level + directories are shown, etc. + + The final output is re-sorted by original DFS order so indentation is + visually coherent. 
+ """ + skip = set(_DIR_SKIP) | _DIR_SKIP_GENERATED + if exclude: + skip.update(exclude) + + # --- Pass 1: DFS, collect (depth, seq, line_text) --- + collected: List[tuple] = [] # (depth, seq, line_text) + seq = [0] + + def _dfs(p: Path, depth: int, indent: str) -> None: + if depth > max_depth: + return + try: + entries = sorted(p.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) + except PermissionError: + return + + dirs = [e for e in entries if e.is_dir() + and not e.name.startswith('.') and e.name not in skip] + files = [e for e in entries if e.is_file() + and not e.name.startswith('.') and e.name not in skip] + + child_indent = indent + ' ' + + for d in dirs: + try: + file_count = sum( + 1 for _ in d.iterdir() + if _.is_file() and not _.name.startswith('.')) + except PermissionError: + file_count = 0 + count_hint = f' ({file_count} files)' if file_count else '' + collected.append((depth, seq[0], f'{indent}📁 {d.name}/{count_hint}')) + seq[0] += 1 + _dfs(d, depth + 1, child_indent) + + if files: + MAX_SHOW = 5 + shown = [f.name for f in files[:MAX_SHOW]] + overflow = len(files) - MAX_SHOW + file_line = f'{indent} ' + ' '.join(shown) + if overflow > 0: + file_line += f' … +{overflow}' + # File hint lines get depth + 0.5 so they sort after their parent + # dir line but before the parent's children directories. + collected.append((depth + 0.5, seq[0], file_line)) + seq[0] += 1 + + _dfs(root, 0, '') + + # --- Pass 2: breadth-first selection within char budget --- + # Sort by (depth, seq) to process shallowest lines first. 
+ by_depth = sorted(collected, key=lambda t: (t[0], t[1])) + budget = max_chars - 40 # reserve for truncation note + selected_seqs: set = set() + used = 0 + truncated = False + for depth, s, line in by_depth: + cost = len(line) + 1 + if used + cost > budget: + truncated = True + break + selected_seqs.add(s) + used += cost + + # Re-sort selected lines by original DFS seq to restore correct indentation + output_lines = [line for (_, s, line) in collected if s in selected_seqs] + + result = '\n'.join(output_lines) + if truncated: + result += '\n… (tree truncated — deeper directories omitted)' + return result + + +def _compact_file_summary(candidate: Any, root_dir: str, max_preview: int) -> str: + """Single-line or two-line compact summary for a FileCandidate. + + Format: + - path/to/file.py (.py, 12KB) — Title or first-line preview + """ + from pathlib import Path as _Path + try: + rel = _Path(candidate.path).relative_to(_Path(root_dir)).as_posix() + except (ValueError, TypeError): + rel = _Path(candidate.path).as_posix() + + size = candidate.size_bytes + if size < 1024: + size_str = f'{size}B' + elif size < 1024 * 1024: + size_str = f'{size / 1024:.0f}KB' + else: + size_str = f'{size / 1024 / 1024:.1f}MB' + + label = candidate.title or '' + if not label and candidate.preview: + # First sentence / line of preview, capped + label = candidate.preview.replace('\n', ' ').strip() + label = label[:max_preview] if label else '' + + base = f'- {rel} ({candidate.extension or "?"}, {size_str})' + return f'{base} — {label}' if label else base + + +async def build_file_catalog_text( + roots: List[str], + *, + max_files: int, + max_depth: int, + max_preview_chars: int, + exclude_extra: Optional[List[str]], + max_file_size_mb: int = 50, + oversized_pdf_timeout_s: float = 1.0, + max_chars: int = 10_000, +) -> str: + """Build a two-section catalog for the localsearch tool description. 
+ + Section 1 — Directory tree (filesystem walk, no IO beyond stat): + Gives the model the full directory structure so it understands where + to look without needing to enumerate every file. Capped at ~60% of + the max_chars budget so that file summaries always get space too. + + Section 2 — File summaries (from DirectoryScanner, capped by max_files): + Compact one-liners: relative path, size, and a short content hint. + Sorted by path so related files appear together. + + The combined output fits within max_chars (the caller may still apply + _truncate_catalog_text for final trimming of the file-summary entries). + """ + try: + from sirchmunk.scan.dir_scanner import DirectoryScanner + except ImportError as e: + raise ImportError( + 'sirchmunk is required for description_catalog. ' + f'Import failed: {e}') from e + + # Tree gets 60% of the budget; file summaries get the remaining 40%. + # Each root shares the budget equally. + num_roots = max(1, len(roots)) + per_root_budget = max(500, max_chars // num_roots) + tree_budget = max(300, int(per_root_budget * 0.60)) + + max_file_size_bytes = (max_file_size_mb * 1024 * 1024) if max_file_size_mb > 0 else None + # Merge the built-in skip sets with any user-provided excludes so the + # scanner also skips generated/artifact directories. 
+ scanner_exclude: List[str] = sorted(_DIR_SKIP | _DIR_SKIP_GENERATED) + if exclude_extra: + scanner_exclude.extend(exclude_extra) + scanner = DirectoryScanner( + llm=None, + max_depth=max_depth, + max_files=max_files, + max_preview_chars=max_preview_chars, + exclude_patterns=scanner_exclude, + max_file_size_bytes=max_file_size_bytes, + oversized_pdf_timeout_s=oversized_pdf_timeout_s, + ) + + sections: List[str] = [] + for root in roots: + p = Path(root) + if not p.exists(): + sections.append(f'### `{root}`\n(missing on disk)') + continue + + # --- Section 1: directory tree (fast, no content reads) --- + tree = _build_dir_tree( + p, max_depth=max_depth, exclude=exclude_extra, max_chars=tree_budget) + tree_block = f'#### Directory structure of `{p}`\n{tree}' if tree else '' + + # --- Section 2: per-file compact summaries --- + result = await scanner.scan(p) + # Sort first so round-robin within each subdir is deterministic + all_candidates = sorted(result.candidates, key=lambda c: c.path) + # Stratified reorder: root files first, then round-robin across + # subdirectories (smallest subdir first for maximum coverage). + # Always reorder regardless of count — the char-budget trim below + # does the actual limiting. + reordered = _stratified_sample(all_candidates, p, len(all_candidates)) + + # Trim to the files char budget (40% of total per-root budget). + # Trimming the *reordered* list ensures diverse coverage is preserved. 
+ files_budget = max(200, per_root_budget - tree_budget - 80) + file_lines: List[str] = [] + omitted = 0 + if reordered: + used_f = 0 + for c in reordered: + line = _compact_file_summary(c, str(p), max_preview=100) + if used_f + len(line) + 1 > files_budget: + omitted = len(reordered) - len(file_lines) + break + file_lines.append(line) + used_f += len(line) + 1 + else: + file_lines.append('_(no scannable files in depth budget)_') + + total_scanned = len(all_candidates) + header = f'#### File summaries ({len(file_lines)} of {total_scanned} sampled)' + if omitted: + file_block = header + '\n' + '\n'.join(file_lines) + f'\n… ({omitted} more files not shown)' + else: + file_block = header + '\n' + '\n'.join(file_lines) + + root_section = f'### Under `{p}`\n\n{tree_block}\n\n{file_block}' + sections.append(root_section.strip()) + + return '\n\n---\n\n'.join(sections).strip() + + +def _stratified_sample(candidates: List[Any], root: Path, max_entries: int) -> List[Any]: + """Return a stratified sample of candidates so every subdirectory gets + representation, rather than simply taking the first N by path. + + Algorithm: + 1. Root-level files first (high information density — README, pyproject, etc.) + 2. One representative per unique immediate subdirectory, round-robin until + the budget is exhausted. Subdirectories with fewer files are serviced + first so smaller modules are not crowded out by large doc/data trees. + + Within each group files are ordered by path so the result is deterministic. + Always reorders — callers rely on this even when len(candidates) == max_entries. 
+ """ + # Split into root-level vs sub-directory files + root_files: List[Any] = [] + by_subdir: dict = {} + for c in candidates: + try: + rel = Path(c.path).relative_to(root) + except ValueError: + rel = Path(c.path) + parts = rel.parts + if len(parts) == 1: + root_files.append(c) + else: + subdir = parts[0] + by_subdir.setdefault(subdir, []).append(c) + + result: List[Any] = [] + + # Always include all root-level files first (usually just a handful) + result.extend(root_files) + + if not by_subdir: + return result[:max_entries] + + # Round-robin across subdirectories. + # Sort subdirs by file count ascending so smaller directories (which are + # likely to have fewer but more targeted files) get their representative + # entry before large documentation/data directories exhaust the budget. + subdirs = sorted(by_subdir.keys(), key=lambda s: len(by_subdir[s])) + subdir_iters = {s: iter(by_subdir[s]) for s in subdirs} + while len(result) < max_entries and subdir_iters: + exhausted = [] + for s in subdirs: + if len(result) >= max_entries: + break + it = subdir_iters.get(s) + if it is None: + continue + try: + result.append(next(it)) + except StopIteration: + exhausted.append(s) + for s in exhausted: + del subdir_iters[s] + if not subdir_iters: + break + + return result[:max_entries] diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py index d8ab34ef7..a0fb305b9 100644 --- a/ms_agent/tools/search/localsearch_tool.py +++ b/ms_agent/tools/search/localsearch_tool.py @@ -1,10 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
"""On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" +import asyncio import json +import time from pathlib import Path -from typing import Any, Dict, List, Optional - +from typing import Any, AsyncGenerator, Dict, List, Optional + +from ms_agent.tools.search.localsearch_catalog import ( + build_file_catalog_text, + catalog_cache_path, + catalog_fingerprint, + description_catalog_settings, + load_cached_catalog, + save_cached_catalog, +) from ms_agent.tools.search.sirchmunk_search import ( SirchmunkSearch, effective_localsearch_settings, @@ -26,6 +36,8 @@ - You need to find information in source code, config files, or documents - The query references a local path, project structure, or codebase - You need to search PDF, DOCX, XLSX, PPTX, CSV, JSON, YAML, Markdown, etc. +- Large files or directories should be searched by this tool. + DO NOT USE THIS TOOL WHEN: - The user is asking a general knowledge question @@ -39,6 +51,8 @@ Configured search roots for this agent (absolute paths; default search scope when `paths` is omitted): {configured_roots} + +{file_catalog_section} """ @@ -58,6 +72,64 @@ def _resolved_localsearch_paths_from_config(config) -> List[str]: return out +def _work_path_from_config(config) -> Path: + block = effective_localsearch_settings(config) + if not block: + return Path('.sirchmunk').expanduser().resolve() + wp = block.get('work_path', './.sirchmunk') if hasattr( + block, 'get') else getattr(block, 'work_path', './.sirchmunk') + return Path(str(wp)).expanduser().resolve() + + +def _truncate_catalog_text(text: str, max_chars: int) -> str: + """Truncate catalog text to ``max_chars``, preserving the directory tree section + and truncating the file-summary section on entry boundaries. + + The catalog has two sections: + 1. Directory structure (``#### Directory structure of ...``) + 2. 
File summaries (``#### File summaries ...``) — entries start with ``- `` + + Strategy: always keep the full directory tree (it fits in a few hundred chars + per root); truncate only the file-summary entries within the remaining budget. + """ + if max_chars <= 0 or len(text) <= max_chars: + return text + + import re + + # Locate the first file-summary section header. + summary_header_m = re.search(r'^#### File summaries', text, re.MULTILINE) + if summary_header_m is None: + # No file summaries — just hard-truncate. + return text[: max_chars - 24].rstrip() + '\n\n… (truncated)' + + # Everything up to and including the file-summary header line is the + # "prefix" we always keep. + header_section_end = text.find('\n', summary_header_m.end()) + if header_section_end == -1: + return text[: max_chars - 24].rstrip() + '\n\n… (truncated)' + + prefix = text[: header_section_end + 1] + body = text[header_section_end + 1:] + + # Split body into individual entry lines (each starts with "- "). + parts = re.split(r'(?=^- )', body, flags=re.MULTILINE) + parts = [p for p in parts if p.strip()] + + budget = max_chars - len(prefix) - 50 # reserve space for trailing note + kept: list[str] = [] + used = 0 + for part in parts: + if used + len(part) > budget: + break + kept.append(part) + used += len(part) + + omitted = len(parts) - len(kept) + suffix = f'\n… ({omitted} more files not shown)' if omitted > 0 else '' + return prefix + ''.join(kept).rstrip() + suffix + + def _format_configured_roots(paths: List[str]) -> str: if not paths: return ( @@ -93,10 +165,36 @@ def __init__(self, config, **kwargs): self._searcher: Optional[SirchmunkSearch] = None self._configured_roots: List[str] = ( _resolved_localsearch_paths_from_config(config)) + block = effective_localsearch_settings(config) + self._catalog_enabled, self._catalog_opts = description_catalog_settings( + block) + self._work_path = _work_path_from_config(config) + self._catalog_text: str = '' + self._catalog_build_error: 
Optional[str] = None + + def _file_catalog_section(self) -> str: + if not self._catalog_enabled: + return '' + err = self._catalog_build_error + if err: + return ( + '\n\n## Local knowledge catalog\n' + f'_(Catalog build failed: {err})_\n') + body = (self._catalog_text or '').strip() + if not body: + return ( + '\n\n## Local knowledge catalog\n' + '_(No scannable files or catalog empty.)_\n') + shown = _truncate_catalog_text(body, self._catalog_opts['max_chars']) + return ( + '\n\n## Local knowledge catalog (shallow scan)\n' + 'Brief previews of files under the configured roots; call this tool ' + 'with a `query` for full search.\n\n' + shown + '\n') def _tool_description(self) -> str: return _LOCALSEARCH_DESCRIPTION.format( - configured_roots=_format_configured_roots(self._configured_roots)) + configured_roots=_format_configured_roots(self._configured_roots), + file_catalog_section=self._file_catalog_section()) def _paths_param_description(self) -> str: roots = _format_configured_roots(self._configured_roots) @@ -112,7 +210,63 @@ def _ensure_searcher(self) -> SirchmunkSearch: return self._searcher async def connect(self) -> None: - return None + self._catalog_build_error = None + self._catalog_text = '' + if not self._catalog_enabled: + return + roots = [r for r in self._configured_roots if r] + if not roots: + self._catalog_build_error = 'no configured roots' + return + o = self._catalog_opts + fp = catalog_fingerprint( + roots, + o['max_files'], + o['max_depth'], + o['max_preview_chars'], + o['max_chars'], + o['exclude_extra'], + ) + cache_path = catalog_cache_path(self._work_path, fp) + ttl = float(o['cache_ttl_seconds']) + t0 = time.monotonic() + cached = load_cached_catalog(cache_path, ttl) + if cached is not None: + elapsed = time.monotonic() - t0 + self._catalog_text = cached + logger.info( + f'localsearch catalog: loaded from cache in {elapsed:.3f}s ' + f'({len(cached)} chars) roots={roots}') + return + try: + built = await build_file_catalog_text( + 
roots, + max_files=o['max_files'], + max_depth=o['max_depth'], + max_preview_chars=o['max_preview_chars'], + exclude_extra=o['exclude_extra'], + max_file_size_mb=o['max_file_size_mb'], + oversized_pdf_timeout_s=o['oversized_pdf_timeout_s'], + max_chars=o['max_chars'], + ) + elapsed = time.monotonic() - t0 + self._catalog_text = built + logger.info( + f'localsearch catalog: scanned in {elapsed:.3f}s ' + f'({len(built)} chars) roots={roots}') + if ttl > 0 and built.strip(): + try: + save_cached_catalog(cache_path, built) + except OSError as exc: + logger.debug(f'localsearch catalog cache write failed: {exc}') + except ImportError as exc: + elapsed = time.monotonic() - t0 + self._catalog_build_error = str(exc) + logger.warning(f'localsearch description_catalog ({elapsed:.3f}s): {exc}') + except Exception as exc: + elapsed = time.monotonic() - t0 + self._catalog_build_error = str(exc) + logger.warning(f'localsearch description_catalog scan failed ({elapsed:.3f}s): {exc}') async def _get_tools_inner(self) -> Dict[str, List[Tool]]: return { @@ -280,3 +434,60 @@ async def call_tool(self, server_name: str, *, tool_name: str, logger.warning(f'localsearch failed: {exc}') return f'Local search failed: {exc}' + async def call_tool_streaming(self, server_name: str, *, tool_name: str, + tool_args: dict): + """Streaming variant: yield log lines while searching, then yield final result. + + Intermediate yields are plain strings (log lines). + The final yield is the result dict (or error string) from call_tool. + + Timeout semantics: the caller should treat the absence of any yield + within 30 s as a hang and cancel the task. + """ + log_queue: asyncio.Queue = asyncio.Queue() + + # Register the streaming callback on the searcher so sirchmunk pushes + # log lines into our queue as they are emitted. + async def _on_log(entry: str): + await log_queue.put(entry) + + # We need the searcher to exist before we can register the callback. 
+ # _ensure_searcher() is synchronous and cheap if already initialized. + try: + searcher = self._ensure_searcher() + searcher.enable_streaming_logs(_on_log) + except Exception as exc: + yield f'Local search failed: {exc}' + return + + # Sentinel placed in the queue by the search coroutine when done. + _DONE = object() + + async def _run_search(): + try: + result = await self.call_tool( + server_name, tool_name=tool_name, tool_args=tool_args) + except Exception as exc: + result = f'Local search failed: {exc}' + await log_queue.put(_DONE) + await log_queue.put(result) + + search_task = asyncio.create_task(_run_search()) + + try: + while True: + item = await log_queue.get() + if item is _DONE: + # Next item is the final result. + final = await log_queue.get() + yield final + break + # Intermediate log line. + yield item + finally: + search_task.cancel() + try: + await search_task + except asyncio.CancelledError: + pass + diff --git a/ms_agent/tools/search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py index cd86f9d63..800e3f41b 100644 --- a/ms_agent/tools/search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -324,11 +324,14 @@ async def retrieve(self, self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) - if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit = any( + 'Found similar cluster' in entry + or 'Reused existing knowledge cluster' in entry + for entry in self._search_logs + ) + if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + result.cluster, 'last_modified', None) return self._parse_search_result(result, score_threshold, limit) @@ -419,11 +422,19 @@ async def query( self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and 
result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) - if hasattr(result.cluster, 'updated_at'): + # Detect cluster reuse from search logs: sirchmunk emits + # "[SUCCESS] Found similar cluster: ..." or + # "[SUCCESS] Reused existing knowledge cluster" when a cached + # cluster is reused. KnowledgeCluster has no _reused_from_cache + # attribute, so log-based detection is the correct approach. + self._cluster_cache_hit = any( + 'Found similar cluster' in entry + or 'Reused existing knowledge cluster' in entry + for entry in self._search_logs + ) + if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + result.cluster, 'last_modified', None) self._last_search_result = self._parse_search_result( result, score_threshold=0.7, limit=5) diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index e7885b92d..b0cad08f4 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -7,7 +7,7 @@ import uuid from copy import copy from types import TracebackType -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional import json from ms_agent.llm.utils import Tool, ToolCall @@ -247,11 +247,138 @@ async def single_call_tool(self, tool_info: ToolCall): logger.warning(traceback.format_exc()) return f'Tool calling failed: {brief_info}, details: {str(e)}' + async def single_call_tool_streaming( + self, tool_info: ToolCall) -> AsyncGenerator: + """Streaming variant of single_call_tool. + + Yields (tool_call_id, item) pairs where item is either: + - a str: intermediate log line + - a str/dict (final): the tool result (same as single_call_tool return value) + + The final item is always the last yield; it is distinguished from log + lines by being a dict, or by being the item after which the generator + stops. 
Callers can rely on the generator exhausting after the final + result. + + Timeout: if no yield arrives within self.tool_call_timeout seconds the + tool is considered hung and a timeout error string is yielded as the + final result. + """ + if self._concurrent_limiter is None: + if self._init_lock is None: + self._init_lock = asyncio.Lock() + async with self._init_lock: + if self._concurrent_limiter is None: + self._concurrent_limiter = asyncio.Semaphore( + MAX_CONCURRENT_TOOLS) + + call_id = tool_info.get('id', '') + brief_info = json.dumps(tool_info, ensure_ascii=False) + if len(brief_info) > 1024: + brief_info = brief_info[:1024] + '...' + + async with self._concurrent_limiter: + try: + tool_name = tool_info['tool_name'] + tool_args = tool_info['arguments'] + while isinstance(tool_args, str): + try: + tool_args = json.loads(tool_args) + except Exception: # noqa + yield call_id, f'The input {tool_args} is not a valid JSON, fix your arguments and try again' + return + assert tool_name in self._tool_index, \ + f'Tool name {tool_name} not found' + tool_ins, server_name, _ = self._tool_index[tool_name] + call_args = tool_args + if isinstance(tool_ins, AgentTool): + call_args = dict(tool_args or {}) + call_args['__call_id'] = call_id or str(uuid.uuid4()) + + # Use the streaming variant; default impl wraps call_tool. + gen = tool_ins.call_tool_streaming( + server_name, + tool_name=tool_name.split(self.TOOL_SPLITER)[1], + tool_args=call_args) + + # Enforce per-item timeout: if no item arrives within + # tool_call_timeout seconds the tool is considered hung. + timeout = self.tool_call_timeout + last_item = None + while True: + try: + item = await asyncio.wait_for( + gen.__anext__(), timeout=timeout) + except StopAsyncIteration: + # Generator exhausted normally; last_item is the result. 
+ if last_item is not None: + yield call_id, last_item + return + except asyncio.TimeoutError: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Execute tool call timeout: {brief_info}' + return + except Exception as e: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}' + return + + if last_item is not None: + # Emit previous item as an intermediate log. + yield call_id, last_item + last_item = item + + # Once we have the first yield (any kind), relax the + # per-item timeout to avoid cutting off long-running but + # active searches. + timeout = max(timeout, 120) + + except Exception as e: + import traceback + logger.warning(traceback.format_exc()) + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}' + async def parallel_call_tool(self, tool_list: List[ToolCall]): tasks = [self.single_call_tool(tool) for tool in tool_list] result = await asyncio.gather(*tasks) return result + async def parallel_call_tool_streaming( + self, tool_list: List[ToolCall]) -> AsyncGenerator: + """Run all tools concurrently; yield (call_id, item) as they arrive. + + Items are interleaved in arrival order. The caller must track call_id + to associate intermediate logs with their tool. Each tool's final + result is a dict or non-log string; intermediate logs are plain strings + emitted before the final result. + """ + # Shared queue: producers push (call_id, item); sentinel signals done. 
+ queue: asyncio.Queue = asyncio.Queue() + _DONE = object() + + async def _producer(tool_info: ToolCall): + async for call_id, item in self.single_call_tool_streaming( + tool_info): + await queue.put((call_id, item)) + await queue.put(_DONE) + + tasks = [asyncio.create_task(_producer(t)) for t in tool_list] + remaining = len(tasks) + + try: + while remaining > 0: + item = await queue.get() + if item is _DONE: + remaining -= 1 + else: + yield item + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + async def __aenter__(self) -> 'ToolManager': return self diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py index 705c7184e..9fd1737eb 100644 --- a/tests/knowledge_search/test_sirschmunk.py +++ b/tests/knowledge_search/test_sirschmunk.py @@ -11,6 +11,14 @@ import unittest from pathlib import Path + +def _sirchmunk_dir_scanner_available() -> bool: + try: + import sirchmunk.scan.dir_scanner # noqa: F401 + return True + except ImportError: + return False + from ms_agent.agent import LLMAgent from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch from ms_agent.llm.utils import Message @@ -81,7 +89,7 @@ async def run(): Message(role='system', content='You are a helper.'), Message(role='user', content=original), ] - await agent.run(messages) + messages = await agent.run(messages) return messages messages = asyncio.run(run()) @@ -105,6 +113,29 @@ async def run(): f'Expected localsearch in tools, got: {names}', ) + @unittest.skipUnless( + _sirchmunk_dir_scanner_available(), + 'sirchmunk scan not installed', + ) + def test_localsearch_description_catalog_injects_file_preview(self): + """Optional: shallow DirectoryScanner summaries appear in tool description.""" + + async def run(): + config = self._base_config() + config.tools.localsearch['description_catalog'] = True + config.tools.localsearch['description_catalog_cache_ttl_seconds'] = 0 + tm = ToolManager(config, 
trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + loc = next(t for t in tools if t['tool_name'].endswith('localsearch')) + desc = loc.get('description') or '' + self.assertIn('Local knowledge catalog', desc) + self.assertIn('UserManager', desc) + @unittest.skipUnless( os.getenv('TEST_SIRCHMUNK_SMOKE', ''), 'Set TEST_SIRCHMUNK_SMOKE=1 to run sirchmunk API smoke test', From 8f503a3538059e6285b3dadf2a7f4788eb5bb0f6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 27 Mar 2026 15:08:04 +0800 Subject: [PATCH 09/12] fix comment --- ms_agent/agent/llm_agent.py | 18 +++------------ ms_agent/tools/base.py | 19 +++++++++++----- ms_agent/tools/tool_manager.py | 40 +++++++++++++++++----------------- 3 files changed, 37 insertions(+), 40 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 606427de8..04ed2ca89 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -556,11 +556,10 @@ async def parallel_tool_call_streaming( # Accumulate final results keyed by call_id. final_results: dict = {} - async for call_id, item in self.tool_manager.parallel_call_tool_streaming( + async for call_id, item, is_final in self.tool_manager.parallel_call_tool_streaming( tool_calls): - if isinstance(item, dict) or not isinstance(item, str) or \ - not self._is_log_line(item): - # Final result for this call_id. + if is_final: + # Final result for this call_id (any type; not inferred from content). final_results[call_id] = item else: # Intermediate log line: emit a snapshot with searching_detail_delta. 
@@ -596,17 +595,6 @@ async def parallel_tool_call_streaming( yield messages - @staticmethod - def _is_log_line(item: str) -> bool: - """Heuristic: a plain log string from sirchmunk is not a timeout/error result.""" - if item.startswith('Execute tool call timeout:'): - return False - if item.startswith('Tool calling failed:'): - return False - if item.startswith('The input ') and 'is not a valid JSON' in item: - return False - return True - async def prepare_tools(self): """Initialize and connect the tool manager.""" self.tool_manager = ToolManager( diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index f3fd18515..aed0b9768 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -94,11 +94,20 @@ async def call_tool_streaming(self, server_name: str, *, tool_name: str, tool_args: dict): """Streaming variant of call_tool. - Yields incremental log strings (str) during execution, then yields the - final result (str or dict) as the last item. - - Default implementation simply delegates to call_tool (no streaming). - Override in subclasses that support real-time log emission. + Contract for overrides: + - Emit zero or more intermediate updates as strings (UI / log lines). + - The **last** value yielded before the async generator finishes must + be the same shape as ``call_tool`` would return: a ``str`` or a + ``dict`` understood by ``ToolResult.from_raw`` (e.g. with a + ``result`` / ``text`` field). + + Callers inside this package (``ToolManager.single_call_tool_streaming``) + do **not** infer \"final vs log\" from types: they mark the terminal + emission explicitly. Short string finals (e.g. ``\"OK\"``) are therefore + valid and unambiguous. + + Default implementation: no intermediate yields; a single final yield + from ``call_tool``. 
""" result = await self.call_tool( server_name, tool_name=tool_name, tool_args=tool_args) diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index b0cad08f4..aecd0416b 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -251,14 +251,13 @@ async def single_call_tool_streaming( self, tool_info: ToolCall) -> AsyncGenerator: """Streaming variant of single_call_tool. - Yields (tool_call_id, item) pairs where item is either: - - a str: intermediate log line - - a str/dict (final): the tool result (same as single_call_tool return value) + Yields (tool_call_id, item, is_final) triples: + - is_final False: intermediate log line (str). + - is_final True: final tool result (same shape as single_call_tool), + including error/timeout strings. - The final item is always the last yield; it is distinguished from log - lines by being a dict, or by being the item after which the generator - stops. Callers can rely on the generator exhausting after the final - result. + Callers must use is_final to tell logs from short string results (e.g. + \"OK\"); content-based heuristics are unreliable. Timeout: if no yield arrives within self.tool_call_timeout seconds the tool is considered hung and a timeout error string is yielded as the @@ -285,7 +284,9 @@ async def single_call_tool_streaming( try: tool_args = json.loads(tool_args) except Exception: # noqa - yield call_id, f'The input {tool_args} is not a valid JSON, fix your arguments and try again' + yield (call_id, + f'The input {tool_args} is not a valid JSON, fix your arguments and try again', + True) return assert tool_name in self._tool_index, \ f'Tool name {tool_name} not found' @@ -312,22 +313,22 @@ async def single_call_tool_streaming( except StopAsyncIteration: # Generator exhausted normally; last_item is the result. 
if last_item is not None: - yield call_id, last_item + yield call_id, last_item, True return except asyncio.TimeoutError: import traceback logger.warning(traceback.format_exc()) - yield call_id, f'Execute tool call timeout: {brief_info}' + yield call_id, f'Execute tool call timeout: {brief_info}', True return except Exception as e: import traceback logger.warning(traceback.format_exc()) - yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}' + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}', True return if last_item is not None: # Emit previous item as an intermediate log. - yield call_id, last_item + yield call_id, last_item, False last_item = item # Once we have the first yield (any kind), relax the @@ -338,7 +339,7 @@ async def single_call_tool_streaming( except Exception as e: import traceback logger.warning(traceback.format_exc()) - yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}' + yield call_id, f'Tool calling failed: {brief_info}, details: {str(e)}', True async def parallel_call_tool(self, tool_list: List[ToolCall]): tasks = [self.single_call_tool(tool) for tool in tool_list] @@ -347,21 +348,20 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]): async def parallel_call_tool_streaming( self, tool_list: List[ToolCall]) -> AsyncGenerator: - """Run all tools concurrently; yield (call_id, item) as they arrive. + """Run all tools concurrently; yield (call_id, item, is_final) as they arrive. Items are interleaved in arrival order. The caller must track call_id - to associate intermediate logs with their tool. Each tool's final - result is a dict or non-log string; intermediate logs are plain strings - emitted before the final result. + to associate intermediate logs with their tool. is_final distinguishes + the last result for that tool from streaming log lines. """ - # Shared queue: producers push (call_id, item); sentinel signals done. 
+ # Shared queue: producers push (call_id, item, is_final); sentinel signals done. queue: asyncio.Queue = asyncio.Queue() _DONE = object() async def _producer(tool_info: ToolCall): - async for call_id, item in self.single_call_tool_streaming( + async for call_id, item, is_final in self.single_call_tool_streaming( tool_info): - await queue.put((call_id, item)) + await queue.put((call_id, item, is_final)) await queue.put(_DONE) tasks = [asyncio.create_task(_producer(t)) for t in tool_list] From 38eb680800d0335d5db66758563df7ee52f664da Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 27 Mar 2026 15:13:45 +0800 Subject: [PATCH 10/12] fix lint --- ms_agent/agent/llm_agent.py | 9 +- ms_agent/cli/run.py | 4 +- ms_agent/tools/search/localsearch_catalog.py | 137 +++++++++++++------ ms_agent/tools/search/localsearch_tool.py | 98 +++++++------ ms_agent/tools/search/sirchmunk_search.py | 25 ++-- ms_agent/tools/tool_manager.py | 10 +- 6 files changed, 166 insertions(+), 117 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 04ed2ca89..96760ad9e 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -567,8 +567,8 @@ async def parallel_tool_call_streaming( role='tool', content='', tool_call_id=call_id, - name=call_id_to_query.get(call_id, {}).get( - 'tool_name', ''), + name=call_id_to_query.get(call_id, + {}).get('tool_name', ''), searching_detail_delta=item, ) yield messages + [log_message] @@ -958,9 +958,8 @@ async def step( # Use the streaming variant so intermediate tool logs are yielded # back to the caller while the tools are still running. 
async for messages in self.parallel_tool_call_streaming(messages): - is_final = ( - not messages[-1].searching_detail_delta - if hasattr(messages[-1], 'searching_detail_delta') else True) + is_final = (not messages[-1].searching_detail_delta if hasattr( + messages[-1], 'searching_detail_delta') else True) if not is_final: yield messages diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 2accdb40e..e900bb616 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -152,9 +152,7 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default=None, - help= - 'Comma-separated list of paths for knowledge search.' - ) + help='Comma-separated list of paths for knowledge search.') parser.set_defaults(func=subparser_func) def execute(self): diff --git a/ms_agent/tools/search/localsearch_catalog.py b/ms_agent/tools/search/localsearch_catalog.py index 9dd1b2bf7..3ce840802 100644 --- a/ms_agent/tools/search/localsearch_catalog.py +++ b/ms_agent/tools/search/localsearch_catalog.py @@ -2,13 +2,13 @@ """Build a compact file catalog for localsearch tool descriptions using sirchmunk's DirectoryScanner.""" from __future__ import annotations - import hashlib -import json import time from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple +import json + def catalog_cache_path(work_path: Path, fingerprint: str) -> Path: work_path.mkdir(parents=True, exist_ok=True) @@ -31,7 +31,8 @@ def catalog_fingerprint( 'max_chars': max_chars, 'exclude': sorted(exclude or []), } - raw = json.dumps(payload, sort_keys=True, ensure_ascii=False).encode('utf-8') + raw = json.dumps( + payload, sort_keys=True, ensure_ascii=False).encode('utf-8') return hashlib.sha256(raw).hexdigest()[:24] @@ -102,41 +103,84 @@ def description_catalog_settings(block: Any) -> Tuple[bool, Dict[str, Any]]: """Parse ``tools.localsearch`` (or legacy ``knowledge_search``) catalog options.""" enabled = _bool_from_block(block, 'description_catalog', True) opts = { - 
'max_files': max(1, min(2000, _int_from_block( - block, 'description_catalog_max_files', 120))), - 'max_depth': max(1, min(20, _int_from_block( - block, 'description_catalog_max_depth', 5))), - 'max_chars': max(500, min(100_000, _int_from_block( - block, 'description_catalog_max_chars', 3_000))), - 'max_preview_chars': max(80, min(4000, _int_from_block( - block, 'description_catalog_max_preview_chars', 400))), - 'cache_ttl_seconds': max(0, _int_from_block( - block, 'description_catalog_cache_ttl_seconds', 300)), - 'exclude_extra': _exclude_from_block(block), + 'max_files': + max( + 1, + min(2000, + _int_from_block(block, 'description_catalog_max_files', 120))), + 'max_depth': + max( + 1, + min(20, _int_from_block(block, 'description_catalog_max_depth', + 5))), + 'max_chars': + max( + 500, + min(100_000, + _int_from_block(block, 'description_catalog_max_chars', + 3_000))), + 'max_preview_chars': + max( + 80, + min( + 4000, + _int_from_block(block, 'description_catalog_max_preview_chars', + 400))), + 'cache_ttl_seconds': + max( + 0, + _int_from_block(block, 'description_catalog_cache_ttl_seconds', + 300)), + 'exclude_extra': + _exclude_from_block(block), # Files larger than this are skipped during catalog scan (default 50 MB). # Set to 0 to disable the cap. - 'max_file_size_mb': max(0, _int_from_block( - block, 'description_catalog_max_file_size_mb', 50)), + 'max_file_size_mb': + max(0, + _int_from_block(block, 'description_catalog_max_file_size_mb', + 50)), # Wall-clock timeout (seconds) for extracting the first page of an # oversized PDF. Corrupt or pathological PDFs are abandoned after this # and only their filename + size appear in the catalog. 
- 'oversized_pdf_timeout_s': max(0.1, _int_from_block( - block, 'description_catalog_oversized_pdf_timeout_s', 1)), + 'oversized_pdf_timeout_s': + max( + 0.1, + _int_from_block(block, + 'description_catalog_oversized_pdf_timeout_s', 1)), } return enabled, opts _DIR_SKIP: Set[str] = { - '.git', '.svn', 'node_modules', '__pycache__', '.idea', '.vscode', - '.cache', '.tox', '.eggs', 'dist', 'build', '.DS_Store', + '.git', + '.svn', + 'node_modules', + '__pycache__', + '.idea', + '.vscode', + '.cache', + '.tox', + '.eggs', + 'dist', + 'build', + '.DS_Store', } # Directories that typically contain generated/compiled artifacts, not source. # These are skipped by default to keep the tree focused on meaningful content. _DIR_SKIP_GENERATED: Set[str] = { - '_build', '_static', '_templates', '_sphinx_design_static', - '__pycache__', '.mypy_cache', '.pytest_cache', '.ruff_cache', - 'htmlcov', 'site-packages', 'egg-info', '.egg-info', + '_build', + '_static', + '_templates', + '_sphinx_design_static', + '__pycache__', + '.mypy_cache', + '.pytest_cache', + '.ruff_cache', + 'htmlcov', + 'site-packages', + 'egg-info', + '.egg-info', } @@ -170,33 +214,38 @@ def _build_dir_tree( skip.update(exclude) # --- Pass 1: DFS, collect (depth, seq, line_text) --- - collected: List[tuple] = [] # (depth, seq, line_text) + collected: List[tuple] = [] # (depth, seq, line_text) seq = [0] def _dfs(p: Path, depth: int, indent: str) -> None: if depth > max_depth: return try: - entries = sorted(p.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) + entries = sorted( + p.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) except PermissionError: return - dirs = [e for e in entries if e.is_dir() - and not e.name.startswith('.') and e.name not in skip] - files = [e for e in entries if e.is_file() - and not e.name.startswith('.') and e.name not in skip] + dirs = [ + e for e in entries + if e.is_dir() and not e.name.startswith('.') and e.name not in skip + ] + files = [ + e for e in entries if 
e.is_file() and not e.name.startswith('.') + and e.name not in skip + ] child_indent = indent + ' ' for d in dirs: try: - file_count = sum( - 1 for _ in d.iterdir() - if _.is_file() and not _.name.startswith('.')) + file_count = sum(1 for _ in d.iterdir() + if _.is_file() and not _.name.startswith('.')) except PermissionError: file_count = 0 count_hint = f' ({file_count} files)' if file_count else '' - collected.append((depth, seq[0], f'{indent}📁 {d.name}/{count_hint}')) + collected.append( + (depth, seq[0], f'{indent}📁 {d.name}/{count_hint}')) seq[0] += 1 _dfs(d, depth + 1, child_indent) @@ -238,7 +287,8 @@ def _dfs(p: Path, depth: int, indent: str) -> None: return result -def _compact_file_summary(candidate: Any, root_dir: str, max_preview: int) -> str: +def _compact_file_summary(candidate: Any, root_dir: str, + max_preview: int) -> str: """Single-line or two-line compact summary for a FileCandidate. Format: @@ -296,9 +346,8 @@ async def build_file_catalog_text( try: from sirchmunk.scan.dir_scanner import DirectoryScanner except ImportError as e: - raise ImportError( - 'sirchmunk is required for description_catalog. ' - f'Import failed: {e}') from e + raise ImportError('sirchmunk is required for description_catalog. ' + f'Import failed: {e}') from e # Tree gets 60% of the budget; file summaries get the remaining 40%. # Each root shares the budget equally. @@ -306,7 +355,8 @@ async def build_file_catalog_text( per_root_budget = max(500, max_chars // num_roots) tree_budget = max(300, int(per_root_budget * 0.60)) - max_file_size_bytes = (max_file_size_mb * 1024 * 1024) if max_file_size_mb > 0 else None + max_file_size_bytes = (max_file_size_mb * 1024 + * 1024) if max_file_size_mb > 0 else None # Merge the built-in skip sets with any user-provided excludes so the # scanner also skips generated/artifact directories. 
scanner_exclude: List[str] = sorted(_DIR_SKIP | _DIR_SKIP_GENERATED) @@ -331,7 +381,10 @@ async def build_file_catalog_text( # --- Section 1: directory tree (fast, no content reads) --- tree = _build_dir_tree( - p, max_depth=max_depth, exclude=exclude_extra, max_chars=tree_budget) + p, + max_depth=max_depth, + exclude=exclude_extra, + max_chars=tree_budget) tree_block = f'#### Directory structure of `{p}`\n{tree}' if tree else '' # --- Section 2: per-file compact summaries --- @@ -364,7 +417,8 @@ async def build_file_catalog_text( total_scanned = len(all_candidates) header = f'#### File summaries ({len(file_lines)} of {total_scanned} sampled)' if omitted: - file_block = header + '\n' + '\n'.join(file_lines) + f'\n… ({omitted} more files not shown)' + file_block = header + '\n' + '\n'.join( + file_lines) + f'\n… ({omitted} more files not shown)' else: file_block = header + '\n' + '\n'.join(file_lines) @@ -374,7 +428,8 @@ async def build_file_catalog_text( return '\n\n---\n\n'.join(sections).strip() -def _stratified_sample(candidates: List[Any], root: Path, max_entries: int) -> List[Any]: +def _stratified_sample(candidates: List[Any], root: Path, + max_entries: int) -> List[Any]: """Return a stratified sample of candidates so every subdirectory gets representation, rather than simply taking the first N by path. 
diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py index a0fb305b9..f2a2b7b22 100644 --- a/ms_agent/tools/search/localsearch_tool.py +++ b/ms_agent/tools/search/localsearch_tool.py @@ -2,25 +2,18 @@ """On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" import asyncio -import json import time from pathlib import Path from typing import Any, AsyncGenerator, Dict, List, Optional -from ms_agent.tools.search.localsearch_catalog import ( - build_file_catalog_text, - catalog_cache_path, - catalog_fingerprint, - description_catalog_settings, - load_cached_catalog, - save_cached_catalog, -) -from ms_agent.tools.search.sirchmunk_search import ( - SirchmunkSearch, - effective_localsearch_settings, -) +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase +from ms_agent.tools.search.localsearch_catalog import ( + build_file_catalog_text, catalog_cache_path, catalog_fingerprint, + description_catalog_settings, load_cached_catalog, save_cached_catalog) +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, effective_localsearch_settings) from ms_agent.utils.logger import get_logger logger = get_logger() @@ -101,15 +94,15 @@ def _truncate_catalog_text(text: str, max_chars: int) -> str: summary_header_m = re.search(r'^#### File summaries', text, re.MULTILINE) if summary_header_m is None: # No file summaries — just hard-truncate. - return text[: max_chars - 24].rstrip() + '\n\n… (truncated)' + return text[:max_chars - 24].rstrip() + '\n\n… (truncated)' # Everything up to and including the file-summary header line is the # "prefix" we always keep. 
header_section_end = text.find('\n', summary_header_m.end()) if header_section_end == -1: - return text[: max_chars - 24].rstrip() + '\n\n… (truncated)' + return text[:max_chars - 24].rstrip() + '\n\n… (truncated)' - prefix = text[: header_section_end + 1] + prefix = text[:header_section_end + 1] body = text[header_section_end + 1:] # Split body into individual entry lines (each starts with "- "). @@ -132,9 +125,8 @@ def _truncate_catalog_text(text: str, max_chars: int) -> str: def _format_configured_roots(paths: List[str]) -> str: if not paths: - return ( - '(none — set tools.localsearch.paths in agent config, ' - 'or legacy knowledge_search.paths)') + return ('(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') return '\n'.join(f'- {p}' for p in paths) @@ -159,7 +151,8 @@ class LocalSearchTool(ToolBase): def __init__(self, config, **kwargs): super().__init__(config) tools_root = getattr(config, 'tools', None) - tool_cfg = getattr(tools_root, 'localsearch', None) if tools_root else None + tool_cfg = getattr(tools_root, 'localsearch', + None) if tools_root else None if tool_cfg is not None: self.exclude_func(tool_cfg) self._searcher: Optional[SirchmunkSearch] = None @@ -177,14 +170,12 @@ def _file_catalog_section(self) -> str: return '' err = self._catalog_build_error if err: - return ( - '\n\n## Local knowledge catalog\n' - f'_(Catalog build failed: {err})_\n') + return ('\n\n## Local knowledge catalog\n' + f'_(Catalog build failed: {err})_\n') body = (self._catalog_text or '').strip() if not body: - return ( - '\n\n## Local knowledge catalog\n' - '_(No scannable files or catalog empty.)_\n') + return ('\n\n## Local knowledge catalog\n' + '_(No scannable files or catalog empty.)_\n') shown = _truncate_catalog_text(body, self._catalog_opts['max_chars']) return ( '\n\n## Local knowledge catalog (shallow scan)\n' @@ -251,22 +242,25 @@ async def connect(self) -> None: ) elapsed = time.monotonic() - t0 self._catalog_text = 
built - logger.info( - f'localsearch catalog: scanned in {elapsed:.3f}s ' - f'({len(built)} chars) roots={roots}') + logger.info(f'localsearch catalog: scanned in {elapsed:.3f}s ' + f'({len(built)} chars) roots={roots}') if ttl > 0 and built.strip(): try: save_cached_catalog(cache_path, built) except OSError as exc: - logger.debug(f'localsearch catalog cache write failed: {exc}') + logger.debug( + f'localsearch catalog cache write failed: {exc}') except ImportError as exc: elapsed = time.monotonic() - t0 self._catalog_build_error = str(exc) - logger.warning(f'localsearch description_catalog ({elapsed:.3f}s): {exc}') + logger.warning( + f'localsearch description_catalog ({elapsed:.3f}s): {exc}') except Exception as exc: elapsed = time.monotonic() - t0 self._catalog_build_error = str(exc) - logger.warning(f'localsearch description_catalog scan failed ({elapsed:.3f}s): {exc}') + logger.warning( + f'localsearch description_catalog scan failed ({elapsed:.3f}s): {exc}' + ) async def _get_tools_inner(self) -> Dict[str, List[Tool]]: return { @@ -276,8 +270,7 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: server_name=_SERVER, description=self._tool_description(), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'query': { 'type': @@ -286,13 +279,11 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Search keywords or natural-language question about local content.', }, 'paths': { - 'type': - 'array', + 'type': 'array', 'items': { 'type': 'string' }, - 'description': - self._paths_param_description(), + 'description': self._paths_param_description(), }, 'mode': { 'type': @@ -302,21 +293,28 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Search mode; omit to use agent default (usually FAST).', }, 'max_depth': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 20, + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, 'description': 'Max directory depth for filesystem search.', }, 'top_k_files': { - 'type': 
'integer', - 'minimum': 1, - 'maximum': 20, + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, 'description': 'Max files for evidence / filename hits.', }, 'include': { - 'type': 'array', + 'type': + 'array', 'items': { 'type': 'string' }, @@ -324,7 +322,8 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Glob patterns to include (e.g. *.py, *.md).', }, 'exclude': { - 'type': 'array', + 'type': + 'array', 'items': { 'type': 'string' }, @@ -373,8 +372,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, if paths_arg: resolved_paths = searcher.resolve_tool_paths(paths_arg) if not resolved_paths: - roots = _format_configured_roots( - self._configured_roots) + roots = _format_configured_roots(self._configured_roots) return ( 'Error: `paths` are invalid. Each path must exist on disk and lie ' 'under one of these configured roots:\n' + roots) @@ -420,8 +418,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, result_parts.append('\nSource paths:') for item in excerpts[:12]: meta = item.get('metadata') or {} - result_parts.append( - f'- {meta.get("source", "?")}') + result_parts.append(f'- {meta.get("source", "?")}') result_text = '\n'.join(result_parts) return { @@ -490,4 +487,3 @@ async def _run_search(): await search_task except asyncio.CancelledError: pass - diff --git a/ms_agent/tools/search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py index 800e3f41b..7aa09b5e6 100644 --- a/ms_agent/tools/search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -6,10 +6,10 @@ """ import asyncio -import json from pathlib import Path from typing import Any, Callable, Dict, List, Optional +import json from loguru import logger from omegaconf import DictConfig @@ -142,8 +142,8 @@ def _validate_config(self, config: DictConfig): 'tools.localsearch.paths (or legacy knowledge_search.paths) ' 'must be specified and non-empty') - def resolve_tool_paths( - self, paths: Optional[List[str]]) -> 
Optional[List[str]]: + def resolve_tool_paths(self, + paths: Optional[List[str]]) -> Optional[List[str]]: """Restrict per-call paths to configured search roots.""" if not paths: return None @@ -154,10 +154,10 @@ def resolve_tool_paths( continue p = Path(str(raw).strip()).expanduser().resolve() if not p.exists(): - logger.warning(f'localsearch: path does not exist, skipped: {p}') + logger.warning( + f'localsearch: path does not exist, skipped: {p}') continue - allowed = any( - p == r or p.is_relative_to(r) for r in roots) + allowed = any(p == r or p.is_relative_to(r) for r in roots) if not allowed: logger.warning( f'localsearch: path outside configured search roots, ' @@ -327,8 +327,7 @@ async def retrieve(self, self._cluster_cache_hit = any( 'Found similar cluster' in entry or 'Reused existing knowledge cluster' in entry - for entry in self._search_logs - ) + for entry in self._search_logs) if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( result.cluster, 'last_modified', None) @@ -405,8 +404,9 @@ async def query( self._last_search_result = [] for item in result[:20]: if isinstance(item, dict): - src = (item.get('path') or item.get('file_path') - or item.get('file') or '') + src = ( + item.get('path') or item.get('file_path') + or item.get('file') or '') self._last_search_result.append({ 'text': json.dumps(item, ensure_ascii=False), @@ -430,8 +430,7 @@ async def query( self._cluster_cache_hit = any( 'Found similar cluster' in entry or 'Reused existing knowledge cluster' in entry - for entry in self._search_logs - ) + for entry in self._search_logs) if hasattr(result.cluster, 'last_modified'): self._cluster_cache_hit_time = getattr( result.cluster, 'last_modified', None) @@ -440,7 +439,7 @@ async def query( result, score_threshold=0.7, limit=5) if hasattr(result, 'answer') and getattr(result, 'answer', - None) is not None: + None) is not None: return result.answer if isinstance(result, str): diff --git 
a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index aecd0416b..7a43a154e 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -18,7 +18,8 @@ from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient from ms_agent.tools.search.localsearch_tool import LocalSearchTool -from ms_agent.tools.search.sirchmunk_search import effective_localsearch_settings +from ms_agent.tools.search.sirchmunk_search import \ + effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool @@ -284,9 +285,10 @@ async def single_call_tool_streaming( try: tool_args = json.loads(tool_args) except Exception: # noqa - yield (call_id, - f'The input {tool_args} is not a valid JSON, fix your arguments and try again', - True) + yield ( + call_id, + f'The input {tool_args} is not a valid JSON, fix your arguments and try again', + True) return assert tool_name in self._tool_index, \ f'Tool name {tool_name} not found' From 7fa4996e1a9f7b7248b81c3670c2f34c966eb315 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 27 Mar 2026 15:19:23 +0800 Subject: [PATCH 11/12] rm dead code --- ms_agent/cli/run.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index e900bb616..1b8701695 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -48,22 +48,6 @@ class RunCMD(CLICommand): def __init__(self, args): self.args = args - def load_env_file(self): - """Load environment variables from .env file in current directory.""" - env_file = os.path.join(os.getcwd(), '.env') - if os.path.exists(env_file): - with open(env_file, 'r') as f: - for line in f: - line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, value = line.split('=', 1) - key = key.strip() - value = value.strip() - # Only set if not 
already set in environment - if key not in os.environ: - os.environ[key] = value - logger.debug(f'Loaded {key} from .env file') - @staticmethod def define_args(parsers: argparse.ArgumentParser): """Define args for run command.""" From 70d3bb128d288b54ebb0f62c7c6e8311e0c72014 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 27 Mar 2026 17:51:41 +0800 Subject: [PATCH 12/12] fix --- docs/zh/Components/config.md | 2 +- ms_agent/agent/llm_agent.py | 19 ++++++++++--------- ms_agent/llm/utils.py | 8 +++----- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 49e7a133e..a3031d918 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -183,7 +183,7 @@ handler: custom_handler } } ``` -- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用,不再在每轮自动注入上下文;除非配置 `tools.localsearch.llm_*`,否则 LLM 从 `llm` 模块复用 +- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用;如未配置 `tools.localsearch.llm_*`, LLM 从 `llm` 模块复用 > agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 96760ad9e..0e32cfc3a 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -542,9 +542,9 @@ async def parallel_tool_call_streaming( """Streaming variant of parallel_tool_call. Yields messages list snapshots during tool execution: - - While tools are running: yields messages with the latest - searching_detail_delta set on a temporary placeholder Message so the - caller can stream intermediate logs to the frontend. + - While tools are running: yields messages with the latest incremental + ``tool_detail`` on a temporary placeholder Message (content='') so the + caller can stream logs to the frontend. 
- After all tools finish: yields the final messages list (with proper tool result Messages appended), same as parallel_tool_call. """ @@ -562,14 +562,14 @@ async def parallel_tool_call_streaming( # Final result for this call_id (any type; not inferred from content). final_results[call_id] = item else: - # Intermediate log line: emit a snapshot with searching_detail_delta. + # Intermediate log line: one incremental chunk in tool_detail. log_message = Message( role='tool', content='', tool_call_id=call_id, name=call_id_to_query.get(call_id, {}).get('tool_name', ''), - searching_detail_delta=item, + tool_detail=item, ) yield messages + [log_message] @@ -585,7 +585,6 @@ async def parallel_tool_call_streaming( tool_call_id=cid, name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, - tool_detail=tool_call_result_format.tool_detail, ) if _new_message.tool_call_id is None: _new_message.tool_call_id = str(uuid.uuid4())[:8] @@ -958,9 +957,11 @@ async def step( # Use the streaming variant so intermediate tool logs are yielded # back to the caller while the tools are still running. async for messages in self.parallel_tool_call_streaming(messages): - is_final = (not messages[-1].searching_detail_delta if hasattr( - messages[-1], 'searching_detail_delta') else True) - if not is_final: + _lm = messages[-1] + _progress = ( + _lm.role == 'tool' and _lm.content == '' + and _lm.tool_detail is not None) + if _progress: yield messages await self.after_tool_call(messages) diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 05fdbb02b..08131dbd2 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,13 +61,11 @@ class Message: api_calls: int = 1 - # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). + # role=tool: UI / SSE only; omitted from LLM API via to_dict_clean(). 
+ # During parallel_tool_call_streaming, temporary placeholder messages use this for + # one incremental log line each (content=''); completed tool messages omit it here. tool_detail: Optional[str] = None - # role=tool (streaming): incremental log line emitted while the tool is still running. - # Non-empty only on intermediate yields; final yield has tool_detail set instead. - searching_detail_delta: Optional[str] = None - def to_dict(self): return asdict(self)