From e74c82a511c5a3bdfedab2c1accc858ae11ecdb4 Mon Sep 17 00:00:00 2001 From: ulleo Date: Mon, 22 Dec 2025 16:39:00 +0800 Subject: [PATCH] feat: support command in mcp --- backend/apps/chat/api/chat.py | 21 +++++++++++------- backend/apps/chat/curd/chat.py | 2 +- backend/apps/chat/task/llm.py | 39 ++++++++++++++++++++++------------ backend/common/core/config.py | 1 + 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index ce96cf23..382c3bae 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -198,11 +198,10 @@ async def question_answer_inner(session: SessionDep, current_user: CurrentUser, if rec_first_chat: raise Exception(f'Record id: {record_id} does not support this operation') - if command == QuickCommand.REGENERATE: - if rec_analysis_record_id: - raise Exception('Analysis record does not support this operation') - if rec_predict_record_id: - raise Exception('Predict data record does not support this operation') + if rec_analysis_record_id: + raise Exception('Analysis record does not support this operation') + if rec_predict_record_id: + raise Exception('Predict data record does not support this operation') else: # get last record id stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.regenerate_record_id).where( @@ -233,10 +232,12 @@ async def question_answer_inner(session: SessionDep, current_user: CurrentUser, finish_step, embedding) elif command == QuickCommand.ANALYSIS: - return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant, in_chat, stream) + return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant, in_chat, + stream) elif command == QuickCommand.PREDICT_DATA: - return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant, in_chat, stream) + return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant, in_chat, + stream) else: raise Exception(f'Unknown command: {command.value}') else: @@ -247,7 +248,11 @@ async def question_answer_inner(session: SessionDep, current_user: CurrentUser, if stream: def _err(_e: Exception): - yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + if in_chat: + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + else: + yield f'❌ **ERROR:**\n' + yield f'> {str(_e)}\n' return StreamingResponse(_err(e), media_type="text/event-stream") else: diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 8dc391ba..f93ad18e 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -168,7 +168,7 @@ def get_chat_chart_config(session: SessionDep, chat_record_id: int): res = session.execute(stmt) for row in res: try: - return orjson.loads(row.data) + return orjson.loads(row.chart) except Exception: pass return {} diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 7db118b0..b7402ab1 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -1172,15 +1172,17 @@ def run_task(self, in_chat: bool = True, stream: bool = True, else: # generate picture try: - if chart['type'] != 'table': + if chart.get('type') != 'table': yield '### generated chart picture\n\n' - image_url = request_picture(self.record.chat_id, self.record.id, chart, - format_json_data(result)) + image_url, error = request_picture(self.record.chat_id, self.record.id, chart, + format_json_data(result)) SQLBotLogUtil.info(image_url) if stream: - yield f'![{chart["type"]}]({image_url})' + yield f'![{chart.get("type")}]({image_url})' else: json_result['image_url'] = image_url + if error is not None: + raise error except Exception as e: if stream: raise e @@ -1207,7 +1209,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' else: if stream: - yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + yield f'❌ **ERROR:**\n' + yield f'> {error_msg}\n' else: json_result['success'] = False json_result['message'] = error_msg @@ -1327,19 +1330,21 @@ def run_analysis_or_predict_task(self, action_type: str, in_chat: bool = True, s # generate picture try: - if chart['type'] != 'table': + if chart.get('type') != 'table': yield '### generated chart picture\n\n' _data = get_chat_chart_data(_session, self.record.id) - _data['data'] = _data['data'] + predict_data + _data['data'] = _data.get('data') + predict_data - image_url = request_picture(self.record.chat_id, self.record.id, chart, - format_json_data(_data)) + image_url, error = request_picture(self.record.chat_id, self.record.id, chart, + format_json_data(_data)) SQLBotLogUtil.info(image_url) if stream: - yield f'![{chart["type"]}]({image_url})' + yield f'![{chart.get("type")}]({image_url})' else: json_result['image_url'] = image_url + if error is not None: + raise error except Exception as e: if stream: raise e @@ -1360,6 +1365,7 @@ def run_analysis_or_predict_task(self, action_type: str, in_chat: bool = True, s if not stream: yield json_result except Exception as e: + traceback.print_exc() error_msg: str if isinstance(e, SingleMessageError): error_msg = str(e) @@ -1371,7 +1377,8 @@ def run_analysis_or_predict_task(self, action_type: str, in_chat: bool = True, s yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' else: if stream: - yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + yield f'❌ **ERROR:**\n' + yield f'> {error_msg}\n' else: json_result['success'] = False json_result['message'] = error_msg @@ -1451,16 +1458,20 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): request_obj = { "path": os.path.join(settings.MCP_IMAGE_PATH, file_name), - "type": chart['type'], + "type": chart.get('type'), "data": orjson.dumps(data.get('data') if data.get('data') else []).decode(), "axis": orjson.dumps(axis).decode(), } - requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj) + _error = None + try: + requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj, timeout=settings.SERVER_IMAGE_TIMEOUT) + except Exception as e: + _error = e request_path = urllib.parse.urljoin(settings.SERVER_IMAGE_HOST, f"{file_name}.png") - return request_path + return request_path, _error def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None): diff --git a/backend/common/core/config.py b/backend/common/core/config.py index fec6832d..4e09c201 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -87,6 +87,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: EXCEL_PATH: str = '/opt/sqlbot/data/excel' MCP_IMAGE_HOST: str = 'http://localhost:3000' SERVER_IMAGE_HOST: str = 'http://YOUR_SERVE_IP:MCP_PORT/images/' + SERVER_IMAGE_TIMEOUT: int = 15 LOCAL_MODEL_PATH: str = '/opt/sqlbot/models' DEFAULT_EMBEDDING_MODEL: str = 'shibing624/text2vec-base-chinese'