From 97e833c486fefd873db18a842f032bbec4be4627 Mon Sep 17 00:00:00 2001 From: fit2cloud-chenyw Date: Tue, 25 Nov 2025 19:00:19 +0800 Subject: [PATCH 01/31] perf: Multiple Domain Validation for Embedded Systems #388 --- backend/apps/system/api/assistant.py | 13 ++++--- backend/apps/system/crud/assistant.py | 16 ++++++++ backend/apps/system/middleware/auth.py | 5 ++- backend/apps/system/schemas/system_schema.py | 1 + backend/common/utils/utils.py | 6 +++ frontend/src/i18n/en.json | 4 +- frontend/src/i18n/ko-KR.json | 4 +- frontend/src/i18n/zh-CN.json | 4 +- frontend/src/views/system/embedded/Page.vue | 21 ++++++---- frontend/src/views/system/embedded/iframe.vue | 39 +++++++++++++------ 10 files changed, 86 insertions(+), 27 deletions(-) diff --git a/backend/apps/system/api/assistant.py b/backend/apps/system/api/assistant.py index 9733d3f0..c6036142 100644 --- a/backend/apps/system/api/assistant.py +++ b/backend/apps/system/api/assistant.py @@ -17,7 +17,7 @@ from common.core.deps import SessionDep, Trans from common.core.security import create_access_token from common.core.sqlbot_cache import clear_cache -from common.utils.utils import get_origin_from_referer +from common.utils.utils import get_origin_from_referer, origin_match_domain router = APIRouter(tags=["system/assistant"], prefix="/system/assistant") @@ -30,13 +30,15 @@ async def info(request: Request, response: Response, session: SessionDep, trans: if not db_model: raise RuntimeError(f"assistant application not exist") db_model = AssistantModel.model_validate(db_model) - response.headers["Access-Control-Allow-Origin"] = db_model.domain + origin = request.headers.get("origin") or get_origin_from_referer(request) if not origin: raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) origin = origin.rstrip('/') - if origin != db_model.domain: + if not origin_match_domain(origin, db_model.domain): raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) + + response.headers["Access-Control-Allow-Origin"] = origin return db_model @@ -48,13 +50,14 @@ async def getApp(request: Request, response: Response, session: SessionDep, tran if not db_model: raise RuntimeError(f"assistant application not exist") db_model = AssistantModel.model_validate(db_model) - response.headers["Access-Control-Allow-Origin"] = db_model.domain origin = request.headers.get("origin") or get_origin_from_referer(request) if not origin: raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) origin = origin.rstrip('/') - if origin != db_model.domain: + if not origin_match_domain(origin, db_model.domain): raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) + + response.headers["Access-Control-Allow-Origin"] = origin return db_model diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 72168221..2196e61b 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -99,17 +99,22 @@ class AssistantOutDs: assistant: AssistantHeader ds_list: Optional[list[AssistantOutDsSchema]] = None certificate: Optional[str] = None + request_origin: Optional[str] = None def __init__(self, assistant: AssistantHeader): self.assistant = assistant self.ds_list = None self.certificate = assistant.certificate + self.request_origin = assistant.request_origin self.get_ds_from_api() # @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id") def get_ds_from_api(self): config: dict[any] = json.loads(self.assistant.configuration) endpoint: str = config['endpoint'] + endpoint = self.get_complete_endpoint(endpoint=endpoint) + if not endpoint: + raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]") certificateList: list[any] = json.loads(self.certificate) header = {} cookies = {} @@ -137,6 +142,17 @@ def get_ds_from_api(self): else: raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}") + def get_complete_endpoint(self, endpoint: str) -> str | None: + if endpoint.startswith("http://") or endpoint.startswith("https://"): + return endpoint + domain_text = self.assistant.domain + if not domain_text: + return None + if ',' in domain_text: + return self.request_origin.strip('/') if self.request_origin else domain_text.split(',')[0].strip('/') + endpoint + else: + return f"{domain_text}{endpoint}" + def get_simple_ds_list(self): if self.ds_list: return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list] diff --git a/backend/apps/system/middleware/auth.py b/backend/apps/system/middleware/auth.py index 3ea720c6..4aaffd58 100644 --- a/backend/apps/system/middleware/auth.py +++ b/backend/apps/system/middleware/auth.py @@ -16,7 +16,7 @@ from common.core.config import settings from common.core.schemas import TokenPayload from common.utils.locale import I18n -from common.utils.utils import SQLBotLogUtil +from common.utils.utils import SQLBotLogUtil, get_origin_from_referer from common.utils.whitelist import whiteUtils from fastapi.security.utils import get_authorization_scheme_param from common.core.deps import get_i18n @@ -40,6 +40,9 @@ async def dispatch(self, request, call_next): if validator[0]: request.state.current_user = validator[1] request.state.assistant = validator[2] + origin = request.headers.get("origin") or get_origin_from_referer(request) + if origin and validator[2]: + request.state.assistant.request_origin = origin return await call_next(request) message = trans('i18n_permission.authenticate_invalid', msg = validator[1]) return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"}) diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py index 52dc20c9..e37bacab 100644 --- a/backend/apps/system/schemas/system_schema.py +++ b/backend/apps/system/schemas/system_schema.py @@ -116,6 +116,7 @@ class AssistantHeader(AssistantDTO): unique: Optional[str] = None certificate: Optional[str] = None online: bool = False + request_origin: Optional[str] = None class AssistantValidator(BaseModel): diff --git a/backend/common/utils/utils.py b/backend/common/utils/utils.py index ee6b0964..8d3fc808 100644 --- a/backend/common/utils/utils.py +++ b/backend/common/utils/utils.py @@ -263,6 +263,12 @@ def get_origin_from_referer(request: Request): SQLBotLogUtil.error(f"解析 Referer 出错: {e}") return referer +def origin_match_domain(origin: str, domain: str) -> bool: + if not origin or not domain: + return False + origin_text = origin.rstrip('/') + domain_list = domain.replace(" ", "").split(',') + return origin_text in [d.rstrip('/') for d in domain_list] def equals_ignore_case(str1: str, *args: str) -> bool: if str1 is None: diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index ca3d1b7a..53f2f2dc 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -572,7 +572,9 @@ "creating_advanced_applications": "Creating Advanced Applications", "configure_interface": "Configure interface", "interface_url": "Interface URL", - "format_is_incorrect": "format is incorrect", + "format_is_incorrect": "format is incorrect{msg}", + "domain_format_incorrect": ",start with http/https, no trailing slash (/), multiple domains separated by half-width commas (,)", + "interface_url_incorrect": ",enter a relative path starting with /", "aes_enable": "Enable AES encryption", "aes_enable_tips": "The fields (host, user, password, dataBase, schema) are all encrypted using the AES-CBC-PKCS5Padding encryption method", "bit": "bit", diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index a9bf5820..0cf1a873 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -571,7 +571,9 @@ "private": "비공개", "configure_interface": "인터페이스 설정", "interface_url": "인터페이스 URL", - "format_is_incorrect": "형식이 올바르지 않습니다", + "format_is_incorrect": "형식이 올바르지 않습니다{msg}", + "domain_format_incorrect": ", http/https로 시작, 슬래시(/)로 끝나지 않음, 여러 도메인은 반각 쉼표(,)로 구분", + "interface_url_incorrect": ", 상대 경로를 입력해주세요. /로 시작합니다", "aes_enable": "AES 암호화 활성화", "aes_enable_tips": "암호화 필드 (host, user, password, dataBase, schema)는 모두 AES-CBC-PKCS5Padding 암호화 방식을 사용합니다", "bit": "비트", diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index 0f9d6e2b..d1088070 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -571,7 +571,9 @@ "private": "私有", "configure_interface": "配置接口", "interface_url": "接口 URL", - "format_is_incorrect": "格式不对", + "format_is_incorrect": "格式不对{msg}", + "domain_format_incorrect": ",http或https开头,不能以 / 结尾,多个域名以逗号(半角)分隔", + "interface_url_incorrect": ",请填写相对路径,以/开头", "aes_enable": "开启 AES 加密", "aes_enable_tips": "加密字段 (host, user, password, dataBase, schema) 均采用 AES-CBC-PKCS5Padding 加密方式", "bit": "位", diff --git a/frontend/src/views/system/embedded/Page.vue b/frontend/src/views/system/embedded/Page.vue index a49fe5d3..42784bdc 100644 --- a/frontend/src/views/system/embedded/Page.vue +++ b/frontend/src/views/system/embedded/Page.vue @@ -209,13 +209,20 @@ const validateUrl = (_: any, value: any, callback: any) => { ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line - var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i - var objExp = new RegExp(Expression) - if (objExp.test(value) && !value.endsWith('/')) { - callback() - } else { - callback(t('embedded.format_is_incorrect')) - } + value + .trim() + .split(',') + .forEach((tempVal: string) => { + var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i + var objExp = new RegExp(Expression) + if (objExp.test(tempVal) && !tempVal.endsWith('/')) { + callback() + } else { + callback( + t('embedded.format_is_incorrect', { msg: t('embedded.domain_format_incorrect') }) + ) + } + }) } } const rules = { diff --git a/frontend/src/views/system/embedded/iframe.vue b/frontend/src/views/system/embedded/iframe.vue index eb15cabd..61becbde 100644 --- a/frontend/src/views/system/embedded/iframe.vue +++ b/frontend/src/views/system/embedded/iframe.vue @@ -157,7 +157,16 @@ const handleBaseEmbedded = (row: any) => { const handleAdvancedEmbedded = (row: any) => { advancedApplication.value = true if (row) { - Object.assign(urlForm, cloneDeep(JSON.parse(row.configuration))) + const tempData = cloneDeep(JSON.parse(row.configuration)) + if (tempData?.endpoint.startsWith('http')) { + row.domain + .trim() + .split(',') + .forEach((domain: string) => { + tempData.endpoint = tempData.endpoint.replace(domain, '') + }) + } + Object.assign(urlForm, tempData) } ruleConfigvVisible.value = true dialogTitle.value = row?.id @@ -265,13 +274,20 @@ const validateUrl = (_: any, value: any, callback: any) => { ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line - var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i - var objExp = new RegExp(Expression) - if (objExp.test(value) && !value.endsWith('/')) { - callback() - } else { - callback(t('embedded.format_is_incorrect')) - } + value + .trim() + .split(',') + .forEach((tempVal: string) => { + var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i + var objExp = new RegExp(Expression) + if (objExp.test(tempVal) && !tempVal.endsWith('/')) { + callback() + } else { + callback( + t('embedded.format_is_incorrect', { msg: t('embedded.domain_format_incorrect') }) + ) + } + }) } } const rules = { @@ -307,12 +323,13 @@ const validatePass = (_: any, value: any, callback: any) => { ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line - var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i + // var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i + var Expression = /^\/([a-zA-Z0-9_-]+\/)*[a-zA-Z0-9_-]+(\?[a-zA-Z0-9_=&-]+)?$/ var objExp = new RegExp(Expression) - if (objExp.test(value) && value.startsWith(currentEmbedded.domain)) { + if (objExp.test(value)) { callback() } else { - callback(t('embedded.format_is_incorrect')) + callback(t('embedded.format_is_incorrect', { msg: t('embedded.interface_url_incorrect') })) } } } From 12b25fa7359aa7da8a58868f843427d612976e3d Mon Sep 17 00:00:00 2001 From: ulleo Date: Tue, 25 Nov 2025 15:51:32 +0800 Subject: [PATCH 02/31] perf: improve generate Oracle SQL --- backend/apps/chat/models/chat_model.py | 27 ++-- backend/templates/sql_examples/Oracle.yaml | 137 ++++++++++++++------- backend/templates/template.yaml | 40 +++++- 3 files changed, 144 insertions(+), 60 deletions(-) diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 5e7f0136..7713270f 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -185,8 +185,12 @@ class AiModelQuestion(BaseModel): def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True): _sql_template = get_sql_example_template(db_type) - _query_limit = get_sql_template()['query_limit'] if enable_query_limit else get_sql_template()['no_query_limit'] - _base_sql_rules = _sql_template['quot_rule'] + _query_limit + _sql_template['limit_rule'] + _sql_template['other_rule'] + _base_template = get_sql_template() + _process_check = _sql_template.get('process_check') if _sql_template.get('process_check') else _base_template[ + 'process_check'] + _query_limit = _base_template['query_limit'] if enable_query_limit else _base_template['no_query_limit'] + _base_sql_rules = _sql_template['quot_rule'] + _query_limit + _sql_template['limit_rule'] + _sql_template[ + 'other_rule'] _sql_examples = _sql_template['basic_example'] _example_engine = _sql_template['example_engine'] _example_answer_1 = _sql_template['example_answer_1_with_limit'] if enable_query_limit else _sql_template[ @@ -195,15 +199,16 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T 'example_answer_2'] _example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[ 'example_answer_3'] - return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, - lang=self.lang, terminologies=self.terminologies, - data_training=self.data_training, custom_prompt=self.custom_prompt, - base_sql_rules=_base_sql_rules, - basic_sql_examples=_sql_examples, - example_engine=_example_engine, - example_answer_1=_example_answer_1, - example_answer_2=_example_answer_2, - example_answer_3=_example_answer_3) + return _base_template['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, + lang=self.lang, terminologies=self.terminologies, + data_training=self.data_training, custom_prompt=self.custom_prompt, + process_check=_process_check, + base_sql_rules=_base_sql_rules, + basic_sql_examples=_sql_examples, + example_engine=_example_engine, + example_answer_1=_example_answer_1, + example_answer_2=_example_answer_2, + example_answer_3=_example_answer_3) def sql_user_question(self, current_time: str): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, diff --git a/backend/templates/sql_examples/Oracle.yaml b/backend/templates/sql_examples/Oracle.yaml index b7293465..ba3015b2 100644 --- a/backend/templates/sql_examples/Oracle.yaml +++ b/backend/templates/sql_examples/Oracle.yaml @@ -1,4 +1,18 @@ template: + process_check: | + + 1. 分析用户问题,确定查询需求 + 2. 根据表结构生成基础SQL + 3. 强制检查:SQL是否包含GROUP BY/聚合函数? + 4. 如果是GROUP BY查询:必须使用外层查询结构包裹 + 5. 强制检查:应用数据量限制规则 + 6. 应用其他规则(引号、别名等) + 7. 最终验证:GROUP BY查询的ROWNUM位置是否正确? + 8. 强制检查:检查语法是否正确? + 9. 确定图表类型 + 10. 返回JSON结果 + + quot_rule: | 必须对数据库名、表名、字段名、别名外层加双引号(")。 @@ -10,42 +24,61 @@ template: limit_rule: | - - 当需要限制行数时: - 1. 12c以下版本必须使用ROWNUM语法 - 2. 12c+版本推荐使用FETCH FIRST语法 - - 版本适配: - - Oracle 12c以下:必须使用 WHERE ROWNUM <= N - - Oracle 12c+:推荐使用 FETCH FIRST N ROWS ONLY - - - 重要:ROWNUM必须放在正确的位置,避免语法错误 - 1. 单层查询:ROWNUM直接跟在WHERE子句后 - - 2. 多层查询:ROWNUM只能放在最外层 - - 3. 禁止的错误写法: - - SELECT ... FROM table - WHERE conditions - GROUP BY ... - ORDER BY ... - WHERE ROWNUM <= N -- 错误:不能有多个WHERE - - 4. 正确顺序:WHERE → GROUP BY → HAVING → ORDER BY → ROWNUM - 5. 括号位置:从内层SELECT开始到内层结束都要括起来 - - SELECT ... FROM ( - -- 内层完整查询(包含自己的SELECT、FROM、WHERE、GROUP BY、ORDER BY) - SELECT columns FROM table WHERE conditions GROUP BY ... ORDER BY ... - ) alias WHERE ROWNUM <= N - - + + Oracle版本语法适配 + + 如果db-engine版本号小于12 + 必须使用ROWNUM语法 + 如果db-engine版本号大于等于12 + 推荐使用FETCH FIRST语法 + + + + Oracle数据库 FETCH FIRST 语法规范 + 若使用 FETCH FIRST 语法,则必须遵循该规范 + + + + + + Oracle数据库ROWNUM语法规范 + 若使用ROWNUM语法,则必须遵循该规范 + + + 简单查询 + + + + 语法禁区 + + - 禁止多个WHERE子句 + - 禁止ROWNUM在GROUP BY内层(影响分组结果) + - 禁止括号不完整 + + + + + + GROUP BY查询的ROWNUM强制规范(必须严格遵守) + 所有包含GROUP BY或聚合函数的查询必须使用外层查询结构 + ROWNUM必须放在最外层查询的WHERE子句中 + + + 如果SQL包含GROUP BY、COUNT、SUM等聚合函数 + 必须使用:SELECT ... FROM (内层完整查询) WHERE ROWNUM <= N + 否则(简单查询) + 可以使用:SELECT ... FROM table WHERE conditions AND ROWNUM <= N + + + + -- 错误:ROWNUM在内层影响分组结果 + SELECT ... GROUP BY ... WHERE ROWNUM <= N + + + + -- 正确:ROWNUM在外层 + SELECT ... FROM (SELECT ... GROUP BY ...) WHERE ROWNUM <= N + other_rule: | @@ -73,7 +106,7 @@ template: SELECT "订单ID", "金额" FROM "TEST"."ORDERS" "t1" WHERE ROWNUM <= 100 -- 错误:缺少英文别名 SELECT COUNT("订单ID") FROM "TEST"."ORDERS" "t1" -- 错误:函数未加别名 - + SELECT "t1"."订单ID" AS "order_id", "t1"."金额" AS "amount", @@ -90,7 +123,7 @@ template: SELECT DATE, status FROM PUBLIC.USERS -- 错误:未处理关键字和引号 SELECT "DATE", ROUND(active_ratio) FROM "PUBLIC"."USERS" -- 错误:百分比格式错误 - + SELECT "u"."DATE" AS "create_date", TO_CHAR("u"."active_ratio" * 100, '990.99') || '%' AS "active_percent" @@ -98,6 +131,14 @@ template: WHERE "u"."status" = 1 AND ROWNUM <= 1000 + + SELECT + "u"."DATE" AS "create_date", + TO_CHAR("u"."active_ratio" * 100, '990.99') || '%' AS "active_percent" + FROM "PUBLIC"."USERS" "u" + WHERE "u"."status" = 1 + FETCH FIRST 1000 ROWS ONLY + @@ -108,9 +149,9 @@ template: count(*) AS "user_count" FROM "PUBLIC"."USERS" "u" WHERE "u"."status" = 1 - AND ROWNUM <= 100 + AND ROWNUM <= 100 -- 严重错误:影响分组结果! GROUP BY "u"."DEPARTMENT" - ORDER BY "department_name" -- 错误:ROWNUM 应当写在最外层,这样会导致查询结果条数比实际数据的数量少 + ORDER BY "department_name" SELECT "department_name", "user_count" FROM @@ -123,7 +164,7 @@ template: ORDER BY "department_name" WHERE ROWNUM <= 100 -- 错误:语法错误,同级内只能有一个WHERE - + SELECT "department_name", "user_count" FROM ( SELECT "u"."DEPARTMENT" AS "department_name", @@ -133,12 +174,22 @@ template: GROUP BY "u"."DEPARTMENT" ORDER BY "department_name" ) - WHERE ROWNUM <= 100 -- 外层限制(确保最终结果可控) + WHERE ROWNUM <= 100 -- 正确,在外层限制数量(确保最终结果可控) + + + SELECT + "u"."DEPARTMENT" AS "department_name", + count(*) AS "user_count" + FROM "PUBLIC"."USERS" "u" + WHERE "u"."status" = 1 + GROUP BY "u"."DEPARTMENT" + ORDER BY "department_name" + FETCH FIRST 100 ROWS ONLY - example_engine: Oracle 19c + example_engine: Oracle 11g example_answer_1: | {"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"} example_answer_1_with_limit: | diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index 7e15db29..def665e6 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -6,14 +6,37 @@ template: {data_training} sql: + process_check: | + + 1. 分析用户问题,确定查询需求 + 2. 根据表结构生成基础SQL + 3. 强制检查:应用数据量限制规则 + 4. 应用其他规则(引号、别名等) + 5. 强制检查:检查语法是否正确? + 6. 确定图表类型 + 7. 返回JSON结果 + query_limit: | - - 1. 必须遵守:所有生成的SQL必须包含数据量限制 - 2. 默认限制:1000条(除非用户明确指定其他数量) + + 数据量限制策略(必须严格遵守 - 零容忍) + + 所有生成的SQL必须包含数据量限制,这是强制要求 + 默认限制:1000条(除非用户明确指定其他数量) + 忘记添加数据量限制是不可接受的错误 + + + + 如果生成的SQL没有数据量限制,必须重新生成 + 在最终返回前必须验证限制是否存在 + no_query_limit: | - - 如果没有指定数据条数的限制,则查询的SQL默认返回全部数据 + + 数据量限制策略(必须严格遵守) + + 默认不限制数据量,返回全部数据(除非用户明确指定其他数量) + 不要臆测场景可能需要的数据量限制,以用户明确指定的数量为准 + system: | @@ -27,9 +50,10 @@ template: :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例。 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的生成SQL的要求,请结合额外信息或要求后生成你的回答。 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 + 你必须遵守内规定的生成SQL规则 + 你必须遵守内规定的检查步骤生成你的回答 - 你必须遵守以下规则: 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 @@ -90,6 +114,8 @@ template: + {process_check} + {basic_sql_examples} @@ -369,6 +395,8 @@ template: ### 以往提问: {old_questions} + + /no_think analysis: system: | From e79a04158111915644e2cf2e78809c7a8bb0a250 Mon Sep 17 00:00:00 2001 From: ulleo Date: Tue, 25 Nov 2025 18:13:15 +0800 Subject: [PATCH 03/31] fix: "Object of type datetime is not JSON serializable" error in mcp chat --- backend/apps/chat/task/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 801cbf16..39ab5423 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -1047,7 +1047,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if in_chat: yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' if not stream: - json_result['data'] = result.get('data') + json_result['data'] = get_chat_chart_data(_session, self.record.id) if finish_step.value <= ChatFinishStep.QUERY_DATA.value: if stream: From 2a389067302e4618a0f1451efe0bedbbf2d7e00d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=98=89=E8=B1=AA?= <42510293+ziyujiahao@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:40:04 +0800 Subject: [PATCH 04/31] refactor:update recommended questions custom config. (#491) --- backend/apps/chat/curd/chat.py | 11 +- .../datasource/api/recommended_problem.py | 12 +- backend/apps/datasource/crud/datasource.py | 5 + .../datasource/crud/recommended_problem.py | 26 +++- backend/apps/datasource/models/datasource.py | 12 ++ frontend/src/api/recommendedApi.ts | 3 +- frontend/src/i18n/en.json | 1 + frontend/src/i18n/ko-KR.json | 1 + frontend/src/i18n/zh-CN.json | 1 + frontend/src/views/chat/index.vue | 2 +- frontend/src/views/ds/Card.vue | 4 +- frontend/src/views/ds/Datasource.vue | 1 + .../ds/RecommendedProblemConfigDialog.vue | 122 ++++++++++++------ 13 files changed, 155 insertions(+), 46 deletions(-) diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 48b3f054..034c4287 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -8,7 +8,8 @@ from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \ TypeEnum, OperationEnum, ChatRecordResult -from apps.datasource.models.datasource import CoreDatasource +from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart +from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem from apps.system.crud.assistant import AssistantOutDsFactory from common.core.deps import CurrentAssistant, SessionDep, CurrentUser from common.utils.utils import extract_nested_json @@ -70,6 +71,7 @@ def get_chart_config(session: SessionDep, chart_record_id: int): pass return {} + def format_chart_fields(chart_info: dict): fields = [] if chart_info.get('columns') and len(chart_info.get('columns')) > 0: @@ -88,6 +90,7 @@ def format_chart_fields(chart_info: dict): fields.append(column_str) return fields + def get_last_execute_sql_error(session: SessionDep, chart_id: int): stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by( ChatRecord.create_time.desc()).limit(1) @@ -396,6 +399,12 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: record.finish = True record.create_time = datetime.datetime.now() record.create_by = current_user.id + if ds.recommended_config == 2: + questions = get_datasource_recommended_chart(session, ds.id) + record.recommended_question = orjson.dumps(questions).decode() + record.recommended_question_answer = orjson.dumps({ + "content": questions + }).decode() _record = ChatRecord(**record.model_dump()) diff --git a/backend/apps/datasource/api/recommended_problem.py b/backend/apps/datasource/api/recommended_problem.py index 29fac41c..b7c75624 100644 --- a/backend/apps/datasource/api/recommended_problem.py +++ b/backend/apps/datasource/api/recommended_problem.py @@ -1,7 +1,10 @@ from fastapi import APIRouter -from apps.datasource.crud.recommended_problem import get_datasource_recommended -from common.core.deps import SessionDep +from apps.datasource.crud.datasource import update_ds_recommended_config +from apps.datasource.crud.recommended_problem import get_datasource_recommended, \ + save_recommended_problem +from apps.datasource.models.datasource import RecommendedProblemBase +from common.core.deps import SessionDep, CurrentUser router = APIRouter(tags=["recommended_problem"], prefix="/recommended_problem") @@ -10,3 +13,8 @@ async def datasource_recommended(session: SessionDep, ds_id: int): return get_datasource_recommended(session, ds_id) + +@router.post("/save_recommended_problem") +async def datasource_recommended(session: SessionDep, user: CurrentUser, data_info: RecommendedProblemBase): + update_ds_recommended_config(session, data_info.datasource_id, data_info.recommended_config) + return save_recommended_problem(session, user, data_info) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index a6a867e1..e3a9a833 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -109,6 +109,11 @@ def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreData run_save_ds_embeddings([ds.id]) return ds +def update_ds_recommended_config(session: SessionDep,datasource_id: int, recommended_config:int): + record = session.exec(select(CoreDatasource).where(CoreDatasource.id == datasource_id)).first() + record.recommended_config = recommended_config + session.add(record) + session.commit() def delete_ds(session: SessionDep, id: int): term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first() diff --git a/backend/apps/datasource/crud/recommended_problem.py b/backend/apps/datasource/crud/recommended_problem.py index e3f8ab67..21bbff71 100644 --- a/backend/apps/datasource/crud/recommended_problem.py +++ b/backend/apps/datasource/crud/recommended_problem.py @@ -1,11 +1,31 @@ +import datetime + from sqlmodel import select -from common.core.deps import SessionDep -from ..models.datasource import DsRecommendedProblem +from common.core.deps import SessionDep, CurrentUser, Trans +from ..models.datasource import DsRecommendedProblem, RecommendedProblemBase, RecommendedProblemBaseChat def get_datasource_recommended(session: SessionDep, ds_id: int): statement = select(DsRecommendedProblem).where(DsRecommendedProblem.datasource_id == ds_id) - dsRecommendedProblem = session.exec(statement) + dsRecommendedProblem = session.exec(statement).all() return dsRecommendedProblem +def get_datasource_recommended_chart(session: SessionDep, ds_id: int): + statement = select(DsRecommendedProblem.question).where(DsRecommendedProblem.datasource_id == ds_id) + dsRecommendedProblems = session.exec(statement).all() + return dsRecommendedProblems + +def save_recommended_problem(session: SessionDep,user: CurrentUser, data_info: RecommendedProblemBase): + session.query(DsRecommendedProblem).filter(DsRecommendedProblem.datasource_id == data_info.datasource_id).delete(synchronize_session=False) + problemInfo = data_info.problemInfo + if problemInfo is not None: + for problemItem in problemInfo: + problemItem.id = None + problemItem.create_time = datetime.datetime.now() + problemItem.create_by = user.id + record = DsRecommendedProblem(**problemItem.model_dump()) + session.add(record) + session.flush() + session.refresh(record) + session.commit() diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py index 2cdfed5a..dd814ea4 100644 --- a/backend/apps/datasource/models/datasource.py +++ b/backend/apps/datasource/models/datasource.py @@ -74,6 +74,18 @@ class CreateDatasource(BaseModel): tables: List[CoreTable] = [] recommended_config: int = 1 +class RecommendedProblemBase(BaseModel): + datasource_id: int = None + recommended_config: int = None + problemInfo: List[DsRecommendedProblem] = [] + + +class RecommendedProblemBaseChat: + def __init__(self, content): + self.content = content + + content: List[str] = [] + # edit local saved table and fields class TableObj(BaseModel): diff --git a/frontend/src/api/recommendedApi.ts b/frontend/src/api/recommendedApi.ts index 64f83538..7c73b62e 100644 --- a/frontend/src/api/recommendedApi.ts +++ b/frontend/src/api/recommendedApi.ts @@ -3,5 +3,6 @@ import { request } from '@/utils/request' export const recommendedApi = { get_recommended_problem: (dsId: any) => request.get(`/recommended_problem/get_datasource_recommended/${dsId}`), - save_recommended_problem: (data: any) => request.post(`/recommended_problem/save`, data), + save_recommended_problem: (data: any) => + request.post(`/recommended_problem/save_recommended_problem`, data), } diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index 53f2f2dc..7772e794 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -309,6 +309,7 @@ } }, "datasource": { + "recommended_problem_tips": "Custom configuration requires at least one problem, each problem should be 2-200 characters", "recommended_problem_configuration": "Recommended Problem Configuration", "problem_generation_method": "Problem Generation Method", "ai_automatic_generation": "AI Automatic Generation", diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index 0cf1a873..cb149dd5 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -309,6 +309,7 @@ } }, "datasource": { + "recommended_problem_tips": "사용자 정의 구성으로 최소 한 개의 문제를 생성하세요, 각 문제는 2~200자로 작성", "recommended_problem_configuration": "추천 문제 구성", "problem_generation_method": "문제 생성 방식", "ai_automatic_generation": "AI 자동 생성", diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index d1088070..0897aa9a 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -309,6 +309,7 @@ } }, "datasource": { + "recommended_problem_tips": "自定义配置至少一个问题,每个问题2-200个字符", "recommended_problem_configuration": "推荐问题配置", "problem_generation_method": "问题生成方式", "ai_automatic_generation": "AI 自动生成", diff --git a/frontend/src/views/chat/index.vue b/frontend/src/views/chat/index.vue index 52c1bf2e..42ac1691 100644 --- a/frontend/src/views/chat/index.vue +++ b/frontend/src/views/chat/index.vue @@ -704,7 +704,7 @@ function onChatCreatedQuick(chat: ChatInfo) { } function onChatCreated(chat: ChatInfo) { - if (chat.records.length === 1) { + if (chat.records.length === 1 && !chat.records[0].recommended_question) { getRecommendQuestions(chat.records[0].id) } } diff --git a/frontend/src/views/ds/Card.vue b/frontend/src/views/ds/Card.vue index 1a9759e8..a30da180 100644 --- a/frontend/src/views/ds/Card.vue +++ b/frontend/src/views/ds/Card.vue @@ -136,7 +136,7 @@ const onClickOutside = () => { {{ $t('datasource.edit') }} -