From 09b82e9952226d8d488879582564663d4d29aa1d Mon Sep 17 00:00:00 2001 From: ulleo Date: Thu, 18 Dec 2025 19:00:25 +0800 Subject: [PATCH] feat: support command in mcp --- backend/apps/chat/api/chat.py | 111 +++++++++++++++----- backend/apps/chat/curd/chat.py | 11 ++ backend/apps/chat/task/llm.py | 157 ++++++++++++++++++++-------- backend/apps/mcp/mcp.py | 79 ++------------ backend/common/utils/data_format.py | 27 +++++ 5 files changed, 241 insertions(+), 144 deletions(-) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 473bfad91..ce96cf232 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -8,12 +8,13 @@ from fastapi import APIRouter, HTTPException, Path from fastapi.responses import StreamingResponse from sqlalchemy import and_, select +from starlette.responses import JSONResponse from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \ format_json_data, format_json_list_data, get_chart_config, list_recent_questions from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \ - ChatInfo, Chat + ChatInfo, Chat, ChatFinishStep from apps.chat.task.llm import LLMService from apps.swagger.i18n import PLACEHOLDER_PREFIX from apps.system.schemas.permission import SqlbotPermission, require_permissions @@ -166,11 +167,18 @@ def find_base_question(record_id: int, session: SessionDep): @require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id")) async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, current_assistant: CurrentAssistant): + return await question_answer_inner(session, current_user, request_question, current_assistant, embedding=True) + + +async def question_answer_inner(session: SessionDep, 
current_user: CurrentUser, request_question: ChatQuestion, + current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True, + stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False): try: command, text_before_command, record_id, warning_info = parse_quick_command(request_question.question) if command: - # todo 暂不支持分析和预测,需要改造前端 - if command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA: + # todo 对话界面下,暂不支持分析和预测,需要改造前端 + if in_chat and (command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA): raise Exception(f'Command: {command.value} temporary not supported') if record_id is not None: @@ -221,53 +229,83 @@ async def question_answer(session: SessionDep, current_user: CurrentUser, reques if command == QuickCommand.REGENERATE: request_question.question = text_before_command request_question.regenerate_record_id = rec_id - return await stream_sql(session, current_user, request_question, current_assistant) + return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream, + finish_step, embedding) elif command == QuickCommand.ANALYSIS: - return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant) + return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant, in_chat, stream) elif command == QuickCommand.PREDICT_DATA: - return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant) + return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant, in_chat, stream) else: raise Exception(f'Unknown command: {command.value}') else: - return await stream_sql(session, current_user, request_question, current_assistant) + return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream, + finish_step, embedding) except Exception as e: traceback.print_exc() - def _err(_e: Exception): 
- yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + if stream: + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' - return StreamingResponse(_err(e), media_type="text/event-stream") + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, - current_assistant: CurrentAssistant): + current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False): try: llm_service = await LLMService.create(session, current_user, request_question, current_assistant, - embedding=True) + embedding=embedding) llm_service.init_record(session=session) - llm_service.run_task_async() + llm_service.run_task_async(in_chat=in_chat, stream=stream, finish_step=finish_step) except Exception as e: traceback.print_exc() - def _err(_e: Exception): - yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' - - return StreamingResponse(_err(e), media_type="text/event-stream") + if stream: + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' - return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) + if stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + 
status_code=status_code, + ) @router.post("/record/{chat_record_id}/{action_type}", summary=f"{PLACEHOLDER_PREFIX}analysis_or_predict") async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant, chat_record_id: int, - action_type: str = Path(..., description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")): + action_type: str = Path(..., + description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")): return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant) async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str, - current_assistant: CurrentAssistant): + current_assistant: CurrentAssistant, in_chat: bool = True, stream: bool = True): try: if action_type != 'analysis' and action_type != 'predict': raise Exception(f"Type {action_type} Not Found") @@ -294,16 +332,35 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch request_question = ChatQuestion(chat_id=record.chat_id, question=record.question) llm_service = await LLMService.create(session, current_user, request_question, current_assistant) - llm_service.run_analysis_or_predict_task_async(session, action_type, record) + llm_service.run_analysis_or_predict_task_async(session, action_type, record, in_chat, stream) except Exception as e: traceback.print_exc() + if stream: + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' - def _err(_e: Exception): - yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' - - return StreamingResponse(_err(e), media_type="text/event-stream") - - return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) 
+ if stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + status_code=status_code, + ) @router.get("/record/{chat_record_id}/excel/export", summary=f"{PLACEHOLDER_PREFIX}export_chart_data") diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index f5c2c4e32..8dc391ba9 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -163,6 +163,17 @@ def format_json_list_data(origin_data: list[dict]): return data +def get_chat_chart_config(session: SessionDep, chat_record_id: int): + stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chat_record_id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.chart) + except Exception: + pass + return {} + + def get_chat_chart_data(session: SessionDep, chat_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id)) res = session.execute(stmt) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 3e2106b12..7db118b0a 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -30,7 +30,8 @@ save_select_datasource_answer, save_recommend_question_answer, \ get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \ - get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate + get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate, get_chat_predict_data, \ + get_chat_chart_config from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ ChatFinishStep, AxisObj from 
apps.data_training.curd.data_training import get_training_template @@ -969,6 +970,10 @@ def run_task(self, in_chat: bool = True, stream: bool = True, 'regenerate_record_id': self.get_record().regenerate_record_id}).decode() + '\n\n' yield 'data:' + orjson.dumps( {'type': 'question', 'question': self.get_record().question}).decode() + '\n\n' + else: + if stream: + yield '> ID: ' + str(self.get_record().id) + '\n' + yield '> ' + self.get_record().question + '\n\n' if not stream: json_result['record_id'] = self.get_record().id @@ -1150,26 +1155,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, {'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n' else: if stream: - _fields = {} - if chart.get('columns'): - for _column in chart.get('columns'): - if _column: - _fields[_column.get('value')] = _column.get('name') - if chart.get('axis'): - if chart.get('axis').get('x'): - _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') - if chart.get('axis').get('y'): - _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') - if chart.get('axis').get('series'): - _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( - 'name') - _column_list = [] - for field in result.get('fields'): - _column_list.append( - AxisObj(name=field if not _fields.get(field) else _fields.get(field), value=field)) - - md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data')) - + md_data, _fields_list = DataFormat.convert_data_fields_for_pandas(chart, result.get('fields'), + result.get('data')) # data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data')) if not md_data or not _fields_list: @@ -1183,7 +1170,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if in_chat: yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' else: - # todo generate picture + # 
generate picture try: if chart['type'] != 'table': yield '### generated chart picture\n\n' @@ -1255,51 +1242,123 @@ def run_recommend_questions_task(self): finally: session_maker.remove() - def run_analysis_or_predict_task_async(self, session: Session, action_type: str, base_record: ChatRecord): + def run_analysis_or_predict_task_async(self, session: Session, action_type: str, base_record: ChatRecord, + in_chat: bool = True, stream: bool = True): self.set_record(save_analysis_predict_record(session, base_record, action_type)) - self.future = executor.submit(self.run_analysis_or_predict_task_cache, action_type) + self.future = executor.submit(self.run_analysis_or_predict_task_cache, action_type, in_chat, stream) - def run_analysis_or_predict_task_cache(self, action_type: str): - for chunk in self.run_analysis_or_predict_task(action_type): + def run_analysis_or_predict_task_cache(self, action_type: str, in_chat: bool = True, stream: bool = True): + for chunk in self.run_analysis_or_predict_task(action_type, in_chat, stream): self.chunk_list.append(chunk) - def run_analysis_or_predict_task(self, action_type: str): + def run_analysis_or_predict_task(self, action_type: str, in_chat: bool = True, stream: bool = True): + json_result: Dict[str, Any] = {'success': True} _session = None try: _session = session_maker() - yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n' + if in_chat: + yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n' + else: + if stream: + yield '> ID: ' + str(self.get_record().id) + '\n' + yield '> ' + self.get_record().question + '\n\n' + if not stream: + json_result['record_id'] = self.get_record().id if action_type == 'analysis': # generate analysis analysis_res = self.generate_analysis(_session) + full_text = '' for chunk in analysis_res: - yield 'data:' + orjson.dumps( - {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), - 'type': 
'analysis-result'}).decode() + '\n\n' - yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n' - - yield 'data:' + orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n' + full_text += chunk.get('content') + if in_chat: + yield 'data:' + orjson.dumps( + {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), + 'type': 'analysis-result'}).decode() + '\n\n' + else: + if stream: + yield chunk.get('content') + if in_chat: + yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n' + yield 'data:' + orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n' + else: + if stream: + yield '\n\n' + if not stream: + json_result['content'] = full_text elif action_type == 'predict': # generate predict analysis_res = self.generate_predict(_session) full_text = '' for chunk in analysis_res: - yield 'data:' + orjson.dumps( - {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), - 'type': 'predict-result'}).decode() + '\n\n' full_text += chunk.get('content') - yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n' + if in_chat: + yield 'data:' + orjson.dumps( + {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), + 'type': 'predict-result'}).decode() + '\n\n' + if in_chat: + yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n' - _data = self.check_save_predict_data(session=_session, res=full_text) - if _data: - yield 'data:' + orjson.dumps({'type': 'predict-success'}).decode() + '\n\n' - else: - yield 'data:' + orjson.dumps({'type': 'predict-failed'}).decode() + '\n\n' + has_data = self.check_save_predict_data(session=_session, res=full_text) + if has_data: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'predict-success'}).decode() + '\n\n' + else: + chart = get_chat_chart_config(_session, self.record.id) + origin_data = 
get_chat_chart_data(_session, self.record.id) + predict_data = get_chat_predict_data(_session, self.record.id) - yield 'data:' + orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n' + if stream: + md_data, _fields_list = DataFormat.convert_data_fields_for_pandas(chart, + origin_data.get('fields'), + predict_data) + if not md_data or not _fields_list: + yield 'Predict data result is empty.\n\n' + else: + df = pd.DataFrame(md_data, columns=_fields_list) + df_safe = DataFormat.safe_convert_to_string(df) + markdown_table = df_safe.to_markdown(index=False) + yield markdown_table + '\n\n' + + else: + json_result['origin_data'] = origin_data + json_result['predict_data'] = predict_data + + # generate picture + try: + if chart['type'] != 'table': + yield '### generated chart picture\n\n' + + _data = get_chat_chart_data(_session, self.record.id) + _data['data'] = _data['data'] + predict_data + + image_url = request_picture(self.record.chat_id, self.record.id, chart, + format_json_data(_data)) + SQLBotLogUtil.info(image_url) + if stream: + yield f'![{chart["type"]}]({image_url})' + else: + json_result['image_url'] = image_url + except Exception as e: + if stream: + raise e + else: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'predict-failed'}).decode() + '\n\n' + else: + if stream: + yield full_text + '\n\n' + if not stream: + json_result['success'] = False + json_result['message'] = full_text + if in_chat: + yield 'data:' + orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n' self.finish(_session) + + if not stream: + yield json_result except Exception as e: error_msg: str if isinstance(e, SingleMessageError): @@ -1308,7 +1367,15 @@ def run_analysis_or_predict_task(self, action_type: str): error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode() if _session: self.save_error(session=_session, message=error_msg) - yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' + if in_chat: 
+ yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' + else: + if stream: + yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + else: + json_result['success'] = False + json_result['message'] = error_msg + yield json_result finally: # end session_maker.remove() diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index 0085c7fde..b6b20d2a3 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -1,22 +1,18 @@ # Author: Junjun # Date: 2025/7/1 import json -import traceback from datetime import timedelta import jwt from fastapi import HTTPException, status, APIRouter -from fastapi.responses import StreamingResponse # from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pydantic import ValidationError from sqlmodel import select -from starlette.responses import JSONResponse -from apps.chat.api.chat import create_chat +from apps.chat.api.chat import create_chat, question_answer_inner from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \ ChatFinishStep -from apps.chat.task.llm import LLMService from apps.system.crud.user import authenticate from apps.system.crud.user import get_db_user from apps.system.models.system_model import UserWsModel @@ -33,7 +29,7 @@ tokenUrl=f"{settings.API_V1_STR}/login/access-token" ) -router = APIRouter(tags=["mcp"], prefix="/mcp", include_in_schema=False) +router = APIRouter(tags=["mcp"], prefix="/mcp") # @router.post("/access_token", operation_id="access_token") @@ -110,39 +106,8 @@ async def mcp_question(session: SessionDep, chat: McpQuestion): mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) - try: - llm_service = await LLMService.create(session, session_user, mcp_chat) - llm_service.init_record(session=session) - llm_service.run_task_async(False, chat.stream) - except Exception as e: - traceback.print_exc() - - if chat.stream: - def 
_err(_e: Exception): - yield str(_e) + '\n\n' - - return StreamingResponse(_err(e), media_type="text/event-stream") - else: - return JSONResponse( - content={'message': str(e)}, - status_code=500, - ) - if chat.stream: - return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") - else: - res = llm_service.await_result() - raw_data = {} - for chunk in res: - if chunk: - raw_data = chunk - status_code = 200 - if not raw_data.get('success'): - status_code = 500 - - return JSONResponse( - content=raw_data, - status_code=status_code, - ) + return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat, + in_chat=False, stream=chat.stream) @router.post("/mcp_assistant", operation_id="mcp_assistant") @@ -166,36 +131,6 @@ async def mcp_assistant(session: SessionDep, chat: McpAssistant): # assistant question mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question) # ask - try: - llm_service = await LLMService.create(session, session_user, mcp_chat, mcp_assistant_header) - llm_service.init_record(session=session) - llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA) - except Exception as e: - traceback.print_exc() - - if chat.stream: - def _err(_e: Exception): - yield str(_e) + '\n\n' - - return StreamingResponse(_err(e), media_type="text/event-stream") - else: - return JSONResponse( - content={'message': str(e)}, - status_code=500, - ) - if chat.stream: - return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") - else: - res = llm_service.await_result() - raw_data = {} - for chunk in res: - if chunk: - raw_data = chunk - status_code = 200 - if not raw_data.get('success'): - status_code = 500 - - return JSONResponse( - content=raw_data, - status_code=status_code, - ) + return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat, + current_assistant=mcp_assistant_header, + in_chat=False, stream=chat.stream, 
finish_step=ChatFinishStep.QUERY_DATA) diff --git a/backend/common/utils/data_format.py b/backend/common/utils/data_format.py index 63037fa34..9dc1304a7 100644 --- a/backend/common/utils/data_format.py +++ b/backend/common/utils/data_format.py @@ -1,5 +1,8 @@ import pandas as pd +from apps.chat.models.chat_model import AxisObj + + class DataFormat: @staticmethod def safe_convert_to_string(df): @@ -77,6 +80,30 @@ def convert_object_array_for_pandas(column_list: list, data_list: list): md_data.append(_row) return md_data, _fields_list + @staticmethod + def convert_data_fields_for_pandas(chart: dict, fields: list, data: list): + _fields = {} + if chart.get('columns'): + for _column in chart.get('columns'): + if _column: + _fields[_column.get('value')] = _column.get('name') + if chart.get('axis'): + if chart.get('axis').get('x'): + _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') + if chart.get('axis').get('y'): + _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') + if chart.get('axis').get('series'): + _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( + 'name') + _column_list = [] + for field in fields: + _column_list.append( + AxisObj(name=field if not _fields.get(field) else _fields.get(field), value=field)) + + md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, data) + + return md_data, _fields_list + @staticmethod def format_pd_data(column_list: list, data_list: list, col_formats: dict = None): # 预处理数据并记录每列的格式类型