From a8a2cd945c0d4357012ad0ddb2bb74a7c32c424f Mon Sep 17 00:00:00 2001 From: ulleo Date: Mon, 2 Feb 2026 16:06:11 +0800 Subject: [PATCH] feat: support specifying datasource ID in MCP question --- backend/apps/chat/models/chat_model.py | 2 +- backend/apps/mcp/mcp.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index c78646fc..83626181 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -329,7 +329,7 @@ class McpQuestion(BaseModel): token: str = Body(description='token') stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) lang: Optional[str] = Body(description='语言:zh-CN|en|ko-KR', default='zh-CN') - datasource_id: Optional[int] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None) + datasource_id: Optional[int | str] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None) class AxisObj(BaseModel): diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index 1d6dfd6d..5947594f 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -2,6 +2,7 @@ # Date: 2025/7/1 import json from datetime import timedelta +from typing import Optional import jwt from fastapi import HTTPException, status, APIRouter @@ -113,8 +114,22 @@ async def mcp_start(session: SessionDep, chat: ChatStart): @router.post("/mcp_question", operation_id="mcp_question") async def mcp_question(session: SessionDep, chat: McpQuestion): session_user = get_user(session, chat.token) - - mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question, datasource_id=chat.datasource_id) + ds_id: Optional[int] = None + if chat.datasource_id: + if isinstance(chat.datasource_id, str): + if chat.datasource_id.strip() == "": + ds_id = None + else: + try: + ds_id = int(chat.datasource_id.strip()) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid datasource ID") + elif isinstance(chat.datasource_id, int): + ds_id = chat.datasource_id + else: + raise HTTPException(status_code=400, detail="Invalid datasource ID") + + mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question, datasource_id=ds_id) return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat, in_chat=False, stream=chat.stream)